diff options
author | Mina Almasry <almasrymina@google.com> | 2024-09-10 20:14:53 +0300 |
---|---|---|
committer | Jakub Kicinski <kuba@kernel.org> | 2024-09-12 06:44:32 +0300 |
commit | 8f0b3cc9a4c102c24808c87f1bc943659d7a7f9f (patch) | |
tree | 30642b181f7c7bfbeab6c87f268d1ffabc17d2cf | |
parent | 65249feb6b3df9e17bab5911ee56fa7b0971e231 (diff) | |
download | linux-8f0b3cc9a4c102c24808c87f1bc943659d7a7f9f.tar.xz |
tcp: RX path for devmem TCP
In tcp_recvmsg_locked(), detect if the skb being received by the user
is a devmem skb. In this case - if the user provided the MSG_SOCK_DEVMEM
flag - pass it to tcp_recvmsg_devmem() for custom handling.
tcp_recvmsg_devmem() copies any data in the skb header to the linear
buffer, and returns a cmsg to the user indicating the number of bytes
returned in the linear buffer.
tcp_recvmsg_devmem() then loops over the unaccessible devmem skb frags,
and returns to the user a cmsg_devmem indicating the location of the
data in the dmabuf device memory. cmsg_devmem contains this information:
1. the offset into the dmabuf where the payload starts. 'frag_offset'.
2. the size of the frag. 'frag_size'.
3. an opaque token 'frag_token' to return to the kernel when the buffer
is to be released.
The pages awaiting freeing are stored in the newly added
sk->sk_user_frags, and each page passed to userspace is get_page()'d.
This reference is dropped once the userspace indicates that it is
done reading this page. All pages are released when the socket is
destroyed.
Signed-off-by: Willem de Bruijn <willemb@google.com>
Signed-off-by: Kaiyuan Zhang <kaiyuanz@google.com>
Signed-off-by: Mina Almasry <almasrymina@google.com>
Reviewed-by: Pavel Begunkov <asml.silence@gmail.com>
Reviewed-by: Eric Dumazet <edumazet@google.com>
Link: https://patch.msgid.link/20240910171458.219195-10-almasrymina@google.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
-rw-r--r-- | arch/alpha/include/uapi/asm/socket.h | 5 | ||||
-rw-r--r-- | arch/mips/include/uapi/asm/socket.h | 5 | ||||
-rw-r--r-- | arch/parisc/include/uapi/asm/socket.h | 5 | ||||
-rw-r--r-- | arch/sparc/include/uapi/asm/socket.h | 5 | ||||
-rw-r--r-- | include/linux/socket.h | 1 | ||||
-rw-r--r-- | include/net/sock.h | 2 | ||||
-rw-r--r-- | include/uapi/asm-generic/socket.h | 5 | ||||
-rw-r--r-- | include/uapi/linux/uio.h | 13 | ||||
-rw-r--r-- | net/core/devmem.h | 22 | ||||
-rw-r--r-- | net/ipv4/tcp.c | 257 | ||||
-rw-r--r-- | net/ipv4/tcp_ipv4.c | 16 | ||||
-rw-r--r-- | net/ipv4/tcp_minisocks.c | 2 |
12 files changed, 333 insertions, 5 deletions
diff --git a/arch/alpha/include/uapi/asm/socket.h b/arch/alpha/include/uapi/asm/socket.h index e94f621903fe..ef4656a41058 100644 --- a/arch/alpha/include/uapi/asm/socket.h +++ b/arch/alpha/include/uapi/asm/socket.h @@ -140,6 +140,11 @@ #define SO_PASSPIDFD 76 #define SO_PEERPIDFD 77 +#define SO_DEVMEM_LINEAR 78 +#define SCM_DEVMEM_LINEAR SO_DEVMEM_LINEAR +#define SO_DEVMEM_DMABUF 79 +#define SCM_DEVMEM_DMABUF SO_DEVMEM_DMABUF + #if !defined(__KERNEL__) #if __BITS_PER_LONG == 64 diff --git a/arch/mips/include/uapi/asm/socket.h b/arch/mips/include/uapi/asm/socket.h index 60ebaed28a4c..414807d55e33 100644 --- a/arch/mips/include/uapi/asm/socket.h +++ b/arch/mips/include/uapi/asm/socket.h @@ -151,6 +151,11 @@ #define SO_PASSPIDFD 76 #define SO_PEERPIDFD 77 +#define SO_DEVMEM_LINEAR 78 +#define SCM_DEVMEM_LINEAR SO_DEVMEM_LINEAR +#define SO_DEVMEM_DMABUF 79 +#define SCM_DEVMEM_DMABUF SO_DEVMEM_DMABUF + #if !defined(__KERNEL__) #if __BITS_PER_LONG == 64 diff --git a/arch/parisc/include/uapi/asm/socket.h b/arch/parisc/include/uapi/asm/socket.h index be264c2b1a11..2b817efd4544 100644 --- a/arch/parisc/include/uapi/asm/socket.h +++ b/arch/parisc/include/uapi/asm/socket.h @@ -132,6 +132,11 @@ #define SO_PASSPIDFD 0x404A #define SO_PEERPIDFD 0x404B +#define SO_DEVMEM_LINEAR 78 +#define SCM_DEVMEM_LINEAR SO_DEVMEM_LINEAR +#define SO_DEVMEM_DMABUF 79 +#define SCM_DEVMEM_DMABUF SO_DEVMEM_DMABUF + #if !defined(__KERNEL__) #if __BITS_PER_LONG == 64 diff --git a/arch/sparc/include/uapi/asm/socket.h b/arch/sparc/include/uapi/asm/socket.h index 682da3714686..00248fc68977 100644 --- a/arch/sparc/include/uapi/asm/socket.h +++ b/arch/sparc/include/uapi/asm/socket.h @@ -133,6 +133,11 @@ #define SO_PASSPIDFD 0x0055 #define SO_PEERPIDFD 0x0056 +#define SO_DEVMEM_LINEAR 0x0057 +#define SCM_DEVMEM_LINEAR SO_DEVMEM_LINEAR +#define SO_DEVMEM_DMABUF 0x0058 +#define SCM_DEVMEM_DMABUF SO_DEVMEM_DMABUF + #if !defined(__KERNEL__) diff --git a/include/linux/socket.h b/include/linux/socket.h index df9cdb8bbfb8..d18cc47e89bd 100644 --- a/include/linux/socket.h +++ b/include/linux/socket.h @@ -327,6 +327,7 @@ struct ucred { * plain text and require encryption */ +#define MSG_SOCK_DEVMEM 0x2000000 /* Receive devmem skbs as cmsg */ #define MSG_ZEROCOPY 0x4000000 /* Use user data in kernel path */ #define MSG_SPLICE_PAGES 0x8000000 /* Splice the pages from the iterator in sendmsg() */ #define MSG_FASTOPEN 0x20000000 /* Send data in TCP SYN */ diff --git a/include/net/sock.h b/include/net/sock.h index f51d61fab059..c58ca8dd561b 100644 --- a/include/net/sock.h +++ b/include/net/sock.h @@ -337,6 +337,7 @@ struct sk_filter; * @sk_txtime_report_errors: set report errors mode for SO_TXTIME * @sk_txtime_unused: unused txtime flags * @ns_tracker: tracker for netns reference + * @sk_user_frags: xarray of pages the user is holding a reference on. */ struct sock { /* @@ -542,6 +543,7 @@ struct sock { #endif struct rcu_head sk_rcu; netns_tracker ns_tracker; + struct xarray sk_user_frags; }; struct sock_bh_locked { diff --git a/include/uapi/asm-generic/socket.h b/include/uapi/asm-generic/socket.h index 8ce8a39a1e5f..e993edc9c0ee 100644 --- a/include/uapi/asm-generic/socket.h +++ b/include/uapi/asm-generic/socket.h @@ -135,6 +135,11 @@ #define SO_PASSPIDFD 76 #define SO_PEERPIDFD 77 +#define SO_DEVMEM_LINEAR 78 +#define SCM_DEVMEM_LINEAR SO_DEVMEM_LINEAR +#define SO_DEVMEM_DMABUF 79 +#define SCM_DEVMEM_DMABUF SO_DEVMEM_DMABUF + #if !defined(__KERNEL__) #if __BITS_PER_LONG == 64 || (defined(__x86_64__) && defined(__ILP32__)) diff --git a/include/uapi/linux/uio.h b/include/uapi/linux/uio.h index 059b1a9147f4..3a22ddae376a 100644 --- a/include/uapi/linux/uio.h +++ b/include/uapi/linux/uio.h @@ -20,6 +20,19 @@ struct iovec __kernel_size_t iov_len; /* Must be size_t (1003.1g) */ }; +struct dmabuf_cmsg { + __u64 frag_offset; /* offset into the dmabuf where the frag starts. + */ + __u32 frag_size; /* size of the frag. */ + __u32 frag_token; /* token representing this frag for + * DEVMEM_DONTNEED. + */ + __u32 dmabuf_id; /* dmabuf id this frag belongs to. */ + __u32 flags; /* Currently unused. Reserved for future + * uses. + */ +}; + /* * UIO_MAXIOV shall be at least 16 1003.1g (5.4.1.1) */ diff --git a/net/core/devmem.h b/net/core/devmem.h index b1db4877cff9..76099ef9c482 100644 --- a/net/core/devmem.h +++ b/net/core/devmem.h @@ -91,6 +91,19 @@ net_iov_binding(const struct net_iov *niov) return net_iov_owner(niov)->binding; } +static inline unsigned long net_iov_virtual_addr(const struct net_iov *niov) +{ + struct dmabuf_genpool_chunk_owner *owner = net_iov_owner(niov); + + return owner->base_virtual + + ((unsigned long)net_iov_idx(niov) << PAGE_SHIFT); +} + +static inline u32 net_iov_binding_id(const struct net_iov *niov) +{ + return net_iov_owner(niov)->binding->id; +} + static inline void net_devmem_dmabuf_binding_get(struct net_devmem_dmabuf_binding *binding) { @@ -153,6 +166,15 @@ static inline void net_devmem_free_dmabuf(struct net_iov *ppiov) { } +static inline unsigned long net_iov_virtual_addr(const struct net_iov *niov) +{ + return 0; +} + +static inline u32 net_iov_binding_id(const struct net_iov *niov) +{ + return 0; +} #endif #endif /* _NET_DEVMEM_H */ diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c index a2fac029a84a..4f77bd862e95 100644 --- a/net/ipv4/tcp.c +++ b/net/ipv4/tcp.c @@ -285,6 +285,8 @@ #include <trace/events/tcp.h> #include <net/rps.h> +#include "../core/devmem.h" + /* Track pending CMSGs. */ enum { TCP_CMSG_INQ = 1, @@ -471,6 +473,7 @@ void tcp_init_sock(struct sock *sk) set_bit(SOCK_SUPPORT_ZC, &sk->sk_socket->flags); sk_sockets_allocated_inc(sk); + xa_init_flags(&sk->sk_user_frags, XA_FLAGS_ALLOC1); } EXPORT_SYMBOL(tcp_init_sock); @@ -2328,6 +2331,220 @@ static int tcp_inq_hint(struct sock *sk) return inq; } +/* batch __xa_alloc() calls and reduce xa_lock()/xa_unlock() overhead. */ +struct tcp_xa_pool { + u8 max; /* max <= MAX_SKB_FRAGS */ + u8 idx; /* idx <= max */ + __u32 tokens[MAX_SKB_FRAGS]; + netmem_ref netmems[MAX_SKB_FRAGS]; +}; + +static void tcp_xa_pool_commit_locked(struct sock *sk, struct tcp_xa_pool *p) +{ + int i; + + /* Commit part that has been copied to user space. */ + for (i = 0; i < p->idx; i++) + __xa_cmpxchg(&sk->sk_user_frags, p->tokens[i], XA_ZERO_ENTRY, + (__force void *)p->netmems[i], GFP_KERNEL); + /* Rollback what has been pre-allocated and is no longer needed. */ + for (; i < p->max; i++) + __xa_erase(&sk->sk_user_frags, p->tokens[i]); + + p->max = 0; + p->idx = 0; +} + +static void tcp_xa_pool_commit(struct sock *sk, struct tcp_xa_pool *p) +{ + if (!p->max) + return; + + xa_lock_bh(&sk->sk_user_frags); + + tcp_xa_pool_commit_locked(sk, p); + + xa_unlock_bh(&sk->sk_user_frags); +} + +static int tcp_xa_pool_refill(struct sock *sk, struct tcp_xa_pool *p, + unsigned int max_frags) +{ + int err, k; + + if (p->idx < p->max) + return 0; + + xa_lock_bh(&sk->sk_user_frags); + + tcp_xa_pool_commit_locked(sk, p); + + for (k = 0; k < max_frags; k++) { + err = __xa_alloc(&sk->sk_user_frags, &p->tokens[k], + XA_ZERO_ENTRY, xa_limit_31b, GFP_KERNEL); + if (err) + break; + } + + xa_unlock_bh(&sk->sk_user_frags); + + p->max = k; + p->idx = 0; + return k ? 0 : err; +} + +/* On error, returns the -errno. On success, returns number of bytes sent to the + * user. May not consume all of @remaining_len. + */ +static int tcp_recvmsg_dmabuf(struct sock *sk, const struct sk_buff *skb, + unsigned int offset, struct msghdr *msg, + int remaining_len) +{ + struct dmabuf_cmsg dmabuf_cmsg = { 0 }; + struct tcp_xa_pool tcp_xa_pool; + unsigned int start; + int i, copy, n; + int sent = 0; + int err = 0; + + tcp_xa_pool.max = 0; + tcp_xa_pool.idx = 0; + do { + start = skb_headlen(skb); + + if (skb_frags_readable(skb)) { + err = -ENODEV; + goto out; + } + + /* Copy header. */ + copy = start - offset; + if (copy > 0) { + copy = min(copy, remaining_len); + + n = copy_to_iter(skb->data + offset, copy, + &msg->msg_iter); + if (n != copy) { + err = -EFAULT; + goto out; + } + + offset += copy; + remaining_len -= copy; + + /* First a dmabuf_cmsg for # bytes copied to user + * buffer. + */ + memset(&dmabuf_cmsg, 0, sizeof(dmabuf_cmsg)); + dmabuf_cmsg.frag_size = copy; + err = put_cmsg(msg, SOL_SOCKET, SO_DEVMEM_LINEAR, + sizeof(dmabuf_cmsg), &dmabuf_cmsg); + if (err || msg->msg_flags & MSG_CTRUNC) { + msg->msg_flags &= ~MSG_CTRUNC; + if (!err) + err = -ETOOSMALL; + goto out; + } + + sent += copy; + + if (remaining_len == 0) + goto out; + } + + /* after that, send information of dmabuf pages through a + * sequence of cmsg + */ + for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) { + skb_frag_t *frag = &skb_shinfo(skb)->frags[i]; + struct net_iov *niov; + u64 frag_offset; + int end; + + /* !skb_frags_readable() should indicate that ALL the + * frags in this skb are dmabuf net_iovs. We're checking + * for that flag above, but also check individual frags + * here. If the tcp stack is not setting + * skb_frags_readable() correctly, we still don't want + * to crash here. + */ + if (!skb_frag_net_iov(frag)) { + net_err_ratelimited("Found non-dmabuf skb with net_iov"); + err = -ENODEV; + goto out; + } + + niov = skb_frag_net_iov(frag); + end = start + skb_frag_size(frag); + copy = end - offset; + + if (copy > 0) { + copy = min(copy, remaining_len); + + frag_offset = net_iov_virtual_addr(niov) + + skb_frag_off(frag) + offset - + start; + dmabuf_cmsg.frag_offset = frag_offset; + dmabuf_cmsg.frag_size = copy; + err = tcp_xa_pool_refill(sk, &tcp_xa_pool, + skb_shinfo(skb)->nr_frags - i); + if (err) + goto out; + + /* Will perform the exchange later */ + dmabuf_cmsg.frag_token = tcp_xa_pool.tokens[tcp_xa_pool.idx]; + dmabuf_cmsg.dmabuf_id = net_iov_binding_id(niov); + + offset += copy; + remaining_len -= copy; + + err = put_cmsg(msg, SOL_SOCKET, + SO_DEVMEM_DMABUF, + sizeof(dmabuf_cmsg), + &dmabuf_cmsg); + if (err || msg->msg_flags & MSG_CTRUNC) { + msg->msg_flags &= ~MSG_CTRUNC; + if (!err) + err = -ETOOSMALL; + goto out; + } + + atomic_long_inc(&niov->pp_ref_count); + tcp_xa_pool.netmems[tcp_xa_pool.idx++] = skb_frag_netmem(frag); + + sent += copy; + + if (remaining_len == 0) + goto out; + } + start = end; + } + + tcp_xa_pool_commit(sk, &tcp_xa_pool); + if (!remaining_len) + goto out; + + /* if remaining_len is not satisfied yet, we need to go to the + * next frag in the frag_list to satisfy remaining_len. + */ + skb = skb_shinfo(skb)->frag_list ?: skb->next; + + offset = offset - start; + } while (skb); + + if (remaining_len) { + err = -EFAULT; + goto out; + } + +out: + tcp_xa_pool_commit(sk, &tcp_xa_pool); + if (!sent) + sent = err; + + return sent; +} + /* * This routine copies from a sock struct into the user buffer. * @@ -2341,6 +2558,7 @@ static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len, int *cmsg_flags) { struct tcp_sock *tp = tcp_sk(sk); + int last_copied_dmabuf = -1; /* uninitialized */ int copied = 0; u32 peek_seq; u32 *seq; @@ -2520,15 +2738,44 @@ found_ok_skb: } if (!(flags & MSG_TRUNC)) { - err = skb_copy_datagram_msg(skb, offset, msg, used); - if (err) { - /* Exception. Bailout! */ - if (!copied) - copied = -EFAULT; + if (last_copied_dmabuf != -1 && + last_copied_dmabuf != !skb_frags_readable(skb)) break; + + if (skb_frags_readable(skb)) { + err = skb_copy_datagram_msg(skb, offset, msg, + used); + if (err) { + /* Exception. Bailout! */ + if (!copied) + copied = -EFAULT; + break; + } + } else { + if (!(flags & MSG_SOCK_DEVMEM)) { + /* dmabuf skbs can only be received + * with the MSG_SOCK_DEVMEM flag. + */ + if (!copied) + copied = -EFAULT; + + break; + } + + err = tcp_recvmsg_dmabuf(sk, skb, offset, msg, + used); + if (err <= 0) { + if (!copied) + copied = -EFAULT; + + break; + } + used = err; } } + last_copied_dmabuf = !skb_frags_readable(skb); + WRITE_ONCE(*seq, *seq + used); copied += used; len -= used; diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c index eb631e66ee03..5afe5e57c89b 100644 --- a/net/ipv4/tcp_ipv4.c +++ b/net/ipv4/tcp_ipv4.c @@ -79,6 +79,7 @@ #include <linux/seq_file.h> #include <linux/inetdevice.h> #include <linux/btf_ids.h> +#include <linux/skbuff_ref.h> #include <crypto/hash.h> #include <linux/scatterlist.h> @@ -2512,10 +2513,25 @@ static void tcp_md5sig_info_free_rcu(struct rcu_head *head) } #endif +static void tcp_release_user_frags(struct sock *sk) +{ +#ifdef CONFIG_PAGE_POOL + unsigned long index; + void *netmem; + + xa_for_each(&sk->sk_user_frags, index, netmem) + WARN_ON_ONCE(!napi_pp_put_page((__force netmem_ref)netmem)); +#endif +} + void tcp_v4_destroy_sock(struct sock *sk) { struct tcp_sock *tp = tcp_sk(sk); + tcp_release_user_frags(sk); + + xa_destroy(&sk->sk_user_frags); + trace_tcp_destroy_sock(sk); tcp_clear_xmit_timers(sk); diff --git a/net/ipv4/tcp_minisocks.c b/net/ipv4/tcp_minisocks.c index ad562272db2e..bb1fe1ba867a 100644 --- a/net/ipv4/tcp_minisocks.c +++ b/net/ipv4/tcp_minisocks.c @@ -628,6 +628,8 @@ struct sock *tcp_create_openreq_child(const struct sock *sk, __TCP_INC_STATS(sock_net(sk), TCP_MIB_PASSIVEOPENS); + xa_init_flags(&newsk->sk_user_frags, XA_FLAGS_ALLOC1); + return newsk; } EXPORT_SYMBOL(tcp_create_openreq_child); |