diff options
Diffstat (limited to 'tools')
46 files changed, 5080 insertions, 238 deletions
diff --git a/tools/bpf/bpftool/Makefile b/tools/bpf/bpftool/Makefile index f610e184ce02..ab20ecc5acce 100644 --- a/tools/bpf/bpftool/Makefile +++ b/tools/bpf/bpftool/Makefile @@ -293,3 +293,6 @@ FORCE: .PHONY: all FORCE bootstrap clean install-bin install uninstall .PHONY: doc doc-clean doc-install doc-uninstall .DEFAULT_GOAL := all + +# Delete partially updated (corrupted) files on error +.DELETE_ON_ERROR: diff --git a/tools/bpf/resolve_btfids/Makefile b/tools/bpf/resolve_btfids/Makefile index 19a3112e271a..f7375a119f54 100644 --- a/tools/bpf/resolve_btfids/Makefile +++ b/tools/bpf/resolve_btfids/Makefile @@ -56,13 +56,17 @@ $(BPFOBJ): $(wildcard $(LIBBPF_SRC)/*.[ch] $(LIBBPF_SRC)/Makefile) | $(LIBBPF_OU DESTDIR=$(LIBBPF_DESTDIR) prefix= EXTRA_CFLAGS="$(CFLAGS)" \ $(abspath $@) install_headers +LIBELF_FLAGS := $(shell $(HOSTPKG_CONFIG) libelf --cflags 2>/dev/null) +LIBELF_LIBS := $(shell $(HOSTPKG_CONFIG) libelf --libs 2>/dev/null || echo -lelf) + CFLAGS += -g \ -I$(srctree)/tools/include \ -I$(srctree)/tools/include/uapi \ -I$(LIBBPF_INCLUDE) \ - -I$(SUBCMD_SRC) + -I$(SUBCMD_SRC) \ + $(LIBELF_FLAGS) -LIBS = -lelf -lz +LIBS = $(LIBELF_LIBS) -lz export srctree OUTPUT CFLAGS Q include $(srctree)/tools/build/Makefile.include diff --git a/tools/include/uapi/linux/bpf.h b/tools/include/uapi/linux/bpf.h index 464ca3f01fe7..bc1a3d232ae4 100644 --- a/tools/include/uapi/linux/bpf.h +++ b/tools/include/uapi/linux/bpf.h @@ -2001,6 +2001,9 @@ union bpf_attr { * sending the packet. This flag was added for GRE * encapsulation, but might be used with other protocols * as well in the future. + * **BPF_F_NO_TUNNEL_KEY** + * Add a flag to tunnel metadata indicating that no tunnel + * key should be set in the resulting tunnel header. 
* * Here is a typical usage on the transmit path: * @@ -5764,6 +5767,7 @@ enum { BPF_F_ZERO_CSUM_TX = (1ULL << 1), BPF_F_DONT_FRAGMENT = (1ULL << 2), BPF_F_SEQ_NUMBER = (1ULL << 3), + BPF_F_NO_TUNNEL_KEY = (1ULL << 4), }; /* BPF_FUNC_skb_get_tunnel_key flags. */ diff --git a/tools/lib/bpf/bpf_tracing.h b/tools/lib/bpf/bpf_tracing.h index 2972dc25ff72..bdb0f6b5be84 100644 --- a/tools/lib/bpf/bpf_tracing.h +++ b/tools/lib/bpf/bpf_tracing.h @@ -32,6 +32,9 @@ #elif defined(__TARGET_ARCH_arc) #define bpf_target_arc #define bpf_target_defined +#elif defined(__TARGET_ARCH_loongarch) + #define bpf_target_loongarch + #define bpf_target_defined #else /* Fall back to what the compiler says */ @@ -62,6 +65,9 @@ #elif defined(__arc__) #define bpf_target_arc #define bpf_target_defined +#elif defined(__loongarch__) + #define bpf_target_loongarch + #define bpf_target_defined #endif /* no compiler target */ #endif @@ -137,7 +143,7 @@ struct pt_regs___s390 { #define __PT_PARM3_REG gprs[4] #define __PT_PARM4_REG gprs[5] #define __PT_PARM5_REG gprs[6] -#define __PT_RET_REG grps[14] +#define __PT_RET_REG gprs[14] #define __PT_FP_REG gprs[11] /* Works only with CONFIG_FRAME_POINTER */ #define __PT_RC_REG gprs[2] #define __PT_SP_REG gprs[15] @@ -258,6 +264,23 @@ struct pt_regs___arm64 { /* arc does not select ARCH_HAS_SYSCALL_WRAPPER. */ #define PT_REGS_SYSCALL_REGS(ctx) ctx +#elif defined(bpf_target_loongarch) + +/* https://loongson.github.io/LoongArch-Documentation/LoongArch-ELF-ABI-EN.html */ + +#define __PT_PARM1_REG regs[4] +#define __PT_PARM2_REG regs[5] +#define __PT_PARM3_REG regs[6] +#define __PT_PARM4_REG regs[7] +#define __PT_PARM5_REG regs[8] +#define __PT_RET_REG regs[1] +#define __PT_FP_REG regs[22] +#define __PT_RC_REG regs[4] +#define __PT_SP_REG regs[3] +#define __PT_IP_REG csr_era +/* loongarch does not select ARCH_HAS_SYSCALL_WRAPPER. 
*/ +#define PT_REGS_SYSCALL_REGS(ctx) ctx + #endif #if defined(bpf_target_defined) diff --git a/tools/lib/bpf/btf.c b/tools/lib/bpf/btf.c index 71e165b09ed5..64841117fbb2 100644 --- a/tools/lib/bpf/btf.c +++ b/tools/lib/bpf/btf.c @@ -688,8 +688,21 @@ int btf__align_of(const struct btf *btf, __u32 id) if (align <= 0) return libbpf_err(align); max_align = max(max_align, align); + + /* if field offset isn't aligned according to field + * type's alignment, then struct must be packed + */ + if (btf_member_bitfield_size(t, i) == 0 && + (m->offset % (8 * align)) != 0) + return 1; } + /* if struct/union size isn't a multiple of its alignment, + * then struct must be packed + */ + if ((t->size % max_align) != 0) + return 1; + return max_align; } default: @@ -990,7 +1003,8 @@ static struct btf *btf_parse_elf(const char *path, struct btf *base_btf, err = 0; if (!btf_data) { - err = -ENOENT; + pr_warn("failed to find '%s' ELF section in %s\n", BTF_ELF_SEC, path); + err = -ENODATA; goto done; } btf = btf_new(btf_data->d_buf, btf_data->d_size, base_btf); diff --git a/tools/lib/bpf/btf_dump.c b/tools/lib/bpf/btf_dump.c index deb2bc9a0a7b..580985ee5545 100644 --- a/tools/lib/bpf/btf_dump.c +++ b/tools/lib/bpf/btf_dump.c @@ -13,6 +13,7 @@ #include <ctype.h> #include <endian.h> #include <errno.h> +#include <limits.h> #include <linux/err.h> #include <linux/btf.h> #include <linux/kernel.h> @@ -833,14 +834,9 @@ static bool btf_is_struct_packed(const struct btf *btf, __u32 id, const struct btf_type *t) { const struct btf_member *m; - int align, i, bit_sz; + int max_align = 1, align, i, bit_sz; __u16 vlen; - align = btf__align_of(btf, id); - /* size of a non-packed struct has to be a multiple of its alignment*/ - if (align && t->size % align) - return true; - m = btf_members(t); vlen = btf_vlen(t); /* all non-bitfield fields have to be naturally aligned */ @@ -849,8 +845,11 @@ static bool btf_is_struct_packed(const struct btf *btf, __u32 id, bit_sz = btf_member_bitfield_size(t, i); if 
(align && bit_sz == 0 && m->offset % (8 * align) != 0) return true; + max_align = max(align, max_align); } - + /* size of a non-packed struct has to be a multiple of its alignment */ + if (t->size % max_align != 0) + return true; /* * if original struct was marked as packed, but its layout is * naturally aligned, we'll detect that it's not packed @@ -858,44 +857,97 @@ static bool btf_is_struct_packed(const struct btf *btf, __u32 id, return false; } -static int chip_away_bits(int total, int at_most) -{ - return total % at_most ? : at_most; -} - static void btf_dump_emit_bit_padding(const struct btf_dump *d, - int cur_off, int m_off, int m_bit_sz, - int align, int lvl) + int cur_off, int next_off, int next_align, + bool in_bitfield, int lvl) { - int off_diff = m_off - cur_off; - int ptr_bits = d->ptr_sz * 8; + const struct { + const char *name; + int bits; + } pads[] = { + {"long", d->ptr_sz * 8}, {"int", 32}, {"short", 16}, {"char", 8} + }; + int new_off, pad_bits, bits, i; + const char *pad_type; + + if (cur_off >= next_off) + return; /* no gap */ + + /* For filling out padding we want to take advantage of + * natural alignment rules to minimize unnecessary explicit + * padding. First, we find the largest type (among long, int, + * short, or char) that can be used to force naturally aligned + * boundary. Once determined, we'll use such type to fill in + * the remaining padding gap. In some cases we can rely on + * compiler filling some gaps, but sometimes we need to force + * alignment to close natural alignment with markers like + * `long: 0` (this is always the case for bitfields). Note + * that even if struct itself has, let's say 4-byte alignment + * (i.e., it only uses up to int-aligned types), using `long: + * X;` explicit padding doesn't actually change struct's + * overall alignment requirements, but compiler does take into + * account that type's (long, in this example) natural + * alignment requirements when adding implicit padding. 
We use + * this fact heavily and don't worry about ruining correct + * struct alignment requirement. + */ + for (i = 0; i < ARRAY_SIZE(pads); i++) { + pad_bits = pads[i].bits; + pad_type = pads[i].name; - if (off_diff <= 0) - /* no gap */ - return; - if (m_bit_sz == 0 && off_diff < align * 8) - /* natural padding will take care of a gap */ - return; + new_off = roundup(cur_off, pad_bits); + if (new_off <= next_off) + break; + } - while (off_diff > 0) { - const char *pad_type; - int pad_bits; - - if (ptr_bits > 32 && off_diff > 32) { - pad_type = "long"; - pad_bits = chip_away_bits(off_diff, ptr_bits); - } else if (off_diff > 16) { - pad_type = "int"; - pad_bits = chip_away_bits(off_diff, 32); - } else if (off_diff > 8) { - pad_type = "short"; - pad_bits = chip_away_bits(off_diff, 16); - } else { - pad_type = "char"; - pad_bits = chip_away_bits(off_diff, 8); + if (new_off > cur_off && new_off <= next_off) { + /* We need explicit `<type>: 0` aligning mark if next + * field is right on alignment offset and its + * alignment requirement is less strict than <type>'s + * alignment (so compiler won't naturally align to the + * offset we expect), or if subsequent `<type>: X`, + * will actually completely fit in the remaining hole, + * making compiler basically ignore `<type>: X` + * completely. + */ + if (in_bitfield || + (new_off == next_off && roundup(cur_off, next_align * 8) != new_off) || + (new_off != next_off && next_off - new_off <= new_off - cur_off)) + /* but for bitfields we'll emit explicit bit count */ + btf_dump_printf(d, "\n%s%s: %d;", pfx(lvl), pad_type, + in_bitfield ? new_off - cur_off : 0); + cur_off = new_off; + } + + /* Now we know we start at naturally aligned offset for a chosen + * padding type (long, int, short, or char), and so the rest is just + * a straightforward filling of remaining padding gap with full + * `<type>: sizeof(<type>);` markers, except for the last one, which + * might need smaller than sizeof(<type>) padding. 
+ */ + while (cur_off != next_off) { + bits = min(next_off - cur_off, pad_bits); + if (bits == pad_bits) { + btf_dump_printf(d, "\n%s%s: %d;", pfx(lvl), pad_type, pad_bits); + cur_off += bits; + continue; + } + /* For the remainder padding that doesn't cover entire + * pad_type bit length, we pick the smallest necessary type. + * This is pure aesthetics, we could have just used `long`, + * but having smallest necessary one communicates better the + * scale of the padding gap. + */ + for (i = ARRAY_SIZE(pads) - 1; i >= 0; i--) { + pad_type = pads[i].name; + pad_bits = pads[i].bits; + if (pad_bits < bits) + continue; + + btf_dump_printf(d, "\n%s%s: %d;", pfx(lvl), pad_type, bits); + cur_off += bits; + break; } - btf_dump_printf(d, "\n%s%s: %d;", pfx(lvl), pad_type, pad_bits); - off_diff -= pad_bits; } } @@ -915,9 +967,11 @@ static void btf_dump_emit_struct_def(struct btf_dump *d, { const struct btf_member *m = btf_members(t); bool is_struct = btf_is_struct(t); - int align, i, packed, off = 0; + bool packed, prev_bitfield = false; + int align, i, off = 0; __u16 vlen = btf_vlen(t); + align = btf__align_of(d->btf, id); packed = is_struct ? btf_is_struct_packed(d->btf, id, t) : 0; btf_dump_printf(d, "%s%s%s {", @@ -927,41 +981,47 @@ static void btf_dump_emit_struct_def(struct btf_dump *d, for (i = 0; i < vlen; i++, m++) { const char *fname; - int m_off, m_sz; + int m_off, m_sz, m_align; + bool in_bitfield; fname = btf_name_of(d, m->name_off); m_sz = btf_member_bitfield_size(t, i); m_off = btf_member_bit_offset(t, i); - align = packed ? 1 : btf__align_of(d->btf, m->type); + m_align = packed ? 
1 : btf__align_of(d->btf, m->type); + + in_bitfield = prev_bitfield && m_sz != 0; - btf_dump_emit_bit_padding(d, off, m_off, m_sz, align, lvl + 1); + btf_dump_emit_bit_padding(d, off, m_off, m_align, in_bitfield, lvl + 1); btf_dump_printf(d, "\n%s", pfx(lvl + 1)); btf_dump_emit_type_decl(d, m->type, fname, lvl + 1); if (m_sz) { btf_dump_printf(d, ": %d", m_sz); off = m_off + m_sz; + prev_bitfield = true; } else { m_sz = max((__s64)0, btf__resolve_size(d->btf, m->type)); off = m_off + m_sz * 8; + prev_bitfield = false; } + btf_dump_printf(d, ";"); } /* pad at the end, if necessary */ - if (is_struct) { - align = packed ? 1 : btf__align_of(d->btf, id); - btf_dump_emit_bit_padding(d, off, t->size * 8, 0, align, - lvl + 1); - } + if (is_struct) + btf_dump_emit_bit_padding(d, off, t->size * 8, align, false, lvl + 1); /* * Keep `struct empty {}` on a single line, * only print newline when there are regular or padding fields. */ - if (vlen || t->size) + if (vlen || t->size) { btf_dump_printf(d, "\n"); - btf_dump_printf(d, "%s}", pfx(lvl)); + btf_dump_printf(d, "%s}", pfx(lvl)); + } else { + btf_dump_printf(d, "}"); + } if (packed) btf_dump_printf(d, " __attribute__((packed))"); } @@ -1073,6 +1133,43 @@ static void btf_dump_emit_enum_def(struct btf_dump *d, __u32 id, else btf_dump_emit_enum64_val(d, t, lvl, vlen); btf_dump_printf(d, "\n%s}", pfx(lvl)); + + /* special case enums with special sizes */ + if (t->size == 1) { + /* one-byte enums can be forced with mode(byte) attribute */ + btf_dump_printf(d, " __attribute__((mode(byte)))"); + } else if (t->size == 8 && d->ptr_sz == 8) { + /* enum can be 8-byte sized if one of the enumerator values + * doesn't fit in 32-bit integer, or by adding mode(word) + * attribute (but probably only on 64-bit architectures); do + * our best here to try to satisfy the contract without adding + * unnecessary attributes + */ + bool needs_word_mode; + + if (btf_is_enum(t)) { + /* enum can't represent 64-bit values, so we need word mode */ + 
needs_word_mode = true; + } else { + /* enum64 needs mode(word) if none of its values has + * non-zero upper 32-bits (which means that all values + * fit in 32-bit integers and won't cause compiler to + * bump enum to be 64-bit naturally) + */ + int i; + + needs_word_mode = true; + for (i = 0; i < vlen; i++) { + if (btf_enum64(t)[i].val_hi32 != 0) { + needs_word_mode = false; + break; + } + } + } + if (needs_word_mode) + btf_dump_printf(d, " __attribute__((mode(word)))"); + } + } static void btf_dump_emit_fwd_def(struct btf_dump *d, __u32 id, diff --git a/tools/lib/bpf/libbpf.c b/tools/lib/bpf/libbpf.c index 2a82f49ce16f..a5c67a3c93c5 100644 --- a/tools/lib/bpf/libbpf.c +++ b/tools/lib/bpf/libbpf.c @@ -9903,7 +9903,7 @@ static int perf_event_open_probe(bool uprobe, bool retprobe, const char *name, char errmsg[STRERR_BUFSIZE]; int type, pfd; - if (ref_ctr_off >= (1ULL << PERF_UPROBE_REF_CTR_OFFSET_BITS)) + if ((__u64)ref_ctr_off >= (1ULL << PERF_UPROBE_REF_CTR_OFFSET_BITS)) return -EINVAL; memset(&attr, 0, attr_sz); diff --git a/tools/lib/bpf/libbpf.h b/tools/lib/bpf/libbpf.h index eee883f007f9..898db26e42e9 100644 --- a/tools/lib/bpf/libbpf.h +++ b/tools/lib/bpf/libbpf.h @@ -96,6 +96,12 @@ enum libbpf_print_level { typedef int (*libbpf_print_fn_t)(enum libbpf_print_level level, const char *, va_list ap); +/** + * @brief **libbpf_set_print()** sets user-provided log callback function to + * be used for libbpf warnings and informational messages. + * @param fn The log print function. If NULL, libbpf won't print anything. + * @return Pointer to old print function. + */ LIBBPF_API libbpf_print_fn_t libbpf_set_print(libbpf_print_fn_t fn); /* Hide internal to user */ @@ -174,6 +180,14 @@ struct bpf_object_open_opts { }; #define bpf_object_open_opts__last_field kernel_log_level +/** + * @brief **bpf_object__open()** creates a bpf_object by opening + * the BPF ELF object file pointed to by the passed path and loading it + * into memory. + * @param path BPF object file path. 
+ * @return pointer to the new bpf_object; or NULL is returned on error, + * error code is stored in errno + */ LIBBPF_API struct bpf_object *bpf_object__open(const char *path); /** @@ -203,10 +217,21 @@ LIBBPF_API struct bpf_object * bpf_object__open_mem(const void *obj_buf, size_t obj_buf_sz, const struct bpf_object_open_opts *opts); -/* Load/unload object into/from kernel */ +/** + * @brief **bpf_object__load()** loads BPF object into kernel. + * @param obj Pointer to a valid BPF object instance returned by + * **bpf_object__open*()** APIs + * @return 0, on success; negative error code, otherwise, error code is + * stored in errno + */ LIBBPF_API int bpf_object__load(struct bpf_object *obj); -LIBBPF_API void bpf_object__close(struct bpf_object *object); +/** + * @brief **bpf_object__close()** closes a BPF object and releases all + * resources. + * @param obj Pointer to a valid BPF object + */ +LIBBPF_API void bpf_object__close(struct bpf_object *obj); /* pin_maps and unpin_maps can both be called with a NULL path, in which case * they will use the pin_path attribute of each map (and ignore all maps that diff --git a/tools/lib/bpf/libbpf.map b/tools/lib/bpf/libbpf.map index 71bf5691a689..11c36a3c1a9f 100644 --- a/tools/lib/bpf/libbpf.map +++ b/tools/lib/bpf/libbpf.map @@ -382,3 +382,6 @@ LIBBPF_1.1.0 { user_ring_buffer__reserve_blocking; user_ring_buffer__submit; } LIBBPF_1.0.0; + +LIBBPF_1.2.0 { +} LIBBPF_1.1.0; diff --git a/tools/lib/bpf/libbpf_errno.c b/tools/lib/bpf/libbpf_errno.c index 96f67a772a1b..6b180172ec6b 100644 --- a/tools/lib/bpf/libbpf_errno.c +++ b/tools/lib/bpf/libbpf_errno.c @@ -39,14 +39,14 @@ static const char *libbpf_strerror_table[NR_ERRNO] = { int libbpf_strerror(int err, char *buf, size_t size) { + int ret; + if (!buf || !size) return libbpf_err(-EINVAL); err = err > 0 ? 
err : -err; if (err < __LIBBPF_ERRNO__START) { - int ret; - ret = strerror_r(err, buf, size); buf[size - 1] = '\0'; return libbpf_err_errno(ret); @@ -56,12 +56,20 @@ int libbpf_strerror(int err, char *buf, size_t size) const char *msg; msg = libbpf_strerror_table[ERRNO_OFFSET(err)]; - snprintf(buf, size, "%s", msg); + ret = snprintf(buf, size, "%s", msg); buf[size - 1] = '\0'; + /* The length of the buf and msg is positive. + * A negative number may be returned only when the + * size exceeds INT_MAX. Not likely to appear. + */ + if (ret >= size) + return libbpf_err(-ERANGE); return 0; } - snprintf(buf, size, "Unknown libbpf error %d", err); + ret = snprintf(buf, size, "Unknown libbpf error %d", err); buf[size - 1] = '\0'; + if (ret >= size) + return libbpf_err(-ERANGE); return libbpf_err(-ENOENT); } diff --git a/tools/lib/bpf/libbpf_internal.h b/tools/lib/bpf/libbpf_internal.h index 377642ff51fc..e4d05662a96c 100644 --- a/tools/lib/bpf/libbpf_internal.h +++ b/tools/lib/bpf/libbpf_internal.h @@ -543,6 +543,7 @@ static inline int ensure_good_fd(int fd) fd = fcntl(fd, F_DUPFD_CLOEXEC, 3); saved_errno = errno; close(old_fd); + errno = saved_errno; if (fd < 0) { pr_warn("failed to dup FD %d to FD > 2: %d\n", old_fd, -saved_errno); errno = saved_errno; diff --git a/tools/lib/bpf/libbpf_version.h b/tools/lib/bpf/libbpf_version.h index e944f5bce728..1fd2eeac5cfc 100644 --- a/tools/lib/bpf/libbpf_version.h +++ b/tools/lib/bpf/libbpf_version.h @@ -4,6 +4,6 @@ #define __LIBBPF_VERSION_H #define LIBBPF_MAJOR_VERSION 1 -#define LIBBPF_MINOR_VERSION 1 +#define LIBBPF_MINOR_VERSION 2 #endif /* __LIBBPF_VERSION_H */ diff --git a/tools/net/ynl/samples/cli.py b/tools/net/ynl/samples/cli.py new file mode 100755 index 000000000000..b27159c70710 --- /dev/null +++ b/tools/net/ynl/samples/cli.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python +# SPDX-License-Identifier: BSD-3-Clause + +import argparse +import json +import pprint +import time + +from ynl import YnlFamily + + +def main(): + parser 
= argparse.ArgumentParser(description='YNL CLI sample') + parser.add_argument('--spec', dest='spec', type=str, required=True) + parser.add_argument('--schema', dest='schema', type=str) + parser.add_argument('--json', dest='json_text', type=str) + parser.add_argument('--do', dest='do', type=str) + parser.add_argument('--dump', dest='dump', type=str) + parser.add_argument('--sleep', dest='sleep', type=int) + parser.add_argument('--subscribe', dest='ntf', type=str) + args = parser.parse_args() + + attrs = {} + if args.json_text: + attrs = json.loads(args.json_text) + + ynl = YnlFamily(args.spec, args.schema) + + if args.ntf: + ynl.ntf_subscribe(args.ntf) + + if args.sleep: + time.sleep(args.sleep) + + if args.do or args.dump: + method = getattr(ynl, args.do if args.do else args.dump) + + reply = method(attrs, dump=bool(args.dump)) + pprint.PrettyPrinter().pprint(reply) + + if args.ntf: + ynl.check_ntf() + pprint.PrettyPrinter().pprint(ynl.async_msg_queue) + + +if __name__ == "__main__": + main() diff --git a/tools/net/ynl/samples/ynl.py b/tools/net/ynl/samples/ynl.py new file mode 100644 index 000000000000..b71523d71d46 --- /dev/null +++ b/tools/net/ynl/samples/ynl.py @@ -0,0 +1,534 @@ +# SPDX-License-Identifier: BSD-3-Clause + +import functools +import jsonschema +import os +import random +import socket +import struct +import yaml + +# +# Generic Netlink code which should really be in some library, but I can't quickly find one. 
+# + + +class Netlink: + # Netlink socket + SOL_NETLINK = 270 + + NETLINK_ADD_MEMBERSHIP = 1 + NETLINK_CAP_ACK = 10 + NETLINK_EXT_ACK = 11 + + # Netlink message + NLMSG_ERROR = 2 + NLMSG_DONE = 3 + + NLM_F_REQUEST = 1 + NLM_F_ACK = 4 + NLM_F_ROOT = 0x100 + NLM_F_MATCH = 0x200 + NLM_F_APPEND = 0x800 + + NLM_F_CAPPED = 0x100 + NLM_F_ACK_TLVS = 0x200 + + NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH + + NLA_F_NESTED = 0x8000 + NLA_F_NET_BYTEORDER = 0x4000 + + NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER + + # Genetlink defines + NETLINK_GENERIC = 16 + + GENL_ID_CTRL = 0x10 + + # nlctrl + CTRL_CMD_GETFAMILY = 3 + + CTRL_ATTR_FAMILY_ID = 1 + CTRL_ATTR_FAMILY_NAME = 2 + CTRL_ATTR_MAXATTR = 5 + CTRL_ATTR_MCAST_GROUPS = 7 + + CTRL_ATTR_MCAST_GRP_NAME = 1 + CTRL_ATTR_MCAST_GRP_ID = 2 + + # Extack types + NLMSGERR_ATTR_MSG = 1 + NLMSGERR_ATTR_OFFS = 2 + NLMSGERR_ATTR_COOKIE = 3 + NLMSGERR_ATTR_POLICY = 4 + NLMSGERR_ATTR_MISS_TYPE = 5 + NLMSGERR_ATTR_MISS_NEST = 6 + + +class NlAttr: + def __init__(self, raw, offset): + self._len, self._type = struct.unpack("HH", raw[offset:offset + 4]) + self.type = self._type & ~Netlink.NLA_TYPE_MASK + self.payload_len = self._len + self.full_len = (self.payload_len + 3) & ~3 + self.raw = raw[offset + 4:offset + self.payload_len] + + def as_u16(self): + return struct.unpack("H", self.raw)[0] + + def as_u32(self): + return struct.unpack("I", self.raw)[0] + + def as_u64(self): + return struct.unpack("Q", self.raw)[0] + + def as_strz(self): + return self.raw.decode('ascii')[:-1] + + def as_bin(self): + return self.raw + + def __repr__(self): + return f"[type:{self.type} len:{self._len}] {self.raw}" + + +class NlAttrs: + def __init__(self, msg): + self.attrs = [] + + offset = 0 + while offset < len(msg): + attr = NlAttr(msg, offset) + offset += attr.full_len + self.attrs.append(attr) + + def __iter__(self): + yield from self.attrs + + def __repr__(self): + msg = '' + for a in self.attrs: + if msg: + msg += '\n' + msg += repr(a) + return msg + + 
+class NlMsg: + def __init__(self, msg, offset, attr_space=None): + self.hdr = msg[offset:offset + 16] + + self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \ + struct.unpack("IHHII", self.hdr) + + self.raw = msg[offset + 16:offset + self.nl_len] + + self.error = 0 + self.done = 0 + + extack_off = None + if self.nl_type == Netlink.NLMSG_ERROR: + self.error = struct.unpack("i", self.raw[0:4])[0] + self.done = 1 + extack_off = 20 + elif self.nl_type == Netlink.NLMSG_DONE: + self.done = 1 + extack_off = 4 + + self.extack = None + if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off: + self.extack = dict() + extack_attrs = NlAttrs(self.raw[extack_off:]) + for extack in extack_attrs: + if extack.type == Netlink.NLMSGERR_ATTR_MSG: + self.extack['msg'] = extack.as_strz() + elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE: + self.extack['miss-type'] = extack.as_u32() + elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST: + self.extack['miss-nest'] = extack.as_u32() + elif extack.type == Netlink.NLMSGERR_ATTR_OFFS: + self.extack['bad-attr-offs'] = extack.as_u32() + else: + if 'unknown' not in self.extack: + self.extack['unknown'] = [] + self.extack['unknown'].append(extack) + + if attr_space: + # We don't have the ability to parse nests yet, so only do global + if 'miss-type' in self.extack and 'miss-nest' not in self.extack: + miss_type = self.extack['miss-type'] + if len(attr_space.attr_list) > miss_type: + spec = attr_space.attr_list[miss_type] + desc = spec['name'] + if 'doc' in spec: + desc += f" ({spec['doc']})" + self.extack['miss-type'] = desc + + def __repr__(self): + msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}\n" + if self.error: + msg += '\terror: ' + str(self.error) + if self.extack: + msg += '\textack: ' + repr(self.extack) + return msg + + +class NlMsgs: + def __init__(self, data, attr_space=None): + self.msgs = [] + + offset = 0 + while offset < len(data): + msg = 
NlMsg(data, offset, attr_space=attr_space) + offset += msg.nl_len + self.msgs.append(msg) + + def __iter__(self): + yield from self.msgs + + +genl_family_name_to_id = None + + +def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None): + # we prepend length in _genl_msg_finalize() + if seq is None: + seq = random.randint(1, 1024) + nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0) + genlmsg = struct.pack("bbH", genl_cmd, genl_version, 0) + return nlmsg + genlmsg + + +def _genl_msg_finalize(msg): + return struct.pack("I", len(msg) + 4) + msg + + +def _genl_load_families(): + with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock: + sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1) + + msg = _genl_msg(Netlink.GENL_ID_CTRL, + Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP, + Netlink.CTRL_CMD_GETFAMILY, 1) + msg = _genl_msg_finalize(msg) + + sock.send(msg, 0) + + global genl_family_name_to_id + genl_family_name_to_id = dict() + + while True: + reply = sock.recv(128 * 1024) + nms = NlMsgs(reply) + for nl_msg in nms: + if nl_msg.error: + print("Netlink error:", nl_msg.error) + return + if nl_msg.done: + return + + gm = GenlMsg(nl_msg) + fam = dict() + for attr in gm.raw_attrs: + if attr.type == Netlink.CTRL_ATTR_FAMILY_ID: + fam['id'] = attr.as_u16() + elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME: + fam['name'] = attr.as_strz() + elif attr.type == Netlink.CTRL_ATTR_MAXATTR: + fam['maxattr'] = attr.as_u32() + elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS: + fam['mcast'] = dict() + for entry in NlAttrs(attr.raw): + mcast_name = None + mcast_id = None + for entry_attr in NlAttrs(entry.raw): + if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME: + mcast_name = entry_attr.as_strz() + elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID: + mcast_id = entry_attr.as_u32() + if mcast_name and mcast_id is not None: + fam['mcast'][mcast_name] = mcast_id + if 'name' in fam and 'id' in fam: + 
genl_family_name_to_id[fam['name']] = fam + + +class GenlMsg: + def __init__(self, nl_msg): + self.nl = nl_msg + + self.hdr = nl_msg.raw[0:4] + self.raw = nl_msg.raw[4:] + + self.genl_cmd, self.genl_version, _ = struct.unpack("bbH", self.hdr) + + self.raw_attrs = NlAttrs(self.raw) + + def __repr__(self): + msg = repr(self.nl) + msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n" + for a in self.raw_attrs: + msg += '\t\t' + repr(a) + '\n' + return msg + + +class GenlFamily: + def __init__(self, family_name): + self.family_name = family_name + + global genl_family_name_to_id + if genl_family_name_to_id is None: + _genl_load_families() + + self.genl_family = genl_family_name_to_id[family_name] + self.family_id = genl_family_name_to_id[family_name]['id'] + + +# +# YNL implementation details. +# + + +class YnlAttrSpace: + def __init__(self, family, yaml): + self.yaml = yaml + + self.attrs = dict() + self.name = self.yaml['name'] + self.subspace_of = self.yaml['subset-of'] if 'subspace-of' in self.yaml else None + + val = 0 + max_val = 0 + for elem in self.yaml['attributes']: + if 'value' in elem: + val = elem['value'] + else: + elem['value'] = val + if val > max_val: + max_val = val + val += 1 + + self.attrs[elem['name']] = elem + + self.attr_list = [None] * (max_val + 1) + for elem in self.yaml['attributes']: + self.attr_list[elem['value']] = elem + + def __getitem__(self, key): + return self.attrs[key] + + def __contains__(self, key): + return key in self.yaml + + def __iter__(self): + yield from self.attrs + + def items(self): + return self.attrs.items() + + +class YnlFamily: + def __init__(self, def_path, schema=None): + self.include_raw = False + + with open(def_path, "r") as stream: + self.yaml = yaml.safe_load(stream) + + if schema: + with open(schema, "r") as stream: + schema = yaml.safe_load(stream) + + jsonschema.validate(self.yaml, schema) + + self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) + 
self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1) + self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1) + + self._ops = dict() + self._spaces = dict() + self._types = dict() + + for elem in self.yaml['attribute-sets']: + self._spaces[elem['name']] = YnlAttrSpace(self, elem) + + for elem in self.yaml['definitions']: + self._types[elem['name']] = elem + + async_separation = 'async-prefix' in self.yaml['operations'] + self.async_msg_ids = set() + self.async_msg_queue = [] + val = 0 + max_val = 0 + for elem in self.yaml['operations']['list']: + if not (async_separation and ('notify' in elem or 'event' in elem)): + if 'value' in elem: + val = elem['value'] + else: + elem['value'] = val + val += 1 + max_val = max(val, max_val) + + if 'notify' in elem or 'event' in elem: + self.async_msg_ids.add(elem['value']) + + self._ops[elem['name']] = elem + + op_name = elem['name'].replace('-', '_') + + bound_f = functools.partial(self._op, elem['name']) + setattr(self, op_name, bound_f) + + self._op_array = [None] * max_val + for _, op in self._ops.items(): + self._op_array[op['value']] = op + if 'notify' in op: + op['attribute-set'] = self._ops[op['notify']]['attribute-set'] + + self.family = GenlFamily(self.yaml['name']) + + def ntf_subscribe(self, mcast_name): + if mcast_name not in self.family.genl_family['mcast']: + raise Exception(f'Multicast group "{mcast_name}" not present in the family') + + self.sock.bind((0, 0)) + self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP, + self.family.genl_family['mcast'][mcast_name]) + + def _add_attr(self, space, name, value): + attr = self._spaces[space][name] + nl_type = attr['value'] + if attr["type"] == 'nest': + nl_type |= Netlink.NLA_F_NESTED + attr_payload = b'' + for subname, subvalue in value.items(): + attr_payload += self._add_attr(attr['nested-attributes'], subname, subvalue) + elif attr["type"] == 'u32': + attr_payload = struct.pack("I", int(value)) + elif attr["type"] 
== 'string': + attr_payload = str(value).encode('ascii') + b'\x00' + elif attr["type"] == 'binary': + attr_payload = value + else: + raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}') + + pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4) + return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad + + def _decode_enum(self, rsp, attr_spec): + raw = rsp[attr_spec['name']] + enum = self._types[attr_spec['enum']] + i = attr_spec.get('value-start', 0) + if 'enum-as-flags' in attr_spec and attr_spec['enum-as-flags']: + value = set() + while raw: + if raw & 1: + value.add(enum['entries'][i]) + raw >>= 1 + i += 1 + else: + value = enum['entries'][raw - i] + rsp[attr_spec['name']] = value + + def _decode(self, attrs, space): + attr_space = self._spaces[space] + rsp = dict() + for attr in attrs: + attr_spec = attr_space.attr_list[attr.type] + if attr_spec["type"] == 'nest': + subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes']) + rsp[attr_spec['name']] = subdict + elif attr_spec['type'] == 'u32': + rsp[attr_spec['name']] = attr.as_u32() + elif attr_spec['type'] == 'u64': + rsp[attr_spec['name']] = attr.as_u64() + elif attr_spec["type"] == 'string': + rsp[attr_spec['name']] = attr.as_strz() + elif attr_spec["type"] == 'binary': + rsp[attr_spec['name']] = attr.as_bin() + else: + raise Exception(f'Unknown {attr.type} {attr_spec["name"]} {attr_spec["type"]}') + + if 'enum' in attr_spec: + self._decode_enum(rsp, attr_spec) + return rsp + + def handle_ntf(self, nl_msg, genl_msg): + msg = dict() + if self.include_raw: + msg['nlmsg'] = nl_msg + msg['genlmsg'] = genl_msg + op = self._op_array[genl_msg.genl_cmd] + msg['name'] = op['name'] + msg['msg'] = self._decode(genl_msg.raw_attrs, op['attribute-set']) + self.async_msg_queue.append(msg) + + def check_ntf(self): + while True: + try: + reply = self.sock.recv(128 * 1024, socket.MSG_DONTWAIT) + except BlockingIOError: + return + + nms = NlMsgs(reply) + for nl_msg in nms: 
+ if nl_msg.error: + print("Netlink error in ntf!?", os.strerror(-nl_msg.error)) + print(nl_msg) + continue + if nl_msg.done: + print("Netlink done while checking for ntf!?") + continue + + gm = GenlMsg(nl_msg) + if gm.genl_cmd not in self.async_msg_ids: + print("Unexpected msg id done while checking for ntf", gm) + continue + + self.handle_ntf(nl_msg, gm) + + def _op(self, method, vals, dump=False): + op = self._ops[method] + + nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK + if dump: + nl_flags |= Netlink.NLM_F_DUMP + + req_seq = random.randint(1024, 65535) + msg = _genl_msg(self.family.family_id, nl_flags, op['value'], 1, req_seq) + for name, value in vals.items(): + msg += self._add_attr(op['attribute-set'], name, value) + msg = _genl_msg_finalize(msg) + + self.sock.send(msg, 0) + + done = False + rsp = [] + while not done: + reply = self.sock.recv(128 * 1024) + nms = NlMsgs(reply, attr_space=self._spaces[op['attribute-set']]) + for nl_msg in nms: + if nl_msg.error: + print("Netlink error:", os.strerror(-nl_msg.error)) + print(nl_msg) + return + if nl_msg.done: + done = True + break + + gm = GenlMsg(nl_msg) + # Check if this is a reply to our request + if nl_msg.nl_seq != req_seq or gm.genl_cmd != op['value']: + if gm.genl_cmd in self.async_msg_ids: + self.handle_ntf(nl_msg, gm) + continue + else: + print('Unexpected message: ' + repr(gm)) + continue + + rsp.append(self._decode(gm.raw_attrs, op['attribute-set'])) + + if not rsp: + return None + if not dump and len(rsp) == 1: + return rsp[0] + return rsp diff --git a/tools/net/ynl/ynl-gen-c.py b/tools/net/ynl/ynl-gen-c.py new file mode 100755 index 000000000000..1aa872e582ab --- /dev/null +++ b/tools/net/ynl/ynl-gen-c.py @@ -0,0 +1,2379 @@ +#!/usr/bin/env python + +import argparse +import collections +import jsonschema +import os +import yaml + + +def c_upper(name): + return name.upper().replace('-', '_') + + +def c_lower(name): + return name.lower().replace('-', '_') + + +class BaseNlLib: + def 
get_family_id(self): + return 'ys->family_id' + + def parse_cb_run(self, cb, data, is_dump=False, indent=1): + ind = '\n\t\t' + '\t' * indent + ' ' + if is_dump: + return f"mnl_cb_run2(ys->rx_buf, len, 0, 0, {cb}, {data},{ind}ynl_cb_array, NLMSG_MIN_TYPE)" + else: + return f"mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid,{ind}{cb}, {data},{ind}" + \ + "ynl_cb_array, NLMSG_MIN_TYPE)" + + +class Type: + def __init__(self, family, attr_set, attr): + self.family = family + self.attr = attr + self.value = attr['value'] + self.name = c_lower(attr['name']) + self.type = attr['type'] + self.checks = attr.get('checks', {}) + + if 'len' in attr: + self.len = attr['len'] + if 'nested-attributes' in attr: + self.nested_attrs = attr['nested-attributes'] + if self.nested_attrs == family.name: + self.nested_render_name = f"{family.name}" + else: + self.nested_render_name = f"{family.name}_{c_lower(self.nested_attrs)}" + + self.enum_name = f"{attr_set.name_prefix}{self.name}" + self.enum_name = c_upper(self.enum_name) + self.c_name = c_lower(self.name) + if self.c_name in _C_KW: + self.c_name += '_' + + def __getitem__(self, key): + return self.attr[key] + + def __contains__(self, key): + return key in self.attr + + def is_multi_val(self): + return None + + def is_scalar(self): + return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'} + + def presence_type(self): + return 'bit' + + def presence_member(self, space, type_filter): + if self.presence_type() != type_filter: + return + + if self.presence_type() == 'bit': + pfx = '__' if space == 'user' else '' + return f"{pfx}u32 {self.c_name}:1;" + + if self.presence_type() == 'len': + pfx = '__' if space == 'user' else '' + return f"{pfx}u32 {self.c_name}_len;" + + def _complex_member_type(self, ri): + return None + + def free_needs_iter(self): + return False + + def free(self, ri, var, ref): + if self.is_multi_val() or self.presence_type() == 'len': + ri.cw.p(f'free({var}->{ref}{self.c_name});') + + def arg_member(self, ri): 
+ member = self._complex_member_type(ri) + if member: + return [member + ' *' + self.c_name] + raise Exception(f"Struct member not implemented for class type {self.type}") + + def struct_member(self, ri): + if self.is_multi_val(): + ri.cw.p(f"unsigned int n_{self.c_name};") + member = self._complex_member_type(ri) + if member: + ptr = '*' if self.is_multi_val() else '' + ri.cw.p(f"{member} {ptr}{self.c_name};") + return + members = self.arg_member(ri) + for one in members: + ri.cw.p(one + ';') + + def _attr_policy(self, policy): + return '{ .type = ' + policy + ', }' + + def attr_policy(self, cw): + policy = c_upper('nla-' + self.attr['type']) + + spec = self._attr_policy(policy) + cw.p(f"\t[{self.enum_name}] = {spec},") + + def _attr_typol(self): + raise Exception(f"Type policy not implemented for class type {self.type}") + + def attr_typol(self, cw): + typol = self._attr_typol() + cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},') + + def _attr_put_line(self, ri, var, line): + if self.presence_type() == 'bit': + ri.cw.p(f"if ({var}->_present.{self.c_name})") + elif self.presence_type() == 'len': + ri.cw.p(f"if ({var}->_present.{self.c_name}_len)") + ri.cw.p(f"{line};") + + def _attr_put_simple(self, ri, var, put_type): + line = f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})" + self._attr_put_line(ri, var, line) + + def attr_put(self, ri, var): + raise Exception(f"Put not implemented for class type {self.type}") + + def _attr_get(self, ri, var): + raise Exception(f"Attr get not implemented for class type {self.type}") + + def attr_get(self, ri, var, first): + lines, init_lines, local_vars = self._attr_get(ri, var) + if type(lines) is str: + lines = [lines] + if type(init_lines) is str: + init_lines = [init_lines] + + kw = 'if' if first else 'else if' + ri.cw.block_start(line=f"{kw} (mnl_attr_get_type(attr) == {self.enum_name})") + if local_vars: + for local in local_vars: + ri.cw.p(local) + ri.cw.nl() + + if not 
self.is_multi_val(): + ri.cw.p("if (ynl_attr_validate(yarg, attr))") + ri.cw.p("return MNL_CB_ERROR;") + if self.presence_type() == 'bit': + ri.cw.p(f"{var}->_present.{self.c_name} = 1;") + + if init_lines: + ri.cw.nl() + for line in init_lines: + ri.cw.p(line) + + for line in lines: + ri.cw.p(line) + ri.cw.block_end() + + def _setter_lines(self, ri, member, presence): + raise Exception(f"Setter not implemented for class type {self.type}") + + def setter(self, ri, space, direction, deref=False, ref=None): + ref = (ref if ref else []) + [self.c_name] + var = "req" + member = f"{var}->{'.'.join(ref)}" + + code = [] + presence = '' + for i in range(0, len(ref)): + presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{ref[i]}" + if self.presence_type() == 'bit': + code.append(presence + ' = 1;') + code += self._setter_lines(ri, member, presence) + + ri.cw.write_func('static inline void', + f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}", + body=code, + args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri)) + + +class TypeUnused(Type): + def presence_type(self): + return '' + + def _attr_typol(self): + return '.type = YNL_PT_REJECT, ' + + def attr_policy(self, cw): + pass + + +class TypePad(Type): + def presence_type(self): + return '' + + def _attr_typol(self): + return '.type = YNL_PT_REJECT, ' + + def attr_policy(self, cw): + pass + + +class TypeScalar(Type): + def __init__(self, family, attr_set, attr): + super().__init__(family, attr_set, attr) + + self.is_bitfield = False + if 'enum' in self.attr: + self.is_bitfield = family.consts[self.attr['enum']]['type'] == 'flags' + if 'enum-as-flags' in self.attr and self.attr['enum-as-flags']: + self.is_bitfield = True + + if 'enum' in self.attr and not self.is_bitfield: + self.type_name = f"enum {family.name}_{c_lower(self.attr['enum'])}" + else: + self.type_name = '__' + self.type + + self.byte_order_comment = '' + if 'byte-order' in attr: + self.byte_order_comment = f" /* 
{attr['byte-order']} */" + + def _mnl_type(self): + t = self.type + # mnl does not have a helper for signed types + if t[0] == 's': + t = 'u' + t[1:] + return t + + def _attr_policy(self, policy): + if 'flags-mask' in self.checks or self.is_bitfield: + if self.is_bitfield: + mask = self.family.consts[self.attr['enum']].get_mask() + else: + flags = self.family.consts[self.checks['flags-mask']] + flag_cnt = len(flags['entries']) + mask = (1 << flag_cnt) - 1 + return f"NLA_POLICY_MASK({policy}, 0x{mask:x})" + elif 'min' in self.checks: + return f"NLA_POLICY_MIN({policy}, {self.checks['min']})" + elif 'enum' in self.attr: + enum = self.family.consts[self.attr['enum']] + cnt = len(enum['entries']) + return f"NLA_POLICY_MAX({policy}, {cnt - 1})" + return super()._attr_policy(policy) + + def _attr_typol(self): + return f'.type = YNL_PT_U{self.type[1:]}, ' + + def arg_member(self, ri): + return [f'{self.type_name} {self.c_name}{self.byte_order_comment}'] + + def attr_put(self, ri, var): + self._attr_put_simple(ri, var, self._mnl_type()) + + def _attr_get(self, ri, var): + return f"{var}->{self.c_name} = mnl_attr_get_{self._mnl_type()}(attr);", None, None + + def _setter_lines(self, ri, member, presence): + return [f"{member} = {self.c_name};"] + + +class TypeFlag(Type): + def arg_member(self, ri): + return [] + + def _attr_typol(self): + return '.type = YNL_PT_FLAG, ' + + def attr_put(self, ri, var): + self._attr_put_line(ri, var, f"mnl_attr_put(nlh, {self.enum_name}, 0, NULL)") + + def _attr_get(self, ri, var): + return [], None, None + + def _setter_lines(self, ri, member, presence): + return [] + + +class TypeString(Type): + def arg_member(self, ri): + return [f"const char *{self.c_name}"] + + def presence_type(self): + return 'len' + + def struct_member(self, ri): + ri.cw.p(f"char *{self.c_name};") + + def _attr_typol(self): + return f'.type = YNL_PT_NUL_STR, ' + + def _attr_policy(self, policy): + mem = '{ .type = ' + policy + if 'max-len' in self.checks: + mem += ', 
.len = ' + str(self.checks['max-len']) + mem += ', }' + return mem + + def attr_policy(self, cw): + if self.checks.get('unterminated-ok', False): + policy = 'NLA_STRING' + else: + policy = 'NLA_NUL_STRING' + + spec = self._attr_policy(policy) + cw.p(f"\t[{self.enum_name}] = {spec},") + + def attr_put(self, ri, var): + self._attr_put_simple(ri, var, 'strz') + + def _attr_get(self, ri, var): + len_mem = var + '->_present.' + self.c_name + '_len' + return [f"{len_mem} = len;", + f"{var}->{self.c_name} = malloc(len + 1);", + f"memcpy({var}->{self.c_name}, mnl_attr_get_str(attr), len);", + f"{var}->{self.c_name}[len] = 0;"], \ + ['len = strnlen(mnl_attr_get_str(attr), mnl_attr_get_payload_len(attr));'], \ + ['unsigned int len;'] + + def _setter_lines(self, ri, member, presence): + return [f"free({member});", + f"{presence}_len = strlen({self.c_name});", + f"{member} = malloc({presence}_len + 1);", + f'memcpy({member}, {self.c_name}, {presence}_len);', + f'{member}[{presence}_len] = 0;'] + + +class TypeBinary(Type): + def arg_member(self, ri): + return [f"const void *{self.c_name}", 'size_t len'] + + def presence_type(self): + return 'len' + + def struct_member(self, ri): + ri.cw.p(f"void *{self.c_name};") + + def _attr_typol(self): + return f'.type = YNL_PT_BINARY,' + + def _attr_policy(self, policy): + mem = '{ ' + if len(self.checks) == 1 and 'min-len' in self.checks: + mem += '.len = ' + str(self.checks['min-len']) + elif len(self.checks) == 0: + mem += '.type = NLA_BINARY' + else: + raise Exception('One or more of binary type checks not implemented, yet') + mem += ', }' + return mem + + def attr_put(self, ri, var): + self._attr_put_line(ri, var, f"mnl_attr_put(nlh, {self.enum_name}, " + + f"{var}->_present.{self.c_name}_len, {var}->{self.c_name})") + + def _attr_get(self, ri, var): + len_mem = var + '->_present.' 
+ self.c_name + '_len' + return [f"{len_mem} = len;", + f"{var}->{self.c_name} = malloc(len);", + f"memcpy({var}->{self.c_name}, mnl_attr_get_payload(attr), len);"], \ + ['len = mnl_attr_get_payload_len(attr);'], \ + ['unsigned int len;'] + + def _setter_lines(self, ri, member, presence): + return [f"free({member});", + f"{member} = malloc({presence}_len);", + f'memcpy({member}, {self.c_name}, {presence}_len);'] + + +class TypeNest(Type): + def _complex_member_type(self, ri): + return f"struct {self.nested_render_name}" + + def free(self, ri, var, ref): + ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name});') + + def _attr_typol(self): + return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, ' + + def _attr_policy(self, policy): + return 'NLA_POLICY_NESTED(' + self.nested_render_name + '_nl_policy)' + + def attr_put(self, ri, var): + self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " + + f"{self.enum_name}, &{var}->{self.c_name})") + + def _attr_get(self, ri, var): + get_lines = [f"{self.nested_render_name}_parse(&parg, attr);"] + init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;", + f"parg.data = &{var}->{self.c_name};"] + return get_lines, init_lines, None + + def setter(self, ri, space, direction, deref=False, ref=None): + ref = (ref if ref else []) + [self.c_name] + + for _, attr in ri.family.pure_nested_structs[self.nested_attrs].member_list(): + attr.setter(ri, self.nested_attrs, direction, deref=deref, ref=ref) + + +class TypeMultiAttr(Type): + def is_multi_val(self): + return True + + def presence_type(self): + return 'count' + + def _complex_member_type(self, ri): + if 'type' not in self.attr or self.attr['type'] == 'nest': + return f"struct {self.nested_render_name}" + elif self.attr['type'] in scalars: + scalar_pfx = '__' if ri.ku_space == 'user' else '' + return scalar_pfx + self.attr['type'] + else: + raise Exception(f"Sub-type {self.attr['type']} not supported yet") + + def 
free_needs_iter(self): + return 'type' not in self.attr or self.attr['type'] == 'nest' + + def free(self, ri, var, ref): + if 'type' not in self.attr or self.attr['type'] == 'nest': + ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)") + ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name}[i]);') + + def _attr_typol(self): + if 'type' not in self.attr or self.attr['type'] == 'nest': + return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, ' + elif self.attr['type'] in scalars: + return f".type = YNL_PT_U{self.attr['type'][1:]}, " + else: + raise Exception(f"Sub-type {self.attr['type']} not supported yet") + + def _attr_get(self, ri, var): + return f'{var}->n_{self.c_name}++;', None, None + + +class TypeArrayNest(Type): + def is_multi_val(self): + return True + + def presence_type(self): + return 'count' + + def _complex_member_type(self, ri): + if 'sub-type' not in self.attr or self.attr['sub-type'] == 'nest': + return f"struct {self.nested_render_name}" + elif self.attr['sub-type'] in scalars: + scalar_pfx = '__' if ri.ku_space == 'user' else '' + return scalar_pfx + self.attr['sub-type'] + else: + raise Exception(f"Sub-type {self.attr['sub-type']} not supported yet") + + def _attr_typol(self): + return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, ' + + def _attr_get(self, ri, var): + local_vars = ['const struct nlattr *attr2;'] + get_lines = [f'attr_{self.c_name} = attr;', + 'mnl_attr_for_each_nested(attr2, attr)', + f'\t{var}->n_{self.c_name}++;'] + return get_lines, None, local_vars + + +class TypeNestTypeValue(Type): + def _complex_member_type(self, ri): + return f"struct {self.nested_render_name}" + + def _attr_typol(self): + return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, ' + + def _attr_get(self, ri, var): + prev = 'attr' + tv_args = '' + get_lines = [] + local_vars = [] + init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;", + f"parg.data = 
&{var}->{self.c_name};"] + if 'type-value' in self.attr: + tv_names = [c_lower(x) for x in self.attr["type-value"]] + local_vars += [f'const struct nlattr *attr_{", *attr_".join(tv_names)};'] + local_vars += [f'__u32 {", ".join(tv_names)};'] + for level in self.attr["type-value"]: + level = c_lower(level) + get_lines += [f'attr_{level} = mnl_attr_get_payload({prev});'] + get_lines += [f'{level} = mnl_attr_get_type(attr_{level});'] + prev = 'attr_' + level + + tv_args = f", {', '.join(tv_names)}" + + get_lines += [f"{self.nested_render_name}_parse(&parg, {prev}{tv_args});"] + return get_lines, init_lines, local_vars + + +class Struct: + def __init__(self, family, space_name, type_list=None, inherited=None): + self.family = family + self.space_name = space_name + self.attr_set = family.attr_sets[space_name] + # Use list to catch comparisons with empty sets + self._inherited = inherited if inherited is not None else [] + self.inherited = [] + + self.nested = type_list is None + if family.name == c_lower(space_name): + self.render_name = f"{family.name}" + else: + self.render_name = f"{family.name}_{c_lower(space_name)}" + self.struct_name = 'struct ' + self.render_name + self.ptr_name = self.struct_name + ' *' + + self.request = False + self.reply = False + + self.attr_list = [] + self.attrs = dict() + if type_list: + for t in type_list: + self.attr_list.append((t, self.attr_set[t]),) + else: + for t in self.attr_set: + self.attr_list.append((t, self.attr_set[t]),) + + max_val = 0 + self.attr_max_val = None + for name, attr in self.attr_list: + if attr.value > max_val: + max_val = attr.value + self.attr_max_val = attr + self.attrs[name] = attr + + def __iter__(self): + yield from self.attrs + + def __getitem__(self, key): + return self.attrs[key] + + def member_list(self): + return self.attr_list + + def set_inherited(self, new_inherited): + if self._inherited != new_inherited: + raise Exception("Inheriting different members not supported") + self.inherited = 
[c_lower(x) for x in sorted(self._inherited)] + + +class EnumEntry: + def __init__(self, enum_set, yaml, prev, value_start): + if isinstance(yaml, str): + self.name = yaml + yaml = {} + self.doc = '' + else: + self.name = yaml['name'] + self.doc = yaml.get('doc', '') + + self.yaml = yaml + self.enum_set = enum_set + self.c_name = c_upper(enum_set.value_pfx + self.name) + + if 'value' in yaml: + self.value = yaml['value'] + if prev: + self.value_change = (self.value != prev.value + 1) + elif prev: + self.value_change = False + self.value = prev.value + 1 + else: + self.value = value_start + self.value_change = (self.value != 0) + + self.value_change = self.value_change or self.enum_set['type'] == 'flags' + + def __getitem__(self, key): + return self.yaml[key] + + def __contains__(self, key): + return key in self.yaml + + def has_doc(self): + return bool(self.doc) + + # raw value, i.e. the id in the enum, unlike user value which is a mask for flags + def raw_value(self): + return self.value + + # user value, same as raw value for enums, for flags it's the mask + def user_value(self): + if self.enum_set['type'] == 'flags': + return 1 << self.value + else: + return self.value + + +class EnumSet: + def __init__(self, family, yaml): + self.yaml = yaml + self.family = family + + self.render_name = c_lower(family.name + '-' + yaml['name']) + self.enum_name = 'enum ' + self.render_name + + self.value_pfx = yaml.get('name-prefix', f"{family.name}-{yaml['name']}-") + + self.type = yaml['type'] + + prev_entry = None + value_start = self.yaml.get('value-start', 0) + self.entries = {} + self.entry_list = [] + for entry in self.yaml['entries']: + e = EnumEntry(self, entry, prev_entry, value_start) + self.entries[e.name] = e + self.entry_list.append(e) + prev_entry = e + + def __getitem__(self, key): + return self.yaml[key] + + def __contains__(self, key): + return key in self.yaml + + def has_doc(self): + if 'doc' in self.yaml: + return True + for entry in self.entry_list: + if 
entry.has_doc(): + return True + return False + + def get_mask(self): + mask = 0 + idx = self.yaml.get('value-start', 0) + for _ in self.entry_list: + mask |= 1 << idx + idx += 1 + return mask + + +class AttrSet: + def __init__(self, family, yaml): + self.yaml = yaml + + self.attrs = dict() + self.name = self.yaml['name'] + if 'subset-of' not in yaml: + self.subset_of = None + if 'name-prefix' in yaml: + pfx = yaml['name-prefix'] + elif self.name == family.name: + pfx = family.name + '-a-' + else: + pfx = f"{family.name}-a-{self.name}-" + self.name_prefix = c_upper(pfx) + self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max")) + else: + self.subset_of = self.yaml['subset-of'] + self.name_prefix = family.attr_sets[self.subset_of].name_prefix + self.max_name = family.attr_sets[self.subset_of].max_name + + self.c_name = c_lower(self.name) + if self.c_name in _C_KW: + self.c_name += '_' + if self.c_name == family.c_name: + self.c_name = '' + + val = 0 + for elem in self.yaml['attributes']: + if 'value' in elem: + val = elem['value'] + else: + elem['value'] = val + val += 1 + + if 'multi-attr' in elem and elem['multi-attr']: + attr = TypeMultiAttr(family, self, elem) + elif elem['type'] in scalars: + attr = TypeScalar(family, self, elem) + elif elem['type'] == 'unused': + attr = TypeUnused(family, self, elem) + elif elem['type'] == 'pad': + attr = TypePad(family, self, elem) + elif elem['type'] == 'flag': + attr = TypeFlag(family, self, elem) + elif elem['type'] == 'string': + attr = TypeString(family, self, elem) + elif elem['type'] == 'binary': + attr = TypeBinary(family, self, elem) + elif elem['type'] == 'nest': + attr = TypeNest(family, self, elem) + elif elem['type'] == 'array-nest': + attr = TypeArrayNest(family, self, elem) + elif elem['type'] == 'nest-type-value': + attr = TypeNestTypeValue(family, self, elem) + else: + raise Exception(f"No typed class for type {elem['type']}") + + self.attrs[elem['name']] = attr + + def 
__getitem__(self, key): + return self.attrs[key] + + def __contains__(self, key): + return key in self.yaml + + def __iter__(self): + yield from self.attrs + + def items(self): + return self.attrs.items() + + +class Operation: + def __init__(self, family, yaml, value): + self.yaml = yaml + self.value = value + + self.name = self.yaml['name'] + self.render_name = family.name + '_' + c_lower(self.name) + self.is_async = 'notify' in yaml or 'event' in yaml + if not self.is_async: + self.enum_name = family.op_prefix + c_upper(self.name) + else: + self.enum_name = family.async_op_prefix + c_upper(self.name) + + self.dual_policy = ('do' in yaml and 'request' in yaml['do']) and \ + ('dump' in yaml and 'request' in yaml['dump']) + + def __getitem__(self, key): + return self.yaml[key] + + def __contains__(self, key): + return key in self.yaml + + def add_notification(self, op): + if 'notify' not in self.yaml: + self.yaml['notify'] = dict() + self.yaml['notify']['reply'] = self.yaml['do']['reply'] + self.yaml['notify']['cmds'] = [] + self.yaml['notify']['cmds'].append(op) + + +class Family: + def __init__(self, file_name): + with open(file_name, "r") as stream: + self.yaml = yaml.safe_load(stream) + + self.proto = self.yaml.get('protocol', 'genetlink') + + with open(os.path.dirname(os.path.dirname(file_name)) + + f'/{self.proto}.yaml', "r") as stream: + schema = yaml.safe_load(stream) + + jsonschema.validate(self.yaml, schema) + + if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}: + raise Exception("Codegen only supported for genetlink") + + self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME')) + self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION')) + + if 'definitions' not in self.yaml: + self.yaml['definitions'] = [] + + self.name = self.yaml['name'] + self.c_name = c_lower(self.name) + if 'uapi-header' in self.yaml: + self.uapi_header = 
self.yaml['uapi-header'] + else: + self.uapi_header = f"linux/{self.name}.h" + if 'name-prefix' in self.yaml['operations']: + self.op_prefix = c_upper(self.yaml['operations']['name-prefix']) + else: + self.op_prefix = c_upper(self.yaml['name'] + '-cmd-') + if 'async-prefix' in self.yaml['operations']: + self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix']) + else: + self.async_op_prefix = self.op_prefix + + self.mcgrps = self.yaml.get('mcast-groups', {'list': []}) + + self.consts = dict() + # list of all operations + self.msg_list = [] + # dict of operations which have their own message type (have attributes) + self.ops = collections.OrderedDict() + self.attr_sets = dict() + self.attr_sets_list = [] + + self.hooks = dict() + for when in ['pre', 'post']: + self.hooks[when] = dict() + for op_mode in ['do', 'dump']: + self.hooks[when][op_mode] = dict() + self.hooks[when][op_mode]['set'] = set() + self.hooks[when][op_mode]['list'] = [] + + # dict space-name -> 'request': set(attrs), 'reply': set(attrs) + self.root_sets = dict() + # dict space-name -> set('request', 'reply') + self.pure_nested_structs = dict() + self.all_notify = dict() + + self._mock_up_events() + + self._dictify() + self._load_root_sets() + self._load_nested_sets() + self._load_all_notify() + self._load_hooks() + + self.kernel_policy = self.yaml.get('kernel-policy', 'split') + if self.kernel_policy == 'global': + self._load_global_policy() + + def __getitem__(self, key): + return self.yaml[key] + + def get(self, key, default=None): + return self.yaml.get(key, default) + + # Fake a 'do' equivalent of all events, so that we can render their response parsing + def _mock_up_events(self): + for op in self.yaml['operations']['list']: + if 'event' in op: + op['do'] = { + 'reply': { + 'attributes': op['event']['attributes'] + } + } + + def _dictify(self): + for elem in self.yaml['definitions']: + if elem['type'] == 'enum' or elem['type'] == 'flags': + self.consts[elem['name']] = 
EnumSet(self, elem) + else: + self.consts[elem['name']] = elem + + for elem in self.yaml['attribute-sets']: + attr_set = AttrSet(self, elem) + self.attr_sets[elem['name']] = attr_set + self.attr_sets_list.append((elem['name'], attr_set), ) + + ntf = [] + val = 0 + for elem in self.yaml['operations']['list']: + if 'value' in elem: + val = elem['value'] + + op = Operation(self, elem, val) + val += 1 + + self.msg_list.append(op) + if 'notify' in elem: + ntf.append(op) + continue + if 'attribute-set' not in elem: + continue + self.ops[elem['name']] = op + for n in ntf: + self.ops[n['notify']].add_notification(n) + + def _load_root_sets(self): + for op_name, op in self.ops.items(): + if 'attribute-set' not in op: + continue + + req_attrs = set() + rsp_attrs = set() + for op_mode in ['do', 'dump']: + if op_mode in op and 'request' in op[op_mode]: + req_attrs.update(set(op[op_mode]['request']['attributes'])) + if op_mode in op and 'reply' in op[op_mode]: + rsp_attrs.update(set(op[op_mode]['reply']['attributes'])) + + if op['attribute-set'] not in self.root_sets: + self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs} + else: + self.root_sets[op['attribute-set']]['request'].update(req_attrs) + self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs) + + def _load_nested_sets(self): + for root_set, rs_members in self.root_sets.items(): + for attr, spec in self.attr_sets[root_set].items(): + if 'nested-attributes' in spec: + inherit = set() + nested = spec['nested-attributes'] + if nested not in self.root_sets: + self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit) + if attr in rs_members['request']: + self.pure_nested_structs[nested].request = True + if attr in rs_members['reply']: + self.pure_nested_structs[nested].reply = True + + if 'type-value' in spec: + if nested in self.root_sets: + raise Exception("Inheriting members to a space used as root not supported") + inherit.update(set(spec['type-value'])) + elif 
spec['type'] == 'array-nest': + inherit.add('idx') + self.pure_nested_structs[nested].set_inherited(inherit) + + def _load_all_notify(self): + for op_name, op in self.ops.items(): + if not op: + continue + + if 'notify' in op: + self.all_notify[op_name] = op['notify']['cmds'] + + def _load_global_policy(self): + global_set = set() + attr_set_name = None + for op_name, op in self.ops.items(): + if not op: + continue + if 'attribute-set' not in op: + continue + + if attr_set_name is None: + attr_set_name = op['attribute-set'] + if attr_set_name != op['attribute-set']: + raise Exception('For a global policy all ops must use the same set') + + for op_mode in {'do', 'dump'}: + if op_mode in op: + global_set.update(op[op_mode].get('request', [])) + + self.global_policy = [] + self.global_policy_set = attr_set_name + for attr in self.attr_sets[attr_set_name]: + if attr in global_set: + self.global_policy.append(attr) + + def _load_hooks(self): + for op in self.ops.values(): + for op_mode in ['do', 'dump']: + if op_mode not in op: + continue + for when in ['pre', 'post']: + if when not in op[op_mode]: + continue + name = op[op_mode][when] + if name in self.hooks[when][op_mode]['set']: + continue + self.hooks[when][op_mode]['set'].add(name) + self.hooks[when][op_mode]['list'].append(name) + + +class RenderInfo: + def __init__(self, cw, family, ku_space, op, op_name, op_mode, attr_set=None): + self.family = family + self.nl = cw.nlib + self.ku_space = ku_space + self.op = op + self.op_name = op_name + self.op_mode = op_mode + + # 'do' and 'dump' response parsing is identical + if op_mode != 'do' and 'dump' in op and 'do' in op and 'reply' in op['do'] and \ + op["do"]["reply"] == op["dump"]["reply"]: + self.type_consistent = True + else: + self.type_consistent = op_mode == 'event' + + self.attr_set = attr_set + if not self.attr_set: + self.attr_set = op['attribute-set'] + + if op: + self.type_name = c_lower(op_name) + else: + self.type_name = c_lower(attr_set) + + self.cw = 
cw + + self.struct = dict() + for op_dir in ['request', 'reply']: + if op and op_dir in op[op_mode]: + self.struct[op_dir] = Struct(family, self.attr_set, + type_list=op[op_mode][op_dir]['attributes']) + if op_mode == 'event': + self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes']) + + +class CodeWriter: + def __init__(self, nlib, out_file): + self.nlib = nlib + + self._nl = False + self._silent_block = False + self._ind = 0 + self._out = out_file + + @classmethod + def _is_cond(cls, line): + return line.startswith('if') or line.startswith('while') or line.startswith('for') + + def p(self, line, add_ind=0): + if self._nl: + self._out.write('\n') + self._nl = False + ind = self._ind + if line[-1] == ':': + ind -= 1 + if self._silent_block: + ind += 1 + self._silent_block = line.endswith(')') and CodeWriter._is_cond(line) + if add_ind: + ind += add_ind + self._out.write('\t' * ind + line + '\n') + + def nl(self): + self._nl = True + + def block_start(self, line=''): + if line: + line = line + ' ' + self.p(line + '{') + self._ind += 1 + + def block_end(self, line=''): + if line and line[0] not in {';', ','}: + line = ' ' + line + self._ind -= 1 + self.p('}' + line) + + def write_doc_line(self, doc, indent=True): + words = doc.split() + line = ' *' + for word in words: + if len(line) + len(word) >= 79: + self.p(line) + line = ' *' + if indent: + line += ' ' + line += ' ' + word + self.p(line) + + def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''): + if not args: + args = ['void'] + + if doc: + self.p('/*') + self.p(' * ' + doc) + self.p(' */') + + oneline = qual_ret + if qual_ret[-1] != '*': + oneline += ' ' + oneline += f"{name}({', '.join(args)}){suffix}" + + if len(oneline) < 80: + self.p(oneline) + return + + v = qual_ret + if len(v) > 3: + self.p(v) + v = '' + elif qual_ret[-1] != '*': + v += ' ' + v += name + '(' + ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8) + delta_ind = len(v) - len(ind) + v += 
args[0] + i = 1 + while i < len(args): + next_len = len(v) + len(args[i]) + if v[0] == '\t': + next_len += delta_ind + if next_len > 76: + self.p(v + ',') + v = ind + else: + v += ', ' + v += args[i] + i += 1 + self.p(v + ')' + suffix) + + def write_func_lvar(self, local_vars): + if not local_vars: + return + + if type(local_vars) is str: + local_vars = [local_vars] + + local_vars.sort(key=len, reverse=True) + for var in local_vars: + self.p(var) + self.nl() + + def write_func(self, qual_ret, name, body, args=None, local_vars=None): + self.write_func_prot(qual_ret=qual_ret, name=name, args=args) + self.write_func_lvar(local_vars=local_vars) + + self.block_start() + for line in body: + self.p(line) + self.block_end() + + def writes_defines(self, defines): + longest = 0 + for define in defines: + if len(define[0]) > longest: + longest = len(define[0]) + longest = ((longest + 8) // 8) * 8 + for define in defines: + line = '#define ' + define[0] + line += '\t' * ((longest - len(define[0]) + 7) // 8) + if type(define[1]) is int: + line += str(define[1]) + elif type(define[1]) is str: + line += '"' + define[1] + '"' + self.p(line) + + def write_struct_init(self, members): + longest = max([len(x[0]) for x in members]) + longest += 1 # because we prepend a . + longest = ((longest + 8) // 8) * 8 + for one in members: + line = '.' 
+ one[0] + line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8) + line += '= ' + one[1] + ',' + self.p(line) + + +scalars = {'u8', 'u16', 'u32', 'u64', 's32', 's64'} + +direction_to_suffix = { + 'reply': '_rsp', + 'request': '_req', + '': '' +} + +op_mode_to_wrapper = { + 'do': '', + 'dump': '_list', + 'notify': '_ntf', + 'event': '', +} + +_C_KW = { + 'do' +} + + +def rdir(direction): + if direction == 'reply': + return 'request' + if direction == 'request': + return 'reply' + return direction + + +def op_prefix(ri, direction, deref=False): + suffix = f"_{ri.type_name}" + + if not ri.op_mode or ri.op_mode == 'do': + suffix += f"{direction_to_suffix[direction]}" + else: + if direction == 'request': + suffix += '_req_dump' + else: + if ri.type_consistent: + if deref: + suffix += f"{direction_to_suffix[direction]}" + else: + suffix += op_mode_to_wrapper[ri.op_mode] + else: + suffix += '_rsp' + suffix += '_dump' if deref else '_list' + + return f"{ri.family['name']}{suffix}" + + +def type_name(ri, direction, deref=False): + return f"struct {op_prefix(ri, direction, deref=deref)}" + + +def print_prototype(ri, direction, terminate=True, doc=None): + suffix = ';' if terminate else '' + + fname = ri.op.render_name + if ri.op_mode == 'dump': + fname += '_dump' + + args = ['struct ynl_sock *ys'] + if 'request' in ri.op[ri.op_mode]: + args.append(f"{type_name(ri, direction)} *" + f"{direction_to_suffix[direction][1:]}") + + ret = 'int' + if 'reply' in ri.op[ri.op_mode]: + ret = f"{type_name(ri, rdir(direction))} *" + + ri.cw.write_func_prot(ret, fname, args, doc=doc, suffix=suffix) + + +def print_req_prototype(ri): + print_prototype(ri, "request", doc=ri.op['doc']) + + +def print_dump_prototype(ri): + print_prototype(ri, "request") + + +def put_typol_fwd(cw, struct): + cw.p(f'extern struct ynl_policy_nest {struct.render_name}_nest;') + + +def put_typol(cw, struct): + type_max = struct.attr_set.max_name + cw.block_start(line=f'struct ynl_policy_attr 
{struct.render_name}_policy[{type_max} + 1] =') + + for _, arg in struct.member_list(): + arg.attr_typol(cw) + + cw.block_end(line=';') + cw.nl() + + cw.block_start(line=f'struct ynl_policy_nest {struct.render_name}_nest =') + cw.p(f'.max_attr = {type_max},') + cw.p(f'.table = {struct.render_name}_policy,') + cw.block_end(line=';') + cw.nl() + + +def put_req_nested(ri, struct): + func_args = ['struct nlmsghdr *nlh', + 'unsigned int attr_type', + f'{struct.ptr_name}obj'] + + ri.cw.write_func_prot('int', f'{struct.render_name}_put', func_args) + ri.cw.block_start() + ri.cw.write_func_lvar('struct nlattr *nest;') + + ri.cw.p("nest = mnl_attr_nest_start(nlh, attr_type);") + + for _, arg in struct.member_list(): + arg.attr_put(ri, "obj") + + ri.cw.p("mnl_attr_nest_end(nlh, nest);") + + ri.cw.nl() + ri.cw.p('return 0;') + ri.cw.block_end() + ri.cw.nl() + + +def _multi_parse(ri, struct, init_lines, local_vars): + if struct.nested: + iter_line = "mnl_attr_for_each_nested(attr, nested)" + else: + iter_line = "mnl_attr_for_each(attr, nlh, sizeof(struct genlmsghdr))" + + array_nests = set() + multi_attrs = set() + needs_parg = False + for arg, aspec in struct.member_list(): + if aspec['type'] == 'array-nest': + local_vars.append(f'const struct nlattr *attr_{aspec.c_name};') + array_nests.add(arg) + if 'multi-attr' in aspec: + multi_attrs.add(arg) + needs_parg |= 'nested-attributes' in aspec + if array_nests or multi_attrs: + local_vars.append('int i;') + if needs_parg: + local_vars.append('struct ynl_parse_arg parg;') + init_lines.append('parg.ys = yarg->ys;') + + ri.cw.block_start() + ri.cw.write_func_lvar(local_vars) + + for line in init_lines: + ri.cw.p(line) + ri.cw.nl() + + for arg in struct.inherited: + ri.cw.p(f'dst->{arg} = {arg};') + + ri.cw.nl() + ri.cw.block_start(line=iter_line) + + first = True + for _, arg in struct.member_list(): + arg.attr_get(ri, 'dst', first=first) + first = False + + ri.cw.block_end() + ri.cw.nl() + + for anest in sorted(array_nests): + 
aspec = struct[anest] + + ri.cw.block_start(line=f"if (dst->n_{aspec.c_name})") + ri.cw.p(f"dst->{aspec.c_name} = calloc(dst->n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));") + ri.cw.p('i = 0;') + ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;") + ri.cw.block_start(line=f"mnl_attr_for_each_nested(attr, attr_{aspec.c_name})") + ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];") + ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, mnl_attr_get_type(attr)))") + ri.cw.p('return MNL_CB_ERROR;') + ri.cw.p('i++;') + ri.cw.block_end() + ri.cw.block_end() + ri.cw.nl() + + for anest in sorted(multi_attrs): + aspec = struct[anest] + ri.cw.block_start(line=f"if (dst->n_{aspec.c_name})") + ri.cw.p(f"dst->{aspec.c_name} = calloc(dst->n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));") + ri.cw.p('i = 0;') + if 'nested-attributes' in aspec: + ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;") + ri.cw.block_start(line=iter_line) + ri.cw.block_start(line=f"if (mnl_attr_get_type(attr) == {aspec.enum_name})") + if 'nested-attributes' in aspec: + ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];") + ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))") + ri.cw.p('return MNL_CB_ERROR;') + elif aspec['type'] in scalars: + t = aspec['type'] + if t[0] == 's': + t = 'u' + t[1:] + ri.cw.p(f"dst->{aspec.c_name}[i] = mnl_attr_get_{t}(attr);") + else: + raise Exception('Nest parsing type not supported yet') + ri.cw.p('i++;') + ri.cw.block_end() + ri.cw.block_end() + ri.cw.block_end() + ri.cw.nl() + + if struct.nested: + ri.cw.p('return 0;') + else: + ri.cw.p('return MNL_CB_OK;') + ri.cw.block_end() + ri.cw.nl() + + +def parse_rsp_nested(ri, struct): + func_args = ['struct ynl_parse_arg *yarg', + 'const struct nlattr *nested'] + for arg in struct.inherited: + func_args.append('__u32 ' + arg) + + local_vars = ['const struct nlattr *attr;', + f'{struct.ptr_name}dst = yarg->data;'] + init_lines = [] + + ri.cw.write_func_prot('int', 
f'{struct.render_name}_parse', func_args) + + _multi_parse(ri, struct, init_lines, local_vars) + + +def parse_rsp_msg(ri, deref=False): + if 'reply' not in ri.op[ri.op_mode] and ri.op_mode != 'event': + return + + func_args = ['const struct nlmsghdr *nlh', + 'void *data'] + + local_vars = [f'{type_name(ri, "reply", deref=deref)} *dst;', + 'struct ynl_parse_arg *yarg = data;', + 'const struct nlattr *attr;'] + init_lines = ['dst = yarg->data;'] + + ri.cw.write_func_prot('int', f'{op_prefix(ri, "reply", deref=deref)}_parse', func_args) + + _multi_parse(ri, ri.struct["reply"], init_lines, local_vars) + + +def print_req(ri): + ret_ok = '0' + ret_err = '-1' + direction = "request" + local_vars = ['struct nlmsghdr *nlh;', + 'int len, err;'] + + if 'reply' in ri.op[ri.op_mode]: + ret_ok = 'rsp' + ret_err = 'NULL' + local_vars += [f'{type_name(ri, rdir(direction))} *rsp;', + 'struct ynl_parse_arg yarg = { .ys = ys, };'] + + print_prototype(ri, direction, terminate=False) + ri.cw.block_start() + ri.cw.write_func_lvar(local_vars) + + ri.cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);") + + ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;") + if 'reply' in ri.op[ri.op_mode]: + ri.cw.p(f"yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;") + ri.cw.nl() + for _, attr in ri.struct["request"].member_list(): + attr.attr_put(ri, "req") + ri.cw.nl() + + ri.cw.p('err = mnl_socket_sendto(ys->sock, nlh, nlh->nlmsg_len);') + ri.cw.p('if (err < 0)') + ri.cw.p(f"return {ret_err};") + ri.cw.nl() + ri.cw.p('len = mnl_socket_recvfrom(ys->sock, ys->rx_buf, MNL_SOCKET_BUFFER_SIZE);') + ri.cw.p('if (len < 0)') + ri.cw.p(f"return {ret_err};") + ri.cw.nl() + + if 'reply' in ri.op[ri.op_mode]: + ri.cw.p('rsp = calloc(1, sizeof(*rsp));') + ri.cw.p('yarg.data = rsp;') + ri.cw.nl() + ri.cw.p(f"err = {ri.nl.parse_cb_run(op_prefix(ri, 'reply') + '_parse', '&yarg', False)};") + ri.cw.p('if (err < 0)') + ri.cw.p('goto err_free;') + 
ri.cw.nl() + + ri.cw.p('err = ynl_recv_ack(ys, err);') + ri.cw.p('if (err)') + ri.cw.p('goto err_free;') + ri.cw.nl() + ri.cw.p(f"return {ret_ok};") + ri.cw.nl() + ri.cw.p('err_free:') + + if 'reply' in ri.op[ri.op_mode]: + ri.cw.p(f"{call_free(ri, rdir(direction), 'rsp')}") + ri.cw.p(f"return {ret_err};") + ri.cw.block_end() + + +def print_dump(ri): + direction = "request" + print_prototype(ri, direction, terminate=False) + ri.cw.block_start() + local_vars = ['struct ynl_dump_state yds = {};', + 'struct nlmsghdr *nlh;', + 'int len, err;'] + + for var in local_vars: + ri.cw.p(f'{var}') + ri.cw.nl() + + ri.cw.p('yds.ys = ys;') + ri.cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});") + ri.cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;") + ri.cw.p(f"yds.rsp_policy = &{ri.struct['reply'].render_name}_nest;") + ri.cw.nl() + ri.cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);") + + if "request" in ri.op[ri.op_mode]: + ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;") + ri.cw.nl() + for _, attr in ri.struct["request"].member_list(): + attr.attr_put(ri, "req") + ri.cw.nl() + + ri.cw.p('err = mnl_socket_sendto(ys->sock, nlh, nlh->nlmsg_len);') + ri.cw.p('if (err < 0)') + ri.cw.p('return NULL;') + ri.cw.nl() + + ri.cw.block_start(line='do') + ri.cw.p('len = mnl_socket_recvfrom(ys->sock, ys->rx_buf, MNL_SOCKET_BUFFER_SIZE);') + ri.cw.p('if (len < 0)') + ri.cw.p('goto free_list;') + ri.cw.nl() + ri.cw.p(f"err = {ri.nl.parse_cb_run('ynl_dump_trampoline', '&yds', False, indent=2)};") + ri.cw.p('if (err < 0)') + ri.cw.p('goto free_list;') + ri.cw.block_end(line='while (err > 0);') + ri.cw.nl() + + ri.cw.p('return yds.first;') + ri.cw.nl() + ri.cw.p('free_list:') + ri.cw.p(call_free(ri, rdir(direction), 'yds.first')) + ri.cw.p('return NULL;') + ri.cw.block_end() + + +def call_free(ri, direction, var): + return f"{op_prefix(ri, direction)}_free({var});" + + +def free_arg_name(direction): + if 
direction: + return direction_to_suffix[direction][1:] + return 'obj' + + +def print_free_prototype(ri, direction, suffix=';'): + name = op_prefix(ri, direction) + arg = free_arg_name(direction) + ri.cw.write_func_prot('void', f"{name}_free", [f"struct {name} *{arg}"], suffix=suffix) + + +def _print_type(ri, direction, struct): + suffix = f'_{ri.type_name}{direction_to_suffix[direction]}' + + if ri.op_mode == 'dump': + suffix += '_dump' + + ri.cw.block_start(line=f"struct {ri.family['name']}{suffix}") + + meta_started = False + for _, attr in struct.member_list(): + for type_filter in ['len', 'bit']: + line = attr.presence_member(ri.ku_space, type_filter) + if line: + if not meta_started: + ri.cw.block_start(line=f"struct") + meta_started = True + ri.cw.p(line) + if meta_started: + ri.cw.block_end(line='_present;') + ri.cw.nl() + + for arg in struct.inherited: + ri.cw.p(f"__u32 {arg};") + + for _, attr in struct.member_list(): + attr.struct_member(ri) + + ri.cw.block_end(line=';') + ri.cw.nl() + + +def print_type(ri, direction): + _print_type(ri, direction, ri.struct[direction]) + + +def print_type_full(ri, struct): + _print_type(ri, "", struct) + + +def print_type_helpers(ri, direction, deref=False): + print_free_prototype(ri, direction) + + if ri.ku_space == 'user' and direction == 'request': + for _, attr in ri.struct[direction].member_list(): + attr.setter(ri, ri.attr_set, direction, deref=deref) + ri.cw.nl() + + +def print_req_type_helpers(ri): + print_type_helpers(ri, "request") + + +def print_rsp_type_helpers(ri): + if 'reply' not in ri.op[ri.op_mode]: + return + print_type_helpers(ri, "reply") + + +def print_parse_prototype(ri, direction, terminate=True): + suffix = "_rsp" if direction == "reply" else "_req" + term = ';' if terminate else '' + + ri.cw.write_func_prot('void', f"{ri.op.render_name}{suffix}_parse", + ['const struct nlattr **tb', + f"struct {ri.op.render_name}{suffix} *req"], + suffix=term) + + +def print_req_type(ri): + print_type(ri, 
"request") + + +def print_rsp_type(ri): + if (ri.op_mode == 'do' or ri.op_mode == 'dump') and 'reply' in ri.op[ri.op_mode]: + direction = 'reply' + elif ri.op_mode == 'event': + direction = 'reply' + else: + return + print_type(ri, direction) + + +def print_wrapped_type(ri): + ri.cw.block_start(line=f"{type_name(ri, 'reply')}") + if ri.op_mode == 'dump': + ri.cw.p(f"{type_name(ri, 'reply')} *next;") + elif ri.op_mode == 'notify' or ri.op_mode == 'event': + ri.cw.p('__u16 family;') + ri.cw.p('__u8 cmd;') + ri.cw.p(f"void (*free)({type_name(ri, 'reply')} *ntf);") + ri.cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__ ((aligned (8)));") + ri.cw.block_end(line=';') + ri.cw.nl() + print_free_prototype(ri, 'reply') + ri.cw.nl() + + +def _free_type_members_iter(ri, struct): + for _, attr in struct.member_list(): + if attr.free_needs_iter(): + ri.cw.p('unsigned int i;') + ri.cw.nl() + break + + +def _free_type_members(ri, var, struct, ref=''): + for _, attr in struct.member_list(): + attr.free(ri, var, ref) + + +def _free_type(ri, direction, struct): + var = free_arg_name(direction) + + print_free_prototype(ri, direction, suffix='') + ri.cw.block_start() + _free_type_members_iter(ri, struct) + _free_type_members(ri, var, struct) + if direction: + ri.cw.p(f'free({var});') + ri.cw.block_end() + ri.cw.nl() + + +def free_rsp_nested(ri, struct): + _free_type(ri, "", struct) + + +def print_rsp_free(ri): + if 'reply' not in ri.op[ri.op_mode]: + return + _free_type(ri, 'reply', ri.struct['reply']) + + +def print_dump_type_free(ri): + sub_type = type_name(ri, 'reply') + + print_free_prototype(ri, 'reply', suffix='') + ri.cw.block_start() + ri.cw.p(f"{sub_type} *next = rsp;") + ri.cw.nl() + ri.cw.block_start(line='while (next)') + _free_type_members_iter(ri, ri.struct['reply']) + ri.cw.p('rsp = next;') + ri.cw.p('next = rsp->next;') + ri.cw.nl() + + _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.') + ri.cw.p(f'free(rsp);') + ri.cw.block_end() + 
ri.cw.block_end() + ri.cw.nl() + + +def print_ntf_type_free(ri): + print_free_prototype(ri, 'reply', suffix='') + ri.cw.block_start() + _free_type_members_iter(ri, ri.struct['reply']) + _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.') + ri.cw.p(f'free(rsp);') + ri.cw.block_end() + ri.cw.nl() + + +def print_ntf_parse_prototype(family, cw, suffix=';'): + cw.write_func_prot('struct ynl_ntf_base_type *', f"{family['name']}_ntf_parse", + ['struct ynl_sock *ys'], suffix=suffix) + + +def print_ntf_type_parse(family, cw, ku_mode): + print_ntf_parse_prototype(family, cw, suffix='') + cw.block_start() + cw.write_func_lvar(['struct genlmsghdr *genlh;', + 'struct nlmsghdr *nlh;', + 'struct ynl_parse_arg yarg = { .ys = ys, };', + 'struct ynl_ntf_base_type *rsp;', + 'int len, err;', + 'mnl_cb_t parse;']) + cw.p('len = mnl_socket_recvfrom(ys->sock, ys->rx_buf, MNL_SOCKET_BUFFER_SIZE);') + cw.p('if (len < (ssize_t)(sizeof(*nlh) + sizeof(*genlh)))') + cw.p('return NULL;') + cw.nl() + cw.p('nlh = (struct nlmsghdr *)ys->rx_buf;') + cw.p('genlh = mnl_nlmsg_get_payload(nlh);') + cw.nl() + cw.block_start(line='switch (genlh->cmd)') + for ntf_op in sorted(family.all_notify.keys()): + op = family.ops[ntf_op] + ri = RenderInfo(cw, family, ku_mode, op, ntf_op, "notify") + for ntf in op['notify']['cmds']: + cw.p(f"case {ntf.enum_name}:") + cw.p(f"rsp = calloc(1, sizeof({type_name(ri, 'notify')}));") + cw.p(f"parse = {op_prefix(ri, 'reply', deref=True)}_parse;") + cw.p(f"yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;") + cw.p(f"rsp->free = (void *){op_prefix(ri, 'notify')}_free;") + cw.p('break;') + for op_name, op in family.ops.items(): + if 'event' not in op: + continue + ri = RenderInfo(cw, family, ku_mode, op, op_name, "event") + cw.p(f"case {op.enum_name}:") + cw.p(f"rsp = calloc(1, sizeof({type_name(ri, 'event')}));") + cw.p(f"parse = {op_prefix(ri, 'reply', deref=True)}_parse;") + cw.p(f"yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;") + 
cw.p(f"rsp->free = (void *){op_prefix(ri, 'notify')}_free;") + cw.p('break;') + cw.p('default:') + cw.p('ynl_error_unknown_notification(ys, genlh->cmd);') + cw.p('return NULL;') + cw.block_end() + cw.nl() + cw.p('yarg.data = rsp->data;') + cw.nl() + cw.p(f"err = {cw.nlib.parse_cb_run('parse', '&yarg', True)};") + cw.p('if (err < 0)') + cw.p('goto err_free;') + cw.nl() + cw.p('rsp->family = nlh->nlmsg_type;') + cw.p('rsp->cmd = genlh->cmd;') + cw.p('return rsp;') + cw.nl() + cw.p('err_free:') + cw.p('free(rsp);') + cw.p('return NULL;') + cw.block_end() + cw.nl() + + +def print_req_policy_fwd(cw, struct, ri=None, terminate=True): + if terminate and ri and kernel_can_gen_family_struct(struct.family): + return + + if terminate: + prefix = 'extern ' + else: + if kernel_can_gen_family_struct(struct.family) and ri: + prefix = 'static ' + else: + prefix = '' + + suffix = ';' if terminate else ' = {' + + max_attr = struct.attr_max_val + if ri: + name = ri.op.render_name + if ri.op.dual_policy: + name += '_' + ri.op_mode + else: + name = struct.render_name + cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}") + + +def print_req_policy(cw, struct, ri=None): + print_req_policy_fwd(cw, struct, ri=ri, terminate=False) + for _, arg in struct.member_list(): + arg.attr_policy(cw) + cw.p("};") + + +def kernel_can_gen_family_struct(family): + return family.proto == 'genetlink' + + +def print_kernel_op_table_fwd(family, cw, terminate): + exported = not kernel_can_gen_family_struct(family) + + if not terminate or exported: + cw.p(f"/* Ops table for {family.name} */") + + pol_to_struct = {'global': 'genl_small_ops', + 'per-op': 'genl_ops', + 'split': 'genl_split_ops'} + struct_type = pol_to_struct[family.kernel_policy] + + if family.kernel_policy == 'split': + cnt = 0 + for op in family.ops.values(): + if 'do' in op: + cnt += 1 + if 'dump' in op: + cnt += 1 + else: + cnt = len(family.ops) + + qual = 'static const' if not exported else 'const' + 
line = f"{qual} struct {struct_type} {family.name}_nl_ops[{cnt}]" + if terminate: + cw.p(f"extern {line};") + else: + cw.block_start(line=line + ' =') + + if not terminate: + return + + cw.nl() + for name in family.hooks['pre']['do']['list']: + cw.write_func_prot('int', c_lower(name), + ['const struct genl_split_ops *ops', + 'struct sk_buff *skb', 'struct genl_info *info'], suffix=';') + for name in family.hooks['post']['do']['list']: + cw.write_func_prot('void', c_lower(name), + ['const struct genl_split_ops *ops', + 'struct sk_buff *skb', 'struct genl_info *info'], suffix=';') + for name in family.hooks['pre']['dump']['list']: + cw.write_func_prot('int', c_lower(name), + ['struct netlink_callback *cb'], suffix=';') + for name in family.hooks['post']['dump']['list']: + cw.write_func_prot('int', c_lower(name), + ['struct netlink_callback *cb'], suffix=';') + + cw.nl() + + for op_name, op in family.ops.items(): + if op.is_async: + continue + + if 'do' in op: + name = c_lower(f"{family.name}-nl-{op_name}-doit") + cw.write_func_prot('int', name, + ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';') + + if 'dump' in op: + name = c_lower(f"{family.name}-nl-{op_name}-dumpit") + cw.write_func_prot('int', name, + ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';') + cw.nl() + + +def print_kernel_op_table_hdr(family, cw): + print_kernel_op_table_fwd(family, cw, terminate=True) + + +def print_kernel_op_table(family, cw): + print_kernel_op_table_fwd(family, cw, terminate=False) + if family.kernel_policy == 'global' or family.kernel_policy == 'per-op': + for op_name, op in family.ops.items(): + if op.is_async: + continue + + cw.block_start() + members = [('cmd', op.enum_name)] + if 'dont-validate' in op: + members.append(('validate', + ' | '.join([c_upper('genl-dont-validate-' + x) + for x in op['dont-validate']])), ) + for op_mode in ['do', 'dump']: + if op_mode in op: + name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it") + 
members.append((op_mode + 'it', name)) + if family.kernel_policy == 'per-op': + struct = Struct(family, op['attribute-set'], + type_list=op['do']['request']['attributes']) + + name = c_lower(f"{family.name}-{op_name}-nl-policy") + members.append(('policy', name)) + members.append(('maxattr', struct.attr_max_val.enum_name)) + if 'flags' in op: + members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']]))) + cw.write_struct_init(members) + cw.block_end(line=',') + elif family.kernel_policy == 'split': + cb_names = {'do': {'pre': 'pre_doit', 'post': 'post_doit'}, + 'dump': {'pre': 'start', 'post': 'done'}} + + for op_name, op in family.ops.items(): + for op_mode in ['do', 'dump']: + if op.is_async or op_mode not in op: + continue + + cw.block_start() + members = [('cmd', op.enum_name)] + if 'dont-validate' in op: + members.append(('validate', + ' | '.join([c_upper('genl-dont-validate-' + x) + for x in op['dont-validate']])), ) + name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it") + if 'pre' in op[op_mode]: + members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre']))) + members.append((op_mode + 'it', name)) + if 'post' in op[op_mode]: + members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post']))) + if 'request' in op[op_mode]: + struct = Struct(family, op['attribute-set'], + type_list=op[op_mode]['request']['attributes']) + + if op.dual_policy: + name = c_lower(f"{family.name}-{op_name}-{op_mode}-nl-policy") + else: + name = c_lower(f"{family.name}-{op_name}-nl-policy") + members.append(('policy', name)) + members.append(('maxattr', struct.attr_max_val.enum_name)) + flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode] + members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags]))) + cw.write_struct_init(members) + cw.block_end(line=',') + + cw.block_end(line=';') + cw.nl() + + +def print_kernel_mcgrp_hdr(family, cw): + if not family.mcgrps['list']: + return + + 
cw.block_start('enum') + for grp in family.mcgrps['list']: + grp_id = c_upper(f"{family.name}-nlgrp-{grp['name']},") + cw.p(grp_id) + cw.block_end(';') + cw.nl() + + +def print_kernel_mcgrp_src(family, cw): + if not family.mcgrps['list']: + return + + cw.block_start('static const struct genl_multicast_group ' + family.name + '_nl_mcgrps[] =') + for grp in family.mcgrps['list']: + name = grp['name'] + grp_id = c_upper(f"{family.name}-nlgrp-{name}") + cw.p('[' + grp_id + '] = { "' + name + '", },') + cw.block_end(';') + cw.nl() + + +def print_kernel_family_struct_hdr(family, cw): + if not kernel_can_gen_family_struct(family): + return + + cw.p(f"extern struct genl_family {family.name}_nl_family;") + cw.nl() + + +def print_kernel_family_struct_src(family, cw): + if not kernel_can_gen_family_struct(family): + return + + cw.block_start(f"struct genl_family {family.name}_nl_family __ro_after_init =") + cw.p('.name\t\t= ' + family.fam_key + ',') + cw.p('.version\t= ' + family.ver_key + ',') + cw.p('.netnsok\t= true,') + cw.p('.parallel_ops\t= true,') + cw.p('.module\t\t= THIS_MODULE,') + if family.kernel_policy == 'per-op': + cw.p(f'.ops\t\t= {family.name}_nl_ops,') + cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.name}_nl_ops),') + elif family.kernel_policy == 'split': + cw.p(f'.split_ops\t= {family.name}_nl_ops,') + cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.name}_nl_ops),') + if family.mcgrps['list']: + cw.p(f'.mcgrps\t\t= {family.name}_nl_mcgrps,') + cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.name}_nl_mcgrps),') + cw.block_end(';') + + +def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'): + start_line = 'enum' + if enum_name in obj: + if obj[enum_name]: + start_line = 'enum ' + c_lower(obj[enum_name]) + elif ckey and ckey in obj: + start_line = 'enum ' + family.name + '_' + c_lower(obj[ckey]) + cw.block_start(line=start_line) + + +def render_uapi(family, cw): + hdr_prot = f"_UAPI_LINUX_{family.name.upper()}_H" + cw.p('#ifndef ' + hdr_prot) + cw.p('#define ' + 
hdr_prot) + cw.nl() + + defines = [(family.fam_key, family["name"]), + (family.ver_key, family.get('version', 1))] + cw.writes_defines(defines) + cw.nl() + + defines = [] + for const in family['definitions']: + if const['type'] != 'const': + cw.writes_defines(defines) + defines = [] + cw.nl() + + # Write kdoc for enum and flags (one day maybe also structs) + if const['type'] == 'enum' or const['type'] == 'flags': + enum = family.consts[const['name']] + + if enum.has_doc(): + cw.p('/**') + doc = '' + if 'doc' in enum: + doc = ' - ' + enum['doc'] + cw.write_doc_line(enum.enum_name + doc) + for entry in enum.entry_list: + if entry.has_doc(): + doc = '@' + entry.c_name + ': ' + entry['doc'] + cw.write_doc_line(doc) + cw.p(' */') + + uapi_enum_start(family, cw, const, 'name') + name_pfx = const.get('name-prefix', f"{family.name}-{const['name']}-") + for entry in enum.entry_list: + suffix = ',' + if entry.value_change: + suffix = f" = {entry.user_value()}" + suffix + cw.p(entry.c_name + suffix) + + if const.get('render-max', False): + cw.nl() + max_name = c_upper(name_pfx + 'max') + cw.p('__' + max_name + ',') + cw.p(max_name + ' = (__' + max_name + ' - 1)') + cw.block_end(line=';') + cw.nl() + elif const['type'] == 'const': + defines.append([c_upper(family.get('c-define-name', + f"{family.name}-{const['name']}")), + const['value']]) + + if defines: + cw.writes_defines(defines) + cw.nl() + + max_by_define = family.get('max-by-define', False) + + for _, attr_set in family.attr_sets_list: + if attr_set.subset_of: + continue + + cnt_name = c_upper(family.get('attr-cnt-name', f"__{attr_set.name_prefix}MAX")) + max_value = f"({cnt_name} - 1)" + + val = 0 + uapi_enum_start(family, cw, attr_set.yaml, 'enum-name') + for _, attr in attr_set.items(): + suffix = ',' + if attr['value'] != val: + suffix = f" = {attr['value']}," + val = attr['value'] + val += 1 + cw.p(attr.enum_name + suffix) + cw.nl() + cw.p(cnt_name + ('' if max_by_define else ',')) + if not max_by_define: + 
cw.p(f"{attr_set.max_name} = {max_value}") + cw.block_end(line=';') + if max_by_define: + cw.p(f"#define {attr_set.max_name} {max_value}") + cw.nl() + + # Commands + separate_ntf = 'async-prefix' in family['operations'] + + max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX")) + cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX")) + max_value = f"({cnt_name} - 1)" + + uapi_enum_start(family, cw, family['operations'], 'enum-name') + for op in family.msg_list: + if separate_ntf and ('notify' in op or 'event' in op): + continue + + suffix = ',' + if 'value' in op: + suffix = f" = {op['value']}," + cw.p(op.enum_name + suffix) + cw.nl() + cw.p(cnt_name + ('' if max_by_define else ',')) + if not max_by_define: + cw.p(f"{max_name} = {max_value}") + cw.block_end(line=';') + if max_by_define: + cw.p(f"#define {max_name} {max_value}") + cw.nl() + + if separate_ntf: + uapi_enum_start(family, cw, family['operations'], enum_name='async-enum') + for op in family.msg_list: + if separate_ntf and not ('notify' in op or 'event' in op): + continue + + suffix = ',' + if 'value' in op: + suffix = f" = {op['value']}," + cw.p(op.enum_name + suffix) + cw.block_end(line=';') + cw.nl() + + # Multicast + defines = [] + for grp in family.mcgrps['list']: + name = grp['name'] + defines.append([c_upper(grp.get('c-define-name', f"{family.name}-mcgrp-{name}")), + f'{name}']) + cw.nl() + if defines: + cw.writes_defines(defines) + cw.nl() + + cw.p(f'#endif /* {hdr_prot} */') + + +def find_kernel_root(full_path): + sub_path = '' + while True: + sub_path = os.path.join(os.path.basename(full_path), sub_path) + full_path = os.path.dirname(full_path) + maintainers = os.path.join(full_path, "MAINTAINERS") + if os.path.exists(maintainers): + return full_path, sub_path[:-1] + + +def main(): + parser = argparse.ArgumentParser(description='Netlink simple parsing generator') + parser.add_argument('--mode', dest='mode', type=str, required=True) + 
parser.add_argument('--spec', dest='spec', type=str, required=True) + parser.add_argument('--header', dest='header', action='store_true', default=None) + parser.add_argument('--source', dest='header', action='store_false') + parser.add_argument('--user-header', nargs='+', default=[]) + parser.add_argument('-o', dest='out_file', type=str) + args = parser.parse_args() + + out_file = open(args.out_file, 'w+') if args.out_file else os.sys.stdout + + if args.header is None: + parser.error("--header or --source is required") + + try: + parsed = Family(args.spec) + except yaml.YAMLError as exc: + print(exc) + os.sys.exit(1) + return + + cw = CodeWriter(BaseNlLib(), out_file) + + _, spec_kernel = find_kernel_root(args.spec) + if args.mode == 'uapi': + cw.p('/* SPDX-License-Identifier: GPL-2.0 WITH Linux-syscall-note */') + else: + if args.header: + cw.p('/* SPDX-License-Identifier: BSD-3-Clause */') + else: + cw.p('// SPDX-License-Identifier: BSD-3-Clause') + cw.p("/* Do not edit directly, auto-generated from: */") + cw.p(f"/*\t{spec_kernel} */") + cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */") + cw.nl() + + if args.mode == 'uapi': + render_uapi(parsed, cw) + return + + hdr_prot = f"_LINUX_{parsed.name.upper()}_GEN_H" + if args.header: + cw.p('#ifndef ' + hdr_prot) + cw.p('#define ' + hdr_prot) + cw.nl() + + if args.mode == 'kernel': + cw.p('#include <net/netlink.h>') + cw.p('#include <net/genetlink.h>') + cw.nl() + if not args.header: + if args.out_file: + cw.p(f'#include "{os.path.basename(args.out_file[:-2])}.h"') + cw.nl() + headers = [parsed.uapi_header] + for definition in parsed['definitions']: + if 'header' in definition: + headers.append(definition['header']) + for one in headers: + cw.p(f"#include <{one}>") + cw.nl() + + if args.mode == "user": + if not args.header: + cw.p("#include <stdlib.h>") + cw.p("#include <stdio.h>") + cw.p("#include <string.h>") + cw.p("#include <libmnl/libmnl.h>") + cw.p("#include <linux/genetlink.h>") + 
cw.nl() + for one in args.user_header: + cw.p(f'#include "{one}"') + else: + cw.p('struct ynl_sock;') + cw.nl() + + if args.mode == "kernel": + if args.header: + for _, struct in sorted(parsed.pure_nested_structs.items()): + if struct.request: + cw.p('/* Common nested types */') + break + for attr_set, struct in sorted(parsed.pure_nested_structs.items()): + if struct.request: + print_req_policy_fwd(cw, struct) + cw.nl() + + if parsed.kernel_policy == 'global': + cw.p(f"/* Global operation policy for {parsed.name} */") + + struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy) + print_req_policy_fwd(cw, struct) + cw.nl() + + if parsed.kernel_policy in {'per-op', 'split'}: + for op_name, op in parsed.ops.items(): + if 'do' in op and 'event' not in op: + ri = RenderInfo(cw, parsed, args.mode, op, op_name, "do") + print_req_policy_fwd(cw, ri.struct['request'], ri=ri) + cw.nl() + + print_kernel_op_table_hdr(parsed, cw) + print_kernel_mcgrp_hdr(parsed, cw) + print_kernel_family_struct_hdr(parsed, cw) + else: + for _, struct in sorted(parsed.pure_nested_structs.items()): + if struct.request: + cw.p('/* Common nested types */') + break + for attr_set, struct in sorted(parsed.pure_nested_structs.items()): + if struct.request: + print_req_policy(cw, struct) + cw.nl() + + if parsed.kernel_policy == 'global': + cw.p(f"/* Global operation policy for {parsed.name} */") + + struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy) + print_req_policy(cw, struct) + cw.nl() + + for op_name, op in parsed.ops.items(): + if parsed.kernel_policy in {'per-op', 'split'}: + for op_mode in {'do', 'dump'}: + if op_mode in op and 'request' in op[op_mode]: + cw.p(f"/* {op.enum_name} - {op_mode} */") + ri = RenderInfo(cw, parsed, args.mode, op, op_name, op_mode) + print_req_policy(cw, ri.struct['request'], ri=ri) + cw.nl() + + print_kernel_op_table(parsed, cw) + print_kernel_mcgrp_src(parsed, cw) + print_kernel_family_struct_src(parsed, cw) + + 
if args.mode == "user": + has_ntf = False + if args.header: + cw.p('/* Common nested types */') + for attr_set, struct in sorted(parsed.pure_nested_structs.items()): + ri = RenderInfo(cw, parsed, args.mode, "", "", "", attr_set) + print_type_full(ri, struct) + + for op_name, op in parsed.ops.items(): + cw.p(f"/* ============== {op.enum_name} ============== */") + + if 'do' in op and 'event' not in op: + cw.p(f"/* {op.enum_name} - do */") + ri = RenderInfo(cw, parsed, args.mode, op, op_name, "do") + print_req_type(ri) + print_req_type_helpers(ri) + cw.nl() + print_rsp_type(ri) + print_rsp_type_helpers(ri) + cw.nl() + print_req_prototype(ri) + cw.nl() + + if 'dump' in op: + cw.p(f"/* {op.enum_name} - dump */") + ri = RenderInfo(cw, parsed, args.mode, op, op_name, 'dump') + if 'request' in op['dump']: + print_req_type(ri) + print_req_type_helpers(ri) + if not ri.type_consistent: + print_rsp_type(ri) + print_wrapped_type(ri) + print_dump_prototype(ri) + cw.nl() + + if 'notify' in op: + cw.p(f"/* {op.enum_name} - notify */") + ri = RenderInfo(cw, parsed, args.mode, op, op_name, 'notify') + has_ntf = True + if not ri.type_consistent: + raise Exception('Only notifications with consistent types supported') + print_wrapped_type(ri) + + if 'event' in op: + ri = RenderInfo(cw, parsed, args.mode, op, op_name, 'event') + cw.p(f"/* {op.enum_name} - event */") + print_rsp_type(ri) + cw.nl() + print_wrapped_type(ri) + + if has_ntf: + cw.p('/* --------------- Common notification parsing --------------- */') + print_ntf_parse_prototype(parsed, cw) + cw.nl() + else: + cw.p('/* Policies */') + for name, _ in parsed.attr_sets.items(): + struct = Struct(parsed, name) + put_typol_fwd(cw, struct) + cw.nl() + + for name, _ in parsed.attr_sets.items(): + struct = Struct(parsed, name) + put_typol(cw, struct) + + cw.p('/* Common nested types */') + for attr_set, struct in sorted(parsed.pure_nested_structs.items()): + ri = RenderInfo(cw, parsed, args.mode, "", "", "", attr_set) + + 
free_rsp_nested(ri, struct) + if struct.request: + put_req_nested(ri, struct) + if struct.reply: + parse_rsp_nested(ri, struct) + + for op_name, op in parsed.ops.items(): + cw.p(f"/* ============== {op.enum_name} ============== */") + if 'do' in op and 'event' not in op: + cw.p(f"/* {op.enum_name} - do */") + ri = RenderInfo(cw, parsed, args.mode, op, op_name, "do") + print_rsp_free(ri) + parse_rsp_msg(ri) + print_req(ri) + cw.nl() + + if 'dump' in op: + cw.p(f"/* {op.enum_name} - dump */") + ri = RenderInfo(cw, parsed, args.mode, op, op_name, "dump") + if not ri.type_consistent: + parse_rsp_msg(ri, deref=True) + print_dump_type_free(ri) + print_dump(ri) + cw.nl() + + if 'notify' in op: + cw.p(f"/* {op.enum_name} - notify */") + ri = RenderInfo(cw, parsed, args.mode, op, op_name, 'notify') + has_ntf = True + if not ri.type_consistent: + raise Exception('Only notifications with consistent types supported') + print_ntf_type_free(ri) + + if 'event' in op: + cw.p(f"/* {op.enum_name} - event */") + has_ntf = True + + ri = RenderInfo(cw, parsed, args.mode, op, op_name, "do") + parse_rsp_msg(ri) + + ri = RenderInfo(cw, parsed, args.mode, op, op_name, "event") + print_ntf_type_free(ri) + + if has_ntf: + cw.p('/* --------------- Common notification parsing --------------- */') + print_ntf_type_parse(parsed, cw, args.mode) + + if args.header: + cw.p(f'#endif /* {hdr_prot} */') + + +if __name__ == "__main__": + main() diff --git a/tools/net/ynl/ynl-regen.sh b/tools/net/ynl/ynl-regen.sh new file mode 100755 index 000000000000..43989ae48ed0 --- /dev/null +++ b/tools/net/ynl/ynl-regen.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# SPDX-License-Identifier: BSD-3-Clause + +TOOL=$(dirname $(realpath $0))/ynl-gen-c.py + +force= + +while [ ! 
-z "$1" ]; do + case "$1" in + -f ) force=yes; shift ;; + * ) echo "Unrecognized option '$1'"; exit 1 ;; + esac +done + +KDIR=$(dirname $(dirname $(dirname $(dirname $(realpath $0))))) + +files=$(git grep --files-with-matches '^/\* YNL-GEN \(kernel\|uapi\)') +for f in $files; do + # params: 0 1 2 3 + # $YAML YNL-GEN kernel $mode + params=( $(git grep -B1 -h '/\* YNL-GEN' $f | sed 's@/\*\(.*\)\*/@\1@') ) + + if [ $f -nt ${params[0]} -a -z "$force" ]; then + echo -e "\tSKIP $f" + continue + fi + + echo -e "\tGEN ${params[2]}\t$f" + $TOOL --mode ${params[2]} --${params[3]} --spec $KDIR/${params[0]} -o $f +done diff --git a/tools/testing/selftests/bpf/DENYLIST.s390x b/tools/testing/selftests/bpf/DENYLIST.s390x index 3fc3e54b19aa..96e8371f5c2a 100644 --- a/tools/testing/selftests/bpf/DENYLIST.s390x +++ b/tools/testing/selftests/bpf/DENYLIST.s390x @@ -27,6 +27,7 @@ get_func_args_test # trampoline get_func_ip_test # get_func_ip_test__attach unexpected error: -524 (trampoline) get_stack_raw_tp # user_stack corrupted user stack (no backchain userspace) htab_update # failed to attach: ERROR: strerror_r(-524)=22 (trampoline) +jit_probe_mem # jit_probe_mem__open_and_load unexpected error: -524 (kfunc) kfree_skb # attach fentry unexpected error: -524 (trampoline) kfunc_call # 'bpf_prog_active': not found in kernel BTF (?) 
kfunc_dynptr_param # JIT does not support calling kernel function (kfunc) diff --git a/tools/testing/selftests/bpf/Makefile b/tools/testing/selftests/bpf/Makefile index c22c43bbee19..205e8c3c346a 100644 --- a/tools/testing/selftests/bpf/Makefile +++ b/tools/testing/selftests/bpf/Makefile @@ -626,3 +626,6 @@ EXTRA_CLEAN := $(TEST_CUSTOM_PROGS) $(SCRATCH_DIR) $(HOST_SCRATCH_DIR) \ liburandom_read.so) .PHONY: docs docs-clean + +# Delete partially updated (corrupted) files on error +.DELETE_ON_ERROR: diff --git a/tools/testing/selftests/bpf/prog_tests/jit_probe_mem.c b/tools/testing/selftests/bpf/prog_tests/jit_probe_mem.c new file mode 100644 index 000000000000..5639428607e6 --- /dev/null +++ b/tools/testing/selftests/bpf/prog_tests/jit_probe_mem.c @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (c) 2022 Meta Platforms, Inc. and affiliates. */ +#include <test_progs.h> +#include <network_helpers.h> + +#include "jit_probe_mem.skel.h" + +void test_jit_probe_mem(void) +{ + LIBBPF_OPTS(bpf_test_run_opts, opts, + .data_in = &pkt_v4, + .data_size_in = sizeof(pkt_v4), + .repeat = 1, + ); + struct jit_probe_mem *skel; + int ret; + + skel = jit_probe_mem__open_and_load(); + if (!ASSERT_OK_PTR(skel, "jit_probe_mem__open_and_load")) + return; + + ret = bpf_prog_test_run_opts(bpf_program__fd(skel->progs.test_jit_probe_mem), &opts); + ASSERT_OK(ret, "jit_probe_mem ret"); + ASSERT_OK(opts.retval, "jit_probe_mem opts.retval"); + ASSERT_EQ(skel->data->total_sum, 192, "jit_probe_mem total_sum"); + + jit_probe_mem__destroy(skel); +} diff --git a/tools/testing/selftests/bpf/progs/btf_dump_test_case_bitfields.c b/tools/testing/selftests/bpf/progs/btf_dump_test_case_bitfields.c index e5560a656030..e01690618e1e 100644 --- a/tools/testing/selftests/bpf/progs/btf_dump_test_case_bitfields.c +++ b/tools/testing/selftests/bpf/progs/btf_dump_test_case_bitfields.c @@ -53,7 +53,7 @@ struct bitfields_only_mixed_types { */ /* ------ END-EXPECTED-OUTPUT ------ */ struct 
bitfield_mixed_with_others { - long: 4; /* char is enough as a backing field */ + char: 4; /* char is enough as a backing field */ int a: 4; /* 8-bit implicit padding */ short b; /* combined with previous bitfield */ diff --git a/tools/testing/selftests/bpf/progs/btf_dump_test_case_packing.c b/tools/testing/selftests/bpf/progs/btf_dump_test_case_packing.c index e304b6204bd9..7998f27df7dd 100644 --- a/tools/testing/selftests/bpf/progs/btf_dump_test_case_packing.c +++ b/tools/testing/selftests/bpf/progs/btf_dump_test_case_packing.c @@ -58,7 +58,81 @@ union jump_code_union { } __attribute__((packed)); }; -/*------ END-EXPECTED-OUTPUT ------ */ +/* ----- START-EXPECTED-OUTPUT ----- */ +/* + *struct nested_packed_but_aligned_struct { + * int x1; + * int x2; + *}; + * + *struct outer_implicitly_packed_struct { + * char y1; + * struct nested_packed_but_aligned_struct y2; + *} __attribute__((packed)); + * + */ +/* ------ END-EXPECTED-OUTPUT ------ */ + +struct nested_packed_but_aligned_struct { + int x1; + int x2; +} __attribute__((packed)); + +struct outer_implicitly_packed_struct { + char y1; + struct nested_packed_but_aligned_struct y2; +}; +/* ----- START-EXPECTED-OUTPUT ----- */ +/* + *struct usb_ss_ep_comp_descriptor { + * char: 8; + * char bDescriptorType; + * char bMaxBurst; + * short wBytesPerInterval; + *}; + * + *struct usb_host_endpoint { + * long: 64; + * char: 8; + * struct usb_ss_ep_comp_descriptor ss_ep_comp; + * long: 0; + *} __attribute__((packed)); + * + */ +/* ------ END-EXPECTED-OUTPUT ------ */ + +struct usb_ss_ep_comp_descriptor { + char: 8; + char bDescriptorType; + char bMaxBurst; + int: 0; + short wBytesPerInterval; +} __attribute__((packed)); + +struct usb_host_endpoint { + long: 64; + char: 8; + struct usb_ss_ep_comp_descriptor ss_ep_comp; + long: 0; +}; + +/* ----- START-EXPECTED-OUTPUT ----- */ +struct nested_packed_struct { + int a; + char b; +} __attribute__((packed)); + +struct outer_nonpacked_struct { + short a; + struct 
nested_packed_struct b; +}; + +struct outer_packed_struct { + short a; + struct nested_packed_struct b; +} __attribute__((packed)); + +/* ------ END-EXPECTED-OUTPUT ------ */ int f(struct { struct packed_trailing_space _1; @@ -69,6 +143,10 @@ int f(struct { union union_is_never_packed _6; union union_does_not_need_packing _7; union jump_code_union _8; + struct outer_implicitly_packed_struct _9; + struct usb_host_endpoint _10; + struct outer_nonpacked_struct _11; + struct outer_packed_struct _12; } *_) { return 0; diff --git a/tools/testing/selftests/bpf/progs/btf_dump_test_case_padding.c b/tools/testing/selftests/bpf/progs/btf_dump_test_case_padding.c index 7cb522d22a66..79276fbe454a 100644 --- a/tools/testing/selftests/bpf/progs/btf_dump_test_case_padding.c +++ b/tools/testing/selftests/bpf/progs/btf_dump_test_case_padding.c @@ -19,7 +19,7 @@ struct padded_implicitly { /* *struct padded_explicitly { * int a; - * int: 32; + * long: 0; * int b; *}; * @@ -28,41 +28,28 @@ struct padded_implicitly { struct padded_explicitly { int a; - int: 1; /* algo will explicitly pad with full 32 bits here */ + int: 1; /* algo will emit aligning `long: 0;` here */ int b; }; /* ----- START-EXPECTED-OUTPUT ----- */ -/* - *struct padded_a_lot { - * int a; - * long: 32; - * long: 64; - * long: 64; - * int b; - *}; - * - */ -/* ------ END-EXPECTED-OUTPUT ------ */ - struct padded_a_lot { int a; - /* 32 bit of implicit padding here, which algo will make explicit */ long: 64; long: 64; int b; }; +/* ------ END-EXPECTED-OUTPUT ------ */ + /* ----- START-EXPECTED-OUTPUT ----- */ /* *struct padded_cache_line { * int a; - * long: 32; * long: 64; * long: 64; * long: 64; * int b; - * long: 32; * long: 64; * long: 64; * long: 64; @@ -85,7 +72,7 @@ struct padded_cache_line { *struct zone { * int a; * short b; - * short: 16; + * long: 0; * struct zone_padding __pad__; *}; * @@ -108,6 +95,131 @@ struct padding_wo_named_members { long: 64; }; +struct padding_weird_1 { + int a; + long: 64; + short: 
16; + short b; +}; + +/* ------ END-EXPECTED-OUTPUT ------ */ + +/* ----- START-EXPECTED-OUTPUT ----- */ +/* + *struct padding_weird_2 { + * long: 56; + * char a; + * long: 56; + * char b; + * char: 8; + *}; + * + */ +/* ------ END-EXPECTED-OUTPUT ------ */ +struct padding_weird_2 { + int: 32; /* these paddings will be collapsed into `long: 56;` */ + short: 16; + char: 8; + char a; + int: 32; /* these paddings will be collapsed into `long: 56;` */ + short: 16; + char: 8; + char b; + char: 8; +}; + +/* ----- START-EXPECTED-OUTPUT ----- */ +struct exact_1byte { + char x; +}; + +struct padded_1byte { + char: 8; +}; + +struct exact_2bytes { + short x; +}; + +struct padded_2bytes { + short: 16; +}; + +struct exact_4bytes { + int x; +}; + +struct padded_4bytes { + int: 32; +}; + +struct exact_8bytes { + long x; +}; + +struct padded_8bytes { + long: 64; +}; + +struct ff_periodic_effect { + int: 32; + short magnitude; + long: 0; + short phase; + long: 0; + int: 32; + int custom_len; + short *custom_data; +}; + +struct ib_wc { + long: 64; + long: 64; + int: 32; + int byte_len; + void *qp; + union {} ex; + long: 64; + int slid; + int wc_flags; + long: 64; + char smac[6]; + long: 0; + char network_hdr_type; +}; + +struct acpi_object_method { + long: 64; + char: 8; + char type; + short reference_count; + char flags; + short: 0; + char: 8; + char sync_level; + long: 64; + void *node; + void *aml_start; + union {} dispatch; + long: 64; + int aml_length; +}; + +struct nested_unpacked { + int x; +}; + +struct nested_packed { + struct nested_unpacked a; + char c; +} __attribute__((packed)); + +struct outer_mixed_but_unpacked { + struct nested_packed b1; + short a1; + struct nested_packed b2; +}; + /* ------ END-EXPECTED-OUTPUT ------ */ int f(struct { @@ -117,6 +229,20 @@ int f(struct { struct padded_cache_line _4; struct zone _5; struct padding_wo_named_members _6; + struct padding_weird_1 _7; + struct padding_weird_2 _8; + struct exact_1byte _100; + struct padded_1byte _101; + 
struct exact_2bytes _102; + struct padded_2bytes _103; + struct exact_4bytes _104; + struct padded_4bytes _105; + struct exact_8bytes _106; + struct padded_8bytes _107; + struct ff_periodic_effect _200; + struct ib_wc _201; + struct acpi_object_method _202; + struct outer_mixed_but_unpacked _203; } *_) { return 0; diff --git a/tools/testing/selftests/bpf/progs/btf_dump_test_case_syntax.c b/tools/testing/selftests/bpf/progs/btf_dump_test_case_syntax.c index 4ee4748133fe..26fffb02ed10 100644 --- a/tools/testing/selftests/bpf/progs/btf_dump_test_case_syntax.c +++ b/tools/testing/selftests/bpf/progs/btf_dump_test_case_syntax.c @@ -25,6 +25,39 @@ typedef enum { H = 2, } e3_t; +/* ----- START-EXPECTED-OUTPUT ----- */ +/* + *enum e_byte { + * EBYTE_1 = 0, + * EBYTE_2 = 1, + *} __attribute__((mode(byte))); + * + */ +/* ----- END-EXPECTED-OUTPUT ----- */ +enum e_byte { + EBYTE_1, + EBYTE_2, +} __attribute__((mode(byte))); + +/* ----- START-EXPECTED-OUTPUT ----- */ +/* + *enum e_word { + * EWORD_1 = 0LL, + * EWORD_2 = 1LL, + *} __attribute__((mode(word))); + * + */ +/* ----- END-EXPECTED-OUTPUT ----- */ +enum e_word { + EWORD_1, + EWORD_2, +} __attribute__((mode(word))); /* force to use 8-byte backing for this enum */ + +/* ----- START-EXPECTED-OUTPUT ----- */ +enum e_big { + EBIG_1 = 1000000000000ULL, +}; + typedef int int_t; typedef volatile const int * volatile const crazy_ptr_t; @@ -224,6 +257,9 @@ struct root_struct { enum e2 _2; e2_t _2_1; e3_t _2_2; + enum e_byte _100; + enum e_word _101; + enum e_big _102; struct struct_w_typedefs _3; anon_struct_t _7; struct struct_fwd *_8; diff --git a/tools/testing/selftests/bpf/progs/jit_probe_mem.c b/tools/testing/selftests/bpf/progs/jit_probe_mem.c new file mode 100644 index 000000000000..2d2e61470794 --- /dev/null +++ b/tools/testing/selftests/bpf/progs/jit_probe_mem.c @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (c) 2022 Meta Platforms, Inc. and affiliates. 
*/ +#include <vmlinux.h> +#include <bpf/bpf_tracing.h> +#include <bpf/bpf_helpers.h> + +static struct prog_test_ref_kfunc __kptr_ref *v; +long total_sum = -1; + +extern struct prog_test_ref_kfunc *bpf_kfunc_call_test_acquire(unsigned long *sp) __ksym; +extern void bpf_kfunc_call_test_release(struct prog_test_ref_kfunc *p) __ksym; + +SEC("tc") +int test_jit_probe_mem(struct __sk_buff *ctx) +{ + struct prog_test_ref_kfunc *p; + unsigned long zero = 0, sum; + + p = bpf_kfunc_call_test_acquire(&zero); + if (!p) + return 1; + + p = bpf_kptr_xchg(&v, p); + if (p) + goto release_out; + + /* Direct map value access of kptr, should be PTR_UNTRUSTED */ + p = v; + if (!p) + return 1; + + asm volatile ( + "r9 = %[p];" + "%[sum] = 0;" + + /* r8 = p->a */ + "r8 = *(u32 *)(r9 + 0);" + "%[sum] += r8;" + + /* r8 = p->b */ + "r8 = *(u32 *)(r9 + 4);" + "%[sum] += r8;" + + "r9 += 8;" + /* r9 = p->a */ + "r9 = *(u32 *)(r9 - 8);" + "%[sum] += r9;" + + : [sum] "=r"(sum) + : [p] "r"(p) + : "r8", "r9" + ); + + total_sum = sum; + return 0; +release_out: + bpf_kfunc_call_test_release(p); + return 1; +} + +char _license[] SEC("license") = "GPL"; diff --git a/tools/testing/selftests/bpf/progs/test_tunnel_kern.c b/tools/testing/selftests/bpf/progs/test_tunnel_kern.c index 98af55f0bcd3..508da4a23c4f 100644 --- a/tools/testing/selftests/bpf/progs/test_tunnel_kern.c +++ b/tools/testing/selftests/bpf/progs/test_tunnel_kern.c @@ -82,6 +82,27 @@ int gre_set_tunnel(struct __sk_buff *skb) } SEC("tc") +int gre_set_tunnel_no_key(struct __sk_buff *skb) +{ + int ret; + struct bpf_tunnel_key key; + + __builtin_memset(&key, 0x0, sizeof(key)); + key.remote_ipv4 = 0xac100164; /* 172.16.1.100 */ + key.tunnel_ttl = 64; + + ret = bpf_skb_set_tunnel_key(skb, &key, sizeof(key), + BPF_F_ZERO_CSUM_TX | BPF_F_SEQ_NUMBER | + BPF_F_NO_TUNNEL_KEY); + if (ret < 0) { + log_err(ret); + return TC_ACT_SHOT; + } + + return TC_ACT_OK; +} + +SEC("tc") int gre_get_tunnel(struct __sk_buff *skb) { int ret; diff --git 
a/tools/testing/selftests/bpf/test_tunnel.sh b/tools/testing/selftests/bpf/test_tunnel.sh index 2eaedc1d9ed3..06857b689c11 100755 --- a/tools/testing/selftests/bpf/test_tunnel.sh +++ b/tools/testing/selftests/bpf/test_tunnel.sh @@ -66,15 +66,20 @@ config_device() add_gre_tunnel() { + tun_key= + if [ -n "$1" ]; then + tun_key="key $1" + fi + # at_ns0 namespace ip netns exec at_ns0 \ - ip link add dev $DEV_NS type $TYPE seq key 2 \ + ip link add dev $DEV_NS type $TYPE seq $tun_key \ local 172.16.1.100 remote 172.16.1.200 ip netns exec at_ns0 ip link set dev $DEV_NS up ip netns exec at_ns0 ip addr add dev $DEV_NS 10.1.1.100/24 # root namespace - ip link add dev $DEV type $TYPE key 2 external + ip link add dev $DEV type $TYPE $tun_key external ip link set dev $DEV up ip addr add dev $DEV 10.1.1.200/24 } @@ -238,7 +243,7 @@ test_gre() check $TYPE config_device - add_gre_tunnel + add_gre_tunnel 2 attach_bpf $DEV gre_set_tunnel gre_get_tunnel ping $PING_ARG 10.1.1.100 check_err $? @@ -253,6 +258,30 @@ test_gre() echo -e ${GREEN}"PASS: $TYPE"${NC} } +test_gre_no_tunnel_key() +{ + TYPE=gre + DEV_NS=gre00 + DEV=gre11 + ret=0 + + check $TYPE + config_device + add_gre_tunnel + attach_bpf $DEV gre_set_tunnel_no_key gre_get_tunnel + ping $PING_ARG 10.1.1.100 + check_err $? + ip netns exec at_ns0 ping $PING_ARG 10.1.1.200 + check_err $? + cleanup + + if [ $ret -ne 0 ]; then + echo -e ${RED}"FAIL: $TYPE"${NC} + return 1 + fi + echo -e ${GREEN}"PASS: $TYPE"${NC} +} + test_ip6gre() { TYPE=ip6gre @@ -589,6 +618,7 @@ cleanup() ip link del ipip6tnl11 2> /dev/null ip link del ip6ip6tnl11 2> /dev/null ip link del gretap11 2> /dev/null + ip link del gre11 2> /dev/null ip link del ip6gre11 2> /dev/null ip link del ip6gretap11 2> /dev/null ip link del geneve11 2> /dev/null @@ -641,6 +671,10 @@ bpf_tunnel_test() test_gre errors=$(( $errors + $? )) + echo "Testing GRE tunnel (without tunnel keys)..." + test_gre_no_tunnel_key + errors=$(( $errors + $? )) + echo "Testing IP6GRE tunnel..." 
test_ip6gre errors=$(( $errors + $? )) diff --git a/tools/testing/selftests/net/Makefile b/tools/testing/selftests/net/Makefile index 3007e98a6d64..951bd5342bc6 100644 --- a/tools/testing/selftests/net/Makefile +++ b/tools/testing/selftests/net/Makefile @@ -45,6 +45,7 @@ TEST_PROGS += arp_ndisc_untracked_subnets.sh TEST_PROGS += stress_reuseport_listen.sh TEST_PROGS += l2_tos_ttl_inherit.sh TEST_PROGS += bind_bhash.sh +TEST_PROGS += ip_local_port_range.sh TEST_PROGS_EXTENDED := in_netns.sh setup_loopback.sh setup_veth.sh TEST_PROGS_EXTENDED += toeplitz_client.sh toeplitz.sh TEST_GEN_FILES = socket nettest @@ -75,14 +76,61 @@ TEST_GEN_PROGS += so_incoming_cpu TEST_PROGS += sctp_vrf.sh TEST_GEN_FILES += sctp_hello TEST_GEN_FILES += csum +TEST_GEN_FILES += nat6to4.o +TEST_GEN_FILES += ip_local_port_range TEST_FILES := settings include ../lib.mk -include bpf/Makefile - $(OUTPUT)/reuseport_bpf_numa: LDLIBS += -lnuma $(OUTPUT)/tcp_mmap: LDLIBS += -lpthread $(OUTPUT)/tcp_inq: LDLIBS += -lpthread $(OUTPUT)/bind_bhash: LDLIBS += -lpthread + +# Rules to generate bpf obj nat6to4.o +CLANG ?= clang +SCRATCH_DIR := $(OUTPUT)/tools +BUILD_DIR := $(SCRATCH_DIR)/build +BPFDIR := $(abspath ../../../lib/bpf) +APIDIR := $(abspath ../../../include/uapi) + +CCINCLUDE += -I../bpf +CCINCLUDE += -I../../../../usr/include/ +CCINCLUDE += -I$(SCRATCH_DIR)/include + +BPFOBJ := $(BUILD_DIR)/libbpf/libbpf.a + +MAKE_DIRS := $(BUILD_DIR)/libbpf +$(MAKE_DIRS): + mkdir -p $@ + +# Get Clang's default includes on this system, as opposed to those seen by +# '-target bpf'. This fixes "missing" files on some architectures/distros, +# such as asm/byteorder.h, asm/socket.h, asm/sockios.h, sys/cdefs.h etc. +# +# Use '-idirafter': Don't interfere with include mechanics except where the +# build would have failed anyways. 
+define get_sys_includes +$(shell $(1) $(2) -v -E - </dev/null 2>&1 \ + | sed -n '/<...> search starts here:/,/End of search list./{ s| \(/.*\)|-idirafter \1|p }') \ +$(shell $(1) $(2) -dM -E - </dev/null | grep '__riscv_xlen ' | awk '{printf("-D__riscv_xlen=%d -D__BITS_PER_LONG=%d", $$3, $$3)}') +endef + +ifneq ($(CROSS_COMPILE),) +CLANG_TARGET_ARCH = --target=$(notdir $(CROSS_COMPILE:%-=%)) +endif + +CLANG_SYS_INCLUDES = $(call get_sys_includes,$(CLANG),$(CLANG_TARGET_ARCH)) + +$(OUTPUT)/nat6to4.o: nat6to4.c $(BPFOBJ) | $(MAKE_DIRS) + $(CLANG) -O2 -target bpf -c $< $(CCINCLUDE) $(CLANG_SYS_INCLUDES) -o $@ + +$(BPFOBJ): $(wildcard $(BPFDIR)/*.[ch] $(BPFDIR)/Makefile) \ + $(APIDIR)/linux/bpf.h \ + | $(BUILD_DIR)/libbpf + $(MAKE) $(submake_extras) -C $(BPFDIR) OUTPUT=$(BUILD_DIR)/libbpf/ \ + EXTRA_CFLAGS='-g -O0' \ + DESTDIR=$(SCRATCH_DIR) prefix= all install_headers + +EXTRA_CLEAN := $(SCRATCH_DIR) diff --git a/tools/testing/selftests/net/bpf/Makefile b/tools/testing/selftests/net/bpf/Makefile deleted file mode 100644 index 4abaf16d2077..000000000000 --- a/tools/testing/selftests/net/bpf/Makefile +++ /dev/null @@ -1,51 +0,0 @@ -# SPDX-License-Identifier: GPL-2.0 - -CLANG ?= clang -SCRATCH_DIR := $(OUTPUT)/tools -BUILD_DIR := $(SCRATCH_DIR)/build -BPFDIR := $(abspath ../../../lib/bpf) -APIDIR := $(abspath ../../../include/uapi) - -CCINCLUDE += -I../../bpf -CCINCLUDE += -I../../../../../usr/include/ -CCINCLUDE += -I$(SCRATCH_DIR)/include - -BPFOBJ := $(BUILD_DIR)/libbpf/libbpf.a - -MAKE_DIRS := $(BUILD_DIR)/libbpf $(OUTPUT)/bpf -$(MAKE_DIRS): - mkdir -p $@ - -TEST_CUSTOM_PROGS = $(OUTPUT)/bpf/nat6to4.o -all: $(TEST_CUSTOM_PROGS) - -# Get Clang's default includes on this system, as opposed to those seen by -# '-target bpf'. This fixes "missing" files on some architectures/distros, -# such as asm/byteorder.h, asm/socket.h, asm/sockios.h, sys/cdefs.h etc. -# -# Use '-idirafter': Don't interfere with include mechanics except where the -# build would have failed anyways. 
-define get_sys_includes -$(shell $(1) $(2) -v -E - </dev/null 2>&1 \ - | sed -n '/<...> search starts here:/,/End of search list./{ s| \(/.*\)|-idirafter \1|p }') \ -$(shell $(1) $(2) -dM -E - </dev/null | grep '__riscv_xlen ' | awk '{printf("-D__riscv_xlen=%d -D__BITS_PER_LONG=%d", $$3, $$3)}') -endef - -ifneq ($(CROSS_COMPILE),) -CLANG_TARGET_ARCH = --target=$(notdir $(CROSS_COMPILE:%-=%)) -endif - -CLANG_SYS_INCLUDES = $(call get_sys_includes,$(CLANG),$(CLANG_TARGET_ARCH)) - -$(TEST_CUSTOM_PROGS): $(OUTPUT)/%.o: %.c $(BPFOBJ) | $(MAKE_DIRS) - $(CLANG) -O2 -target bpf -c $< $(CCINCLUDE) $(CLANG_SYS_INCLUDES) -o $@ - -$(BPFOBJ): $(wildcard $(BPFDIR)/*.[ch] $(BPFDIR)/Makefile) \ - $(APIDIR)/linux/bpf.h \ - | $(BUILD_DIR)/libbpf - $(MAKE) $(submake_extras) -C $(BPFDIR) OUTPUT=$(BUILD_DIR)/libbpf/ \ - EXTRA_CFLAGS='-g -O0' \ - DESTDIR=$(SCRATCH_DIR) prefix= all install_headers - -EXTRA_CLEAN := $(TEST_CUSTOM_PROGS) $(SCRATCH_DIR) - diff --git a/tools/testing/selftests/net/forwarding/tc_actions.sh b/tools/testing/selftests/net/forwarding/tc_actions.sh index 1e0a62f638fe..919c0dd9fe4b 100755 --- a/tools/testing/selftests/net/forwarding/tc_actions.sh +++ b/tools/testing/selftests/net/forwarding/tc_actions.sh @@ -3,7 +3,8 @@ ALL_TESTS="gact_drop_and_ok_test mirred_egress_redirect_test \ mirred_egress_mirror_test matchall_mirred_egress_mirror_test \ - gact_trap_test mirred_egress_to_ingress_test" + gact_trap_test mirred_egress_to_ingress_test \ + mirred_egress_to_ingress_tcp_test" NUM_NETIFS=4 source tc_common.sh source lib.sh @@ -198,6 +199,52 @@ mirred_egress_to_ingress_test() log_test "mirred_egress_to_ingress ($tcflags)" } +mirred_egress_to_ingress_tcp_test() +{ + local tmpfile=$(mktemp) tmpfile1=$(mktemp) + + RET=0 + dd conv=sparse status=none if=/dev/zero bs=1M count=2 of=$tmpfile + tc filter add dev $h1 protocol ip pref 100 handle 100 egress flower \ + $tcflags ip_proto tcp src_ip 192.0.2.1 dst_ip 192.0.2.2 \ + action ct commit nat src addr 192.0.2.2 pipe \ + 
action ct clear pipe \ + action ct commit nat dst addr 192.0.2.1 pipe \ + action ct clear pipe \ + action skbedit ptype host pipe \ + action mirred ingress redirect dev $h1 + tc filter add dev $h1 protocol ip pref 101 handle 101 egress flower \ + $tcflags ip_proto icmp \ + action mirred ingress redirect dev $h1 + tc filter add dev $h1 protocol ip pref 102 handle 102 ingress flower \ + ip_proto icmp \ + action drop + + ip vrf exec v$h1 nc --recv-only -w10 -l -p 12345 -o $tmpfile1 & + local rpid=$! + ip vrf exec v$h1 nc -w1 --send-only 192.0.2.2 12345 <$tmpfile + wait -n $rpid + cmp -s $tmpfile $tmpfile1 + check_err $? "server output check failed" + + $MZ $h1 -c 10 -p 64 -a $h1mac -b $h1mac -A 192.0.2.1 -B 192.0.2.1 \ + -t icmp "ping,id=42,seq=5" -q + tc_check_packets "dev $h1 egress" 101 10 + check_err $? "didn't mirred redirect ICMP" + tc_check_packets "dev $h1 ingress" 102 10 + check_err $? "didn't drop mirred ICMP" + local overlimits=$(tc_rule_stats_get ${h1} 101 egress .overlimits) + test ${overlimits} = 10 + check_err $? "wrong overlimits, expected 10 got ${overlimits}" + + tc filter del dev $h1 egress protocol ip pref 100 handle 100 flower + tc filter del dev $h1 egress protocol ip pref 101 handle 101 flower + tc filter del dev $h1 ingress protocol ip pref 102 handle 102 flower + + rm -f $tmpfile $tmpfile1 + log_test "mirred_egress_to_ingress_tcp ($tcflags)" +} + setup_prepare() { h1=${NETIFS[p1]} diff --git a/tools/testing/selftests/net/ip_local_port_range.c b/tools/testing/selftests/net/ip_local_port_range.c new file mode 100644 index 000000000000..75e3fdacdf73 --- /dev/null +++ b/tools/testing/selftests/net/ip_local_port_range.c @@ -0,0 +1,447 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause +// Copyright (c) 2023 Cloudflare + +/* Test IP_LOCAL_PORT_RANGE socket option: IPv4 + IPv6, TCP + UDP. + * + * Tests assume that net.ipv4.ip_local_port_range is [40000, 49999]. + * Don't run these directly but with ip_local_port_range.sh script. 
+ */ + +#include <fcntl.h> +#include <netinet/ip.h> + +#include "../kselftest_harness.h" + +#ifndef IP_LOCAL_PORT_RANGE +#define IP_LOCAL_PORT_RANGE 51 +#endif + +static __u32 pack_port_range(__u16 lo, __u16 hi) +{ + return ((__u32)hi << 16) | (lo << 0); /* widen first: hi >= 0x8000 would otherwise shift into the sign bit of the promoted int */ +} + +static void unpack_port_range(__u32 range, __u16 *lo, __u16 *hi) +{ + *lo = range & 0xffff; + *hi = range >> 16; +} + +static int get_so_domain(int fd) +{ + int domain, err; + socklen_t len; + + len = sizeof(domain); + err = getsockopt(fd, SOL_SOCKET, SO_DOMAIN, &domain, &len); + if (err) + return -1; + + return domain; +} + +static int bind_to_loopback_any_port(int fd) +{ + union { + struct sockaddr sa; + struct sockaddr_in v4; + struct sockaddr_in6 v6; + } addr; + socklen_t addr_len; + + memset(&addr, 0, sizeof(addr)); + switch (get_so_domain(fd)) { + case AF_INET: + addr.v4.sin_family = AF_INET; + addr.v4.sin_port = htons(0); + addr.v4.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + addr_len = sizeof(addr.v4); + break; + case AF_INET6: + addr.v6.sin6_family = AF_INET6; + addr.v6.sin6_port = htons(0); + addr.v6.sin6_addr = in6addr_loopback; + addr_len = sizeof(addr.v6); + break; + default: + return -1; + } + + return bind(fd, &addr.sa, addr_len); +} + +static int get_sock_port(int fd) +{ + union { + struct sockaddr sa; + struct sockaddr_in v4; + struct sockaddr_in6 v6; + } addr; + socklen_t addr_len; + int err; + + addr_len = sizeof(addr); + memset(&addr, 0, sizeof(addr)); + err = getsockname(fd, &addr.sa, &addr_len); + if (err) + return -1; + + switch (addr.sa.sa_family) { + case AF_INET: + return ntohs(addr.v4.sin_port); + case AF_INET6: + return ntohs(addr.v6.sin6_port); + default: + errno = EAFNOSUPPORT; + return -1; + } +} + +static int get_ip_local_port_range(int fd, __u32 *range) +{ + socklen_t len; + __u32 val; + int err; + + len = sizeof(val); + err = getsockopt(fd, SOL_IP, IP_LOCAL_PORT_RANGE, &val, &len); + if (err) + return -1; + + *range = val; + return 0; +} + +FIXTURE(ip_local_port_range) {}; + 
+FIXTURE_SETUP(ip_local_port_range) +{ +} + +FIXTURE_TEARDOWN(ip_local_port_range) +{ +} + +FIXTURE_VARIANT(ip_local_port_range) { + int so_domain; + int so_type; + int so_protocol; +}; + +FIXTURE_VARIANT_ADD(ip_local_port_range, ip4_tcp) { + .so_domain = AF_INET, + .so_type = SOCK_STREAM, + .so_protocol = 0, +}; + +FIXTURE_VARIANT_ADD(ip_local_port_range, ip4_udp) { + .so_domain = AF_INET, + .so_type = SOCK_DGRAM, + .so_protocol = 0, +}; + +FIXTURE_VARIANT_ADD(ip_local_port_range, ip4_stcp) { + .so_domain = AF_INET, + .so_type = SOCK_STREAM, + .so_protocol = IPPROTO_SCTP, +}; + +FIXTURE_VARIANT_ADD(ip_local_port_range, ip6_tcp) { + .so_domain = AF_INET6, + .so_type = SOCK_STREAM, + .so_protocol = 0, +}; + +FIXTURE_VARIANT_ADD(ip_local_port_range, ip6_udp) { + .so_domain = AF_INET6, + .so_type = SOCK_DGRAM, + .so_protocol = 0, +}; + +FIXTURE_VARIANT_ADD(ip_local_port_range, ip6_stcp) { + .so_domain = AF_INET6, + .so_type = SOCK_STREAM, + .so_protocol = IPPROTO_SCTP, +}; + +TEST_F(ip_local_port_range, invalid_option_value) +{ + __u16 val16; + __u32 val32; + __u64 val64; + int fd, err; + + fd = socket(variant->so_domain, variant->so_type, variant->so_protocol); + ASSERT_GE(fd, 0) TH_LOG("socket failed"); + + /* Too few bytes */ + val16 = 40000; + err = setsockopt(fd, SOL_IP, IP_LOCAL_PORT_RANGE, &val16, sizeof(val16)); + EXPECT_TRUE(err) TH_LOG("expected setsockopt(IP_LOCAL_PORT_RANGE) to fail"); + EXPECT_EQ(errno, EINVAL); + + /* Empty range: low port > high port */ + val32 = pack_port_range(40222, 40111); + err = setsockopt(fd, SOL_IP, IP_LOCAL_PORT_RANGE, &val32, sizeof(val32)); + EXPECT_TRUE(err) TH_LOG("expected setsockopt(IP_LOCAL_PORT_RANGE) to fail"); + EXPECT_EQ(errno, EINVAL); + + /* Too many bytes */ + val64 = pack_port_range(40333, 40444); + err = setsockopt(fd, SOL_IP, IP_LOCAL_PORT_RANGE, &val64, sizeof(val64)); + EXPECT_TRUE(err) TH_LOG("expected setsockopt(IP_LOCAL_PORT_RANGE) to fail"); + EXPECT_EQ(errno, EINVAL); + + err = close(fd); + 
ASSERT_TRUE(!err) TH_LOG("close failed"); +} + +TEST_F(ip_local_port_range, port_range_out_of_netns_range) +{ + const struct test { + __u16 range_lo; + __u16 range_hi; + } tests[] = { + { 30000, 39999 }, /* socket range below netns range */ + { 50000, 59999 }, /* socket range above netns range */ + }; + const struct test *t; + + for (t = tests; t < tests + ARRAY_SIZE(tests); t++) { + /* Bind a couple of sockets, not just one, to check + * that the range wasn't clamped to a single port from + * the netns range. That is [40000, 40000] or [49999, + * 49999], respectively for each test case. + */ + int fds[2], i; + + TH_LOG("lo %5hu, hi %5hu", t->range_lo, t->range_hi); + + for (i = 0; i < ARRAY_SIZE(fds); i++) { + int fd, err, port; + __u32 range; + + fd = socket(variant->so_domain, variant->so_type, variant->so_protocol); + ASSERT_GE(fd, 0) TH_LOG("#%d: socket failed", i); + + range = pack_port_range(t->range_lo, t->range_hi); + err = setsockopt(fd, SOL_IP, IP_LOCAL_PORT_RANGE, &range, sizeof(range)); + ASSERT_TRUE(!err) TH_LOG("#%d: setsockopt(IP_LOCAL_PORT_RANGE) failed", i); + + err = bind_to_loopback_any_port(fd); + ASSERT_TRUE(!err) TH_LOG("#%d: bind failed", i); + + /* Check that socket port range outside of ephemeral range is ignored */ + port = get_sock_port(fd); + ASSERT_GE(port, 40000) TH_LOG("#%d: expected port within netns range", i); + ASSERT_LE(port, 49999) TH_LOG("#%d: expected port within netns range", i); + + fds[i] = fd; + } + + for (i = 0; i < ARRAY_SIZE(fds); i++) + ASSERT_TRUE(close(fds[i]) == 0) TH_LOG("#%d: close failed", i); + } +} + +TEST_F(ip_local_port_range, single_port_range) +{ + const struct test { + __u16 range_lo; + __u16 range_hi; + __u16 expected; + } tests[] = { + /* single port range within ephemeral range */ + { 45000, 45000, 45000 }, + /* first port in the ephemeral range (clamp from above) */ + { 0, 40000, 40000 }, + /* last port in the ephemeral range (clamp from below) */ + { 49999, 0, 49999 }, + }; + const struct test *t; + 
+ for (t = tests; t < tests + ARRAY_SIZE(tests); t++) { + int fd, err, port; + __u32 range; + + TH_LOG("lo %5hu, hi %5hu, expected %5hu", + t->range_lo, t->range_hi, t->expected); + + fd = socket(variant->so_domain, variant->so_type, variant->so_protocol); + ASSERT_GE(fd, 0) TH_LOG("socket failed"); + + range = pack_port_range(t->range_lo, t->range_hi); + err = setsockopt(fd, SOL_IP, IP_LOCAL_PORT_RANGE, &range, sizeof(range)); + ASSERT_TRUE(!err) TH_LOG("setsockopt(IP_LOCAL_PORT_RANGE) failed"); + + err = bind_to_loopback_any_port(fd); + ASSERT_TRUE(!err) TH_LOG("bind failed"); + + port = get_sock_port(fd); + ASSERT_EQ(port, t->expected) TH_LOG("unexpected local port"); + + err = close(fd); + ASSERT_TRUE(!err) TH_LOG("close failed"); + } +} + +TEST_F(ip_local_port_range, exhaust_8_port_range) +{ + __u8 port_set = 0; + int i, fd, err; + __u32 range; + __u16 port; + int fds[8]; + + for (i = 0; i < ARRAY_SIZE(fds); i++) { + fd = socket(variant->so_domain, variant->so_type, variant->so_protocol); + ASSERT_GE(fd, 0) TH_LOG("socket failed"); + + range = pack_port_range(40000, 40007); + err = setsockopt(fd, SOL_IP, IP_LOCAL_PORT_RANGE, &range, sizeof(range)); + ASSERT_TRUE(!err) TH_LOG("setsockopt(IP_LOCAL_PORT_RANGE) failed"); + + err = bind_to_loopback_any_port(fd); + ASSERT_TRUE(!err) TH_LOG("bind failed"); + + port = get_sock_port(fd); + ASSERT_GE(port, 40000) TH_LOG("expected port within sockopt range"); + ASSERT_LE(port, 40007) TH_LOG("expected port within sockopt range"); + + port_set |= 1 << (port - 40000); + fds[i] = fd; + } + + /* Check that every port from the test range is in use */ + ASSERT_EQ(port_set, 0xff) TH_LOG("expected all ports to be busy"); + + /* Check that bind() fails because the whole range is busy */ + fd = socket(variant->so_domain, variant->so_type, variant->so_protocol); + ASSERT_GE(fd, 0) TH_LOG("socket failed"); + + range = pack_port_range(40000, 40007); + err = setsockopt(fd, SOL_IP, IP_LOCAL_PORT_RANGE, &range, sizeof(range)); + 
ASSERT_TRUE(!err) TH_LOG("setsockopt(IP_LOCAL_PORT_RANGE) failed"); + + err = bind_to_loopback_any_port(fd); + ASSERT_TRUE(err) TH_LOG("expected bind to fail"); + ASSERT_EQ(errno, EADDRINUSE); + + err = close(fd); + ASSERT_TRUE(!err) TH_LOG("close failed"); + + for (i = 0; i < ARRAY_SIZE(fds); i++) { + err = close(fds[i]); + ASSERT_TRUE(!err) TH_LOG("close failed"); + } +} + +TEST_F(ip_local_port_range, late_bind) +{ + union { + struct sockaddr sa; + struct sockaddr_in v4; + struct sockaddr_in6 v6; + } addr; + socklen_t addr_len; + const int one = 1; + int fd, err; + __u32 range; + __u16 port; + + if (variant->so_protocol == IPPROTO_SCTP) + SKIP(return, "SCTP doesn't support IP_BIND_ADDRESS_NO_PORT"); + + fd = socket(variant->so_domain, variant->so_type, 0); + ASSERT_GE(fd, 0) TH_LOG("socket failed"); + + range = pack_port_range(40100, 40199); + err = setsockopt(fd, SOL_IP, IP_LOCAL_PORT_RANGE, &range, sizeof(range)); + ASSERT_TRUE(!err) TH_LOG("setsockopt(IP_LOCAL_PORT_RANGE) failed"); + + err = setsockopt(fd, SOL_IP, IP_BIND_ADDRESS_NO_PORT, &one, sizeof(one)); + ASSERT_TRUE(!err) TH_LOG("setsockopt(IP_BIND_ADDRESS_NO_PORT) failed"); + + err = bind_to_loopback_any_port(fd); + ASSERT_TRUE(!err) TH_LOG("bind failed"); + + port = get_sock_port(fd); + ASSERT_EQ(port, 0) TH_LOG("getsockname failed"); + + /* Invalid destination */ + memset(&addr, 0, sizeof(addr)); + switch (variant->so_domain) { + case AF_INET: + addr.v4.sin_family = AF_INET; + addr.v4.sin_port = htons(0); + addr.v4.sin_addr.s_addr = htonl(INADDR_ANY); + addr_len = sizeof(addr.v4); + break; + case AF_INET6: + addr.v6.sin6_family = AF_INET6; + addr.v6.sin6_port = htons(0); + addr.v6.sin6_addr = in6addr_any; + addr_len = sizeof(addr.v6); + break; + default: + ASSERT_TRUE(false) TH_LOG("unsupported socket domain"); + } + + /* connect() doesn't need to succeed for late bind to happen */ + connect(fd, &addr.sa, addr_len); + + port = get_sock_port(fd); + ASSERT_GE(port, 40100); + ASSERT_LE(port, 40199); + + 
err = close(fd); + ASSERT_TRUE(!err) TH_LOG("close failed"); +} + +TEST_F(ip_local_port_range, get_port_range) +{ + __u16 lo, hi; + __u32 range; + int fd, err; + + fd = socket(variant->so_domain, variant->so_type, variant->so_protocol); + ASSERT_GE(fd, 0) TH_LOG("socket failed"); + + /* Get range before it will be set */ + err = get_ip_local_port_range(fd, &range); + ASSERT_TRUE(!err) TH_LOG("getsockopt(IP_LOCAL_PORT_RANGE) failed"); + + unpack_port_range(range, &lo, &hi); + ASSERT_EQ(lo, 0) TH_LOG("unexpected low port"); + ASSERT_EQ(hi, 0) TH_LOG("unexpected high port"); + + range = pack_port_range(12345, 54321); + err = setsockopt(fd, SOL_IP, IP_LOCAL_PORT_RANGE, &range, sizeof(range)); + ASSERT_TRUE(!err) TH_LOG("setsockopt(IP_LOCAL_PORT_RANGE) failed"); + + /* Get range after it has been set */ + err = get_ip_local_port_range(fd, &range); + ASSERT_TRUE(!err) TH_LOG("getsockopt(IP_LOCAL_PORT_RANGE) failed"); + + unpack_port_range(range, &lo, &hi); + ASSERT_EQ(lo, 12345) TH_LOG("unexpected low port"); + ASSERT_EQ(hi, 54321) TH_LOG("unexpected high port"); + + /* Unset the port range */ + range = pack_port_range(0, 0); + err = setsockopt(fd, SOL_IP, IP_LOCAL_PORT_RANGE, &range, sizeof(range)); + ASSERT_TRUE(!err) TH_LOG("setsockopt(IP_LOCAL_PORT_RANGE) failed"); + + /* Get range after it has been unset */ + err = get_ip_local_port_range(fd, &range); + ASSERT_TRUE(!err) TH_LOG("getsockopt(IP_LOCAL_PORT_RANGE) failed"); + + unpack_port_range(range, &lo, &hi); + ASSERT_EQ(lo, 0) TH_LOG("unexpected low port"); + ASSERT_EQ(hi, 0) TH_LOG("unexpected high port"); + + err = close(fd); + ASSERT_TRUE(!err) TH_LOG("close failed"); +} + +TEST_HARNESS_MAIN diff --git a/tools/testing/selftests/net/ip_local_port_range.sh b/tools/testing/selftests/net/ip_local_port_range.sh new file mode 100755 index 000000000000..6c6ad346eaa0 --- /dev/null +++ b/tools/testing/selftests/net/ip_local_port_range.sh @@ -0,0 +1,5 @@ +#!/bin/sh +# SPDX-License-Identifier: GPL-2.0 + +./in_netns.sh \ + 
sh -c 'sysctl -q -w net.ipv4.ip_local_port_range="40000 49999" && ./ip_local_port_range' diff --git a/tools/testing/selftests/net/mptcp/diag.sh b/tools/testing/selftests/net/mptcp/diag.sh index 24bcd7b9bdb2..ef628b16fe9b 100755 --- a/tools/testing/selftests/net/mptcp/diag.sh +++ b/tools/testing/selftests/net/mptcp/diag.sh @@ -17,6 +17,11 @@ flush_pids() sleep 1.1 ip netns pids "${ns}" | xargs --no-run-if-empty kill -SIGUSR1 &>/dev/null + + for _ in $(seq 10); do + [ -z "$(ip netns pids "${ns}")" ] && break + sleep 0.1 + done } cleanup() @@ -37,15 +42,20 @@ if [ $? -ne 0 ];then exit $ksft_skip fi +get_msk_inuse() +{ + ip netns exec $ns cat /proc/net/protocols | awk '$1~/^MPTCP$/{print $3}' +} + __chk_nr() { - local condition="$1" + local command="$1" local expected=$2 local msg nr shift 2 msg=$* - nr=$(ss -inmHMN $ns | $condition) + nr=$(eval $command) printf "%-50s" "$msg" if [ $nr != $expected ]; then @@ -57,9 +67,17 @@ __chk_nr() test_cnt=$((test_cnt+1)) } +__chk_msk_nr() +{ + local condition=$1 + shift 1 + + __chk_nr "ss -inmHMN $ns | $condition" $* +} + chk_msk_nr() { - __chk_nr "grep -c token:" $* + __chk_msk_nr "grep -c token:" $* } wait_msk_nr() @@ -97,12 +115,12 @@ wait_msk_nr() chk_msk_fallback_nr() { - __chk_nr "grep -c fallback" $* + __chk_msk_nr "grep -c fallback" $* } chk_msk_remote_key_nr() { - __chk_nr "grep -c remote_key" $* + __chk_msk_nr "grep -c remote_key" $* } __chk_listen() @@ -142,6 +160,26 @@ chk_msk_listen() nr=$(ss -Ml $filter | wc -l) } +chk_msk_inuse() +{ + local expected=$1 + local listen_nr + + shift 1 + + listen_nr=$(ss -N "${ns}" -Ml | grep -c LISTEN) + expected=$((expected + listen_nr)) + + for _ in $(seq 10); do + if [ $(get_msk_inuse) -eq $expected ];then + break + fi + sleep 0.1 + done + + __chk_nr get_msk_inuse $expected $* +} + # $1: ns, $2: port wait_local_port_listen() { @@ -195,8 +233,10 @@ wait_connected $ns 10000 chk_msk_nr 2 "after MPC handshake " chk_msk_remote_key_nr 2 "....chk remote_key" chk_msk_fallback_nr 0 "....chk 
no fallback" +chk_msk_inuse 2 "....chk 2 msk in use" flush_pids +chk_msk_inuse 0 "....chk 0 msk in use after flush" echo "a" | \ timeout ${timeout_test} \ @@ -211,8 +251,11 @@ echo "b" | \ 127.0.0.1 >/dev/null & wait_connected $ns 10001 chk_msk_fallback_nr 1 "check fallback" +chk_msk_inuse 1 "....chk 1 msk in use" flush_pids +chk_msk_inuse 0 "....chk 0 msk in use after flush" + NR_CLIENTS=100 for I in `seq 1 $NR_CLIENTS`; do echo "a" | \ @@ -232,6 +275,9 @@ for I in `seq 1 $NR_CLIENTS`; do done wait_msk_nr $((NR_CLIENTS*2)) "many msk socket present" +chk_msk_inuse $((NR_CLIENTS*2)) "....chk many msk in use" flush_pids +chk_msk_inuse 0 "....chk 0 msk in use after flush" + exit $ret diff --git a/tools/testing/selftests/net/mptcp/mptcp_connect.c b/tools/testing/selftests/net/mptcp/mptcp_connect.c index 8a8266957bc5..b25a31445ded 100644 --- a/tools/testing/selftests/net/mptcp/mptcp_connect.c +++ b/tools/testing/selftests/net/mptcp/mptcp_connect.c @@ -627,7 +627,7 @@ static int copyfd_io_poll(int infd, int peerfd, int outfd, char rbuf[8192]; ssize_t len; - if (fds.events == 0) + if (fds.events == 0 || quit) break; switch (poll(&fds, 1, poll_timeout)) { @@ -733,7 +733,7 @@ static int copyfd_io_poll(int infd, int peerfd, int outfd, } /* leave some time for late join/announce */ - if (cfg_remove) + if (cfg_remove && !quit) usleep(cfg_wait); return 0; diff --git a/tools/testing/selftests/net/mptcp/mptcp_join.sh b/tools/testing/selftests/net/mptcp/mptcp_join.sh index d11d3d566608..387abdcec011 100755 --- a/tools/testing/selftests/net/mptcp/mptcp_join.sh +++ b/tools/testing/selftests/net/mptcp/mptcp_join.sh @@ -774,24 +774,17 @@ do_transfer() addr_nr_ns2=${addr_nr_ns2:9} fi - local local_addr - if is_v6 "${connect_addr}"; then - local_addr="::" - else - local_addr="0.0.0.0" - fi - extra_srv_args="$extra_args $extra_srv_args" if [ "$test_link_fail" -gt 1 ];then timeout ${timeout_test} \ ip netns exec ${listener_ns} \ ./mptcp_connect -t ${timeout_poll} -l -p $port -s 
${srv_proto} \ - $extra_srv_args ${local_addr} < "$sinfail" > "$sout" & + $extra_srv_args "::" < "$sinfail" > "$sout" & else timeout ${timeout_test} \ ip netns exec ${listener_ns} \ ./mptcp_connect -t ${timeout_poll} -l -p $port -s ${srv_proto} \ - $extra_srv_args ${local_addr} < "$sin" > "$sout" & + $extra_srv_args "::" < "$sin" > "$sout" & fi local spid=$! @@ -2448,6 +2441,47 @@ v4mapped_tests() fi } +mixed_tests() +{ + if reset "IPv4 sockets do not use IPv6 addresses"; then + pm_nl_set_limits $ns1 0 1 + pm_nl_set_limits $ns2 1 1 + pm_nl_add_endpoint $ns1 dead:beef:2::1 flags signal + run_tests $ns1 $ns2 10.0.1.1 0 0 0 slow + chk_join_nr 0 0 0 + fi + + # Need an IPv6 mptcp socket to allow subflows of both families + if reset "simult IPv4 and IPv6 subflows"; then + pm_nl_set_limits $ns1 0 1 + pm_nl_set_limits $ns2 1 1 + pm_nl_add_endpoint $ns1 10.0.1.1 flags signal + run_tests $ns1 $ns2 dead:beef:2::1 0 0 0 slow + chk_join_nr 1 1 1 + fi + + # cross families subflows will not be created even in fullmesh mode + if reset "simult IPv4 and IPv6 subflows, fullmesh 1x1"; then + pm_nl_set_limits $ns1 0 4 + pm_nl_set_limits $ns2 1 4 + pm_nl_add_endpoint $ns2 dead:beef:2::2 flags subflow,fullmesh + pm_nl_add_endpoint $ns1 10.0.1.1 flags signal + run_tests $ns1 $ns2 dead:beef:2::1 0 0 0 slow + chk_join_nr 1 1 1 + fi + + # fullmesh still tries to create all the possibly subflows with + # matching family + if reset "simult IPv4 and IPv6 subflows, fullmesh 2x2"; then + pm_nl_set_limits $ns1 0 4 + pm_nl_set_limits $ns2 2 4 + pm_nl_add_endpoint $ns1 10.0.2.1 flags signal + pm_nl_add_endpoint $ns1 dead:beef:2::1 flags signal + run_tests $ns1 $ns2 dead:beef:1::1 0 0 fullmesh_1 slow + chk_join_nr 4 4 4 + fi +} + backup_tests() { # single subflow, backup @@ -3120,6 +3154,7 @@ all_tests_sorted=( a@add_tests 6@ipv6_tests 4@v4mapped_tests + M@mixed_tests b@backup_tests p@add_addr_ports_tests k@syncookies_tests diff --git a/tools/testing/selftests/net/mptcp/userspace_pm.sh 
b/tools/testing/selftests/net/mptcp/userspace_pm.sh index ab2d581f28a1..66c5be25c13d 100755 --- a/tools/testing/selftests/net/mptcp/userspace_pm.sh +++ b/tools/testing/selftests/net/mptcp/userspace_pm.sh @@ -43,41 +43,40 @@ rndh=$(printf %x "$sec")-$(mktemp -u XXXXXX) ns1="ns1-$rndh" ns2="ns2-$rndh" +print_title() +{ + stdbuf -o0 -e0 printf "INFO: %s\n" "${1}" +} + kill_wait() { + [ $1 -eq 0 ] && return 0 + + kill -SIGUSR1 $1 > /dev/null 2>&1 kill $1 > /dev/null 2>&1 wait $1 2>/dev/null } cleanup() { - echo "cleanup" - - rm -rf $file $client_evts $server_evts + print_title "Cleanup" # Terminate the MPTCP connection and related processes - if [ $client4_pid -ne 0 ]; then - kill -SIGUSR1 $client4_pid > /dev/null 2>&1 - fi - if [ $server4_pid -ne 0 ]; then - kill_wait $server4_pid - fi - if [ $client6_pid -ne 0 ]; then - kill -SIGUSR1 $client6_pid > /dev/null 2>&1 - fi - if [ $server6_pid -ne 0 ]; then - kill_wait $server6_pid - fi - if [ $server_evts_pid -ne 0 ]; then - kill_wait $server_evts_pid - fi - if [ $client_evts_pid -ne 0 ]; then - kill_wait $client_evts_pid - fi + local pid + for pid in $client4_pid $server4_pid $client6_pid $server6_pid\ + $server_evts_pid $client_evts_pid + do + kill_wait $pid + done + local netns for netns in "$ns1" "$ns2" ;do ip netns del "$netns" done + + rm -rf $file $client_evts $server_evts + + stdbuf -o0 -e0 printf "Done\n" } trap cleanup EXIT @@ -108,6 +107,7 @@ ip -net "$ns2" addr add dead:beef:1::2/64 dev ns2eth1 nodad ip -net "$ns2" addr add dead:beef:2::2/64 dev ns2eth1 nodad ip -net "$ns2" link set ns2eth1 up +print_title "Init" stdbuf -o0 -e0 printf "Created network namespaces ns1, ns2 \t\t\t[OK]\n" make_file() @@ -193,11 +193,16 @@ make_connection() server_serverside=$(grep "type:1," "$server_evts" | sed --unbuffered -n 's/.*\(server_side:\)\([[:digit:]]*\).*$/\2/p;q') + stdbuf -o0 -e0 printf "Established IP%s MPTCP Connection ns2 => ns1 \t\t" $is_v6 if [ "$client_token" != "" ] && [ "$server_token" != "" ] && [ 
"$client_serverside" = 0 ] && [ "$server_serverside" = 1 ] then - stdbuf -o0 -e0 printf "Established IP%s MPTCP Connection ns2 => ns1 \t\t[OK]\n" $is_v6 + stdbuf -o0 -e0 printf "[OK]\n" else + stdbuf -o0 -e0 printf "[FAIL]\n" + stdbuf -o0 -e0 printf "\tExpected tokens (c:%s - s:%s) and server (c:%d - s:%d)\n" \ + "${client_token}" "${server_token}" \ + "${client_serverside}" "${server_serverside}" exit 1 fi @@ -217,6 +222,48 @@ make_connection() fi } +# $1: var name ; $2: prev ret +check_expected_one() +{ + local var="${1}" + local exp="e_${var}" + local prev_ret="${2}" + + if [ "${!var}" = "${!exp}" ] + then + return 0 + fi + + if [ "${prev_ret}" = "0" ] + then + stdbuf -o0 -e0 printf "[FAIL]\n" + fi + + stdbuf -o0 -e0 printf "\tExpected value for '%s': '%s', got '%s'.\n" \ + "${var}" "${!var}" "${!exp}" + return 1 +} + +# $@: all var names to check +check_expected() +{ + local ret=0 + local var + + for var in "${@}" + do + check_expected_one "${var}" "${ret}" || ret=1 + done + + if [ ${ret} -eq 0 ] + then + stdbuf -o0 -e0 printf "[OK]\n" + return 0 + fi + + exit 1 +} + verify_announce_event() { local evt=$1 @@ -242,19 +289,14 @@ verify_announce_event() fi dport=$(sed --unbuffered -n 's/.*\(dport:\)\([[:digit:]]*\).*$/\2/p;q' "$evt") id=$(sed --unbuffered -n 's/.*\(rem_id:\)\([[:digit:]]*\).*$/\2/p;q' "$evt") - if [ "$type" = "$e_type" ] && [ "$token" = "$e_token" ] && - [ "$addr" = "$e_addr" ] && [ "$dport" = "$e_dport" ] && - [ "$id" = "$e_id" ] - then - stdbuf -o0 -e0 printf "[OK]\n" - return 0 - fi - stdbuf -o0 -e0 printf "[FAIL]\n" - exit 1 + + check_expected "type" "token" "addr" "dport" "id" } test_announce() { + print_title "Announce tests" + # Capture events on the network namespace running the server :>"$server_evts" @@ -270,7 +312,7 @@ test_announce() then stdbuf -o0 -e0 printf "[OK]\n" else - stdbuf -o0 -e0 printf "[FAIL]\n" + stdbuf -o0 -e0 printf "[FAIL]\n\ttype defined: %s\n" "${type}" exit 1 fi @@ -347,18 +389,14 @@ verify_remove_event() type=$(sed 
--unbuffered -n 's/.*\(type:\)\([[:digit:]]*\).*$/\2/p;q' "$evt") token=$(sed --unbuffered -n 's/.*\(token:\)\([[:digit:]]*\).*$/\2/p;q' "$evt") id=$(sed --unbuffered -n 's/.*\(rem_id:\)\([[:digit:]]*\).*$/\2/p;q' "$evt") - if [ "$type" = "$e_type" ] && [ "$token" = "$e_token" ] && - [ "$id" = "$e_id" ] - then - stdbuf -o0 -e0 printf "[OK]\n" - return 0 - fi - stdbuf -o0 -e0 printf "[FAIL]\n" - exit 1 + + check_expected "type" "token" "id" } test_remove() { + print_title "Remove tests" + # Capture events on the network namespace running the server :>"$server_evts" @@ -507,20 +545,13 @@ verify_subflow_events() daddr=$(sed --unbuffered -n 's/.*\(daddr4:\)\([0-9.]*\).*$/\2/p;q' "$evt") fi - if [ "$type" = "$e_type" ] && [ "$token" = "$e_token" ] && - [ "$daddr" = "$e_daddr" ] && [ "$e_dport" = "$dport" ] && - [ "$family" = "$e_family" ] && [ "$saddr" = "$e_saddr" ] && - [ "$e_locid" = "$locid" ] && [ "$e_remid" = "$remid" ] - then - stdbuf -o0 -e0 printf "[OK]\n" - return 0 - fi - stdbuf -o0 -e0 printf "[FAIL]\n" - exit 1 + check_expected "type" "token" "daddr" "dport" "family" "saddr" "locid" "remid" } test_subflows() { + print_title "Subflows v4 or v6 only tests" + # Capture events on the network namespace running the server :>"$server_evts" @@ -754,6 +785,8 @@ test_subflows() test_subflows_v4_v6_mix() { + print_title "Subflows v4 and v6 mix tests" + # Attempt to add a listener at 10.0.2.1:<subflow-port> ip netns exec "$ns1" ./pm_nl_ctl listen 10.0.2.1\ $app6_port > /dev/null 2>&1 & @@ -800,6 +833,8 @@ test_subflows_v4_v6_mix() test_prio() { + print_title "Prio tests" + local count # Send MP_PRIO signal from client to server machine @@ -811,7 +846,7 @@ test_prio() count=$(ip netns exec "$ns2" nstat -as | grep MPTcpExtMPPrioTx | awk '{print $2}') [ -z "$count" ] && count=0 if [ $count != 1 ]; then - stdbuf -o0 -e0 printf "[FAIL]\n" + stdbuf -o0 -e0 printf "[FAIL]\n\tCount != 1: %d\n" "${count}" exit 1 else stdbuf -o0 -e0 printf "[OK]\n" @@ -822,7 +857,7 @@ 
test_prio() count=$(ip netns exec "$ns1" nstat -as | grep MPTcpExtMPPrioRx | awk '{print $2}') [ -z "$count" ] && count=0 if [ $count != 1 ]; then - stdbuf -o0 -e0 printf "[FAIL]\n" + stdbuf -o0 -e0 printf "[FAIL]\n\tCount != 1: %d\n" "${count}" exit 1 else stdbuf -o0 -e0 printf "[OK]\n" @@ -863,19 +898,13 @@ verify_listener_events() sed --unbuffered -n 's/.*\(saddr4:\)\([0-9.]*\).*$/\2/p;q') fi - if [ $type ] && [ $type = $e_type ] && - [ $family ] && [ $family = $e_family ] && - [ $saddr ] && [ $saddr = $e_saddr ] && - [ $sport ] && [ $sport = $e_sport ]; then - stdbuf -o0 -e0 printf "[OK]\n" - return 0 - fi - stdbuf -o0 -e0 printf "[FAIL]\n" - exit 1 + check_expected "type" "family" "saddr" "sport" } test_listener() { + print_title "Listener tests" + # Capture events on the network namespace running the client :>$client_evts @@ -902,8 +931,10 @@ test_listener() verify_listener_events $client_evts $LISTENER_CLOSED $AF_INET 10.0.2.2 $client4_port } +print_title "Make connections" make_connection make_connection "v6" + test_announce test_remove test_subflows diff --git a/tools/testing/selftests/net/bpf/nat6to4.c b/tools/testing/selftests/net/nat6to4.c index ac54c36b25fc..ac54c36b25fc 100644 --- a/tools/testing/selftests/net/bpf/nat6to4.c +++ b/tools/testing/selftests/net/nat6to4.c diff --git a/tools/testing/selftests/net/tcp_mmap.c b/tools/testing/selftests/net/tcp_mmap.c index 00f837c9bc6c..46a02bbd31d0 100644 --- a/tools/testing/selftests/net/tcp_mmap.c +++ b/tools/testing/selftests/net/tcp_mmap.c @@ -137,7 +137,8 @@ static void *mmap_large_buffer(size_t need, size_t *allocated) if (buffer == (void *)-1) { sz = need; buffer = mmap(NULL, sz, PROT_READ | PROT_WRITE, - MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + MAP_PRIVATE | MAP_ANONYMOUS | MAP_POPULATE, + -1, 0); if (buffer != (void *)-1) fprintf(stderr, "MAP_HUGETLB attempt failed, look at /sys/kernel/mm/hugepages for optimal performance\n"); } diff --git a/tools/testing/selftests/net/udpgro_frglist.sh 
b/tools/testing/selftests/net/udpgro_frglist.sh index c9c4b9d65839..0a6359bed0b9 100755 --- a/tools/testing/selftests/net/udpgro_frglist.sh +++ b/tools/testing/selftests/net/udpgro_frglist.sh @@ -40,8 +40,8 @@ run_one() { ip -n "${PEER_NS}" link set veth1 xdp object ${BPF_FILE} section xdp tc -n "${PEER_NS}" qdisc add dev veth1 clsact - tc -n "${PEER_NS}" filter add dev veth1 ingress prio 4 protocol ipv6 bpf object-file ../bpf/nat6to4.o section schedcls/ingress6/nat_6 direct-action - tc -n "${PEER_NS}" filter add dev veth1 egress prio 4 protocol ip bpf object-file ../bpf/nat6to4.o section schedcls/egress4/snat4 direct-action + tc -n "${PEER_NS}" filter add dev veth1 ingress prio 4 protocol ipv6 bpf object-file nat6to4.o section schedcls/ingress6/nat_6 direct-action + tc -n "${PEER_NS}" filter add dev veth1 egress prio 4 protocol ip bpf object-file nat6to4.o section schedcls/egress4/snat4 direct-action echo ${rx_args} ip netns exec "${PEER_NS}" ./udpgso_bench_rx ${rx_args} -r & @@ -88,8 +88,8 @@ if [ ! -f ${BPF_FILE} ]; then exit -1 fi -if [ ! -f bpf/nat6to4.o ]; then - echo "Missing nat6to4 helper. Build bpfnat6to4.o selftest first" +if [ ! -f nat6to4.o ]; then + echo "Missing nat6to4 helper. Build bpf nat6to4.o selftest first" exit -1 fi diff --git a/tools/testing/vsock/Makefile b/tools/testing/vsock/Makefile index f8293c6910c9..43a254f0e14d 100644 --- a/tools/testing/vsock/Makefile +++ b/tools/testing/vsock/Makefile @@ -1,8 +1,9 @@ # SPDX-License-Identifier: GPL-2.0-only -all: test +all: test vsock_perf test: vsock_test vsock_diag_test vsock_test: vsock_test.o timeout.o control.o util.o vsock_diag_test: vsock_diag_test.o timeout.o control.o util.o +vsock_perf: vsock_perf.o CFLAGS += -g -O2 -Werror -Wall -I. 
-I../../include -I../../../usr/include -Wno-pointer-sign -fno-strict-overflow -fno-strict-aliasing -fno-common -MMD -U_FORTIFY_SOURCE -D_GNU_SOURCE .PHONY: all test clean diff --git a/tools/testing/vsock/README b/tools/testing/vsock/README index 4d5045e7d2c3..84ee217ba8ee 100644 --- a/tools/testing/vsock/README +++ b/tools/testing/vsock/README @@ -35,3 +35,37 @@ Invoke test binaries in both directions as follows: --control-port=$GUEST_IP \ --control-port=1234 \ --peer-cid=3 + +vsock_perf utility +------------------- +'vsock_perf' is a simple tool to measure vsock performance. It works in +sender/receiver modes: the sender connects to the peer at the specified port and +starts data transmission to the receiver. After data processing is done, +it prints several metrics (see below). + +Usage: +# run as sender +# connect to CID 2, port 1234, send 1G of data, tx buf size is 1M +./vsock_perf --sender 2 --port 1234 --bytes 1G --buf-size 1M + +Output: +tx performance: A Gbits/s + +Output explanation: +A is calculated as "number of bits to send" / "time in tx loop". + +# run as receiver +# listen on port 1234, rx buf size is 1M, socket buf size is 1G, SO_RCVLOWAT is 64K +./vsock_perf --port 1234 --buf-size 1M --vsk-size 1G --rcvlowat 64K + +Output: +rx performance: A Gbits/s +total time in 'read()': B sec +POLLIN wakeups: C +average time in 'read()': D ns + +Output explanation: +A is calculated as "number of received bits" / "time in rx loop". +B is the time spent in the 'read()' system call (excluding 'poll()'). +C is the number of 'poll()' wakeups with the POLLIN bit set. +D is B / C, i.e. the average amount of time spent in a single 'read()'. 
diff --git a/tools/testing/vsock/control.c b/tools/testing/vsock/control.c index 4874872fc5a3..d2deb4b15b94 100644 --- a/tools/testing/vsock/control.c +++ b/tools/testing/vsock/control.c @@ -141,6 +141,34 @@ void control_writeln(const char *str) timeout_end(); } +void control_writeulong(unsigned long value) +{ + char str[32]; + + if (snprintf(str, sizeof(str), "%lu", value) >= sizeof(str)) { + perror("snprintf"); + exit(EXIT_FAILURE); + } + + control_writeln(str); +} + +unsigned long control_readulong(void) +{ + unsigned long value; + char *str; + + str = control_readln(); + + if (!str) + exit(EXIT_FAILURE); + + value = strtoul(str, NULL, 10); + free(str); + + return value; +} + /* Return the next line from the control socket (without the trailing newline). * * The program terminates if a timeout occurs. diff --git a/tools/testing/vsock/control.h b/tools/testing/vsock/control.h index 51814b4f9ac1..c1f77fdb2c7a 100644 --- a/tools/testing/vsock/control.h +++ b/tools/testing/vsock/control.h @@ -9,7 +9,9 @@ void control_init(const char *control_host, const char *control_port, void control_cleanup(void); void control_writeln(const char *str); char *control_readln(void); +unsigned long control_readulong(void); void control_expectln(const char *str); bool control_cmpln(char *line, const char *str, bool fail); +void control_writeulong(unsigned long value); #endif /* CONTROL_H */ diff --git a/tools/testing/vsock/util.c b/tools/testing/vsock/util.c index 2acbb7703c6a..01b636d3039a 100644 --- a/tools/testing/vsock/util.c +++ b/tools/testing/vsock/util.c @@ -395,3 +395,16 @@ void skip_test(struct test_case *test_cases, size_t test_cases_len, test_cases[test_id].skip = true; } + +unsigned long hash_djb2(const void *data, size_t len) +{ + unsigned long hash = 5381; + int i = 0; + + while (i < len) { + hash = ((hash << 5) + hash) + ((unsigned char *)data)[i]; + i++; + } + + return hash; +} diff --git a/tools/testing/vsock/util.h b/tools/testing/vsock/util.h index 
a3375ad2fb7f..fb99208a95ea 100644 --- a/tools/testing/vsock/util.h +++ b/tools/testing/vsock/util.h @@ -49,4 +49,5 @@ void run_tests(const struct test_case *test_cases, void list_tests(const struct test_case *test_cases); void skip_test(struct test_case *test_cases, size_t test_cases_len, const char *test_id_str); +unsigned long hash_djb2(const void *data, size_t len); #endif /* UTIL_H */ diff --git a/tools/testing/vsock/vsock_perf.c b/tools/testing/vsock/vsock_perf.c new file mode 100644 index 000000000000..a72520338f84 --- /dev/null +++ b/tools/testing/vsock/vsock_perf.c @@ -0,0 +1,427 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * vsock_perf - benchmark utility for vsock. + * + * Copyright (C) 2022 SberDevices. + * + * Author: Arseniy Krasnov <AVKrasnov@sberdevices.ru> + */ +#include <getopt.h> +#include <stdio.h> +#include <stdlib.h> +#include <stdbool.h> +#include <string.h> +#include <errno.h> +#include <unistd.h> +#include <time.h> +#include <stdint.h> +#include <poll.h> +#include <sys/socket.h> +#include <linux/vm_sockets.h> + +#define DEFAULT_BUF_SIZE_BYTES (128 * 1024) +#define DEFAULT_TO_SEND_BYTES (64 * 1024) +#define DEFAULT_VSOCK_BUF_BYTES (256 * 1024) +#define DEFAULT_RCVLOWAT_BYTES 1 +#define DEFAULT_PORT 1234 + +#define BYTES_PER_GB (1024 * 1024 * 1024ULL) +#define NSEC_PER_SEC (1000000000ULL) + +static unsigned int port = DEFAULT_PORT; +static unsigned long buf_size_bytes = DEFAULT_BUF_SIZE_BYTES; +static unsigned long vsock_buf_bytes = DEFAULT_VSOCK_BUF_BYTES; + +static void error(const char *s) +{ + perror(s); + exit(EXIT_FAILURE); +} + +static time_t current_nsec(void) +{ + struct timespec ts; + + if (clock_gettime(CLOCK_REALTIME, &ts)) + error("clock_gettime"); + + return (ts.tv_sec * NSEC_PER_SEC) + ts.tv_nsec; +} + +/* From lib/cmdline.c. 
*/ +static unsigned long memparse(const char *ptr) +{ + char *endptr; + + unsigned long long ret = strtoull(ptr, &endptr, 0); + + switch (*endptr) { + case 'E': + case 'e': + ret <<= 10; + case 'P': + case 'p': + ret <<= 10; + case 'T': + case 't': + ret <<= 10; + case 'G': + case 'g': + ret <<= 10; + case 'M': + case 'm': + ret <<= 10; + case 'K': + case 'k': + ret <<= 10; + endptr++; + default: + break; + } + + return ret; +} + +static void vsock_increase_buf_size(int fd) +{ + if (setsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_MAX_SIZE, + &vsock_buf_bytes, sizeof(vsock_buf_bytes))) + error("setsockopt(SO_VM_SOCKETS_BUFFER_MAX_SIZE)"); + + if (setsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_SIZE, + &vsock_buf_bytes, sizeof(vsock_buf_bytes))) + error("setsockopt(SO_VM_SOCKETS_BUFFER_SIZE)"); +} + +static int vsock_connect(unsigned int cid, unsigned int port) +{ + union { + struct sockaddr sa; + struct sockaddr_vm svm; + } addr = { + .svm = { + .svm_family = AF_VSOCK, + .svm_port = port, + .svm_cid = cid, + }, + }; + int fd; + + fd = socket(AF_VSOCK, SOCK_STREAM, 0); + + if (fd < 0) { + perror("socket"); + return -1; + } + + if (connect(fd, &addr.sa, sizeof(addr.svm)) < 0) { + perror("connect"); + close(fd); + return -1; + } + + return fd; +} + +static float get_gbps(unsigned long bits, time_t ns_delta) +{ + return ((float)bits / 1000000000ULL) / + ((float)ns_delta / NSEC_PER_SEC); +} + +static void run_receiver(unsigned long rcvlowat_bytes) +{ + unsigned int read_cnt; + time_t rx_begin_ns; + time_t in_read_ns; + size_t total_recv; + int client_fd; + char *data; + int fd; + union { + struct sockaddr sa; + struct sockaddr_vm svm; + } addr = { + .svm = { + .svm_family = AF_VSOCK, + .svm_port = port, + .svm_cid = VMADDR_CID_ANY, + }, + }; + union { + struct sockaddr sa; + struct sockaddr_vm svm; + } clientaddr; + + socklen_t clientaddr_len = sizeof(clientaddr.svm); + + printf("Run as receiver\n"); + printf("Listen port %u\n", port); + printf("RX buffer %lu bytes\n", 
buf_size_bytes); + printf("vsock buffer %lu bytes\n", vsock_buf_bytes); + printf("SO_RCVLOWAT %lu bytes\n", rcvlowat_bytes); + + fd = socket(AF_VSOCK, SOCK_STREAM, 0); + + if (fd < 0) + error("socket"); + + if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) + error("bind"); + + if (listen(fd, 1) < 0) + error("listen"); + + client_fd = accept(fd, &clientaddr.sa, &clientaddr_len); + + if (client_fd < 0) + error("accept"); + + vsock_increase_buf_size(client_fd); + + if (setsockopt(client_fd, SOL_SOCKET, SO_RCVLOWAT, + &rcvlowat_bytes, + sizeof(rcvlowat_bytes))) + error("setsockopt(SO_RCVLOWAT)"); + + data = malloc(buf_size_bytes); + + if (!data) { + fprintf(stderr, "'malloc()' failed\n"); + exit(EXIT_FAILURE); + } + + read_cnt = 0; + in_read_ns = 0; + total_recv = 0; + rx_begin_ns = current_nsec(); + + while (1) { + struct pollfd fds = { 0 }; + + fds.fd = client_fd; + fds.events = POLLIN | POLLERR | + POLLHUP | POLLRDHUP; + + if (poll(&fds, 1, -1) < 0) + error("poll"); + + if (fds.revents & POLLERR) { + fprintf(stderr, "'poll()' error\n"); + exit(EXIT_FAILURE); + } + + if (fds.revents & POLLIN) { + ssize_t bytes_read; + time_t t; + + t = current_nsec(); + bytes_read = read(fds.fd, data, buf_size_bytes); + in_read_ns += (current_nsec() - t); + read_cnt++; + + if (!bytes_read) + break; + + if (bytes_read < 0) { + perror("read"); + exit(EXIT_FAILURE); + } + + total_recv += bytes_read; + } + + if (fds.revents & (POLLHUP | POLLRDHUP)) + break; + } + + printf("total bytes received: %zu\n", total_recv); + printf("rx performance: %f Gbits/s\n", + get_gbps(total_recv * 8, current_nsec() - rx_begin_ns)); + printf("total time in 'read()': %f sec\n", (float)in_read_ns / NSEC_PER_SEC); + printf("average time in 'read()': %f ns\n", (float)in_read_ns / read_cnt); + printf("POLLIN wakeups: %i\n", read_cnt); + + free(data); + close(client_fd); + close(fd); +} + +static void run_sender(int peer_cid, unsigned long to_send_bytes) +{ + time_t tx_begin_ns; + time_t tx_total_ns; + size_t 
total_send; + void *data; + int fd; + + printf("Run as sender\n"); + printf("Connect to %i:%u\n", peer_cid, port); + printf("Send %lu bytes\n", to_send_bytes); + printf("TX buffer %lu bytes\n", buf_size_bytes); + + fd = vsock_connect(peer_cid, port); + + if (fd < 0) + exit(EXIT_FAILURE); + + data = malloc(buf_size_bytes); + + if (!data) { + fprintf(stderr, "'malloc()' failed\n"); + exit(EXIT_FAILURE); + } + + memset(data, 0, buf_size_bytes); + total_send = 0; + tx_begin_ns = current_nsec(); + + while (total_send < to_send_bytes) { + ssize_t sent; + + sent = write(fd, data, buf_size_bytes); + + if (sent <= 0) + error("write"); + + total_send += sent; + } + + tx_total_ns = current_nsec() - tx_begin_ns; + + printf("total bytes sent: %zu\n", total_send); + printf("tx performance: %f Gbits/s\n", + get_gbps(total_send * 8, tx_total_ns)); + printf("total time in 'write()': %f sec\n", + (float)tx_total_ns / NSEC_PER_SEC); + + close(fd); + free(data); +} + +static const char optstring[] = ""; +static const struct option longopts[] = { + { + .name = "help", + .has_arg = no_argument, + .val = 'H', + }, + { + .name = "sender", + .has_arg = required_argument, + .val = 'S', + }, + { + .name = "port", + .has_arg = required_argument, + .val = 'P', + }, + { + .name = "bytes", + .has_arg = required_argument, + .val = 'M', + }, + { + .name = "buf-size", + .has_arg = required_argument, + .val = 'B', + }, + { + .name = "vsk-size", + .has_arg = required_argument, + .val = 'V', + }, + { + .name = "rcvlowat", + .has_arg = required_argument, + .val = 'R', + }, + {}, +}; + +static void usage(void) +{ + printf("Usage: ./vsock_perf [--help] [options]\n" + "\n" + "This is benchmarking utility, to test vsock performance.\n" + "It runs in two modes: sender or receiver. 
In sender mode, it\n" + "connects to the specified CID and starts data transmission.\n" + "\n" + "Options:\n" + " --help This message\n" + " --sender <cid> Sender mode (receiver default)\n" + " <cid> of the receiver to connect to\n" + " --port <port> Port (default %d)\n" + " --bytes <bytes>KMG Bytes to send (default %d)\n" + " --buf-size <bytes>KMG Data buffer size (default %d). In sender mode\n" + " it is the buffer size, passed to 'write()'. In\n" + " receiver mode it is the buffer size passed to 'read()'.\n" + " --vsk-size <bytes>KMG Socket buffer size (default %d)\n" + " --rcvlowat <bytes>KMG SO_RCVLOWAT value (default %d)\n" + "\n", DEFAULT_PORT, DEFAULT_TO_SEND_BYTES, + DEFAULT_BUF_SIZE_BYTES, DEFAULT_VSOCK_BUF_BYTES, + DEFAULT_RCVLOWAT_BYTES); + exit(EXIT_FAILURE); +} + +static long strtolx(const char *arg) +{ + long value; + char *end; + + value = strtol(arg, &end, 10); + + if (end != arg + strlen(arg)) + usage(); + + return value; +} + +int main(int argc, char **argv) +{ + unsigned long to_send_bytes = DEFAULT_TO_SEND_BYTES; + unsigned long rcvlowat_bytes = DEFAULT_RCVLOWAT_BYTES; + int peer_cid = -1; + bool sender = false; + + while (1) { + int opt = getopt_long(argc, argv, optstring, longopts, NULL); + + if (opt == -1) + break; + + switch (opt) { + case 'V': /* Peer buffer size. */ + vsock_buf_bytes = memparse(optarg); + break; + case 'R': /* SO_RCVLOWAT value. */ + rcvlowat_bytes = memparse(optarg); + break; + case 'P': /* Port to connect to. */ + port = strtolx(optarg); + break; + case 'M': /* Bytes to send. */ + to_send_bytes = memparse(optarg); + break; + case 'B': /* Size of rx/tx buffer. */ + buf_size_bytes = memparse(optarg); + break; + case 'S': /* Sender mode. CID to connect to. */ + peer_cid = strtolx(optarg); + sender = true; + break; + case 'H': /* Help. 
*/ + usage(); + break; + default: + usage(); + } + } + + if (!sender) + run_receiver(rcvlowat_bytes); + else + run_sender(peer_cid, to_send_bytes); + + return 0; +} diff --git a/tools/testing/vsock/vsock_test.c b/tools/testing/vsock/vsock_test.c index bb6d691cb30d..67e9f9df3a8c 100644 --- a/tools/testing/vsock/vsock_test.c +++ b/tools/testing/vsock/vsock_test.c @@ -284,10 +284,14 @@ static void test_stream_msg_peek_server(const struct test_opts *opts) close(fd); } -#define MESSAGES_CNT 7 -#define MSG_EOR_IDX (MESSAGES_CNT / 2) +#define SOCK_BUF_SIZE (2 * 1024 * 1024) +#define MAX_MSG_SIZE (32 * 1024) + static void test_seqpacket_msg_bounds_client(const struct test_opts *opts) { + unsigned long curr_hash; + int page_size; + int msg_count; int fd; fd = vsock_seqpacket_connect(opts->peer_cid, 1234); @@ -296,18 +300,79 @@ static void test_seqpacket_msg_bounds_client(const struct test_opts *opts) exit(EXIT_FAILURE); } - /* Send several messages, one with MSG_EOR flag */ - for (int i = 0; i < MESSAGES_CNT; i++) - send_byte(fd, 1, (i == MSG_EOR_IDX) ? MSG_EOR : 0); + /* Wait, until receiver sets buffer size. */ + control_expectln("SRVREADY"); + + curr_hash = 0; + page_size = getpagesize(); + msg_count = SOCK_BUF_SIZE / MAX_MSG_SIZE; + + for (int i = 0; i < msg_count; i++) { + ssize_t send_size; + size_t buf_size; + int flags; + void *buf; + + /* Use "small" buffers and "big" buffers. */ + if (i & 1) + buf_size = page_size + + (rand() % (MAX_MSG_SIZE - page_size)); + else + buf_size = 1 + (rand() % page_size); + + buf = malloc(buf_size); + + if (!buf) { + perror("malloc"); + exit(EXIT_FAILURE); + } + + memset(buf, rand() & 0xff, buf_size); + /* Set at least one MSG_EOR + some random. 
*/ + if (i == (msg_count / 2) || (rand() & 1)) { + flags = MSG_EOR; + curr_hash++; + } else { + flags = 0; + } + + send_size = send(fd, buf, buf_size, flags); + + if (send_size < 0) { + perror("send"); + exit(EXIT_FAILURE); + } + + if (send_size != buf_size) { + fprintf(stderr, "Invalid send size\n"); + exit(EXIT_FAILURE); + } + + /* + * Hash sum is computed at both client and server in + * the same way: + * H += hash('message data') + * Such hash "controls" both data integrity and message + * bounds. After data exchange, both sums are compared + * using control socket, and if message bounds wasn't + * broken - two values must be equal. + */ + curr_hash += hash_djb2(buf, buf_size); + free(buf); + } control_writeln("SENDDONE"); + control_writeulong(curr_hash); close(fd); } static void test_seqpacket_msg_bounds_server(const struct test_opts *opts) { + unsigned long sock_buf_size; + unsigned long remote_hash; + unsigned long curr_hash; int fd; - char buf[16]; + char buf[MAX_MSG_SIZE]; struct msghdr msg = {0}; struct iovec iov = {0}; @@ -317,25 +382,57 @@ static void test_seqpacket_msg_bounds_server(const struct test_opts *opts) exit(EXIT_FAILURE); } + sock_buf_size = SOCK_BUF_SIZE; + + if (setsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_MAX_SIZE, + &sock_buf_size, sizeof(sock_buf_size))) { + perror("setsockopt(SO_VM_SOCKETS_BUFFER_MAX_SIZE)"); + exit(EXIT_FAILURE); + } + + if (setsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_SIZE, + &sock_buf_size, sizeof(sock_buf_size))) { + perror("setsockopt(SO_VM_SOCKETS_BUFFER_SIZE)"); + exit(EXIT_FAILURE); + } + + /* Ready to receive data. */ + control_writeln("SRVREADY"); + /* Wait, until peer sends whole data. 
*/ control_expectln("SENDDONE"); iov.iov_base = buf; iov.iov_len = sizeof(buf); msg.msg_iov = &iov; msg.msg_iovlen = 1; - for (int i = 0; i < MESSAGES_CNT; i++) { - if (recvmsg(fd, &msg, 0) != 1) { - perror("message bound violated"); - exit(EXIT_FAILURE); - } + curr_hash = 0; + + while (1) { + ssize_t recv_size; - if ((i == MSG_EOR_IDX) ^ !!(msg.msg_flags & MSG_EOR)) { - perror("MSG_EOR"); + recv_size = recvmsg(fd, &msg, 0); + + if (!recv_size) + break; + + if (recv_size < 0) { + perror("recvmsg"); exit(EXIT_FAILURE); } + + if (msg.msg_flags & MSG_EOR) + curr_hash++; + + curr_hash += hash_djb2(msg.msg_iov[0].iov_base, recv_size); } close(fd); + remote_hash = control_readulong(); + + if (curr_hash != remote_hash) { + fprintf(stderr, "Message bounds broken\n"); + exit(EXIT_FAILURE); + } } #define MESSAGE_TRUNC_SZ 32 @@ -427,7 +524,7 @@ static void test_seqpacket_timeout_client(const struct test_opts *opts) tv.tv_usec = 0; if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, (void *)&tv, sizeof(tv)) == -1) { - perror("setsockopt 'SO_RCVTIMEO'"); + perror("setsockopt(SO_RCVTIMEO)"); exit(EXIT_FAILURE); } @@ -472,6 +569,70 @@ static void test_seqpacket_timeout_server(const struct test_opts *opts) close(fd); } +static void test_seqpacket_bigmsg_client(const struct test_opts *opts) +{ + unsigned long sock_buf_size; + ssize_t send_size; + socklen_t len; + void *data; + int fd; + + len = sizeof(sock_buf_size); + + fd = vsock_seqpacket_connect(opts->peer_cid, 1234); + if (fd < 0) { + perror("connect"); + exit(EXIT_FAILURE); + } + + if (getsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_SIZE, + &sock_buf_size, &len)) { + perror("getsockopt"); + exit(EXIT_FAILURE); + } + + sock_buf_size++; + + data = malloc(sock_buf_size); + if (!data) { + perror("malloc"); + exit(EXIT_FAILURE); + } + + send_size = send(fd, data, sock_buf_size, 0); + if (send_size != -1) { + fprintf(stderr, "expected 'send(2)' failure, got %zi\n", + send_size); + exit(EXIT_FAILURE); + } + + if (errno != EMSGSIZE) { + 
fprintf(stderr, "expected EMSGSIZE in 'errno', got %i\n", + errno); + exit(EXIT_FAILURE); + } + + control_writeln("CLISENT"); + + free(data); + close(fd); +} + +static void test_seqpacket_bigmsg_server(const struct test_opts *opts) +{ + int fd; + + fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL); + if (fd < 0) { + perror("accept"); + exit(EXIT_FAILURE); + } + + control_expectln("CLISENT"); + + close(fd); +} + #define BUF_PATTERN_1 'a' #define BUF_PATTERN_2 'b' @@ -644,7 +805,7 @@ static void test_stream_poll_rcvlowat_client(const struct test_opts *opts) if (setsockopt(fd, SOL_SOCKET, SO_RCVLOWAT, &lowat_val, sizeof(lowat_val))) { - perror("setsockopt"); + perror("setsockopt(SO_RCVLOWAT)"); exit(EXIT_FAILURE); } @@ -754,6 +915,11 @@ static struct test_case test_cases[] = { .run_client = test_stream_poll_rcvlowat_client, .run_server = test_stream_poll_rcvlowat_server, }, + { + .name = "SOCK_SEQPACKET big message", + .run_client = test_seqpacket_bigmsg_client, + .run_server = test_seqpacket_bigmsg_server, + }, {}, }; @@ -837,6 +1003,7 @@ int main(int argc, char **argv) .peer_cid = VMADDR_CID_ANY, }; + srand(time(NULL)); init_signals(); for (;;) { |