diff options
Diffstat (limited to 'net/netlink/af_netlink.c')
-rw-r--r-- | net/netlink/af_netlink.c | 311 |
1 files changed, 243 insertions, 68 deletions
diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c index d0b3dd60d386..8df7f64c6db3 100644 --- a/net/netlink/af_netlink.c +++ b/net/netlink/af_netlink.c @@ -57,6 +57,7 @@ #include <linux/audit.h> #include <linux/mutex.h> #include <linux/vmalloc.h> +#include <linux/if_arp.h> #include <asm/cacheflush.h> #include <net/net_namespace.h> @@ -101,6 +102,9 @@ static atomic_t nl_table_users = ATOMIC_INIT(0); static ATOMIC_NOTIFIER_HEAD(netlink_chain); +static DEFINE_SPINLOCK(netlink_tap_lock); +static struct list_head netlink_tap_all __read_mostly; + static inline u32 netlink_group_mask(u32 group) { return group ? 1 << (group - 1) : 0; @@ -111,6 +115,130 @@ static inline struct hlist_head *nl_portid_hashfn(struct nl_portid_hash *hash, u return &hash->table[jhash_1word(portid, hash->rnd) & hash->mask]; } +int netlink_add_tap(struct netlink_tap *nt) +{ + if (unlikely(nt->dev->type != ARPHRD_NETLINK)) + return -EINVAL; + + spin_lock(&netlink_tap_lock); + list_add_rcu(&nt->list, &netlink_tap_all); + spin_unlock(&netlink_tap_lock); + + if (nt->module) + __module_get(nt->module); + + return 0; +} +EXPORT_SYMBOL_GPL(netlink_add_tap); + +int __netlink_remove_tap(struct netlink_tap *nt) +{ + bool found = false; + struct netlink_tap *tmp; + + spin_lock(&netlink_tap_lock); + + list_for_each_entry(tmp, &netlink_tap_all, list) { + if (nt == tmp) { + list_del_rcu(&nt->list); + found = true; + goto out; + } + } + + pr_warn("__netlink_remove_tap: %p not found\n", nt); +out: + spin_unlock(&netlink_tap_lock); + + if (found && nt->module) + module_put(nt->module); + + return found ? 0 : -ENODEV; +} +EXPORT_SYMBOL_GPL(__netlink_remove_tap); + +int netlink_remove_tap(struct netlink_tap *nt) +{ + int ret; + + ret = __netlink_remove_tap(nt); + synchronize_net(); + + return ret; +} +EXPORT_SYMBOL_GPL(netlink_remove_tap); + +static bool netlink_filter_tap(const struct sk_buff *skb) +{ + struct sock *sk = skb->sk; + bool pass = false; + + /* We take the more conservative approach and + * whitelist socket protocols that may pass. + */ + switch (sk->sk_protocol) { + case NETLINK_ROUTE: + case NETLINK_USERSOCK: + case NETLINK_SOCK_DIAG: + case NETLINK_NFLOG: + case NETLINK_XFRM: + case NETLINK_FIB_LOOKUP: + case NETLINK_NETFILTER: + case NETLINK_GENERIC: + pass = true; + break; + } + + return pass; +} + +static int __netlink_deliver_tap_skb(struct sk_buff *skb, + struct net_device *dev) +{ + struct sk_buff *nskb; + struct sock *sk = skb->sk; + int ret = -ENOMEM; + + dev_hold(dev); + nskb = skb_clone(skb, GFP_ATOMIC); + if (nskb) { + nskb->dev = dev; + nskb->protocol = htons((u16) sk->sk_protocol); + + ret = dev_queue_xmit(nskb); + if (unlikely(ret > 0)) + ret = net_xmit_errno(ret); + } + + dev_put(dev); + return ret; +} + +static void __netlink_deliver_tap(struct sk_buff *skb) +{ + int ret; + struct netlink_tap *tmp; + + if (!netlink_filter_tap(skb)) + return; + + list_for_each_entry_rcu(tmp, &netlink_tap_all, list) { + ret = __netlink_deliver_tap_skb(skb, tmp->dev); + if (unlikely(ret)) + break; + } +} + +static void netlink_deliver_tap(struct sk_buff *skb) +{ + rcu_read_lock(); + + if (unlikely(!list_empty(&netlink_tap_all))) + __netlink_deliver_tap(skb); + + rcu_read_unlock(); +} + static void netlink_overrun(struct sock *sk) { struct netlink_sock *nlk = nlk_sk(sk); @@ -196,14 +324,14 @@ static void **alloc_pg_vec(struct netlink_sock *nlk, { unsigned int block_nr = req->nm_block_nr; unsigned int i; - void **pg_vec, *ptr; + void **pg_vec; pg_vec = kcalloc(block_nr, sizeof(void *), GFP_KERNEL); if (pg_vec == NULL) return NULL; for (i = 0; i < block_nr; i++) { - pg_vec[i] = ptr = alloc_one_pg_vec_page(order); + pg_vec[i] = alloc_one_pg_vec_page(order); if (pg_vec[i] == NULL) goto err1; } @@ -371,7 +499,7 @@ static int netlink_mmap(struct file *file, struct socket *sock, err = 0; out: mutex_unlock(&nlk->pg_vec_lock); - return 0; + return err; } static void netlink_frame_flush_dcache(const struct nl_mmap_hdr *hdr) @@ -497,7 +625,7 @@ static unsigned int netlink_poll(struct file *file, struct socket *sock, * for dumps is performed here. A dump is allowed to continue * if at least half the ring is unused. */ - while (nlk->cb != NULL && netlink_dump_space(nlk)) { + while (nlk->cb_running && netlink_dump_space(nlk)) { err = netlink_dump(sk); if (err < 0) { sk->sk_err = err; @@ -704,18 +832,6 @@ static void netlink_ring_set_copied(struct sock *sk, struct sk_buff *skb) #define netlink_mmap_sendmsg(sk, msg, dst_portid, dst_group, siocb) 0 #endif /* CONFIG_NETLINK_MMAP */ -static void netlink_destroy_callback(struct netlink_callback *cb) -{ - kfree_skb(cb->skb); - kfree(cb); -} - -static void netlink_consume_callback(struct netlink_callback *cb) -{ - consume_skb(cb->skb); - kfree(cb); -} - static void netlink_skb_destructor(struct sk_buff *skb) { #ifdef CONFIG_NETLINK_MMAP @@ -750,6 +866,13 @@ static void netlink_skb_destructor(struct sk_buff *skb) skb->head = NULL; } #endif + if (is_vmalloc_addr(skb->head)) { + if (!skb->cloned || + !atomic_dec_return(&(skb_shinfo(skb)->dataref))) + vfree(skb->head); + + skb->head = NULL; + } if (skb->sk != NULL) sock_rfree(skb); } @@ -767,12 +890,12 @@ static void netlink_sock_destruct(struct sock *sk) { struct netlink_sock *nlk = nlk_sk(sk); - if (nlk->cb) { - if (nlk->cb->done) - nlk->cb->done(nlk->cb); + if (nlk->cb_running) { + if (nlk->cb.done) + nlk->cb.done(&nlk->cb); - module_put(nlk->cb->module); - netlink_destroy_callback(nlk->cb); + module_put(nlk->cb.module); + kfree_skb(nlk->cb.skb); } skb_queue_purge(&sk->sk_receive_queue); @@ -854,16 +977,23 @@ netlink_unlock_table(void) wake_up(&nl_table_wait); } +static bool netlink_compare(struct net *net, struct sock *sk) +{ + return net_eq(sock_net(sk), net); +} + static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid) { - struct nl_portid_hash *hash = &nl_table[protocol].hash; + struct netlink_table *table = &nl_table[protocol]; + struct nl_portid_hash *hash = &table->hash; struct hlist_head *head; struct sock *sk; read_lock(&nl_table_lock); head = nl_portid_hashfn(hash, portid); sk_for_each(sk, head) { - if (net_eq(sock_net(sk), net) && (nlk_sk(sk)->portid == portid)) { + if (table->compare(net, sk) && + (nlk_sk(sk)->portid == portid)) { sock_hold(sk); goto found; } @@ -976,7 +1106,8 @@ netlink_update_listeners(struct sock *sk) static int netlink_insert(struct sock *sk, struct net *net, u32 portid) { - struct nl_portid_hash *hash = &nl_table[sk->sk_protocol].hash; + struct netlink_table *table = &nl_table[sk->sk_protocol]; + struct nl_portid_hash *hash = &table->hash; struct hlist_head *head; int err = -EADDRINUSE; struct sock *osk; @@ -986,7 +1117,8 @@ static int netlink_insert(struct sock *sk, struct net *net, u32 portid) head = nl_portid_hashfn(hash, portid); len = 0; sk_for_each(osk, head) { - if (net_eq(sock_net(osk), net) && (nlk_sk(osk)->portid == portid)) + if (table->compare(net, osk) && + (nlk_sk(osk)->portid == portid)) break; len++; } @@ -1183,7 +1315,8 @@ static int netlink_autobind(struct socket *sock) { struct sock *sk = sock->sk; struct net *net = sock_net(sk); - struct nl_portid_hash *hash = &nl_table[sk->sk_protocol].hash; + struct netlink_table *table = &nl_table[sk->sk_protocol]; + struct nl_portid_hash *hash = &table->hash; struct hlist_head *head; struct sock *osk; s32 portid = task_tgid_vnr(current); @@ -1195,7 +1328,7 @@ retry: netlink_table_grab(); head = nl_portid_hashfn(hash, portid); sk_for_each(osk, head) { - if (!net_eq(sock_net(osk), net)) + if (!table->compare(net, osk)) continue; if (nlk_sk(osk)->portid == portid) { /* Bind collision, search negative portid values. */ @@ -1420,6 +1553,33 @@ struct sock *netlink_getsockbyfilp(struct file *filp) return sock; } +static struct sk_buff *netlink_alloc_large_skb(unsigned int size, + int broadcast) +{ + struct sk_buff *skb; + void *data; + + if (size <= NLMSG_GOODSIZE || broadcast) + return alloc_skb(size, GFP_KERNEL); + + size = SKB_DATA_ALIGN(size) + + SKB_DATA_ALIGN(sizeof(struct skb_shared_info)); + + data = vmalloc(size); + if (data == NULL) + return NULL; + + skb = build_skb(data, size); + if (skb == NULL) + vfree(data); + else { + skb->head_frag = 0; + skb->destructor = netlink_skb_destructor; + } + + return skb; +} + /* * Attach a skb to a netlink socket. * The caller must hold a reference to the destination socket. On error, the @@ -1475,6 +1635,8 @@ static int __netlink_sendskb(struct sock *sk, struct sk_buff *skb) { int len = skb->len; + netlink_deliver_tap(skb); + #ifdef CONFIG_NETLINK_MMAP if (netlink_skb_is_mmaped(skb)) netlink_queue_mmaped_skb(sk, skb); @@ -1510,7 +1672,7 @@ static struct sk_buff *netlink_trim(struct sk_buff *skb, gfp_t allocation) return skb; delta = skb->end - skb->tail; - if (delta * 2 < skb->truesize) + if (is_vmalloc_addr(skb->head) || delta * 2 < skb->truesize) return skb; if (skb_shared(skb)) { @@ -1535,6 +1697,11 @@ static int netlink_unicast_kernel(struct sock *sk, struct sk_buff *skb, ret = -ECONNREFUSED; if (nlk->netlink_rcv != NULL) { + /* We could do a netlink_deliver_tap(skb) here as well + * but since this is intended for the kernel only, we + * should rather let it stay under the hood. + */ + ret = skb->len; netlink_skb_set_owner_r(skb, sk); NETLINK_CB(skb).sk = ssk; @@ -2096,7 +2263,7 @@ static int netlink_sendmsg(struct kiocb *kiocb, struct socket *sock, if (len > sk->sk_sndbuf - 32) goto out; err = -ENOBUFS; - skb = alloc_skb(len, GFP_KERNEL); + skb = netlink_alloc_large_skb(len, dst_group); if (skb == NULL) goto out; @@ -2201,7 +2368,8 @@ static int netlink_recvmsg(struct kiocb *kiocb, struct socket *sock, skb_free_datagram(sk, skb); - if (nlk->cb && atomic_read(&sk->sk_rmem_alloc) <= sk->sk_rcvbuf / 2) { + if (nlk->cb_running && + atomic_read(&sk->sk_rmem_alloc) <= sk->sk_rcvbuf / 2) { ret = netlink_dump(sk); if (ret) { sk->sk_err = ret; @@ -2285,6 +2453,8 @@ __netlink_kernel_create(struct net *net, int unit, struct module *module, if (cfg) { nl_table[unit].bind = cfg->bind; nl_table[unit].flags = cfg->flags; + if (cfg->compare) + nl_table[unit].compare = cfg->compare; } nl_table[unit].registered = 1; } else { @@ -2415,13 +2585,12 @@ static int netlink_dump(struct sock *sk) int alloc_size; mutex_lock(nlk->cb_mutex); - - cb = nlk->cb; - if (cb == NULL) { + if (!nlk->cb_running) { err = -EINVAL; goto errout_skb; } + cb = &nlk->cb; alloc_size = max_t(int, cb->min_dump_alloc, NLMSG_GOODSIZE); if (!netlink_rx_is_mmaped(sk) && @@ -2459,11 +2628,11 @@ static int netlink_dump(struct sock *sk) if (cb->done) cb->done(cb); - nlk->cb = NULL; - mutex_unlock(nlk->cb_mutex); + nlk->cb_running = false; + mutex_unlock(nlk->cb_mutex); module_put(cb->module); - netlink_consume_callback(cb); + consume_skb(cb->skb); return 0; errout_skb: @@ -2481,59 +2650,51 @@ int __netlink_dump_start(struct sock *ssk, struct sk_buff *skb, struct netlink_sock *nlk; int ret; - cb = kzalloc(sizeof(*cb), GFP_KERNEL); - if (cb == NULL) - return -ENOBUFS; - /* Memory mapped dump requests need to be copied to avoid looping * on the pending state in netlink_mmap_sendmsg() while the CB hold * a reference to the skb. */ if (netlink_skb_is_mmaped(skb)) { skb = skb_copy(skb, GFP_KERNEL); - if (skb == NULL) { - kfree(cb); + if (skb == NULL) return -ENOBUFS; - } } else atomic_inc(&skb->users); - cb->dump = control->dump; - cb->done = control->done; - cb->nlh = nlh; - cb->data = control->data; - cb->module = control->module; - cb->min_dump_alloc = control->min_dump_alloc; - cb->skb = skb; - sk = netlink_lookup(sock_net(ssk), ssk->sk_protocol, NETLINK_CB(skb).portid); if (sk == NULL) { - netlink_destroy_callback(cb); - return -ECONNREFUSED; + ret = -ECONNREFUSED; + goto error_free; } - nlk = nlk_sk(sk); + nlk = nlk_sk(sk); mutex_lock(nlk->cb_mutex); /* A dump is in progress... */ - if (nlk->cb) { - mutex_unlock(nlk->cb_mutex); - netlink_destroy_callback(cb); + if (nlk->cb_running) { ret = -EBUSY; - goto out; + goto error_unlock; } /* add reference of module which cb->dump belongs to */ - if (!try_module_get(cb->module)) { - mutex_unlock(nlk->cb_mutex); - netlink_destroy_callback(cb); + if (!try_module_get(control->module)) { ret = -EPROTONOSUPPORT; - goto out; + goto error_unlock; } - nlk->cb = cb; + cb = &nlk->cb; + memset(cb, 0, sizeof(*cb)); + cb->dump = control->dump; + cb->done = control->done; + cb->nlh = nlh; + cb->data = control->data; + cb->module = control->module; + cb->min_dump_alloc = control->min_dump_alloc; + cb->skb = skb; + + nlk->cb_running = true; + mutex_unlock(nlk->cb_mutex); ret = netlink_dump(sk); -out: sock_put(sk); if (ret) @@ -2543,6 +2704,13 @@ out: * signal not to send ACK even if it was requested. */ return -EINTR; + +error_unlock: + sock_put(sk); + mutex_unlock(nlk->cb_mutex); +error_free: + kfree_skb(skb); + return ret; } EXPORT_SYMBOL(__netlink_dump_start); @@ -2707,6 +2875,7 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos) { struct sock *s; struct nl_seq_iter *iter; + struct net *net; int i, j; ++*pos; @@ -2714,11 +2883,12 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos) if (v == SEQ_START_TOKEN) return netlink_seq_socket_idx(seq, 0); + net = seq_file_net(seq); iter = seq->private; s = v; do { s = sk_next(s); - } while (s && sock_net(s) != seq_file_net(seq)); + } while (s && !nl_table[s->sk_protocol].compare(net, s)); if (s) return s; @@ -2730,7 +2900,8 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos) for (; j <= hash->mask; j++) { s = sk_head(&hash->table[j]); - while (s && sock_net(s) != seq_file_net(seq)) + + while (s && !nl_table[s->sk_protocol].compare(net, s)) s = sk_next(s); if (s) { iter->link = i; @@ -2762,14 +2933,14 @@ static int netlink_seq_show(struct seq_file *seq, void *v) struct sock *s = v; struct netlink_sock *nlk = nlk_sk(s); - seq_printf(seq, "%pK %-3d %-6u %08x %-8d %-8d %pK %-8d %-8d %-8lu\n", + seq_printf(seq, "%pK %-3d %-6u %08x %-8d %-8d %d %-8d %-8d %-8lu\n", s, s->sk_protocol, nlk->portid, nlk->groups ? (u32)nlk->groups[0] : 0, sk_rmem_alloc_get(s), sk_wmem_alloc_get(s), - nlk->cb, + nlk->cb_running, atomic_read(&s->sk_refcnt), atomic_read(&s->sk_drops), sock_i_ino(s) @@ -2923,8 +3094,12 @@ static int __init netlink_proto_init(void) hash->shift = 0; hash->mask = 0; hash->rehash_time = jiffies; + + nl_table[i].compare = netlink_compare; } + INIT_LIST_HEAD(&netlink_tap_all); + netlink_add_usersock_entry(); sock_register(&netlink_family_ops); |