diff options
Diffstat (limited to 'net')
-rw-r--r-- | net/netfilter/nft_hash.c | 70 | ||||
-rw-r--r-- | net/netlink/af_netlink.c | 88 | ||||
-rw-r--r-- | net/tipc/socket.c | 32 |
3 files changed, 103 insertions, 87 deletions
diff --git a/net/netfilter/nft_hash.c b/net/netfilter/nft_hash.c index c82df0a48fcd..4585c5724391 100644 --- a/net/netfilter/nft_hash.c +++ b/net/netfilter/nft_hash.c @@ -29,6 +29,8 @@ struct nft_hash_elem { struct nft_data data[]; }; +static const struct rhashtable_params nft_hash_params; + static bool nft_hash_lookup(const struct nft_set *set, const struct nft_data *key, struct nft_data *data) @@ -36,7 +38,7 @@ static bool nft_hash_lookup(const struct nft_set *set, struct rhashtable *priv = nft_set_priv(set); const struct nft_hash_elem *he; - he = rhashtable_lookup(priv, key); + he = rhashtable_lookup_fast(priv, key, nft_hash_params); if (he && set->flags & NFT_SET_MAP) nft_data_copy(data, he->data); @@ -49,6 +51,7 @@ static int nft_hash_insert(const struct nft_set *set, struct rhashtable *priv = nft_set_priv(set); struct nft_hash_elem *he; unsigned int size; + int err; if (elem->flags != 0) return -EINVAL; @@ -65,9 +68,11 @@ static int nft_hash_insert(const struct nft_set *set, if (set->flags & NFT_SET_MAP) nft_data_copy(he->data, &elem->data); - rhashtable_insert(priv, &he->node); + err = rhashtable_insert_fast(priv, &he->node, nft_hash_params); + if (err) + kfree(he); - return 0; + return err; } static void nft_hash_elem_destroy(const struct nft_set *set, @@ -84,46 +89,26 @@ static void nft_hash_remove(const struct nft_set *set, { struct rhashtable *priv = nft_set_priv(set); - rhashtable_remove(priv, elem->cookie); + rhashtable_remove_fast(priv, elem->cookie, nft_hash_params); synchronize_rcu(); kfree(elem->cookie); } -struct nft_compare_arg { - const struct nft_set *set; - struct nft_set_elem *elem; -}; - -static bool nft_hash_compare(void *ptr, void *arg) -{ - struct nft_hash_elem *he = ptr; - struct nft_compare_arg *x = arg; - - if (!nft_data_cmp(&he->key, &x->elem->key, x->set->klen)) { - x->elem->cookie = he; - x->elem->flags = 0; - if (x->set->flags & NFT_SET_MAP) - nft_data_copy(&x->elem->data, he->data); - - return true; - } - - return false; -} - static int nft_hash_get(const struct nft_set *set, struct nft_set_elem *elem) { struct rhashtable *priv = nft_set_priv(set); - struct nft_compare_arg arg = { - .set = set, - .elem = elem, - }; + struct nft_hash_elem *he; + + he = rhashtable_lookup_fast(priv, &elem->key, nft_hash_params); + if (!he) + return -ENOENT; - if (rhashtable_lookup_compare(priv, &elem->key, - &nft_hash_compare, &arg)) - return 0; + elem->cookie = he; + elem->flags = 0; + if (set->flags & NFT_SET_MAP) + nft_data_copy(&elem->data, he->data); - return -ENOENT; + return 0; } static void nft_hash_walk(const struct nft_ctx *ctx, const struct nft_set *set, @@ -181,18 +166,21 @@ static unsigned int nft_hash_privsize(const struct nlattr * const nla[]) return sizeof(struct rhashtable); } +static const struct rhashtable_params nft_hash_params = { + .head_offset = offsetof(struct nft_hash_elem, node), + .key_offset = offsetof(struct nft_hash_elem, key), + .hashfn = jhash, +}; + static int nft_hash_init(const struct nft_set *set, const struct nft_set_desc *desc, const struct nlattr * const tb[]) { struct rhashtable *priv = nft_set_priv(set); - struct rhashtable_params params = { - .nelem_hint = desc->size ? : NFT_HASH_ELEMENT_HINT, - .head_offset = offsetof(struct nft_hash_elem, node), - .key_offset = offsetof(struct nft_hash_elem, key), - .key_len = set->klen, - .hashfn = jhash, - }; + struct rhashtable_params params = nft_hash_params; + + params.nelem_hint = desc->size ?: NFT_HASH_ELEMENT_HINT; + params.key_len = set->klen; return rhashtable_init(priv, ¶ms); } diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c index d97aed628bda..72c6b55af741 100644 --- a/net/netlink/af_netlink.c +++ b/net/netlink/af_netlink.c @@ -116,6 +116,8 @@ static ATOMIC_NOTIFIER_HEAD(netlink_chain); static DEFINE_SPINLOCK(netlink_tap_lock); static struct list_head netlink_tap_all __read_mostly; +static const struct rhashtable_params netlink_rhashtable_params; + static inline u32 netlink_group_mask(u32 group) { return group ? 1 << (group - 1) : 0; @@ -970,41 +972,49 @@ netlink_unlock_table(void) struct netlink_compare_arg { - struct net *net; + possible_net_t pnet; u32 portid; + char trailer[]; }; -static bool netlink_compare(void *ptr, void *arg) +#define netlink_compare_arg_len offsetof(struct netlink_compare_arg, trailer) + +static inline int netlink_compare(struct rhashtable_compare_arg *arg, + const void *ptr) { - struct netlink_compare_arg *x = arg; - struct sock *sk = ptr; + const struct netlink_compare_arg *x = arg->key; + const struct netlink_sock *nlk = ptr; - return nlk_sk(sk)->portid == x->portid && - net_eq(sock_net(sk), x->net); + return nlk->portid != x->portid || + !net_eq(sock_net(&nlk->sk), read_pnet(&x->pnet)); +} + +static void netlink_compare_arg_init(struct netlink_compare_arg *arg, + struct net *net, u32 portid) +{ + memset(arg, 0, sizeof(*arg)); + write_pnet(&arg->pnet, net); + arg->portid = portid; } static struct sock *__netlink_lookup(struct netlink_table *table, u32 portid, struct net *net) { - struct netlink_compare_arg arg = { - .net = net, - .portid = portid, - }; + struct netlink_compare_arg arg; - return rhashtable_lookup_compare(&table->hash, &portid, - &netlink_compare, &arg); + netlink_compare_arg_init(&arg, net, portid); + return rhashtable_lookup_fast(&table->hash, &arg, + netlink_rhashtable_params); } -static bool __netlink_insert(struct netlink_table *table, struct sock *sk) +static int __netlink_insert(struct netlink_table *table, struct sock *sk) { - struct netlink_compare_arg arg = { - .net = sock_net(sk), - .portid = nlk_sk(sk)->portid, - }; + struct netlink_compare_arg arg; - return rhashtable_lookup_compare_insert(&table->hash, - &nlk_sk(sk)->node, - &netlink_compare, &arg); + netlink_compare_arg_init(&arg, sock_net(sk), nlk_sk(sk)->portid); + return rhashtable_lookup_insert_key(&table->hash, &arg, + &nlk_sk(sk)->node, + netlink_rhashtable_params); } static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid) @@ -1066,9 +1076,10 @@ static int netlink_insert(struct sock *sk, u32 portid) nlk_sk(sk)->portid = portid; sock_hold(sk); - err = 0; - if (!__netlink_insert(table, sk)) { - err = -EADDRINUSE; + err = __netlink_insert(table, sk); + if (err) { + if (err == -EEXIST) + err = -EADDRINUSE; sock_put(sk); } @@ -1082,7 +1093,8 @@ static void netlink_remove(struct sock *sk) struct netlink_table *table; table = &nl_table[sk->sk_protocol]; - if (rhashtable_remove(&table->hash, &nlk_sk(sk)->node)) { + if (!rhashtable_remove_fast(&table->hash, &nlk_sk(sk)->node, + netlink_rhashtable_params)) { WARN_ON(atomic_read(&sk->sk_refcnt) == 1); __sock_put(sk); } @@ -3114,17 +3126,28 @@ static struct pernet_operations __net_initdata netlink_net_ops = { .exit = netlink_net_exit, }; +static inline u32 netlink_hash(const void *data, u32 seed) +{ + const struct netlink_sock *nlk = data; + struct netlink_compare_arg arg; + + netlink_compare_arg_init(&arg, sock_net(&nlk->sk), nlk->portid); + return jhash(&arg, netlink_compare_arg_len, seed); +} + +static const struct rhashtable_params netlink_rhashtable_params = { + .head_offset = offsetof(struct netlink_sock, node), + .key_len = netlink_compare_arg_len, + .hashfn = jhash, + .obj_hashfn = netlink_hash, + .obj_cmpfn = netlink_compare, + .max_size = 65536, +}; + static int __init netlink_proto_init(void) { int i; int err = proto_register(&netlink_proto, 0); - struct rhashtable_params ht_params = { - .head_offset = offsetof(struct netlink_sock, node), - .key_offset = offsetof(struct netlink_sock, portid), - .key_len = sizeof(u32), /* portid */ - .hashfn = jhash, - .max_size = 65536, - }; if (err != 0) goto out; @@ -3136,7 +3159,8 @@ static int __init netlink_proto_init(void) goto panic; for (i = 0; i < MAX_LINKS; i++) { - if (rhashtable_init(&nl_table[i].hash, &ht_params) < 0) { + if (rhashtable_init(&nl_table[i].hash, + &netlink_rhashtable_params) < 0) { while (--i > 0) rhashtable_destroy(&nl_table[i].hash); kfree(nl_table); diff --git a/net/tipc/socket.c b/net/tipc/socket.c index c03a3d33806f..73c2f518a7c0 100644 --- a/net/tipc/socket.c +++ b/net/tipc/socket.c @@ -133,6 +133,8 @@ static const struct nla_policy tipc_nl_sock_policy[TIPC_NLA_SOCK_MAX + 1] = { [TIPC_NLA_SOCK_HAS_PUBL] = { .type = NLA_FLAG } }; +static const struct rhashtable_params tsk_rht_params; + /* * Revised TIPC socket locking policy: * @@ -2245,7 +2247,7 @@ static struct tipc_sock *tipc_sk_lookup(struct net *net, u32 portid) struct tipc_sock *tsk; rcu_read_lock(); - tsk = rhashtable_lookup(&tn->sk_rht, &portid); + tsk = rhashtable_lookup_fast(&tn->sk_rht, &portid, tsk_rht_params); if (tsk) sock_hold(&tsk->sk); rcu_read_unlock(); @@ -2267,7 +2269,8 @@ static int tipc_sk_insert(struct tipc_sock *tsk) portid = TIPC_MIN_PORT; tsk->portid = portid; sock_hold(&tsk->sk); - if (rhashtable_lookup_insert(&tn->sk_rht, &tsk->node)) + if (!rhashtable_lookup_insert_fast(&tn->sk_rht, &tsk->node, + tsk_rht_params)) return 0; sock_put(&tsk->sk); } @@ -2280,26 +2283,27 @@ static void tipc_sk_remove(struct tipc_sock *tsk) struct sock *sk = &tsk->sk; struct tipc_net *tn = net_generic(sock_net(sk), tipc_net_id); - if (rhashtable_remove(&tn->sk_rht, &tsk->node)) { + if (!rhashtable_remove_fast(&tn->sk_rht, &tsk->node, tsk_rht_params)) { WARN_ON(atomic_read(&sk->sk_refcnt) == 1); __sock_put(sk); } } +static const struct rhashtable_params tsk_rht_params = { + .nelem_hint = 192, + .head_offset = offsetof(struct tipc_sock, node), + .key_offset = offsetof(struct tipc_sock, portid), + .key_len = sizeof(u32), /* portid */ + .hashfn = jhash, + .max_size = 1048576, + .min_size = 256, +}; + int tipc_sk_rht_init(struct net *net) { struct tipc_net *tn = net_generic(net, tipc_net_id); - struct rhashtable_params rht_params = { - .nelem_hint = 192, - .head_offset = offsetof(struct tipc_sock, node), - .key_offset = offsetof(struct tipc_sock, portid), - .key_len = sizeof(u32), /* portid */ - .hashfn = jhash, - .max_size = 1048576, - .min_size = 256, - }; - - return rhashtable_init(&tn->sk_rht, &rht_params); + + return rhashtable_init(&tn->sk_rht, &tsk_rht_params); } void tipc_sk_rht_destroy(struct net *net) |