summaryrefslogtreecommitdiff
path: root/net/netlink/af_netlink.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/netlink/af_netlink.c')
-rw-r--r--net/netlink/af_netlink.c80
1 files changed, 57 insertions, 23 deletions
diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c
index 7a401d94463a..c64277659753 100644
--- a/net/netlink/af_netlink.c
+++ b/net/netlink/af_netlink.c
@@ -580,7 +580,9 @@ static int netlink_insert(struct sock *sk, u32 portid)
if (nlk_sk(sk)->bound)
goto err;
- nlk_sk(sk)->portid = portid;
+ /* portid can be read locklessly from netlink_getname(). */
+ WRITE_ONCE(nlk_sk(sk)->portid, portid);
+
sock_hold(sk);
err = __netlink_insert(table, sk);
@@ -812,6 +814,17 @@ static int netlink_release(struct socket *sock)
}
sock_prot_inuse_add(sock_net(sk), &netlink_proto, -1);
+
+ /* Because struct net might disappear soon, do not keep a pointer. */
+ if (!sk->sk_net_refcnt && sock_net(sk) != &init_net) {
+ __netns_tracker_free(sock_net(sk), &sk->ns_tracker, false);
+ /* Because of deferred_put_nlk_sk and use of work queue,
+ * it is possible netns will be freed before this socket.
+ */
+ sock_net_set(sk, &init_net);
+ __netns_tracker_alloc(&init_net, &sk->ns_tracker,
+ false, GFP_KERNEL);
+ }
call_rcu(&nlk->rcu, deferred_put_nlk_sk);
return 0;
}
@@ -1085,9 +1098,11 @@ static int netlink_connect(struct socket *sock, struct sockaddr *addr,
return -EINVAL;
if (addr->sa_family == AF_UNSPEC) {
- sk->sk_state = NETLINK_UNCONNECTED;
- nlk->dst_portid = 0;
- nlk->dst_group = 0;
+ /* paired with READ_ONCE() in netlink_getsockbyportid() */
+ WRITE_ONCE(sk->sk_state, NETLINK_UNCONNECTED);
+ /* dst_portid and dst_group can be read locklessly */
+ WRITE_ONCE(nlk->dst_portid, 0);
+ WRITE_ONCE(nlk->dst_group, 0);
return 0;
}
if (addr->sa_family != AF_NETLINK)
@@ -1108,9 +1123,11 @@ static int netlink_connect(struct socket *sock, struct sockaddr *addr,
err = netlink_autobind(sock);
if (err == 0) {
- sk->sk_state = NETLINK_CONNECTED;
- nlk->dst_portid = nladdr->nl_pid;
- nlk->dst_group = ffs(nladdr->nl_groups);
+ /* paired with READ_ONCE() in netlink_getsockbyportid() */
+ WRITE_ONCE(sk->sk_state, NETLINK_CONNECTED);
+ /* dst_portid and dst_group can be read locklessly */
+ WRITE_ONCE(nlk->dst_portid, nladdr->nl_pid);
+ WRITE_ONCE(nlk->dst_group, ffs(nladdr->nl_groups));
}
return err;
@@ -1127,10 +1144,12 @@ static int netlink_getname(struct socket *sock, struct sockaddr *addr,
nladdr->nl_pad = 0;
if (peer) {
- nladdr->nl_pid = nlk->dst_portid;
- nladdr->nl_groups = netlink_group_mask(nlk->dst_group);
+ /* Paired with WRITE_ONCE() in netlink_connect() */
+ nladdr->nl_pid = READ_ONCE(nlk->dst_portid);
+ nladdr->nl_groups = netlink_group_mask(READ_ONCE(nlk->dst_group));
} else {
- nladdr->nl_pid = nlk->portid;
+ /* Paired with WRITE_ONCE() in netlink_insert() */
+ nladdr->nl_pid = READ_ONCE(nlk->portid);
netlink_lock_table();
nladdr->nl_groups = nlk->groups ? nlk->groups[0] : 0;
netlink_unlock_table();
@@ -1157,8 +1176,9 @@ static struct sock *netlink_getsockbyportid(struct sock *ssk, u32 portid)
/* Don't bother queuing skb if kernel socket has no input function */
nlk = nlk_sk(sock);
- if (sock->sk_state == NETLINK_CONNECTED &&
- nlk->dst_portid != nlk_sk(ssk)->portid) {
+ /* dst_portid and sk_state can be changed in netlink_connect() */
+ if (READ_ONCE(sock->sk_state) == NETLINK_CONNECTED &&
+ READ_ONCE(nlk->dst_portid) != nlk_sk(ssk)->portid) {
sock_put(sock);
return ERR_PTR(-ECONNREFUSED);
}
@@ -1875,8 +1895,9 @@ static int netlink_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
goto out;
netlink_skb_flags |= NETLINK_SKB_DST;
} else {
- dst_portid = nlk->dst_portid;
- dst_group = nlk->dst_group;
+ /* Paired with WRITE_ONCE() in netlink_connect() */
+ dst_portid = READ_ONCE(nlk->dst_portid);
+ dst_group = READ_ONCE(nlk->dst_group);
}
/* Paired with WRITE_ONCE() in netlink_insert() */
@@ -2488,19 +2509,24 @@ void netlink_ack(struct sk_buff *in_skb, struct nlmsghdr *nlh, int err,
flags |= NLM_F_ACK_TLVS;
skb = nlmsg_new(payload + tlvlen, GFP_KERNEL);
- if (!skb) {
- NETLINK_CB(in_skb).sk->sk_err = ENOBUFS;
- sk_error_report(NETLINK_CB(in_skb).sk);
- return;
- }
+ if (!skb)
+ goto err_skb;
rep = nlmsg_put(skb, NETLINK_CB(in_skb).portid, nlh->nlmsg_seq,
- NLMSG_ERROR, payload, flags);
+ NLMSG_ERROR, sizeof(*errmsg), flags);
+ if (!rep)
+ goto err_bad_put;
errmsg = nlmsg_data(rep);
errmsg->error = err;
- unsafe_memcpy(&errmsg->msg, nlh, payload > sizeof(*errmsg)
- ? nlh->nlmsg_len : sizeof(*nlh),
- /* Bounds checked by the skb layer. */);
+ errmsg->msg = *nlh;
+
+ if (!(flags & NLM_F_CAPPED)) {
+ if (!nlmsg_append(skb, nlmsg_len(nlh)))
+ goto err_bad_put;
+
+ memcpy(nlmsg_data(&errmsg->msg), nlmsg_data(nlh),
+ nlmsg_len(nlh));
+ }
if (tlvlen)
netlink_ack_tlv_fill(in_skb, skb, nlh, err, extack);
@@ -2508,6 +2534,14 @@ void netlink_ack(struct sk_buff *in_skb, struct nlmsghdr *nlh, int err,
nlmsg_end(skb, rep);
nlmsg_unicast(in_skb->sk, skb, NETLINK_CB(in_skb).portid);
+
+ return;
+
+err_bad_put:
+ nlmsg_free(skb);
+err_skb:
+ NETLINK_CB(in_skb).sk->sk_err = ENOBUFS;
+ sk_error_report(NETLINK_CB(in_skb).sk);
}
EXPORT_SYMBOL(netlink_ack);