diff options
Diffstat (limited to 'net/xfrm/xfrm_state.c')
-rw-r--r-- | net/xfrm/xfrm_state.c | 116 |
1 files changed, 98 insertions, 18 deletions
diff --git a/net/xfrm/xfrm_state.c b/net/xfrm/xfrm_state.c index 711e816fc404..69af5964c886 100644 --- a/net/xfrm/xfrm_state.c +++ b/net/xfrm/xfrm_state.c @@ -424,18 +424,18 @@ void xfrm_unregister_type_offload(const struct xfrm_type_offload *type, } EXPORT_SYMBOL(xfrm_unregister_type_offload); -static const struct xfrm_type_offload * -xfrm_get_type_offload(u8 proto, unsigned short family, bool try_load) +void xfrm_set_type_offload(struct xfrm_state *x) { const struct xfrm_type_offload *type = NULL; struct xfrm_state_afinfo *afinfo; + bool try_load = true; retry: - afinfo = xfrm_state_get_afinfo(family); + afinfo = xfrm_state_get_afinfo(x->props.family); if (unlikely(afinfo == NULL)) - return NULL; + goto out; - switch (proto) { + switch (x->id.proto) { case IPPROTO_ESP: type = afinfo->type_offload_esp; break; @@ -449,18 +449,16 @@ retry: rcu_read_unlock(); if (!type && try_load) { - request_module("xfrm-offload-%d-%d", family, proto); + request_module("xfrm-offload-%d-%d", x->props.family, + x->id.proto); try_load = false; goto retry; } - return type; -} - -static void xfrm_put_type_offload(const struct xfrm_type_offload *type) -{ - module_put(type->owner); +out: + x->type_offload = type; } +EXPORT_SYMBOL(xfrm_set_type_offload); static const struct xfrm_mode xfrm4_mode_map[XFRM_MODE_MAX] = { [XFRM_MODE_BEET] = { @@ -477,6 +475,11 @@ static const struct xfrm_mode xfrm4_mode_map[XFRM_MODE_MAX] = { .flags = XFRM_MODE_FLAG_TUNNEL, .family = AF_INET, }, + [XFRM_MODE_IPTFS] = { + .encap = XFRM_MODE_IPTFS, + .flags = XFRM_MODE_FLAG_TUNNEL, + .family = AF_INET, + }, }; static const struct xfrm_mode xfrm6_mode_map[XFRM_MODE_MAX] = { @@ -498,6 +501,11 @@ static const struct xfrm_mode xfrm6_mode_map[XFRM_MODE_MAX] = { .flags = XFRM_MODE_FLAG_TUNNEL, .family = AF_INET6, }, + [XFRM_MODE_IPTFS] = { + .encap = XFRM_MODE_IPTFS, + .flags = XFRM_MODE_FLAG_TUNNEL, + .family = AF_INET6, + }, }; static const struct xfrm_mode *xfrm_get_mode(unsigned int encap, int family) @@ -525,6 +533,60 @@ static const struct xfrm_mode *xfrm_get_mode(unsigned int encap, int family) return NULL; } +static const struct xfrm_mode_cbs __rcu *xfrm_mode_cbs_map[XFRM_MODE_MAX]; +static DEFINE_SPINLOCK(xfrm_mode_cbs_map_lock); + +int xfrm_register_mode_cbs(u8 mode, const struct xfrm_mode_cbs *mode_cbs) +{ + if (mode >= XFRM_MODE_MAX) + return -EINVAL; + + spin_lock_bh(&xfrm_mode_cbs_map_lock); + rcu_assign_pointer(xfrm_mode_cbs_map[mode], mode_cbs); + spin_unlock_bh(&xfrm_mode_cbs_map_lock); + + return 0; +} +EXPORT_SYMBOL(xfrm_register_mode_cbs); + +void xfrm_unregister_mode_cbs(u8 mode) +{ + if (mode >= XFRM_MODE_MAX) + return; + + spin_lock_bh(&xfrm_mode_cbs_map_lock); + RCU_INIT_POINTER(xfrm_mode_cbs_map[mode], NULL); + spin_unlock_bh(&xfrm_mode_cbs_map_lock); + synchronize_rcu(); +} +EXPORT_SYMBOL(xfrm_unregister_mode_cbs); + +static const struct xfrm_mode_cbs *xfrm_get_mode_cbs(u8 mode) +{ + const struct xfrm_mode_cbs *cbs; + bool try_load = true; + + if (mode >= XFRM_MODE_MAX) + return NULL; + +retry: + rcu_read_lock(); + + cbs = rcu_dereference(xfrm_mode_cbs_map[mode]); + if (cbs && !try_module_get(cbs->owner)) + cbs = NULL; + + rcu_read_unlock(); + + if (mode == XFRM_MODE_IPTFS && !cbs && try_load) { + request_module("xfrm-iptfs"); + try_load = false; + goto retry; + } + + return cbs; +} + void xfrm_state_free(struct xfrm_state *x) { kmem_cache_free(xfrm_state_cache, x); @@ -533,6 +595,8 @@ EXPORT_SYMBOL(xfrm_state_free); static void ___xfrm_state_destroy(struct xfrm_state *x) { + if (x->mode_cbs && x->mode_cbs->destroy_state) + x->mode_cbs->destroy_state(x); hrtimer_cancel(&x->mtimer); del_timer_sync(&x->rtimer); kfree(x->aead); @@ -543,8 +607,6 @@ static void ___xfrm_state_destroy(struct xfrm_state *x) kfree(x->coaddr); kfree(x->replay_esn); kfree(x->preplay_esn); - if (x->type_offload) - xfrm_put_type_offload(x->type_offload); if (x->type) { x->type->destructor(x); xfrm_put_type(x->type); @@ -692,6 +754,7 @@ struct xfrm_state *xfrm_state_alloc(struct net *net) x->replay_maxdiff = 0; x->pcpu_num = UINT_MAX; spin_lock_init(&x->lock); + x->mode_data = NULL; } return x; } @@ -717,6 +780,8 @@ void xfrm_dev_state_free(struct xfrm_state *x) struct xfrm_dev_offload *xso = &x->xso; struct net_device *dev = READ_ONCE(xso->dev); + xfrm_unset_type_offload(x); + if (dev && dev->xfrmdev_ops) { spin_lock_bh(&xfrm_state_dev_gc_lock); if (!hlist_unhashed(&x->dev_gclist)) @@ -1987,6 +2052,12 @@ static struct xfrm_state *xfrm_state_clone(struct xfrm_state *orig, x->new_mapping_sport = 0; x->dir = orig->dir; + x->mode_cbs = orig->mode_cbs; + if (x->mode_cbs && x->mode_cbs->clone_state) { + if (x->mode_cbs->clone_state(x, orig)) + goto error; + } + return x; error: @@ -2320,6 +2391,7 @@ static int __xfrm6_state_sort_cmp(const void *p) #endif case XFRM_MODE_TUNNEL: case XFRM_MODE_BEET: + case XFRM_MODE_IPTFS: return 4; } return 5; @@ -2346,6 +2418,7 @@ static int __xfrm6_tmpl_sort_cmp(const void *p) #endif case XFRM_MODE_TUNNEL: case XFRM_MODE_BEET: + case XFRM_MODE_IPTFS: return 3; } return 4; @@ -3035,6 +3108,9 @@ u32 xfrm_state_mtu(struct xfrm_state *x, int mtu) case XFRM_MODE_TUNNEL: break; default: + if (x->mode_cbs && x->mode_cbs->get_inner_mtu) + return x->mode_cbs->get_inner_mtu(x, mtu); + WARN_ON_ONCE(1); break; } @@ -3044,7 +3120,7 @@ u32 xfrm_state_mtu(struct xfrm_state *x, int mtu) } EXPORT_SYMBOL_GPL(xfrm_state_mtu); -int __xfrm_init_state(struct xfrm_state *x, bool init_replay, bool offload, +int __xfrm_init_state(struct xfrm_state *x, bool init_replay, struct netlink_ext_ack *extack) { const struct xfrm_mode *inner_mode; @@ -3100,8 +3176,6 @@ int __xfrm_init_state(struct xfrm_state *x, bool init_replay, bool offload, goto error; } - x->type_offload = xfrm_get_type_offload(x->id.proto, family, offload); - err = x->type->init_state(x, extack); if (err) goto error; @@ -3135,6 +3209,12 @@ int __xfrm_init_state(struct xfrm_state *x, bool init_replay, bool offload, } } + x->mode_cbs = xfrm_get_mode_cbs(x->props.mode); + if (x->mode_cbs) { + if (x->mode_cbs->init_state) + err = x->mode_cbs->init_state(x); + module_put(x->mode_cbs->owner); + } error: return err; } @@ -3145,7 +3225,7 @@ int xfrm_init_state(struct xfrm_state *x) { int err; - err = __xfrm_init_state(x, true, false, NULL); + err = __xfrm_init_state(x, true, NULL); if (!err) x->km.state = XFRM_STATE_VALID; |