diff options
Diffstat (limited to 'net/core/sock_map.c')
-rw-r--r-- | net/core/sock_map.c | 194 |
1 files changed, 94 insertions, 100 deletions
diff --git a/net/core/sock_map.c b/net/core/sock_map.c index d758fb83c884..6f1b82b8ad49 100644 --- a/net/core/sock_map.c +++ b/net/core/sock_map.c @@ -24,6 +24,10 @@ struct bpf_stab { #define SOCK_CREATE_FLAG_MASK \ (BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY) +static int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog, + struct bpf_prog *old, u32 which); +static struct sk_psock_progs *sock_map_progs(struct bpf_map *map); + static struct bpf_map *sock_map_alloc(union bpf_attr *attr) { struct bpf_stab *stab; @@ -148,9 +152,11 @@ static void sock_map_del_link(struct sock *sk, struct bpf_map *map = link->map; struct bpf_stab *stab = container_of(map, struct bpf_stab, map); - if (psock->parser.enabled && stab->progs.skb_parser) + if (psock->saved_data_ready && stab->progs.stream_parser) strp_stop = true; - if (psock->parser.enabled && stab->progs.skb_verdict) + if (psock->saved_data_ready && stab->progs.stream_verdict) + verdict_stop = true; + if (psock->saved_data_ready && stab->progs.skb_verdict) verdict_stop = true; list_del(&link->list); sk_psock_free_link(link); @@ -179,26 +185,10 @@ static void sock_map_unref(struct sock *sk, void *link_raw) static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock) { - struct proto *prot; - - switch (sk->sk_type) { - case SOCK_STREAM: - prot = tcp_bpf_get_proto(sk, psock); - break; - - case SOCK_DGRAM: - prot = udp_bpf_get_proto(sk, psock); - break; - - default: + if (!sk->sk_prot->psock_update_sk_prot) return -EINVAL; - } - - if (IS_ERR(prot)) - return PTR_ERR(prot); - - sk_psock_update_proto(sk, psock, prot); - return 0; + psock->psock_update_sk_prot = sk->sk_prot->psock_update_sk_prot; + return sk->sk_prot->psock_update_sk_prot(sk, psock, false); } static struct sk_psock *sock_map_psock_get_checked(struct sock *sk) @@ -221,26 +211,38 @@ out: return psock; } -static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs, - struct sock *sk) +static bool sock_map_redirect_allowed(const struct sock *sk); + +static int sock_map_link(struct bpf_map *map, struct sock *sk) { - struct bpf_prog *msg_parser, *skb_parser, *skb_verdict; + struct sk_psock_progs *progs = sock_map_progs(map); + struct bpf_prog *stream_verdict = NULL; + struct bpf_prog *stream_parser = NULL; + struct bpf_prog *skb_verdict = NULL; + struct bpf_prog *msg_parser = NULL; struct sk_psock *psock; int ret; - skb_verdict = READ_ONCE(progs->skb_verdict); - if (skb_verdict) { - skb_verdict = bpf_prog_inc_not_zero(skb_verdict); - if (IS_ERR(skb_verdict)) - return PTR_ERR(skb_verdict); + /* Only sockets we can redirect into/from in BPF need to hold + * refs to parser/verdict progs and have their sk_data_ready + * and sk_write_space callbacks overridden. + */ + if (!sock_map_redirect_allowed(sk)) + goto no_progs; + + stream_verdict = READ_ONCE(progs->stream_verdict); + if (stream_verdict) { + stream_verdict = bpf_prog_inc_not_zero(stream_verdict); + if (IS_ERR(stream_verdict)) + return PTR_ERR(stream_verdict); } - skb_parser = READ_ONCE(progs->skb_parser); - if (skb_parser) { - skb_parser = bpf_prog_inc_not_zero(skb_parser); - if (IS_ERR(skb_parser)) { - ret = PTR_ERR(skb_parser); - goto out_put_skb_verdict; + stream_parser = READ_ONCE(progs->stream_parser); + if (stream_parser) { + stream_parser = bpf_prog_inc_not_zero(stream_parser); + if (IS_ERR(stream_parser)) { + ret = PTR_ERR(stream_parser); + goto out_put_stream_verdict; } } @@ -249,10 +251,20 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs, msg_parser = bpf_prog_inc_not_zero(msg_parser); if (IS_ERR(msg_parser)) { ret = PTR_ERR(msg_parser); - goto out_put_skb_parser; + goto out_put_stream_parser; } } + skb_verdict = READ_ONCE(progs->skb_verdict); + if (skb_verdict) { + skb_verdict = bpf_prog_inc_not_zero(skb_verdict); + if (IS_ERR(skb_verdict)) { + ret = PTR_ERR(skb_verdict); + goto out_put_msg_parser; + } + } + +no_progs: psock = sock_map_psock_get_checked(sk); if (IS_ERR(psock)) { ret = PTR_ERR(psock); @@ -261,8 +273,11 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs, if (psock) { if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) || - (skb_parser && READ_ONCE(psock->progs.skb_parser)) || - (skb_verdict && READ_ONCE(psock->progs.skb_verdict))) { + (stream_parser && READ_ONCE(psock->progs.stream_parser)) || + (skb_verdict && READ_ONCE(psock->progs.skb_verdict)) || + (skb_verdict && READ_ONCE(psock->progs.stream_verdict)) || + (stream_verdict && READ_ONCE(psock->progs.skb_verdict)) || + (stream_verdict && READ_ONCE(psock->progs.stream_verdict))) { sk_psock_put(sk, psock); ret = -EBUSY; goto out_progs; @@ -283,16 +298,19 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs, goto out_drop; write_lock_bh(&sk->sk_callback_lock); - if (skb_parser && skb_verdict && !psock->parser.enabled) { + if (stream_parser && stream_verdict && !psock->saved_data_ready) { ret = sk_psock_init_strp(sk, psock); if (ret) goto out_unlock_drop; - psock_set_prog(&psock->progs.skb_verdict, skb_verdict); - psock_set_prog(&psock->progs.skb_parser, skb_parser); + psock_set_prog(&psock->progs.stream_verdict, stream_verdict); + psock_set_prog(&psock->progs.stream_parser, stream_parser); sk_psock_start_strp(sk, psock); - } else if (!skb_parser && skb_verdict && !psock->parser.enabled) { - psock_set_prog(&psock->progs.skb_verdict, skb_verdict); + } else if (!stream_parser && stream_verdict && !psock->saved_data_ready) { + psock_set_prog(&psock->progs.stream_verdict, stream_verdict); sk_psock_start_verdict(sk,psock); + } else if (!stream_verdict && skb_verdict && !psock->saved_data_ready) { + psock_set_prog(&psock->progs.skb_verdict, skb_verdict); + sk_psock_start_verdict(sk, psock); } write_unlock_bh(&sk->sk_callback_lock); return 0; @@ -301,35 +319,17 @@ out_unlock_drop: out_drop: sk_psock_put(sk, psock); out_progs: - if (msg_parser) - bpf_prog_put(msg_parser); -out_put_skb_parser: - if (skb_parser) - bpf_prog_put(skb_parser); -out_put_skb_verdict: if (skb_verdict) bpf_prog_put(skb_verdict); - return ret; -} - -static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk) -{ - struct sk_psock *psock; - int ret; - - psock = sock_map_psock_get_checked(sk); - if (IS_ERR(psock)) - return PTR_ERR(psock); - - if (!psock) { - psock = sk_psock_init(sk, map->numa_node); - if (IS_ERR(psock)) - return PTR_ERR(psock); - } - - ret = sock_map_init_proto(sk, psock); - if (ret < 0) - sk_psock_put(sk, psock); +out_put_msg_parser: + if (msg_parser) + bpf_prog_put(msg_parser); +out_put_stream_parser: + if (stream_parser) + bpf_prog_put(stream_parser); +out_put_stream_verdict: + if (stream_verdict) + bpf_prog_put(stream_verdict); return ret; } @@ -463,8 +463,6 @@ static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next) return 0; } -static bool sock_map_redirect_allowed(const struct sock *sk); - static int sock_map_update_common(struct bpf_map *map, u32 idx, struct sock *sk, u64 flags) { @@ -484,14 +482,7 @@ static int sock_map_update_common(struct bpf_map *map, u32 idx, if (!link) return -ENOMEM; - /* Only sockets we can redirect into/from in BPF need to hold - * refs to parser/verdict progs and have their sk_data_ready - * and sk_write_space callbacks overridden. - */ - if (sock_map_redirect_allowed(sk)) - ret = sock_map_link(map, &stab->progs, sk); - else - ret = sock_map_link_no_progs(map, sk); + ret = sock_map_link(map, sk); if (ret < 0) goto out_free; @@ -544,12 +535,15 @@ static bool sk_is_udp(const struct sock *sk) static bool sock_map_redirect_allowed(const struct sock *sk) { - return sk_is_tcp(sk) && sk->sk_state != TCP_LISTEN; + if (sk_is_tcp(sk)) + return sk->sk_state != TCP_LISTEN; + else + return sk->sk_state == TCP_ESTABLISHED; } static bool sock_map_sk_is_suitable(const struct sock *sk) { - return sk_is_tcp(sk) || sk_is_udp(sk); + return !!sk->sk_prot->psock_update_sk_prot; } static bool sock_map_sk_state_allowed(const struct sock *sk) @@ -657,7 +651,6 @@ const struct bpf_func_proto bpf_sock_map_update_proto = { BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb, struct bpf_map *, map, u32, key, u64, flags) { - struct tcp_skb_cb *tcb = TCP_SKB_CB(skb); struct sock *sk; if (unlikely(flags & ~(BPF_F_INGRESS))) @@ -667,8 +660,7 @@ BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb, if (unlikely(!sk || !sock_map_redirect_allowed(sk))) return SK_DROP; - tcb->bpf.flags = flags; - tcb->bpf.sk_redir = sk; + skb_bpf_set_redir(skb, sk, flags & BPF_F_INGRESS); return SK_PASS; } @@ -998,14 +990,7 @@ static int sock_hash_update_common(struct bpf_map *map, void *key, if (!link) return -ENOMEM; - /* Only sockets we can redirect into/from in BPF need to hold - * refs to parser/verdict progs and have their sk_data_ready - * and sk_write_space callbacks overridden. - */ - if (sock_map_redirect_allowed(sk)) - ret = sock_map_link(map, &htab->progs, sk); - else - ret = sock_map_link_no_progs(map, sk); + ret = sock_map_link(map, sk); if (ret < 0) goto out_free; @@ -1250,7 +1235,6 @@ const struct bpf_func_proto bpf_sock_hash_update_proto = { BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb, struct bpf_map *, map, void *, key, u64, flags) { - struct tcp_skb_cb *tcb = TCP_SKB_CB(skb); struct sock *sk; if (unlikely(flags & ~(BPF_F_INGRESS))) @@ -1260,8 +1244,7 @@ BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb, if (unlikely(!sk || !sock_map_redirect_allowed(sk))) return SK_DROP; - tcb->bpf.flags = flags; - tcb->bpf.sk_redir = sk; + skb_bpf_set_redir(skb, sk, flags & BPF_F_INGRESS); return SK_PASS; } @@ -1448,8 +1431,8 @@ static struct sk_psock_progs *sock_map_progs(struct bpf_map *map) return NULL; } -int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog, - struct bpf_prog *old, u32 which) +static int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog, + struct bpf_prog *old, u32 which) { struct sk_psock_progs *progs = sock_map_progs(map); struct bpf_prog **pprog; @@ -1461,10 +1444,19 @@ int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog, case BPF_SK_MSG_VERDICT: pprog = &progs->msg_parser; break; +#if IS_ENABLED(CONFIG_BPF_STREAM_PARSER) case BPF_SK_SKB_STREAM_PARSER: - pprog = &progs->skb_parser; + pprog = &progs->stream_parser; break; +#endif case BPF_SK_SKB_STREAM_VERDICT: + if (progs->skb_verdict) + return -EBUSY; + pprog = &progs->stream_verdict; + break; + case BPF_SK_SKB_VERDICT: + if (progs->stream_verdict) + return -EBUSY; pprog = &progs->skb_verdict; break; default: @@ -1529,7 +1521,7 @@ void sock_map_close(struct sock *sk, long timeout) lock_sock(sk); rcu_read_lock(); - psock = sk_psock(sk); + psock = sk_psock_get(sk); if (unlikely(!psock)) { rcu_read_unlock(); release_sock(sk); @@ -1539,6 +1531,8 @@ void sock_map_close(struct sock *sk, long timeout) saved_close = psock->saved_close; sock_map_remove_links(sk, psock); rcu_read_unlock(); + sk_psock_stop(psock, true); + sk_psock_put(sk, psock); release_sock(sk); saved_close(sk, timeout); } |