diff options
Diffstat (limited to 'net/core')
-rw-r--r-- | net/core/filter.c | 59 | ||||
-rw-r--r-- | net/core/rtnetlink.c | 4 |
2 files changed, 36 insertions, 27 deletions
diff --git a/net/core/filter.c b/net/core/filter.c index c25eb36f1320..aecdeba052d3 100644 --- a/net/core/filter.c +++ b/net/core/filter.c @@ -2282,14 +2282,21 @@ static const struct bpf_func_proto bpf_msg_cork_bytes_proto = { .arg2_type = ARG_ANYTHING, }; +#define sk_msg_iter_var(var) \ + do { \ + var++; \ + if (var == MAX_SKB_FRAGS) \ + var = 0; \ + } while (0) + BPF_CALL_4(bpf_msg_pull_data, struct sk_msg_buff *, msg, u32, start, u32, end, u64, flags) { - unsigned int len = 0, offset = 0, copy = 0; + unsigned int len = 0, offset = 0, copy = 0, poffset = 0; + int bytes = end - start, bytes_sg_total; struct scatterlist *sg = msg->sg_data; int first_sg, last_sg, i, shift; unsigned char *p, *to, *from; - int bytes = end - start; struct page *page; if (unlikely(flags || end <= start)) @@ -2299,21 +2306,22 @@ BPF_CALL_4(bpf_msg_pull_data, i = msg->sg_start; do { len = sg[i].length; - offset += len; if (start < offset + len) break; - i++; - if (i == MAX_SKB_FRAGS) - i = 0; + offset += len; + sk_msg_iter_var(i); } while (i != msg->sg_end); if (unlikely(start >= offset + len)) return -EINVAL; - if (!msg->sg_copy[i] && bytes <= len) - goto out; - first_sg = i; + /* The start may point into the sg element so we need to also + * account for the headroom. + */ + bytes_sg_total = start - offset + bytes; + if (!msg->sg_copy[i] && bytes_sg_total <= len) + goto out; /* At this point we need to linearize multiple scatterlist * elements or a single shared page. Either way we need to @@ -2327,37 +2335,32 @@ BPF_CALL_4(bpf_msg_pull_data, */ do { copy += sg[i].length; - i++; - if (i == MAX_SKB_FRAGS) - i = 0; - if (bytes < copy) + sk_msg_iter_var(i); + if (bytes_sg_total <= copy) break; } while (i != msg->sg_end); last_sg = i; - if (unlikely(copy < end - start)) + if (unlikely(bytes_sg_total > copy)) return -EINVAL; page = alloc_pages(__GFP_NOWARN | GFP_ATOMIC, get_order(copy)); if (unlikely(!page)) return -ENOMEM; p = page_address(page); - offset = 0; i = first_sg; do { from = sg_virt(&sg[i]); len = sg[i].length; - to = p + offset; + to = p + poffset; memcpy(to, from, len); - offset += len; + poffset += len; sg[i].length = 0; put_page(sg_page(&sg[i])); - i++; - if (i == MAX_SKB_FRAGS) - i = 0; + sk_msg_iter_var(i); } while (i != last_sg); sg[first_sg].length = copy; @@ -2367,11 +2370,15 @@ BPF_CALL_4(bpf_msg_pull_data, * had a single entry though we can just replace it and * be done. Otherwise walk the ring and shift the entries. */ - shift = last_sg - first_sg - 1; + WARN_ON_ONCE(last_sg == first_sg); + shift = last_sg > first_sg ? + last_sg - first_sg - 1 : + MAX_SKB_FRAGS - first_sg + last_sg - 1; if (!shift) goto out; - i = first_sg + 1; + i = first_sg; + sk_msg_iter_var(i); do { int move_from; @@ -2388,15 +2395,13 @@ BPF_CALL_4(bpf_msg_pull_data, sg[move_from].page_link = 0; sg[move_from].offset = 0; - i++; - if (i == MAX_SKB_FRAGS) - i = 0; + sk_msg_iter_var(i); } while (1); msg->sg_end -= shift; if (msg->sg_end < 0) msg->sg_end += MAX_SKB_FRAGS; out: - msg->data = sg_virt(&sg[i]) + start - offset; + msg->data = sg_virt(&sg[first_sg]) + start - offset; msg->data_end = msg->data + bytes; return 0; @@ -7281,7 +7286,7 @@ static u32 sk_reuseport_convert_ctx_access(enum bpf_access_type type, break; case offsetof(struct sk_reuseport_md, ip_protocol): - BUILD_BUG_ON(hweight_long(SK_FL_PROTO_MASK) != BITS_PER_BYTE); + BUILD_BUG_ON(HWEIGHT32(SK_FL_PROTO_MASK) != BITS_PER_BYTE); SK_REUSEPORT_LOAD_SK_FIELD_SIZE_OFF(__sk_flags_offset, BPF_W, 0); *insn++ = BPF_ALU32_IMM(BPF_AND, si->dst_reg, SK_FL_PROTO_MASK); diff --git a/net/core/rtnetlink.c b/net/core/rtnetlink.c index 24431e578310..60c928894a78 100644 --- a/net/core/rtnetlink.c +++ b/net/core/rtnetlink.c @@ -324,6 +324,10 @@ void rtnl_unregister_all(int protocol) rtnl_lock(); tab = rtnl_msg_handlers[protocol]; + if (!tab) { + rtnl_unlock(); + return; + } RCU_INIT_POINTER(rtnl_msg_handlers[protocol], NULL); for (msgindex = 0; msgindex < RTM_NR_MSGTYPES; msgindex++) { link = tab[msgindex]; |