diff options
5 files changed, 384 insertions, 9 deletions
diff --git a/tools/testing/selftests/bpf/network_helpers.c b/tools/testing/selftests/bpf/network_helpers.c index 2060bc122c53..26468a8f44f3 100644 --- a/tools/testing/selftests/bpf/network_helpers.c +++ b/tools/testing/selftests/bpf/network_helpers.c @@ -66,17 +66,13 @@ int settimeo(int fd, int timeout_ms) #define save_errno_close(fd) ({ int __save = errno; close(fd); errno = __save; }) -int start_server(int family, int type, const char *addr_str, __u16 port, - int timeout_ms) +static int __start_server(int type, const struct sockaddr *addr, + socklen_t addrlen, int timeout_ms, bool reuseport) { - struct sockaddr_storage addr = {}; - socklen_t len; + int on = 1; int fd; - if (make_sockaddr(family, addr_str, port, &addr, &len)) - return -1; - - fd = socket(family, type, 0); + fd = socket(addr->sa_family, type, 0); if (fd < 0) { log_err("Failed to create server socket"); return -1; @@ -85,7 +81,13 @@ int start_server(int family, int type, const char *addr_str, __u16 port, if (settimeo(fd, timeout_ms)) goto error_close; - if (bind(fd, (const struct sockaddr *)&addr, len) < 0) { + if (reuseport && + setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &on, sizeof(on))) { + log_err("Failed to set SO_REUSEPORT"); + return -1; + } + + if (bind(fd, addr, addrlen) < 0) { log_err("Failed to bind socket"); goto error_close; } @@ -104,6 +106,69 @@ error_close: return -1; } +int start_server(int family, int type, const char *addr_str, __u16 port, + int timeout_ms) +{ + struct sockaddr_storage addr; + socklen_t addrlen; + + if (make_sockaddr(family, addr_str, port, &addr, &addrlen)) + return -1; + + return __start_server(type, (struct sockaddr *)&addr, + addrlen, timeout_ms, false); +} + +int *start_reuseport_server(int family, int type, const char *addr_str, + __u16 port, int timeout_ms, unsigned int nr_listens) +{ + struct sockaddr_storage addr; + unsigned int nr_fds = 0; + socklen_t addrlen; + int *fds; + + if (!nr_listens) + return NULL; + + if (make_sockaddr(family, addr_str, port, &addr, &addrlen)) + return NULL; + + fds = malloc(sizeof(*fds) * nr_listens); + if (!fds) + return NULL; + + fds[0] = __start_server(type, (struct sockaddr *)&addr, addrlen, + timeout_ms, true); + if (fds[0] == -1) + goto close_fds; + nr_fds = 1; + + if (getsockname(fds[0], (struct sockaddr *)&addr, &addrlen)) + goto close_fds; + + for (; nr_fds < nr_listens; nr_fds++) { + fds[nr_fds] = __start_server(type, (struct sockaddr *)&addr, + addrlen, timeout_ms, true); + if (fds[nr_fds] == -1) + goto close_fds; + } + + return fds; + +close_fds: + free_fds(fds, nr_fds); + return NULL; +} + +void free_fds(int *fds, unsigned int nr_close_fds) +{ + if (fds) { + while (nr_close_fds) + close(fds[--nr_close_fds]); + free(fds); + } +} + int fastopen_connect(int server_fd, const char *data, unsigned int data_len, int timeout_ms) { @@ -217,6 +282,7 @@ int make_sockaddr(int family, const char *addr_str, __u16 port, if (family == AF_INET) { struct sockaddr_in *sin = (void *)addr; + memset(addr, 0, sizeof(*sin)); sin->sin_family = AF_INET; sin->sin_port = htons(port); if (addr_str && @@ -230,6 +296,7 @@ int make_sockaddr(int family, const char *addr_str, __u16 port, } else if (family == AF_INET6) { struct sockaddr_in6 *sin6 = (void *)addr; + memset(addr, 0, sizeof(*sin6)); sin6->sin6_family = AF_INET6; sin6->sin6_port = htons(port); if (addr_str && diff --git a/tools/testing/selftests/bpf/network_helpers.h b/tools/testing/selftests/bpf/network_helpers.h index 5e0d51c07b63..d60bc2897770 100644 --- a/tools/testing/selftests/bpf/network_helpers.h +++ b/tools/testing/selftests/bpf/network_helpers.h @@ -36,6 +36,10 @@ extern struct ipv6_packet pkt_v6; int settimeo(int fd, int timeout_ms); int start_server(int family, int type, const char *addr, __u16 port, int timeout_ms); +int *start_reuseport_server(int family, int type, const char *addr_str, + __u16 port, int timeout_ms, + unsigned int nr_listens); +void free_fds(int *fds, unsigned int nr_close_fds); int connect_to_fd(int server_fd, int timeout_ms); int connect_fd_to_fd(int client_fd, int server_fd, int timeout_ms); int fastopen_connect(int server_fd, const char *data, unsigned int data_len, diff --git a/tools/testing/selftests/bpf/prog_tests/bpf_iter_setsockopt.c b/tools/testing/selftests/bpf/prog_tests/bpf_iter_setsockopt.c new file mode 100644 index 000000000000..85babb0487b3 --- /dev/null +++ b/tools/testing/selftests/bpf/prog_tests/bpf_iter_setsockopt.c @@ -0,0 +1,226 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (c) 2021 Facebook */ +#define _GNU_SOURCE +#include <sched.h> +#include <test_progs.h> +#include "network_helpers.h" +#include "bpf_dctcp.skel.h" +#include "bpf_cubic.skel.h" +#include "bpf_iter_setsockopt.skel.h" + +static int create_netns(void) +{ + if (!ASSERT_OK(unshare(CLONE_NEWNET), "create netns")) + return -1; + + if (!ASSERT_OK(system("ip link set dev lo up"), "bring up lo")) + return -1; + + return 0; +} + +static unsigned int set_bpf_cubic(int *fds, unsigned int nr_fds) +{ + unsigned int i; + + for (i = 0; i < nr_fds; i++) { + if (setsockopt(fds[i], SOL_TCP, TCP_CONGESTION, "bpf_cubic", + sizeof("bpf_cubic"))) + return i; + } + + return nr_fds; +} + +static unsigned int check_bpf_dctcp(int *fds, unsigned int nr_fds) +{ + char tcp_cc[16]; + socklen_t optlen = sizeof(tcp_cc); + unsigned int i; + + for (i = 0; i < nr_fds; i++) { + if (getsockopt(fds[i], SOL_TCP, TCP_CONGESTION, + tcp_cc, &optlen) || + strcmp(tcp_cc, "bpf_dctcp")) + return i; + } + + return nr_fds; +} + +static int *make_established(int listen_fd, unsigned int nr_est, + int **paccepted_fds) +{ + int *est_fds, *accepted_fds; + unsigned int i; + + est_fds = malloc(sizeof(*est_fds) * nr_est); + if (!est_fds) + return NULL; + + accepted_fds = malloc(sizeof(*accepted_fds) * nr_est); + if (!accepted_fds) { + free(est_fds); + return NULL; + } + + for (i = 0; i < nr_est; i++) { + est_fds[i] = connect_to_fd(listen_fd, 0); + if (est_fds[i] == -1) + break; + if (set_bpf_cubic(&est_fds[i], 1) != 1) { + close(est_fds[i]); + break; + } + + accepted_fds[i] = accept(listen_fd, NULL, 0); + if (accepted_fds[i] == -1) { + close(est_fds[i]); + break; + } + } + + if (!ASSERT_EQ(i, nr_est, "create established fds")) { + free_fds(accepted_fds, i); + free_fds(est_fds, i); + return NULL; + } + + *paccepted_fds = accepted_fds; + return est_fds; +} + +static unsigned short get_local_port(int fd) +{ + struct sockaddr_in6 addr; + socklen_t addrlen = sizeof(addr); + + if (!getsockname(fd, &addr, &addrlen)) + return ntohs(addr.sin6_port); + + return 0; +} + +static void do_bpf_iter_setsockopt(struct bpf_iter_setsockopt *iter_skel, + bool random_retry) +{ + int *reuse_listen_fds = NULL, *accepted_fds = NULL, *est_fds = NULL; + unsigned int nr_reuse_listens = 256, nr_est = 256; + int err, iter_fd = -1, listen_fd = -1; + char buf; + + /* Prepare non-reuseport listen_fd */ + listen_fd = start_server(AF_INET6, SOCK_STREAM, "::1", 0, 0); + if (!ASSERT_GE(listen_fd, 0, "start_server")) + return; + if (!ASSERT_EQ(set_bpf_cubic(&listen_fd, 1), 1, + "set listen_fd to cubic")) + goto done; + iter_skel->bss->listen_hport = get_local_port(listen_fd); + if (!ASSERT_NEQ(iter_skel->bss->listen_hport, 0, + "get_local_port(listen_fd)")) + goto done; + + /* Connect to non-reuseport listen_fd */ + est_fds = make_established(listen_fd, nr_est, &accepted_fds); + if (!ASSERT_OK_PTR(est_fds, "create established")) + goto done; + + /* Prepare reuseport listen fds */ + reuse_listen_fds = start_reuseport_server(AF_INET6, SOCK_STREAM, + "::1", 0, 0, + nr_reuse_listens); + if (!ASSERT_OK_PTR(reuse_listen_fds, "start_reuseport_server")) + goto done; + if (!ASSERT_EQ(set_bpf_cubic(reuse_listen_fds, nr_reuse_listens), + nr_reuse_listens, "set reuse_listen_fds to cubic")) + goto done; + iter_skel->bss->reuse_listen_hport = get_local_port(reuse_listen_fds[0]); + if (!ASSERT_NEQ(iter_skel->bss->reuse_listen_hport, 0, + "get_local_port(reuse_listen_fds[0])")) + goto done; + + /* Run bpf tcp iter to switch from bpf_cubic to bpf_dctcp */ + iter_skel->bss->random_retry = random_retry; + iter_fd = bpf_iter_create(bpf_link__fd(iter_skel->links.change_tcp_cc)); + if (!ASSERT_GE(iter_fd, 0, "create iter_fd")) + goto done; + + while ((err = read(iter_fd, &buf, sizeof(buf))) == -1 && + errno == EAGAIN) + ; + if (!ASSERT_OK(err, "read iter error")) + goto done; + + /* Check reuseport listen fds for dctcp */ + ASSERT_EQ(check_bpf_dctcp(reuse_listen_fds, nr_reuse_listens), + nr_reuse_listens, + "check reuse_listen_fds dctcp"); + + /* Check non reuseport listen fd for dctcp */ + ASSERT_EQ(check_bpf_dctcp(&listen_fd, 1), 1, + "check listen_fd dctcp"); + + /* Check established fds for dctcp */ + ASSERT_EQ(check_bpf_dctcp(est_fds, nr_est), nr_est, + "check est_fds dctcp"); + + /* Check accepted fds for dctcp */ + ASSERT_EQ(check_bpf_dctcp(accepted_fds, nr_est), nr_est, + "check accepted_fds dctcp"); + +done: + if (iter_fd != -1) + close(iter_fd); + if (listen_fd != -1) + close(listen_fd); + free_fds(reuse_listen_fds, nr_reuse_listens); + free_fds(accepted_fds, nr_est); + free_fds(est_fds, nr_est); +} + +void test_bpf_iter_setsockopt(void) +{ + struct bpf_iter_setsockopt *iter_skel = NULL; + struct bpf_cubic *cubic_skel = NULL; + struct bpf_dctcp *dctcp_skel = NULL; + struct bpf_link *cubic_link = NULL; + struct bpf_link *dctcp_link = NULL; + + if (create_netns()) + return; + + /* Load iter_skel */ + iter_skel = bpf_iter_setsockopt__open_and_load(); + if (!ASSERT_OK_PTR(iter_skel, "iter_skel")) + return; + iter_skel->links.change_tcp_cc = bpf_program__attach_iter(iter_skel->progs.change_tcp_cc, NULL); + if (!ASSERT_OK_PTR(iter_skel->links.change_tcp_cc, "attach iter")) + goto done; + + /* Load bpf_cubic */ + cubic_skel = bpf_cubic__open_and_load(); + if (!ASSERT_OK_PTR(cubic_skel, "cubic_skel")) + goto done; + cubic_link = bpf_map__attach_struct_ops(cubic_skel->maps.cubic); + if (!ASSERT_OK_PTR(cubic_link, "cubic_link")) + goto done; + + /* Load bpf_dctcp */ + dctcp_skel = bpf_dctcp__open_and_load(); + if (!ASSERT_OK_PTR(dctcp_skel, "dctcp_skel")) + goto done; + dctcp_link = bpf_map__attach_struct_ops(dctcp_skel->maps.dctcp); + if (!ASSERT_OK_PTR(dctcp_link, "dctcp_link")) + goto done; + + do_bpf_iter_setsockopt(iter_skel, true); + do_bpf_iter_setsockopt(iter_skel, false); + +done: + bpf_link__destroy(cubic_link); + bpf_link__destroy(dctcp_link); + bpf_cubic__destroy(cubic_skel); + bpf_dctcp__destroy(dctcp_skel); + bpf_iter_setsockopt__destroy(iter_skel); +} diff --git a/tools/testing/selftests/bpf/progs/bpf_iter_setsockopt.c b/tools/testing/selftests/bpf/progs/bpf_iter_setsockopt.c new file mode 100644 index 000000000000..b77adfd55d73 --- /dev/null +++ b/tools/testing/selftests/bpf/progs/bpf_iter_setsockopt.c @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (c) 2021 Facebook */ +#include "bpf_iter.h" +#include "bpf_tracing_net.h" +#include <bpf/bpf_helpers.h> +#include <bpf/bpf_endian.h> + +#define bpf_tcp_sk(skc) ({ \ + struct sock_common *_skc = skc; \ + sk = NULL; \ + tp = NULL; \ + if (_skc) { \ + tp = bpf_skc_to_tcp_sock(_skc); \ + sk = (struct sock *)tp; \ + } \ + tp; \ +}) + +unsigned short reuse_listen_hport = 0; +unsigned short listen_hport = 0; +char cubic_cc[TCP_CA_NAME_MAX] = "bpf_cubic"; +char dctcp_cc[TCP_CA_NAME_MAX] = "bpf_dctcp"; +bool random_retry = false; + +static bool tcp_cc_eq(const char *a, const char *b) +{ + int i; + + for (i = 0; i < TCP_CA_NAME_MAX; i++) { + if (a[i] != b[i]) + return false; + if (!a[i]) + break; + } + + return true; +} + +SEC("iter/tcp") +int change_tcp_cc(struct bpf_iter__tcp *ctx) +{ + char cur_cc[TCP_CA_NAME_MAX]; + struct tcp_sock *tp; + struct sock *sk; + int ret; + + if (!bpf_tcp_sk(ctx->sk_common)) + return 0; + + if (sk->sk_family != AF_INET6 || + (sk->sk_state != TCP_LISTEN && + sk->sk_state != TCP_ESTABLISHED) || + (sk->sk_num != reuse_listen_hport && + sk->sk_num != listen_hport && + bpf_ntohs(sk->sk_dport) != listen_hport)) + return 0; + + if (bpf_getsockopt(tp, SOL_TCP, TCP_CONGESTION, + cur_cc, sizeof(cur_cc))) + return 0; + + if (!tcp_cc_eq(cur_cc, cubic_cc)) + return 0; + + if (random_retry && bpf_get_prandom_u32() % 4 == 1) + return 1; + + bpf_setsockopt(tp, SOL_TCP, TCP_CONGESTION, dctcp_cc, sizeof(dctcp_cc)); + return 0; +} + +char _license[] SEC("license") = "GPL"; diff --git a/tools/testing/selftests/bpf/progs/bpf_tracing_net.h b/tools/testing/selftests/bpf/progs/bpf_tracing_net.h index 01378911252b..3af0998a0623 100644 --- a/tools/testing/selftests/bpf/progs/bpf_tracing_net.h +++ b/tools/testing/selftests/bpf/progs/bpf_tracing_net.h @@ -5,6 +5,10 @@ #define AF_INET 2 #define AF_INET6 10 +#define SOL_TCP 6 +#define TCP_CONGESTION 13 +#define TCP_CA_NAME_MAX 16 + #define ICSK_TIME_RETRANS 1 #define ICSK_TIME_PROBE0 3 #define ICSK_TIME_LOSS_PROBE 5 @@ -32,6 +36,8 @@ #define ir_v6_rmt_addr req.__req_common.skc_v6_daddr #define ir_v6_loc_addr req.__req_common.skc_v6_rcv_saddr +#define sk_num __sk_common.skc_num +#define sk_dport __sk_common.skc_dport #define sk_family __sk_common.skc_family #define sk_rmem_alloc sk_backlog.rmem_alloc #define sk_refcnt __sk_common.skc_refcnt |