diff options
Diffstat (limited to 'net/netlink/af_netlink.c')
-rw-r--r-- | net/netlink/af_netlink.c | 95 |
1 files changed, 59 insertions, 36 deletions
diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c index 05919bf3f670..19909d0786a2 100644 --- a/net/netlink/af_netlink.c +++ b/net/netlink/af_netlink.c @@ -116,6 +116,8 @@ static ATOMIC_NOTIFIER_HEAD(netlink_chain); static DEFINE_SPINLOCK(netlink_tap_lock); static struct list_head netlink_tap_all __read_mostly; +static const struct rhashtable_params netlink_rhashtable_params; + static inline u32 netlink_group_mask(u32 group) { return group ? 1 << (group - 1) : 0; @@ -970,41 +972,50 @@ netlink_unlock_table(void) struct netlink_compare_arg { - struct net *net; + possible_net_t pnet; u32 portid; }; -static bool netlink_compare(void *ptr, void *arg) +/* Doing sizeof directly may yield 4 extra bytes on 64-bit. */ +#define netlink_compare_arg_len \ + (offsetof(struct netlink_compare_arg, portid) + sizeof(u32)) + +static inline int netlink_compare(struct rhashtable_compare_arg *arg, + const void *ptr) { - struct netlink_compare_arg *x = arg; - struct sock *sk = ptr; + const struct netlink_compare_arg *x = arg->key; + const struct netlink_sock *nlk = ptr; - return nlk_sk(sk)->portid == x->portid && - net_eq(sock_net(sk), x->net); + return nlk->portid != x->portid || + !net_eq(sock_net(&nlk->sk), read_pnet(&x->pnet)); +} + +static void netlink_compare_arg_init(struct netlink_compare_arg *arg, + struct net *net, u32 portid) +{ + memset(arg, 0, sizeof(*arg)); + write_pnet(&arg->pnet, net); + arg->portid = portid; } static struct sock *__netlink_lookup(struct netlink_table *table, u32 portid, struct net *net) { - struct netlink_compare_arg arg = { - .net = net, - .portid = portid, - }; + struct netlink_compare_arg arg; - return rhashtable_lookup_compare(&table->hash, &portid, - &netlink_compare, &arg); + netlink_compare_arg_init(&arg, net, portid); + return rhashtable_lookup_fast(&table->hash, &arg, + netlink_rhashtable_params); } -static bool __netlink_insert(struct netlink_table *table, struct sock *sk) +static int __netlink_insert(struct netlink_table *table, struct sock *sk) { - struct netlink_compare_arg arg = { - .net = sock_net(sk), - .portid = nlk_sk(sk)->portid, - }; + struct netlink_compare_arg arg; - return rhashtable_lookup_compare_insert(&table->hash, - &nlk_sk(sk)->node, - &netlink_compare, &arg); + netlink_compare_arg_init(&arg, sock_net(sk), nlk_sk(sk)->portid); + return rhashtable_lookup_insert_key(&table->hash, &arg, + &nlk_sk(sk)->node, + netlink_rhashtable_params); } static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid) @@ -1066,9 +1077,10 @@ static int netlink_insert(struct sock *sk, u32 portid) nlk_sk(sk)->portid = portid; sock_hold(sk); - err = 0; - if (!__netlink_insert(table, sk)) { - err = -EADDRINUSE; + err = __netlink_insert(table, sk); + if (err) { + if (err == -EEXIST) + err = -EADDRINUSE; sock_put(sk); } @@ -1082,7 +1094,8 @@ static void netlink_remove(struct sock *sk) struct netlink_table *table; table = &nl_table[sk->sk_protocol]; - if (rhashtable_remove(&table->hash, &nlk_sk(sk)->node)) { + if (!rhashtable_remove_fast(&table->hash, &nlk_sk(sk)->node, + netlink_rhashtable_params)) { WARN_ON(atomic_read(&sk->sk_refcnt) == 1); __sock_put(sk); } @@ -2256,8 +2269,7 @@ static void netlink_cmsg_recv_pktinfo(struct msghdr *msg, struct sk_buff *skb) put_cmsg(msg, SOL_NETLINK, NETLINK_PKTINFO, sizeof(info), &info); } -static int netlink_sendmsg(struct kiocb *kiocb, struct socket *sock, - struct msghdr *msg, size_t len) +static int netlink_sendmsg(struct socket *sock, struct msghdr *msg, size_t len) { struct sock *sk = sock->sk; struct netlink_sock *nlk = nlk_sk(sk); @@ -2346,8 +2358,7 @@ out: return err; } -static int netlink_recvmsg(struct kiocb *kiocb, struct socket *sock, - struct msghdr *msg, size_t len, +static int netlink_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, int flags) { struct scm_cookie scm; @@ -3116,17 +3127,28 @@ static struct pernet_operations __net_initdata netlink_net_ops = { .exit = netlink_net_exit, }; +static inline u32 netlink_hash(const void *data, u32 len, u32 seed) +{ + const struct netlink_sock *nlk = data; + struct netlink_compare_arg arg; + + netlink_compare_arg_init(&arg, sock_net(&nlk->sk), nlk->portid); + return jhash2((u32 *)&arg, netlink_compare_arg_len / sizeof(u32), seed); +} + +static const struct rhashtable_params netlink_rhashtable_params = { + .head_offset = offsetof(struct netlink_sock, node), + .key_len = netlink_compare_arg_len, + .obj_hashfn = netlink_hash, + .obj_cmpfn = netlink_compare, + .max_size = 65536, + .automatic_shrinking = true, +}; + static int __init netlink_proto_init(void) { int i; int err = proto_register(&netlink_proto, 0); - struct rhashtable_params ht_params = { - .head_offset = offsetof(struct netlink_sock, node), - .key_offset = offsetof(struct netlink_sock, portid), - .key_len = sizeof(u32), /* portid */ - .hashfn = jhash, - .max_shift = 16, /* 64K */ - }; if (err != 0) goto out; @@ -3138,7 +3160,8 @@ static int __init netlink_proto_init(void) goto panic; for (i = 0; i < MAX_LINKS; i++) { - if (rhashtable_init(&nl_table[i].hash, &ht_params) < 0) { + if (rhashtable_init(&nl_table[i].hash, + &netlink_rhashtable_params) < 0) { while (--i > 0) rhashtable_destroy(&nl_table[i].hash); kfree(nl_table); |