summaryrefslogtreecommitdiff
path: root/net/mctp/af_mctp.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/mctp/af_mctp.c')
-rw-r--r--net/mctp/af_mctp.c86
1 files changed, 76 insertions, 10 deletions
diff --git a/net/mctp/af_mctp.c b/net/mctp/af_mctp.c
index 66a411d60b6c..d344b02a1cde 100644
--- a/net/mctp/af_mctp.c
+++ b/net/mctp/af_mctp.c
@@ -77,6 +77,7 @@ static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
const int hlen = MCTP_HEADER_MAXLEN + sizeof(struct mctp_hdr);
int rc, addrlen = msg->msg_namelen;
struct sock *sk = sock->sk;
+ struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
struct mctp_skb_cb *cb;
struct mctp_route *rt;
struct sk_buff *skb;
@@ -100,11 +101,6 @@ static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
if (addr->smctp_network == MCTP_NET_ANY)
addr->smctp_network = mctp_default_net(sock_net(sk));
- rt = mctp_route_lookup(sock_net(sk), addr->smctp_network,
- addr->smctp_addr.s_addr);
- if (!rt)
- return -EHOSTUNREACH;
-
skb = sock_alloc_send_skb(sk, hlen + 1 + len,
msg->msg_flags & MSG_DONTWAIT, &rc);
if (!skb)
@@ -116,19 +112,45 @@ static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
*(u8 *)skb_put(skb, 1) = addr->smctp_type;
rc = memcpy_from_msg((void *)skb_put(skb, len), msg, len);
- if (rc < 0) {
- kfree_skb(skb);
- return rc;
- }
+ if (rc < 0)
+ goto err_free;
/* set up cb */
cb = __mctp_cb(skb);
cb->net = addr->smctp_network;
+ /* direct addressing */
+ if (msk->addr_ext && addrlen >= sizeof(struct sockaddr_mctp_ext)) {
+ DECLARE_SOCKADDR(struct sockaddr_mctp_ext *,
+ extaddr, msg->msg_name);
+
+ if (extaddr->smctp_halen > sizeof(cb->haddr)) {
+ rc = -EINVAL;
+ goto err_free;
+ }
+
+ cb->ifindex = extaddr->smctp_ifindex;
+ cb->halen = extaddr->smctp_halen;
+ memcpy(cb->haddr, extaddr->smctp_haddr, cb->halen);
+
+ rt = NULL;
+ } else {
+ rt = mctp_route_lookup(sock_net(sk), addr->smctp_network,
+ addr->smctp_addr.s_addr);
+ if (!rt) {
+ rc = -EHOSTUNREACH;
+ goto err_free;
+ }
+ }
+
rc = mctp_local_output(sk, rt, skb, addr->smctp_addr.s_addr,
addr->smctp_tag);
return rc ? : len;
+
+err_free:
+ kfree_skb(skb);
+ return rc;
}
static int mctp_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
@@ -136,6 +158,7 @@ static int mctp_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
{
DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name);
struct sock *sk = sock->sk;
+ struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
struct sk_buff *skb;
size_t msglen;
u8 type;
@@ -181,6 +204,16 @@ static int mctp_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
addr->smctp_tag = hdr->flags_seq_tag &
(MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO);
msg->msg_namelen = sizeof(*addr);
+
+ if (msk->addr_ext) {
+ DECLARE_SOCKADDR(struct sockaddr_mctp_ext *, ae,
+ msg->msg_name);
+ msg->msg_namelen = sizeof(*ae);
+ ae->smctp_ifindex = cb->ifindex;
+ ae->smctp_halen = cb->halen;
+ memset(ae->smctp_haddr, 0x0, sizeof(ae->smctp_haddr));
+ memcpy(ae->smctp_haddr, cb->haddr, cb->halen);
+ }
}
rc = len;
@@ -196,12 +229,45 @@ out_free:
static int mctp_setsockopt(struct socket *sock, int level, int optname,
sockptr_t optval, unsigned int optlen)
{
- return -EINVAL;
+ struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk);
+ int val;
+
+ if (level != SOL_MCTP)
+ return -EINVAL;
+
+ if (optname == MCTP_OPT_ADDR_EXT) {
+ if (optlen != sizeof(int))
+ return -EINVAL;
+ if (copy_from_sockptr(&val, optval, sizeof(int)))
+ return -EFAULT;
+ msk->addr_ext = val;
+ return 0;
+ }
+
+ return -ENOPROTOOPT;
}
static int mctp_getsockopt(struct socket *sock, int level, int optname,
char __user *optval, int __user *optlen)
{
+ struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk);
+ int len, val;
+
+ if (level != SOL_MCTP)
+ return -EINVAL;
+
+ if (get_user(len, optlen))
+ return -EFAULT;
+
+ if (optname == MCTP_OPT_ADDR_EXT) {
+ if (len != sizeof(int))
+ return -EINVAL;
+ val = !!msk->addr_ext;
+ if (copy_to_user(optval, &val, len))
+ return -EFAULT;
+ return 0;
+ }
+
return -EINVAL;
}