diff options
Diffstat (limited to 'net')
-rw-r--r-- | net/bpf/test_run.c | 23 | ||||
-rw-r--r-- | net/core/filter.c | 34 | ||||
-rw-r--r-- | net/ipv4/tcp_ipv4.c | 410 | ||||
-rw-r--r-- | net/unix/unix_bpf.c | 16 |
4 files changed, 377 insertions, 106 deletions
diff --git a/net/bpf/test_run.c b/net/bpf/test_run.c index b488e2779718..695449088e42 100644 --- a/net/bpf/test_run.c +++ b/net/bpf/test_run.c @@ -88,17 +88,19 @@ reset: static int bpf_test_run(struct bpf_prog *prog, void *ctx, u32 repeat, u32 *retval, u32 *time, bool xdp) { - struct bpf_cgroup_storage *storage[MAX_BPF_CGROUP_STORAGE_TYPE] = { NULL }; + struct bpf_prog_array_item item = {.prog = prog}; + struct bpf_run_ctx *old_ctx; + struct bpf_cg_run_ctx run_ctx; struct bpf_test_timer t = { NO_MIGRATE }; enum bpf_cgroup_storage_type stype; int ret; for_each_cgroup_storage_type(stype) { - storage[stype] = bpf_cgroup_storage_alloc(prog, stype); - if (IS_ERR(storage[stype])) { - storage[stype] = NULL; + item.cgroup_storage[stype] = bpf_cgroup_storage_alloc(prog, stype); + if (IS_ERR(item.cgroup_storage[stype])) { + item.cgroup_storage[stype] = NULL; for_each_cgroup_storage_type(stype) - bpf_cgroup_storage_free(storage[stype]); + bpf_cgroup_storage_free(item.cgroup_storage[stype]); return -ENOMEM; } } @@ -107,22 +109,19 @@ static int bpf_test_run(struct bpf_prog *prog, void *ctx, u32 repeat, repeat = 1; bpf_test_timer_enter(&t); + old_ctx = bpf_set_run_ctx(&run_ctx.run_ctx); do { - ret = bpf_cgroup_storage_set(storage); - if (ret) - break; - + run_ctx.prog_item = &item; if (xdp) *retval = bpf_prog_run_xdp(prog, ctx); else *retval = BPF_PROG_RUN(prog, ctx); - - bpf_cgroup_storage_unset(); } while (bpf_test_timer_continue(&t, repeat, &ret, time)); + bpf_reset_run_ctx(old_ctx); bpf_test_timer_leave(&t); for_each_cgroup_storage_type(stype) - bpf_cgroup_storage_free(storage[stype]); + bpf_cgroup_storage_free(item.cgroup_storage[stype]); return ret; } diff --git a/net/core/filter.c b/net/core/filter.c index 3b4986e96e9c..faf29fd82276 100644 --- a/net/core/filter.c +++ b/net/core/filter.c @@ -5016,6 +5016,40 @@ err_clear: return -EINVAL; } +BPF_CALL_5(bpf_sk_setsockopt, struct sock *, sk, int, level, + int, optname, char *, optval, int, optlen) +{ + return _bpf_setsockopt(sk, level, optname, optval, optlen); +} + +const struct bpf_func_proto bpf_sk_setsockopt_proto = { + .func = bpf_sk_setsockopt, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_BTF_ID_SOCK_COMMON, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_ANYTHING, + .arg4_type = ARG_PTR_TO_MEM, + .arg5_type = ARG_CONST_SIZE, +}; + +BPF_CALL_5(bpf_sk_getsockopt, struct sock *, sk, int, level, + int, optname, char *, optval, int, optlen) +{ + return _bpf_getsockopt(sk, level, optname, optval, optlen); +} + +const struct bpf_func_proto bpf_sk_getsockopt_proto = { + .func = bpf_sk_getsockopt, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_BTF_ID_SOCK_COMMON, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_ANYTHING, + .arg4_type = ARG_PTR_TO_UNINIT_MEM, + .arg5_type = ARG_CONST_SIZE, +}; + BPF_CALL_5(bpf_sock_addr_setsockopt, struct bpf_sock_addr_kern *, ctx, int, level, int, optname, char *, optval, int, optlen) { diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c index 84db1c9ee92a..2e62e0d6373a 100644 --- a/net/ipv4/tcp_ipv4.c +++ b/net/ipv4/tcp_ipv4.c @@ -2277,51 +2277,72 @@ EXPORT_SYMBOL(tcp_v4_destroy_sock); #ifdef CONFIG_PROC_FS /* Proc filesystem TCP sock list dumping. */ -/* - * Get next listener socket follow cur. If cur is NULL, get first socket - * starting from bucket given in st->bucket; when st->bucket is zero the - * very first socket in the hash table is returned. +static unsigned short seq_file_family(const struct seq_file *seq); + +static bool seq_sk_match(struct seq_file *seq, const struct sock *sk) +{ + unsigned short family = seq_file_family(seq); + + /* AF_UNSPEC is used as a match all */ + return ((family == AF_UNSPEC || family == sk->sk_family) && + net_eq(sock_net(sk), seq_file_net(seq))); +} + +/* Find a non empty bucket (starting from st->bucket) + * and return the first sk from it. */ -static void *listening_get_next(struct seq_file *seq, void *cur) +static void *listening_get_first(struct seq_file *seq) { - struct tcp_seq_afinfo *afinfo; struct tcp_iter_state *st = seq->private; - struct net *net = seq_file_net(seq); - struct inet_listen_hashbucket *ilb; - struct hlist_nulls_node *node; - struct sock *sk = cur; - if (st->bpf_seq_afinfo) - afinfo = st->bpf_seq_afinfo; - else - afinfo = PDE_DATA(file_inode(seq->file)); + st->offset = 0; + for (; st->bucket <= tcp_hashinfo.lhash2_mask; st->bucket++) { + struct inet_listen_hashbucket *ilb2; + struct inet_connection_sock *icsk; + struct sock *sk; - if (!sk) { -get_head: - ilb = &tcp_hashinfo.listening_hash[st->bucket]; - spin_lock(&ilb->lock); - sk = sk_nulls_head(&ilb->nulls_head); - st->offset = 0; - goto get_sk; + ilb2 = &tcp_hashinfo.lhash2[st->bucket]; + if (hlist_empty(&ilb2->head)) + continue; + + spin_lock(&ilb2->lock); + inet_lhash2_for_each_icsk(icsk, &ilb2->head) { + sk = (struct sock *)icsk; + if (seq_sk_match(seq, sk)) + return sk; + } + spin_unlock(&ilb2->lock); } - ilb = &tcp_hashinfo.listening_hash[st->bucket]; + + return NULL; +} + +/* Find the next sk of "cur" within the same bucket (i.e. st->bucket). + * If "cur" is the last one in the st->bucket, + * call listening_get_first() to return the first sk of the next + * non empty bucket. + */ +static void *listening_get_next(struct seq_file *seq, void *cur) +{ + struct tcp_iter_state *st = seq->private; + struct inet_listen_hashbucket *ilb2; + struct inet_connection_sock *icsk; + struct sock *sk = cur; + ++st->num; ++st->offset; - sk = sk_nulls_next(sk); -get_sk: - sk_nulls_for_each_from(sk, node) { - if (!net_eq(sock_net(sk), net)) - continue; - if (afinfo->family == AF_UNSPEC || - sk->sk_family == afinfo->family) + icsk = inet_csk(sk); + inet_lhash2_for_each_icsk_continue(icsk) { + sk = (struct sock *)icsk; + if (seq_sk_match(seq, sk)) return sk; } - spin_unlock(&ilb->lock); - st->offset = 0; - if (++st->bucket < INET_LHTABLE_SIZE) - goto get_head; - return NULL; + + ilb2 = &tcp_hashinfo.lhash2[st->bucket]; + spin_unlock(&ilb2->lock); + ++st->bucket; + return listening_get_first(seq); } static void *listening_get_idx(struct seq_file *seq, loff_t *pos) @@ -2331,7 +2352,7 @@ static void *listening_get_idx(struct seq_file *seq, loff_t *pos) st->bucket = 0; st->offset = 0; - rc = listening_get_next(seq, NULL); + rc = listening_get_first(seq); while (rc && *pos) { rc = listening_get_next(seq, rc); @@ -2351,15 +2372,7 @@ static inline bool empty_bucket(const struct tcp_iter_state *st) */ static void *established_get_first(struct seq_file *seq) { - struct tcp_seq_afinfo *afinfo; struct tcp_iter_state *st = seq->private; - struct net *net = seq_file_net(seq); - void *rc = NULL; - - if (st->bpf_seq_afinfo) - afinfo = st->bpf_seq_afinfo; - else - afinfo = PDE_DATA(file_inode(seq->file)); st->offset = 0; for (; st->bucket <= tcp_hashinfo.ehash_mask; ++st->bucket) { @@ -2373,32 +2386,20 @@ static void *established_get_first(struct seq_file *seq) spin_lock_bh(lock); sk_nulls_for_each(sk, node, &tcp_hashinfo.ehash[st->bucket].chain) { - if ((afinfo->family != AF_UNSPEC && - sk->sk_family != afinfo->family) || - !net_eq(sock_net(sk), net)) { - continue; - } - rc = sk; - goto out; + if (seq_sk_match(seq, sk)) + return sk; } spin_unlock_bh(lock); } -out: - return rc; + + return NULL; } static void *established_get_next(struct seq_file *seq, void *cur) { - struct tcp_seq_afinfo *afinfo; struct sock *sk = cur; struct hlist_nulls_node *node; struct tcp_iter_state *st = seq->private; - struct net *net = seq_file_net(seq); - - if (st->bpf_seq_afinfo) - afinfo = st->bpf_seq_afinfo; - else - afinfo = PDE_DATA(file_inode(seq->file)); ++st->num; ++st->offset; @@ -2406,9 +2407,7 @@ static void *established_get_next(struct seq_file *seq, void *cur) sk = sk_nulls_next(sk); sk_nulls_for_each_from(sk, node) { - if ((afinfo->family == AF_UNSPEC || - sk->sk_family == afinfo->family) && - net_eq(sock_net(sk), net)) + if (seq_sk_match(seq, sk)) return sk; } @@ -2451,17 +2450,18 @@ static void *tcp_get_idx(struct seq_file *seq, loff_t pos) static void *tcp_seek_last_pos(struct seq_file *seq) { struct tcp_iter_state *st = seq->private; + int bucket = st->bucket; int offset = st->offset; int orig_num = st->num; void *rc = NULL; switch (st->state) { case TCP_SEQ_STATE_LISTENING: - if (st->bucket >= INET_LHTABLE_SIZE) + if (st->bucket > tcp_hashinfo.lhash2_mask) break; st->state = TCP_SEQ_STATE_LISTENING; - rc = listening_get_next(seq, NULL); - while (offset-- && rc) + rc = listening_get_first(seq); + while (offset-- && rc && bucket == st->bucket) rc = listening_get_next(seq, rc); if (rc) break; @@ -2472,7 +2472,7 @@ static void *tcp_seek_last_pos(struct seq_file *seq) if (st->bucket > tcp_hashinfo.ehash_mask) break; rc = established_get_first(seq); - while (offset-- && rc) + while (offset-- && rc && bucket == st->bucket) rc = established_get_next(seq, rc); } @@ -2542,7 +2542,7 @@ void tcp_seq_stop(struct seq_file *seq, void *v) switch (st->state) { case TCP_SEQ_STATE_LISTENING: if (v != SEQ_START_TOKEN) - spin_unlock(&tcp_hashinfo.listening_hash[st->bucket].lock); + spin_unlock(&tcp_hashinfo.lhash2[st->bucket].lock); break; case TCP_SEQ_STATE_ESTABLISHED: if (v) @@ -2687,6 +2687,15 @@ out: } #ifdef CONFIG_BPF_SYSCALL +struct bpf_tcp_iter_state { + struct tcp_iter_state state; + unsigned int cur_sk; + unsigned int end_sk; + unsigned int max_sk; + struct sock **batch; + bool st_bucket_done; +}; + struct bpf_iter__tcp { __bpf_md_ptr(struct bpf_iter_meta *, meta); __bpf_md_ptr(struct sock_common *, sk_common); @@ -2705,16 +2714,204 @@ static int tcp_prog_seq_show(struct bpf_prog *prog, struct bpf_iter_meta *meta, return bpf_iter_run_prog(prog, &ctx); } +static void bpf_iter_tcp_put_batch(struct bpf_tcp_iter_state *iter) +{ + while (iter->cur_sk < iter->end_sk) + sock_put(iter->batch[iter->cur_sk++]); +} + +static int bpf_iter_tcp_realloc_batch(struct bpf_tcp_iter_state *iter, + unsigned int new_batch_sz) +{ + struct sock **new_batch; + + new_batch = kvmalloc(sizeof(*new_batch) * new_batch_sz, + GFP_USER | __GFP_NOWARN); + if (!new_batch) + return -ENOMEM; + + bpf_iter_tcp_put_batch(iter); + kvfree(iter->batch); + iter->batch = new_batch; + iter->max_sk = new_batch_sz; + + return 0; +} + +static unsigned int bpf_iter_tcp_listening_batch(struct seq_file *seq, + struct sock *start_sk) +{ + struct bpf_tcp_iter_state *iter = seq->private; + struct tcp_iter_state *st = &iter->state; + struct inet_connection_sock *icsk; + unsigned int expected = 1; + struct sock *sk; + + sock_hold(start_sk); + iter->batch[iter->end_sk++] = start_sk; + + icsk = inet_csk(start_sk); + inet_lhash2_for_each_icsk_continue(icsk) { + sk = (struct sock *)icsk; + if (seq_sk_match(seq, sk)) { + if (iter->end_sk < iter->max_sk) { + sock_hold(sk); + iter->batch[iter->end_sk++] = sk; + } + expected++; + } + } + spin_unlock(&tcp_hashinfo.lhash2[st->bucket].lock); + + return expected; +} + +static unsigned int bpf_iter_tcp_established_batch(struct seq_file *seq, + struct sock *start_sk) +{ + struct bpf_tcp_iter_state *iter = seq->private; + struct tcp_iter_state *st = &iter->state; + struct hlist_nulls_node *node; + unsigned int expected = 1; + struct sock *sk; + + sock_hold(start_sk); + iter->batch[iter->end_sk++] = start_sk; + + sk = sk_nulls_next(start_sk); + sk_nulls_for_each_from(sk, node) { + if (seq_sk_match(seq, sk)) { + if (iter->end_sk < iter->max_sk) { + sock_hold(sk); + iter->batch[iter->end_sk++] = sk; + } + expected++; + } + } + spin_unlock_bh(inet_ehash_lockp(&tcp_hashinfo, st->bucket)); + + return expected; +} + +static struct sock *bpf_iter_tcp_batch(struct seq_file *seq) +{ + struct bpf_tcp_iter_state *iter = seq->private; + struct tcp_iter_state *st = &iter->state; + unsigned int expected; + bool resized = false; + struct sock *sk; + + /* The st->bucket is done. Directly advance to the next + * bucket instead of having the tcp_seek_last_pos() to skip + * one by one in the current bucket and eventually find out + * it has to advance to the next bucket. + */ + if (iter->st_bucket_done) { + st->offset = 0; + st->bucket++; + if (st->state == TCP_SEQ_STATE_LISTENING && + st->bucket > tcp_hashinfo.lhash2_mask) { + st->state = TCP_SEQ_STATE_ESTABLISHED; + st->bucket = 0; + } + } + +again: + /* Get a new batch */ + iter->cur_sk = 0; + iter->end_sk = 0; + iter->st_bucket_done = false; + + sk = tcp_seek_last_pos(seq); + if (!sk) + return NULL; /* Done */ + + if (st->state == TCP_SEQ_STATE_LISTENING) + expected = bpf_iter_tcp_listening_batch(seq, sk); + else + expected = bpf_iter_tcp_established_batch(seq, sk); + + if (iter->end_sk == expected) { + iter->st_bucket_done = true; + return sk; + } + + if (!resized && !bpf_iter_tcp_realloc_batch(iter, expected * 3 / 2)) { + resized = true; + goto again; + } + + return sk; +} + +static void *bpf_iter_tcp_seq_start(struct seq_file *seq, loff_t *pos) +{ + /* bpf iter does not support lseek, so it always + * continue from where it was stop()-ped. + */ + if (*pos) + return bpf_iter_tcp_batch(seq); + + return SEQ_START_TOKEN; +} + +static void *bpf_iter_tcp_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct bpf_tcp_iter_state *iter = seq->private; + struct tcp_iter_state *st = &iter->state; + struct sock *sk; + + /* Whenever seq_next() is called, the iter->cur_sk is + * done with seq_show(), so advance to the next sk in + * the batch. + */ + if (iter->cur_sk < iter->end_sk) { + /* Keeping st->num consistent in tcp_iter_state. + * bpf_iter_tcp does not use st->num. + * meta.seq_num is used instead. + */ + st->num++; + /* Move st->offset to the next sk in the bucket such that + * the future start() will resume at st->offset in + * st->bucket. See tcp_seek_last_pos(). + */ + st->offset++; + sock_put(iter->batch[iter->cur_sk++]); + } + + if (iter->cur_sk < iter->end_sk) + sk = iter->batch[iter->cur_sk]; + else + sk = bpf_iter_tcp_batch(seq); + + ++*pos; + /* Keeping st->last_pos consistent in tcp_iter_state. + * bpf iter does not do lseek, so st->last_pos always equals to *pos. + */ + st->last_pos = *pos; + return sk; +} + static int bpf_iter_tcp_seq_show(struct seq_file *seq, void *v) { struct bpf_iter_meta meta; struct bpf_prog *prog; struct sock *sk = v; + bool slow; uid_t uid; + int ret; if (v == SEQ_START_TOKEN) return 0; + if (sk_fullsock(sk)) + slow = lock_sock_fast(sk); + + if (unlikely(sk_unhashed(sk))) { + ret = SEQ_SKIP; + goto unlock; + } + if (sk->sk_state == TCP_TIME_WAIT) { uid = 0; } else if (sk->sk_state == TCP_NEW_SYN_RECV) { @@ -2728,11 +2925,18 @@ static int bpf_iter_tcp_seq_show(struct seq_file *seq, void *v) meta.seq = seq; prog = bpf_iter_get_info(&meta, false); - return tcp_prog_seq_show(prog, &meta, v, uid); + ret = tcp_prog_seq_show(prog, &meta, v, uid); + +unlock: + if (sk_fullsock(sk)) + unlock_sock_fast(sk, slow); + return ret; + } static void bpf_iter_tcp_seq_stop(struct seq_file *seq, void *v) { + struct bpf_tcp_iter_state *iter = seq->private; struct bpf_iter_meta meta; struct bpf_prog *prog; @@ -2743,16 +2947,33 @@ static void bpf_iter_tcp_seq_stop(struct seq_file *seq, void *v) (void)tcp_prog_seq_show(prog, &meta, v, 0); } - tcp_seq_stop(seq, v); + if (iter->cur_sk < iter->end_sk) { + bpf_iter_tcp_put_batch(iter); + iter->st_bucket_done = false; + } } static const struct seq_operations bpf_iter_tcp_seq_ops = { .show = bpf_iter_tcp_seq_show, - .start = tcp_seq_start, - .next = tcp_seq_next, + .start = bpf_iter_tcp_seq_start, + .next = bpf_iter_tcp_seq_next, .stop = bpf_iter_tcp_seq_stop, }; #endif +static unsigned short seq_file_family(const struct seq_file *seq) +{ + const struct tcp_seq_afinfo *afinfo; + +#ifdef CONFIG_BPF_SYSCALL + /* Iterated from bpf_iter. Let the bpf prog to filter instead. */ + if (seq->op == &bpf_iter_tcp_seq_ops) + return AF_UNSPEC; +#endif + + /* Iterated from proc fs */ + afinfo = PDE_DATA(file_inode(seq->file)); + return afinfo->family; +} static const struct seq_operations tcp4_seq_ops = { .show = tcp4_seq_show, @@ -3002,39 +3223,55 @@ static struct pernet_operations __net_initdata tcp_sk_ops = { DEFINE_BPF_ITER_FUNC(tcp, struct bpf_iter_meta *meta, struct sock_common *sk_common, uid_t uid) +#define INIT_BATCH_SZ 16 + static int bpf_iter_init_tcp(void *priv_data, struct bpf_iter_aux_info *aux) { - struct tcp_iter_state *st = priv_data; - struct tcp_seq_afinfo *afinfo; - int ret; + struct bpf_tcp_iter_state *iter = priv_data; + int err; - afinfo = kmalloc(sizeof(*afinfo), GFP_USER | __GFP_NOWARN); - if (!afinfo) - return -ENOMEM; + err = bpf_iter_init_seq_net(priv_data, aux); + if (err) + return err; - afinfo->family = AF_UNSPEC; - st->bpf_seq_afinfo = afinfo; - ret = bpf_iter_init_seq_net(priv_data, aux); - if (ret) - kfree(afinfo); - return ret; + err = bpf_iter_tcp_realloc_batch(iter, INIT_BATCH_SZ); + if (err) { + bpf_iter_fini_seq_net(priv_data); + return err; + } + + return 0; } static void bpf_iter_fini_tcp(void *priv_data) { - struct tcp_iter_state *st = priv_data; + struct bpf_tcp_iter_state *iter = priv_data; - kfree(st->bpf_seq_afinfo); bpf_iter_fini_seq_net(priv_data); + kvfree(iter->batch); } static const struct bpf_iter_seq_info tcp_seq_info = { .seq_ops = &bpf_iter_tcp_seq_ops, .init_seq_private = bpf_iter_init_tcp, .fini_seq_private = bpf_iter_fini_tcp, - .seq_priv_size = sizeof(struct tcp_iter_state), + .seq_priv_size = sizeof(struct bpf_tcp_iter_state), }; +static const struct bpf_func_proto * +bpf_iter_tcp_get_func_proto(enum bpf_func_id func_id, + const struct bpf_prog *prog) +{ + switch (func_id) { + case BPF_FUNC_setsockopt: + return &bpf_sk_setsockopt_proto; + case BPF_FUNC_getsockopt: + return &bpf_sk_getsockopt_proto; + default: + return NULL; + } +} + static struct bpf_iter_reg tcp_reg_info = { .target = "tcp", .ctx_arg_info_size = 1, @@ -3042,6 +3279,7 @@ static struct bpf_iter_reg tcp_reg_info = { { offsetof(struct bpf_iter__tcp, sk_common), PTR_TO_BTF_ID_OR_NULL }, }, + .get_func_proto = bpf_iter_tcp_get_func_proto, .seq_info = &tcp_seq_info, }; diff --git a/net/unix/unix_bpf.c b/net/unix/unix_bpf.c index db0cda29fb2f..177e883f451e 100644 --- a/net/unix/unix_bpf.c +++ b/net/unix/unix_bpf.c @@ -44,7 +44,7 @@ static int unix_dgram_bpf_recvmsg(struct sock *sk, struct msghdr *msg, { struct unix_sock *u = unix_sk(sk); struct sk_psock *psock; - int copied, ret; + int copied; psock = sk_psock_get(sk); if (unlikely(!psock)) @@ -53,8 +53,9 @@ static int unix_dgram_bpf_recvmsg(struct sock *sk, struct msghdr *msg, mutex_lock(&u->iolock); if (!skb_queue_empty(&sk->sk_receive_queue) && sk_psock_queue_empty(psock)) { - ret = __unix_dgram_recvmsg(sk, msg, len, flags); - goto out; + mutex_unlock(&u->iolock); + sk_psock_put(sk, psock); + return __unix_dgram_recvmsg(sk, msg, len, flags); } msg_bytes_ready: @@ -68,16 +69,15 @@ msg_bytes_ready: if (data) { if (!sk_psock_queue_empty(psock)) goto msg_bytes_ready; - ret = __unix_dgram_recvmsg(sk, msg, len, flags); - goto out; + mutex_unlock(&u->iolock); + sk_psock_put(sk, psock); + return __unix_dgram_recvmsg(sk, msg, len, flags); } copied = -EAGAIN; } - ret = copied; -out: mutex_unlock(&u->iolock); sk_psock_put(sk, psock); - return ret; + return copied; } static struct proto *unix_prot_saved __read_mostly; |