diff options
Diffstat (limited to 'net/netlink/af_netlink.c')
-rw-r--r-- | net/netlink/af_netlink.c | 63 |
1 files changed, 39 insertions, 24 deletions
diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c index 7a186e74b1b3..ef5f77b44ec7 100644 --- a/net/netlink/af_netlink.c +++ b/net/netlink/af_netlink.c @@ -96,6 +96,14 @@ static DECLARE_WAIT_QUEUE_HEAD(nl_table_wait); static int netlink_dump(struct sock *sk); static void netlink_skb_destructor(struct sk_buff *skb); +/* nl_table locking explained: + * Lookup and traversal are protected with nl_sk_hash_lock or nl_table_lock + * combined with an RCU read-side lock. Insertion and removal are protected + * with nl_sk_hash_lock while using RCU list modification primitives and may + * run in parallel to nl_table_lock protected lookups. Destruction of the + * Netlink socket may only occur *after* nl_table_lock has been acquired + * either during or after the socket has been removed from the list. + */ DEFINE_RWLOCK(nl_table_lock); EXPORT_SYMBOL_GPL(nl_table_lock); static atomic_t nl_table_users = ATOMIC_INIT(0); @@ -106,14 +114,14 @@ static atomic_t nl_table_users = ATOMIC_INIT(0); DEFINE_MUTEX(nl_sk_hash_lock); EXPORT_SYMBOL_GPL(nl_sk_hash_lock); -static int lockdep_nl_sk_hash_is_held(void) +#ifdef CONFIG_PROVE_LOCKING +static int lockdep_nl_sk_hash_is_held(void *parent) { -#ifdef CONFIG_LOCKDEP - return (debug_locks) ? lockdep_is_held(&nl_sk_hash_lock) : 1; -#else + if (debug_locks) + return lockdep_is_held(&nl_sk_hash_lock) || lockdep_is_held(&nl_table_lock); return 1; -#endif } +#endif static ATOMIC_NOTIFIER_HEAD(netlink_chain); @@ -134,8 +142,7 @@ int netlink_add_tap(struct netlink_tap *nt) list_add_rcu(&nt->list, &netlink_tap_all); spin_unlock(&netlink_tap_lock); - if (nt->module) - __module_get(nt->module); + __module_get(nt->module); return 0; } @@ -1028,11 +1035,13 @@ static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid) struct netlink_table *table = &nl_table[protocol]; struct sock *sk; + read_lock(&nl_table_lock); rcu_read_lock(); sk = __netlink_lookup(table, portid, net); if (sk) sock_hold(sk); rcu_read_unlock(); + read_unlock(&nl_table_lock); return sk; } @@ -1082,7 +1091,7 @@ static int netlink_insert(struct sock *sk, struct net *net, u32 portid) nlk_sk(sk)->portid = portid; sock_hold(sk); - rhashtable_insert(&table->hash, &nlk_sk(sk)->node, GFP_KERNEL); + rhashtable_insert(&table->hash, &nlk_sk(sk)->node); err = 0; err: mutex_unlock(&nl_sk_hash_lock); @@ -1095,7 +1104,7 @@ static void netlink_remove(struct sock *sk) mutex_lock(&nl_sk_hash_lock); table = &nl_table[sk->sk_protocol]; - if (rhashtable_remove(&table->hash, &nlk_sk(sk)->node, GFP_KERNEL)) { + if (rhashtable_remove(&table->hash, &nlk_sk(sk)->node)) { WARN_ON(atomic_read(&sk->sk_refcnt) == 1); __sock_put(sk); } @@ -1257,9 +1266,6 @@ static int netlink_release(struct socket *sock) } netlink_table_ungrab(); - /* Wait for readers to complete */ - synchronize_net(); - kfree(nlk->groups); nlk->groups = NULL; @@ -1281,6 +1287,7 @@ static int netlink_autobind(struct socket *sock) retry: cond_resched(); + netlink_table_grab(); rcu_read_lock(); if (__netlink_lookup(table, portid, net)) { /* Bind collision, search negative portid values. */ @@ -1288,9 +1295,11 @@ retry: if (rover > -4097) rover = -4097; rcu_read_unlock(); + netlink_table_ungrab(); goto retry; } rcu_read_unlock(); + netlink_table_ungrab(); err = netlink_insert(sk, net, portid); if (err == -EADDRINUSE) @@ -1430,7 +1439,7 @@ static void netlink_unbind(int group, long unsigned int groups, return; for (undo = 0; undo < group; undo++) - if (test_bit(group, &groups)) + if (test_bit(undo, &groups)) nlk->netlink_unbind(undo); } @@ -1482,7 +1491,7 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr, netlink_insert(sk, net, nladdr->nl_pid) : netlink_autobind(sock); if (err) { - netlink_unbind(nlk->ngroups - 1, groups, nlk); + netlink_unbind(nlk->ngroups, groups, nlk); return err; } } @@ -2296,7 +2305,7 @@ static int netlink_sendmsg(struct kiocb *kiocb, struct socket *sock, } if (netlink_tx_is_mmaped(sk) && - msg->msg_iov->iov_base == NULL) { + msg->msg_iter.iov->iov_base == NULL) { err = netlink_mmap_sendmsg(sk, msg, dst_portid, dst_group, siocb); goto out; @@ -2316,7 +2325,7 @@ static int netlink_sendmsg(struct kiocb *kiocb, struct socket *sock, NETLINK_CB(skb).flags = netlink_skb_flags; err = -EFAULT; - if (memcpy_fromiovec(skb_put(skb, len), msg->msg_iov, len)) { + if (memcpy_from_msg(skb_put(skb, len), msg, len)) { kfree_skb(skb); goto out; } @@ -2391,7 +2400,7 @@ static int netlink_recvmsg(struct kiocb *kiocb, struct socket *sock, } skb_reset_transport_header(data_skb); - err = skb_copy_datagram_iovec(data_skb, 0, msg->msg_iov, copied); + err = skb_copy_datagram_msg(data_skb, 0, msg, copied); if (msg->msg_name) { DECLARE_SOCKADDR(struct sockaddr_nl *, addr, msg->msg_name); @@ -2499,6 +2508,7 @@ __netlink_kernel_create(struct net *net, int unit, struct module *module, nl_table[unit].module = module; if (cfg) { nl_table[unit].bind = cfg->bind; + nl_table[unit].unbind = cfg->unbind; nl_table[unit].flags = cfg->flags; if (cfg->compare) nl_table[unit].compare = cfg->compare; @@ -2921,14 +2931,16 @@ static struct sock *netlink_seq_socket_idx(struct seq_file *seq, loff_t pos) } static void *netlink_seq_start(struct seq_file *seq, loff_t *pos) - __acquires(RCU) + __acquires(nl_table_lock) __acquires(RCU) { + read_lock(&nl_table_lock); rcu_read_lock(); return *pos ? netlink_seq_socket_idx(seq, *pos - 1) : SEQ_START_TOKEN; } static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos) { + struct rhashtable *ht; struct netlink_sock *nlk; struct nl_seq_iter *iter; struct net *net; @@ -2943,19 +2955,19 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos) iter = seq->private; nlk = v; - rht_for_each_entry_rcu(nlk, nlk->node.next, node) + i = iter->link; + ht = &nl_table[i].hash; + rht_for_each_entry(nlk, nlk->node.next, ht, node) if (net_eq(sock_net((struct sock *)nlk), net)) return nlk; - i = iter->link; j = iter->hash_idx + 1; do { - struct rhashtable *ht = &nl_table[i].hash; const struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht); for (; j < tbl->size; j++) { - rht_for_each_entry_rcu(nlk, tbl->buckets[j], node) { + rht_for_each_entry(nlk, tbl->buckets[j], ht, node) { if (net_eq(sock_net((struct sock *)nlk), net)) { iter->link = i; iter->hash_idx = j; @@ -2971,9 +2983,10 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos) } static void netlink_seq_stop(struct seq_file *seq, void *v) - __releases(RCU) + __releases(RCU) __releases(nl_table_lock) { rcu_read_unlock(); + read_unlock(&nl_table_lock); } @@ -3116,11 +3129,13 @@ static int __init netlink_proto_init(void) .head_offset = offsetof(struct netlink_sock, node), .key_offset = offsetof(struct netlink_sock, portid), .key_len = sizeof(u32), /* portid */ - .hashfn = arch_fast_hash, + .hashfn = jhash, .max_shift = 16, /* 64K */ .grow_decision = rht_grow_above_75, .shrink_decision = rht_shrink_below_30, +#ifdef CONFIG_PROVE_LOCKING .mutex_is_held = lockdep_nl_sk_hash_is_held, +#endif }; if (err != 0) |