From 88527790c079fb1ea41cbcfa4450ee37906a2fb0 Mon Sep 17 00:00:00 2001 From: Jakub Kicinski Date: Tue, 5 Jul 2022 16:59:24 -0700 Subject: tls: rx: add sockopt for enabling optimistic decrypt with TLS 1.3 Since optimisitic decrypt may add extra load in case of retries require socket owner to explicitly opt-in. Signed-off-by: Jakub Kicinski Signed-off-by: David S. Miller --- net/tls/tls_main.c | 75 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) (limited to 'net/tls/tls_main.c') diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index 2ffede463e4a..1b3efc96db0b 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -533,6 +533,37 @@ static int do_tls_getsockopt_tx_zc(struct sock *sk, char __user *optval, return 0; } +static int do_tls_getsockopt_no_pad(struct sock *sk, char __user *optval, + int __user *optlen) +{ + struct tls_context *ctx = tls_get_ctx(sk); + unsigned int value; + int err, len; + + if (ctx->prot_info.version != TLS_1_3_VERSION) + return -EINVAL; + + if (get_user(len, optlen)) + return -EFAULT; + if (len < sizeof(value)) + return -EINVAL; + + lock_sock(sk); + err = -EINVAL; + if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) + value = ctx->rx_no_pad; + release_sock(sk); + if (err) + return err; + + if (put_user(sizeof(value), optlen)) + return -EFAULT; + if (copy_to_user(optval, &value, sizeof(value))) + return -EFAULT; + + return 0; +} + static int do_tls_getsockopt(struct sock *sk, int optname, char __user *optval, int __user *optlen) { @@ -547,6 +578,9 @@ static int do_tls_getsockopt(struct sock *sk, int optname, case TLS_TX_ZEROCOPY_RO: rc = do_tls_getsockopt_tx_zc(sk, optval, optlen); break; + case TLS_RX_EXPECT_NO_PAD: + rc = do_tls_getsockopt_no_pad(sk, optval, optlen); + break; default: rc = -ENOPROTOOPT; break; @@ -718,6 +752,38 @@ static int do_tls_setsockopt_tx_zc(struct sock *sk, sockptr_t optval, return 0; } +static int do_tls_setsockopt_no_pad(struct sock *sk, sockptr_t optval, + unsigned int optlen) +{ + struct tls_context *ctx = tls_get_ctx(sk); + u32 val; + int rc; + + if (ctx->prot_info.version != TLS_1_3_VERSION || + sockptr_is_null(optval) || optlen < sizeof(val)) + return -EINVAL; + + rc = copy_from_sockptr(&val, optval, sizeof(val)); + if (rc) + return -EFAULT; + if (val > 1) + return -EINVAL; + rc = check_zeroed_sockptr(optval, sizeof(val), optlen - sizeof(val)); + if (rc < 1) + return rc == 0 ? -EINVAL : rc; + + lock_sock(sk); + rc = -EINVAL; + if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) { + ctx->rx_no_pad = val; + tls_update_rx_zc_capable(ctx); + rc = 0; + } + release_sock(sk); + + return rc; +} + static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval, unsigned int optlen) { @@ -736,6 +802,9 @@ static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval, rc = do_tls_setsockopt_tx_zc(sk, optval, optlen); release_sock(sk); break; + case TLS_RX_EXPECT_NO_PAD: + rc = do_tls_setsockopt_no_pad(sk, optval, optlen); + break; default: rc = -ENOPROTOOPT; break; @@ -976,6 +1045,11 @@ static int tls_get_info(const struct sock *sk, struct sk_buff *skb) if (err) goto nla_failure; } + if (ctx->rx_no_pad) { + err = nla_put_flag(skb, TLS_INFO_RX_NO_PAD); + if (err) + goto nla_failure; + } rcu_read_unlock(); nla_nest_end(skb, start); @@ -997,6 +1071,7 @@ static size_t tls_get_info_size(const struct sock *sk) nla_total_size(sizeof(u16)) + /* TLS_INFO_RXCONF */ nla_total_size(sizeof(u16)) + /* TLS_INFO_TXCONF */ nla_total_size(0) + /* TLS_INFO_ZC_RO_TX */ + nla_total_size(0) + /* TLS_INFO_RX_NO_PAD */ 0; return size; -- cgit v1.2.3 From 5879031423089b2e19b769f30fc618af742264c3 Mon Sep 17 00:00:00 2001 From: Jakub Kicinski Date: Thu, 7 Jul 2022 18:03:13 -0700 Subject: tls: create an internal header include/net/tls.h is getting a little long, and is probably hard for driver authors to navigate. Split out the internals into a header which will live under net/tls/. While at it move some static inlines with a single user into the source files, add a few tls_ prefixes and fix spelling of 'proccess'. Signed-off-by: Jakub Kicinski --- include/net/tls.h | 277 +--------------------------------------- net/tls/tls.h | 290 ++++++++++++++++++++++++++++++++++++++++++ net/tls/tls_device.c | 3 +- net/tls/tls_device_fallback.c | 2 + net/tls/tls_main.c | 23 +++- net/tls/tls_proc.c | 2 + net/tls/tls_sw.c | 22 +++- net/tls/tls_toe.c | 2 + 8 files changed, 338 insertions(+), 283 deletions(-) create mode 100644 net/tls/tls.h (limited to 'net/tls/tls_main.c') diff --git a/include/net/tls.h b/include/net/tls.h index 9394c0459fe8..8742e13bc362 100644 --- a/include/net/tls.h +++ b/include/net/tls.h @@ -39,7 +39,6 @@ #include #include #include -#include #include #include #include @@ -50,6 +49,7 @@ #include #include +struct tls_rec; /* Maximum data size carried in a TLS record */ #define TLS_MAX_PAYLOAD_SIZE ((size_t)1 << 14) @@ -78,13 +78,6 @@ #define TLS_AES_CCM_IV_B0_BYTE 2 #define TLS_SM4_CCM_IV_B0_BYTE 2 -#define __TLS_INC_STATS(net, field) \ - __SNMP_INC_STATS((net)->mib.tls_statistics, field) -#define TLS_INC_STATS(net, field) \ - SNMP_INC_STATS((net)->mib.tls_statistics, field) -#define TLS_DEC_STATS(net, field) \ - SNMP_DEC_STATS((net)->mib.tls_statistics, field) - enum { TLS_BASE, TLS_SW, @@ -93,32 +86,6 @@ enum { TLS_NUM_CONFIG, }; -/* TLS records are maintained in 'struct tls_rec'. It stores the memory pages - * allocated or mapped for each TLS record. After encryption, the records are - * stores in a linked list. - */ -struct tls_rec { - struct list_head list; - int tx_ready; - int tx_flags; - - struct sk_msg msg_plaintext; - struct sk_msg msg_encrypted; - - /* AAD | msg_plaintext.sg.data | sg_tag */ - struct scatterlist sg_aead_in[2]; - /* AAD | msg_encrypted.sg.data (data contains overhead for hdr & iv & tag) */ - struct scatterlist sg_aead_out[2]; - - char content_type; - struct scatterlist sg_content_type; - - char aad_space[TLS_AAD_SPACE_SIZE]; - u8 iv_data[MAX_IV_SIZE]; - struct aead_request aead_req; - u8 aead_req_ctx[]; -}; - struct tx_work { struct delayed_work work; struct sock *sk; @@ -349,44 +316,6 @@ struct tls_offload_context_rx { #define TLS_OFFLOAD_CONTEXT_SIZE_RX \ (sizeof(struct tls_offload_context_rx) + TLS_DRIVER_STATE_SIZE_RX) -struct tls_context *tls_ctx_create(struct sock *sk); -void tls_ctx_free(struct sock *sk, struct tls_context *ctx); -void update_sk_prot(struct sock *sk, struct tls_context *ctx); - -int wait_on_pending_writer(struct sock *sk, long *timeo); -int tls_sk_query(struct sock *sk, int optname, char __user *optval, - int __user *optlen); -int tls_sk_attach(struct sock *sk, int optname, char __user *optval, - unsigned int optlen); -void tls_err_abort(struct sock *sk, int err); - -int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx); -void tls_update_rx_zc_capable(struct tls_context *tls_ctx); -void tls_sw_strparser_arm(struct sock *sk, struct tls_context *ctx); -void tls_sw_strparser_done(struct tls_context *tls_ctx); -int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size); -int tls_sw_sendpage_locked(struct sock *sk, struct page *page, - int offset, size_t size, int flags); -int tls_sw_sendpage(struct sock *sk, struct page *page, - int offset, size_t size, int flags); -void tls_sw_cancel_work_tx(struct tls_context *tls_ctx); -void tls_sw_release_resources_tx(struct sock *sk); -void tls_sw_free_ctx_tx(struct tls_context *tls_ctx); -void tls_sw_free_resources_rx(struct sock *sk); -void tls_sw_release_resources_rx(struct sock *sk); -void tls_sw_free_ctx_rx(struct tls_context *tls_ctx); -int tls_sw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, - int flags, int *addr_len); -bool tls_sw_sock_is_readable(struct sock *sk); -ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, - struct pipe_inode_info *pipe, - size_t len, unsigned int flags); - -int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size); -int tls_device_sendpage(struct sock *sk, struct page *page, - int offset, size_t size, int flags); -int tls_tx_records(struct sock *sk, int flags); - struct tls_record_info *tls_get_record(struct tls_offload_context_tx *context, u32 seq, u64 *p_record_sn); @@ -400,58 +329,6 @@ static inline u32 tls_record_start_seq(struct tls_record_info *rec) return rec->end_seq - rec->len; } -int tls_push_sg(struct sock *sk, struct tls_context *ctx, - struct scatterlist *sg, u16 first_offset, - int flags); -int tls_push_partial_record(struct sock *sk, struct tls_context *ctx, - int flags); -void tls_free_partial_record(struct sock *sk, struct tls_context *ctx); - -static inline struct tls_msg *tls_msg(struct sk_buff *skb) -{ - struct sk_skb_cb *scb = (struct sk_skb_cb *)skb->cb; - - return &scb->tls; -} - -static inline bool tls_is_partially_sent_record(struct tls_context *ctx) -{ - return !!ctx->partially_sent_record; -} - -static inline bool tls_is_pending_open_record(struct tls_context *tls_ctx) -{ - return tls_ctx->pending_open_record_frags; -} - -static inline bool is_tx_ready(struct tls_sw_context_tx *ctx) -{ - struct tls_rec *rec; - - rec = list_first_entry(&ctx->tx_list, struct tls_rec, list); - if (!rec) - return false; - - return READ_ONCE(rec->tx_ready); -} - -static inline u16 tls_user_config(struct tls_context *ctx, bool tx) -{ - u16 config = tx ? ctx->tx_conf : ctx->rx_conf; - - switch (config) { - case TLS_BASE: - return TLS_CONF_BASE; - case TLS_SW: - return TLS_CONF_SW; - case TLS_HW: - return TLS_CONF_HW; - case TLS_HW_RECORD: - return TLS_CONF_HW_RECORD; - } - return 0; -} - struct sk_buff * tls_validate_xmit_skb(struct sock *sk, struct net_device *dev, struct sk_buff *skb); @@ -470,31 +347,6 @@ static inline bool tls_is_sk_tx_device_offloaded(struct sock *sk) #endif } -static inline bool tls_bigint_increment(unsigned char *seq, int len) -{ - int i; - - for (i = len - 1; i >= 0; i--) { - ++seq[i]; - if (seq[i] != 0) - break; - } - - return (i == -1); -} - -static inline void tls_bigint_subtract(unsigned char *seq, int n) -{ - u64 rcd_sn; - __be64 *p; - - BUILD_BUG_ON(TLS_MAX_REC_SEQ_SIZE != 8); - - p = (__be64 *)seq; - rcd_sn = be64_to_cpu(*p); - *p = cpu_to_be64(rcd_sn - n); -} - static inline struct tls_context *tls_get_ctx(const struct sock *sk) { struct inet_connection_sock *icsk = inet_csk(sk); @@ -505,82 +357,6 @@ static inline struct tls_context *tls_get_ctx(const struct sock *sk) return (__force void *)icsk->icsk_ulp_data; } -static inline void tls_advance_record_sn(struct sock *sk, - struct tls_prot_info *prot, - struct cipher_context *ctx) -{ - if (tls_bigint_increment(ctx->rec_seq, prot->rec_seq_size)) - tls_err_abort(sk, -EBADMSG); - - if (prot->version != TLS_1_3_VERSION && - prot->cipher_type != TLS_CIPHER_CHACHA20_POLY1305) - tls_bigint_increment(ctx->iv + prot->salt_size, - prot->iv_size); -} - -static inline void tls_fill_prepend(struct tls_context *ctx, - char *buf, - size_t plaintext_len, - unsigned char record_type) -{ - struct tls_prot_info *prot = &ctx->prot_info; - size_t pkt_len, iv_size = prot->iv_size; - - pkt_len = plaintext_len + prot->tag_size; - if (prot->version != TLS_1_3_VERSION && - prot->cipher_type != TLS_CIPHER_CHACHA20_POLY1305) { - pkt_len += iv_size; - - memcpy(buf + TLS_NONCE_OFFSET, - ctx->tx.iv + prot->salt_size, iv_size); - } - - /* we cover nonce explicit here as well, so buf should be of - * size KTLS_DTLS_HEADER_SIZE + KTLS_DTLS_NONCE_EXPLICIT_SIZE - */ - buf[0] = prot->version == TLS_1_3_VERSION ? - TLS_RECORD_TYPE_DATA : record_type; - /* Note that VERSION must be TLS_1_2 for both TLS1.2 and TLS1.3 */ - buf[1] = TLS_1_2_VERSION_MINOR; - buf[2] = TLS_1_2_VERSION_MAJOR; - /* we can use IV for nonce explicit according to spec */ - buf[3] = pkt_len >> 8; - buf[4] = pkt_len & 0xFF; -} - -static inline void tls_make_aad(char *buf, - size_t size, - char *record_sequence, - unsigned char record_type, - struct tls_prot_info *prot) -{ - if (prot->version != TLS_1_3_VERSION) { - memcpy(buf, record_sequence, prot->rec_seq_size); - buf += 8; - } else { - size += prot->tag_size; - } - - buf[0] = prot->version == TLS_1_3_VERSION ? - TLS_RECORD_TYPE_DATA : record_type; - buf[1] = TLS_1_2_VERSION_MAJOR; - buf[2] = TLS_1_2_VERSION_MINOR; - buf[3] = size >> 8; - buf[4] = size & 0xFF; -} - -static inline void xor_iv_with_seq(struct tls_prot_info *prot, char *iv, char *seq) -{ - int i; - - if (prot->version == TLS_1_3_VERSION || - prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305) { - for (i = 0; i < 8; i++) - iv[i + 4] ^= seq[i]; - } -} - - static inline struct tls_sw_context_rx *tls_sw_ctx_rx( const struct tls_context *tls_ctx) { @@ -617,9 +393,6 @@ static inline bool tls_sw_has_ctx_rx(const struct sock *sk) return !!tls_sw_ctx_rx(ctx); } -void tls_sw_write_space(struct sock *sk, struct tls_context *ctx); -void tls_device_write_space(struct sock *sk, struct tls_context *ctx); - static inline struct tls_offload_context_rx * tls_offload_ctx_rx(const struct tls_context *tls_ctx) { @@ -694,31 +467,11 @@ static inline bool tls_offload_tx_resync_pending(struct sock *sk) return ret; } -int __net_init tls_proc_init(struct net *net); -void __net_exit tls_proc_fini(struct net *net); - -int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg, - unsigned char *record_type); -int decrypt_skb(struct sock *sk, struct sk_buff *skb, - struct scatterlist *sgout); struct sk_buff *tls_encrypt_skb(struct sk_buff *skb); -int tls_sw_fallback_init(struct sock *sk, - struct tls_offload_context_tx *offload_ctx, - struct tls_crypto_info *crypto_info); - #ifdef CONFIG_TLS_DEVICE -void tls_device_init(void); -void tls_device_cleanup(void); void tls_device_sk_destruct(struct sock *sk); -int tls_set_device_offload(struct sock *sk, struct tls_context *ctx); -void tls_device_free_resources_tx(struct sock *sk); -int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx); -void tls_device_offload_cleanup_rx(struct sock *sk); -void tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq); void tls_offload_tx_resync_request(struct sock *sk, u32 got_seq, u32 exp_seq); -int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx, - struct sk_buff *skb, struct strp_msg *rxm); static inline bool tls_is_sk_rx_device_offloaded(struct sock *sk) { @@ -727,33 +480,5 @@ static inline bool tls_is_sk_rx_device_offloaded(struct sock *sk) return false; return tls_get_ctx(sk)->rx_conf == TLS_HW; } -#else -static inline void tls_device_init(void) {} -static inline void tls_device_cleanup(void) {} - -static inline int -tls_set_device_offload(struct sock *sk, struct tls_context *ctx) -{ - return -EOPNOTSUPP; -} - -static inline void tls_device_free_resources_tx(struct sock *sk) {} - -static inline int -tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx) -{ - return -EOPNOTSUPP; -} - -static inline void tls_device_offload_cleanup_rx(struct sock *sk) {} -static inline void -tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq) {} - -static inline int -tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx, - struct sk_buff *skb, struct strp_msg *rxm) -{ - return 0; -} #endif #endif /* _TLS_OFFLOAD_H */ diff --git a/net/tls/tls.h b/net/tls/tls.h new file mode 100644 index 000000000000..8005ee25157d --- /dev/null +++ b/net/tls/tls.h @@ -0,0 +1,290 @@ +/* + * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved. + * Copyright (c) 2016-2017, Dave Watson . All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#ifndef _TLS_INT_H +#define _TLS_INT_H + +#include +#include +#include +#include + +#define __TLS_INC_STATS(net, field) \ + __SNMP_INC_STATS((net)->mib.tls_statistics, field) +#define TLS_INC_STATS(net, field) \ + SNMP_INC_STATS((net)->mib.tls_statistics, field) +#define TLS_DEC_STATS(net, field) \ + SNMP_DEC_STATS((net)->mib.tls_statistics, field) + +/* TLS records are maintained in 'struct tls_rec'. It stores the memory pages + * allocated or mapped for each TLS record. After encryption, the records are + * stores in a linked list. + */ +struct tls_rec { + struct list_head list; + int tx_ready; + int tx_flags; + + struct sk_msg msg_plaintext; + struct sk_msg msg_encrypted; + + /* AAD | msg_plaintext.sg.data | sg_tag */ + struct scatterlist sg_aead_in[2]; + /* AAD | msg_encrypted.sg.data (data contains overhead for hdr & iv & tag) */ + struct scatterlist sg_aead_out[2]; + + char content_type; + struct scatterlist sg_content_type; + + char aad_space[TLS_AAD_SPACE_SIZE]; + u8 iv_data[MAX_IV_SIZE]; + struct aead_request aead_req; + u8 aead_req_ctx[]; +}; + +int __net_init tls_proc_init(struct net *net); +void __net_exit tls_proc_fini(struct net *net); + +struct tls_context *tls_ctx_create(struct sock *sk); +void tls_ctx_free(struct sock *sk, struct tls_context *ctx); +void update_sk_prot(struct sock *sk, struct tls_context *ctx); + +int wait_on_pending_writer(struct sock *sk, long *timeo); +int tls_sk_query(struct sock *sk, int optname, char __user *optval, + int __user *optlen); +int tls_sk_attach(struct sock *sk, int optname, char __user *optval, + unsigned int optlen); +void tls_err_abort(struct sock *sk, int err); + +int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx); +void tls_update_rx_zc_capable(struct tls_context *tls_ctx); +void tls_sw_strparser_arm(struct sock *sk, struct tls_context *ctx); +void tls_sw_strparser_done(struct tls_context *tls_ctx); +int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size); +int tls_sw_sendpage_locked(struct sock *sk, struct page *page, + int offset, size_t size, int flags); +int tls_sw_sendpage(struct sock *sk, struct page *page, + int offset, size_t size, int flags); +void tls_sw_cancel_work_tx(struct tls_context *tls_ctx); +void tls_sw_release_resources_tx(struct sock *sk); +void tls_sw_free_ctx_tx(struct tls_context *tls_ctx); +void tls_sw_free_resources_rx(struct sock *sk); +void tls_sw_release_resources_rx(struct sock *sk); +void tls_sw_free_ctx_rx(struct tls_context *tls_ctx); +int tls_sw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, + int flags, int *addr_len); +bool tls_sw_sock_is_readable(struct sock *sk); +ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, + struct pipe_inode_info *pipe, + size_t len, unsigned int flags); + +int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size); +int tls_device_sendpage(struct sock *sk, struct page *page, + int offset, size_t size, int flags); +int tls_tx_records(struct sock *sk, int flags); + +void tls_sw_write_space(struct sock *sk, struct tls_context *ctx); +void tls_device_write_space(struct sock *sk, struct tls_context *ctx); + +int tls_process_cmsg(struct sock *sk, struct msghdr *msg, + unsigned char *record_type); +int decrypt_skb(struct sock *sk, struct sk_buff *skb, + struct scatterlist *sgout); + +int tls_sw_fallback_init(struct sock *sk, + struct tls_offload_context_tx *offload_ctx, + struct tls_crypto_info *crypto_info); + +static inline struct tls_msg *tls_msg(struct sk_buff *skb) +{ + struct sk_skb_cb *scb = (struct sk_skb_cb *)skb->cb; + + return &scb->tls; +} + +#ifdef CONFIG_TLS_DEVICE +void tls_device_init(void); +void tls_device_cleanup(void); +int tls_set_device_offload(struct sock *sk, struct tls_context *ctx); +void tls_device_free_resources_tx(struct sock *sk); +int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx); +void tls_device_offload_cleanup_rx(struct sock *sk); +void tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq); +int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx, + struct sk_buff *skb, struct strp_msg *rxm); +#else +static inline void tls_device_init(void) {} +static inline void tls_device_cleanup(void) {} + +static inline int +tls_set_device_offload(struct sock *sk, struct tls_context *ctx) +{ + return -EOPNOTSUPP; +} + +static inline void tls_device_free_resources_tx(struct sock *sk) {} + +static inline int +tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx) +{ + return -EOPNOTSUPP; +} + +static inline void tls_device_offload_cleanup_rx(struct sock *sk) {} +static inline void +tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq) {} + +static inline int +tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx, + struct sk_buff *skb, struct strp_msg *rxm) +{ + return 0; +} +#endif + +int tls_push_sg(struct sock *sk, struct tls_context *ctx, + struct scatterlist *sg, u16 first_offset, + int flags); +int tls_push_partial_record(struct sock *sk, struct tls_context *ctx, + int flags); +void tls_free_partial_record(struct sock *sk, struct tls_context *ctx); + +static inline bool tls_is_partially_sent_record(struct tls_context *ctx) +{ + return !!ctx->partially_sent_record; +} + +static inline bool tls_is_pending_open_record(struct tls_context *tls_ctx) +{ + return tls_ctx->pending_open_record_frags; +} + +static inline bool tls_bigint_increment(unsigned char *seq, int len) +{ + int i; + + for (i = len - 1; i >= 0; i--) { + ++seq[i]; + if (seq[i] != 0) + break; + } + + return (i == -1); +} + +static inline void tls_bigint_subtract(unsigned char *seq, int n) +{ + u64 rcd_sn; + __be64 *p; + + BUILD_BUG_ON(TLS_MAX_REC_SEQ_SIZE != 8); + + p = (__be64 *)seq; + rcd_sn = be64_to_cpu(*p); + *p = cpu_to_be64(rcd_sn - n); +} + +static inline void +tls_advance_record_sn(struct sock *sk, struct tls_prot_info *prot, + struct cipher_context *ctx) +{ + if (tls_bigint_increment(ctx->rec_seq, prot->rec_seq_size)) + tls_err_abort(sk, -EBADMSG); + + if (prot->version != TLS_1_3_VERSION && + prot->cipher_type != TLS_CIPHER_CHACHA20_POLY1305) + tls_bigint_increment(ctx->iv + prot->salt_size, + prot->iv_size); +} + +static inline void +tls_xor_iv_with_seq(struct tls_prot_info *prot, char *iv, char *seq) +{ + int i; + + if (prot->version == TLS_1_3_VERSION || + prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305) { + for (i = 0; i < 8; i++) + iv[i + 4] ^= seq[i]; + } +} + +static inline void +tls_fill_prepend(struct tls_context *ctx, char *buf, size_t plaintext_len, + unsigned char record_type) +{ + struct tls_prot_info *prot = &ctx->prot_info; + size_t pkt_len, iv_size = prot->iv_size; + + pkt_len = plaintext_len + prot->tag_size; + if (prot->version != TLS_1_3_VERSION && + prot->cipher_type != TLS_CIPHER_CHACHA20_POLY1305) { + pkt_len += iv_size; + + memcpy(buf + TLS_NONCE_OFFSET, + ctx->tx.iv + prot->salt_size, iv_size); + } + + /* we cover nonce explicit here as well, so buf should be of + * size KTLS_DTLS_HEADER_SIZE + KTLS_DTLS_NONCE_EXPLICIT_SIZE + */ + buf[0] = prot->version == TLS_1_3_VERSION ? + TLS_RECORD_TYPE_DATA : record_type; + /* Note that VERSION must be TLS_1_2 for both TLS1.2 and TLS1.3 */ + buf[1] = TLS_1_2_VERSION_MINOR; + buf[2] = TLS_1_2_VERSION_MAJOR; + /* we can use IV for nonce explicit according to spec */ + buf[3] = pkt_len >> 8; + buf[4] = pkt_len & 0xFF; +} + +static inline +void tls_make_aad(char *buf, size_t size, char *record_sequence, + unsigned char record_type, struct tls_prot_info *prot) +{ + if (prot->version != TLS_1_3_VERSION) { + memcpy(buf, record_sequence, prot->rec_seq_size); + buf += 8; + } else { + size += prot->tag_size; + } + + buf[0] = prot->version == TLS_1_3_VERSION ? + TLS_RECORD_TYPE_DATA : record_type; + buf[1] = TLS_1_2_VERSION_MAJOR; + buf[2] = TLS_1_2_VERSION_MINOR; + buf[3] = size >> 8; + buf[4] = size & 0xFF; +} + +#endif diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c index ec6f4b699a2b..227b92a3064a 100644 --- a/net/tls/tls_device.c +++ b/net/tls/tls_device.c @@ -38,6 +38,7 @@ #include #include +#include "tls.h" #include "trace.h" /* device_offload_lock is used to synchronize tls_dev_add @@ -562,7 +563,7 @@ int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) lock_sock(sk); if (unlikely(msg->msg_controllen)) { - rc = tls_proccess_cmsg(sk, msg, &record_type); + rc = tls_process_cmsg(sk, msg, &record_type); if (rc) goto out; } diff --git a/net/tls/tls_device_fallback.c b/net/tls/tls_device_fallback.c index 3bae29ae57ca..618cee704217 100644 --- a/net/tls/tls_device_fallback.c +++ b/net/tls/tls_device_fallback.c @@ -34,6 +34,8 @@ #include #include +#include "tls.h" + static void chain_to_walk(struct scatterlist *sg, struct scatter_walk *walk) { struct scatterlist *src = walk->sg; diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index 1b3efc96db0b..f3d9dbfa507e 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -45,6 +45,8 @@ #include #include +#include "tls.h" + MODULE_AUTHOR("Mellanox Technologies"); MODULE_DESCRIPTION("Transport Layer Security Support"); MODULE_LICENSE("Dual BSD/GPL"); @@ -164,8 +166,8 @@ static int tls_handle_open_record(struct sock *sk, int flags) return 0; } -int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg, - unsigned char *record_type) +int tls_process_cmsg(struct sock *sk, struct msghdr *msg, + unsigned char *record_type) { struct cmsghdr *cmsg; int rc = -EINVAL; @@ -1003,6 +1005,23 @@ static void tls_update(struct sock *sk, struct proto *p, } } +static u16 tls_user_config(struct tls_context *ctx, bool tx) +{ + u16 config = tx ? ctx->tx_conf : ctx->rx_conf; + + switch (config) { + case TLS_BASE: + return TLS_CONF_BASE; + case TLS_SW: + return TLS_CONF_SW; + case TLS_HW: + return TLS_CONF_HW; + case TLS_HW_RECORD: + return TLS_CONF_HW_RECORD; + } + return 0; +} + static int tls_get_info(const struct sock *sk, struct sk_buff *skb) { u16 version, cipher_type; diff --git a/net/tls/tls_proc.c b/net/tls/tls_proc.c index 0c200000cc45..1246e52b48f6 100644 --- a/net/tls/tls_proc.c +++ b/net/tls/tls_proc.c @@ -6,6 +6,8 @@ #include #include +#include "tls.h" + #ifdef CONFIG_PROC_FS static const struct snmp_mib tls_mib_list[] = { SNMP_MIB_ITEM("TlsCurrTxSw", LINUX_MIB_TLSCURRTXSW), diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index 21c76db8f9b3..1376f866734d 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -44,6 +44,8 @@ #include #include +#include "tls.h" + struct tls_decrypt_arg { bool zc; bool async; @@ -524,7 +526,8 @@ static int tls_do_encryption(struct sock *sk, memcpy(&rec->iv_data[iv_offset], tls_ctx->tx.iv, prot->iv_size + prot->salt_size); - xor_iv_with_seq(prot, rec->iv_data + iv_offset, tls_ctx->tx.rec_seq); + tls_xor_iv_with_seq(prot, rec->iv_data + iv_offset, + tls_ctx->tx.rec_seq); sge->offset += prot->prepend_size; sge->length -= prot->prepend_size; @@ -961,7 +964,7 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) lock_sock(sk); if (unlikely(msg->msg_controllen)) { - ret = tls_proccess_cmsg(sk, msg, &record_type); + ret = tls_process_cmsg(sk, msg, &record_type); if (ret) { if (ret == -EINPROGRESS) num_async++; @@ -1495,7 +1498,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, goto exit_free; memcpy(&dctx->iv[iv_offset], tls_ctx->rx.iv, prot->salt_size); } - xor_iv_with_seq(prot, &dctx->iv[iv_offset], tls_ctx->rx.rec_seq); + tls_xor_iv_with_seq(prot, &dctx->iv[iv_offset], tls_ctx->rx.rec_seq); /* Prepare AAD */ tls_make_aad(dctx->aad, rxm->full_len - prot->overhead_size + @@ -2267,12 +2270,23 @@ static void tx_work_handler(struct work_struct *work) mutex_unlock(&tls_ctx->tx_lock); } +static bool tls_is_tx_ready(struct tls_sw_context_tx *ctx) +{ + struct tls_rec *rec; + + rec = list_first_entry(&ctx->tx_list, struct tls_rec, list); + if (!rec) + return false; + + return READ_ONCE(rec->tx_ready); +} + void tls_sw_write_space(struct sock *sk, struct tls_context *ctx) { struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx); /* Schedule the transmission if tx list is ready */ - if (is_tx_ready(tx_ctx) && + if (tls_is_tx_ready(tx_ctx) && !test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask)) schedule_delayed_work(&tx_ctx->tx_work.work, 0); } diff --git a/net/tls/tls_toe.c b/net/tls/tls_toe.c index 7e1330f19165..825669e1ab47 100644 --- a/net/tls/tls_toe.c +++ b/net/tls/tls_toe.c @@ -38,6 +38,8 @@ #include #include +#include "tls.h" + static LIST_HEAD(device_list); static DEFINE_SPINLOCK(device_spinlock); -- cgit v1.2.3 From 57128e98c33d79285adc523e670fe02d11b7e5da Mon Sep 17 00:00:00 2001 From: Jakub Kicinski Date: Fri, 8 Jul 2022 19:52:54 -0700 Subject: tls: rx: fix the NoPad getsockopt Maxim reports do_tls_getsockopt_no_pad() will always return an error. Indeed looks like refactoring gone wrong - remove err and use value. Reported-by: Maxim Mikityanskiy Fixes: 88527790c079 ("tls: rx: add sockopt for enabling optimistic decrypt with TLS 1.3") Signed-off-by: Jakub Kicinski --- net/tls/tls_main.c | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) (limited to 'net/tls/tls_main.c') diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index f3d9dbfa507e..f71b46568112 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -539,8 +539,7 @@ static int do_tls_getsockopt_no_pad(struct sock *sk, char __user *optval, int __user *optlen) { struct tls_context *ctx = tls_get_ctx(sk); - unsigned int value; - int err, len; + int value, len; if (ctx->prot_info.version != TLS_1_3_VERSION) return -EINVAL; @@ -551,12 +550,12 @@ static int do_tls_getsockopt_no_pad(struct sock *sk, char __user *optval, return -EINVAL; lock_sock(sk); - err = -EINVAL; + value = -EINVAL; if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) value = ctx->rx_no_pad; release_sock(sk); - if (err) - return err; + if (value < 0) + return value; if (put_user(sizeof(value), optlen)) return -EFAULT; -- cgit v1.2.3 From 3d8c51b25a235e283e37750943bbf356ef187230 Mon Sep 17 00:00:00 2001 From: Tariq Toukan Date: Thu, 14 Jul 2022 10:07:54 +0300 Subject: net/tls: Check for errors in tls_device_init Add missing error checks in tls_device_init. Fixes: e8f69799810c ("net/tls: Add generic NIC offload infrastructure") Reported-by: Jakub Kicinski Reviewed-by: Maxim Mikityanskiy Signed-off-by: Tariq Toukan Link: https://lore.kernel.org/r/20220714070754.1428-1-tariqt@nvidia.com Signed-off-by: Jakub Kicinski --- include/net/tls.h | 4 ++-- net/tls/tls_device.c | 4 ++-- net/tls/tls_main.c | 7 ++++++- 3 files changed, 10 insertions(+), 5 deletions(-) (limited to 'net/tls/tls_main.c') diff --git a/include/net/tls.h b/include/net/tls.h index 8017f1703447..8bd938f98bdd 100644 --- a/include/net/tls.h +++ b/include/net/tls.h @@ -704,7 +704,7 @@ int tls_sw_fallback_init(struct sock *sk, struct tls_crypto_info *crypto_info); #ifdef CONFIG_TLS_DEVICE -void tls_device_init(void); +int tls_device_init(void); void tls_device_cleanup(void); void tls_device_sk_destruct(struct sock *sk); int tls_set_device_offload(struct sock *sk, struct tls_context *ctx); @@ -724,7 +724,7 @@ static inline bool tls_is_sk_rx_device_offloaded(struct sock *sk) return tls_get_ctx(sk)->rx_conf == TLS_HW; } #else -static inline void tls_device_init(void) {} +static inline int tls_device_init(void) { return 0; } static inline void tls_device_cleanup(void) {} static inline int diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c index ec6f4b699a2b..ce827e79c66a 100644 --- a/net/tls/tls_device.c +++ b/net/tls/tls_device.c @@ -1419,9 +1419,9 @@ static struct notifier_block tls_dev_notifier = { .notifier_call = tls_dev_event, }; -void __init tls_device_init(void) +int __init tls_device_init(void) { - register_netdevice_notifier(&tls_dev_notifier); + return register_netdevice_notifier(&tls_dev_notifier); } void __exit tls_device_cleanup(void) diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index 2ffede463e4a..d80ab3d1764e 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -1048,7 +1048,12 @@ static int __init tls_register(void) if (err) return err; - tls_device_init(); + err = tls_device_init(); + if (err) { + unregister_pernet_subsys(&tls_proc_ops); + return err; + } + tcp_register_ulp(&tcp_tls_ulp_ops); return 0; -- cgit v1.2.3 From 84c61fe1a75b4255df1e1e7c054c9e6d048da417 Mon Sep 17 00:00:00 2001 From: Jakub Kicinski Date: Fri, 22 Jul 2022 16:50:33 -0700 Subject: tls: rx: do not use the standard strparser TLS is a relatively poor fit for strparser. We pause the input every time a message is received, wait for a read which will decrypt the message, start the parser, repeat. strparser is built to delineate the messages, wrap them in individual skbs and let them float off into the stack or a different socket. TLS wants the data pages and nothing else. There's no need for TLS to keep cloning (and occasionally skb_unclone()'ing) the TCP rx queue. This patch uses a pre-allocated skb and attaches the skbs from the TCP rx queue to it as frags. TLS is careful never to modify the input skb without CoW'ing / detaching it first. Since we call TCP rx queue cleanup directly we also get back the benefit of skb deferred free. Overall this results in a 6% gain in my benchmarks. Acked-by: Paolo Abeni Signed-off-by: Jakub Kicinski --- include/net/tls.h | 19 ++- net/tls/tls.h | 24 ++- net/tls/tls_main.c | 20 ++- net/tls/tls_strp.c | 484 +++++++++++++++++++++++++++++++++++++++++++++++++++-- net/tls/tls_sw.c | 80 ++++----- 5 files changed, 558 insertions(+), 69 deletions(-) (limited to 'net/tls/tls_main.c') diff --git a/include/net/tls.h b/include/net/tls.h index 181c496b01b8..abb050b0df83 100644 --- a/include/net/tls.h +++ b/include/net/tls.h @@ -108,18 +108,33 @@ struct tls_sw_context_tx { unsigned long tx_bitmask; }; +struct tls_strparser { + struct sock *sk; + + u32 mark : 8; + u32 stopped : 1; + u32 copy_mode : 1; + u32 msg_ready : 1; + + struct strp_msg stm; + + struct sk_buff *anchor; + struct work_struct work; +}; + struct tls_sw_context_rx { struct crypto_aead *aead_recv; struct crypto_wait async_wait; - struct strparser strp; struct sk_buff_head rx_list; /* list of decrypted 'data' records */ void (*saved_data_ready)(struct sock *sk); - struct sk_buff *recv_pkt; u8 reader_present; u8 async_capable:1; u8 zc_capable:1; u8 reader_contended:1; + + struct tls_strparser strp; + atomic_t decrypt_pending; /* protect crypto_wait with decrypt_pending*/ spinlock_t decrypt_compl_lock; diff --git a/net/tls/tls.h b/net/tls/tls.h index 154a3773e785..0e840a0c3437 100644 --- a/net/tls/tls.h +++ b/net/tls/tls.h @@ -1,4 +1,5 @@ /* + * Copyright (c) 2016 Tom Herbert * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved. * Copyright (c) 2016-2017, Dave Watson . All rights reserved. * @@ -127,10 +128,24 @@ int tls_sw_fallback_init(struct sock *sk, struct tls_offload_context_tx *offload_ctx, struct tls_crypto_info *crypto_info); +int tls_strp_dev_init(void); +void tls_strp_dev_exit(void); + +void tls_strp_done(struct tls_strparser *strp); +void tls_strp_stop(struct tls_strparser *strp); +int tls_strp_init(struct tls_strparser *strp, struct sock *sk); +void tls_strp_data_ready(struct tls_strparser *strp); + +void tls_strp_check_rcv(struct tls_strparser *strp); +void tls_strp_msg_done(struct tls_strparser *strp); + +int tls_rx_msg_size(struct tls_strparser *strp, struct sk_buff *skb); +void tls_rx_msg_ready(struct tls_strparser *strp); + +void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh); int tls_strp_msg_cow(struct tls_sw_context_rx *ctx); struct sk_buff *tls_strp_msg_detach(struct tls_sw_context_rx *ctx); -int tls_strp_msg_hold(struct sock *sk, struct sk_buff *skb, - struct sk_buff_head *dst); +int tls_strp_msg_hold(struct tls_strparser *strp, struct sk_buff_head *dst); static inline struct tls_msg *tls_msg(struct sk_buff *skb) { @@ -141,12 +156,13 @@ static inline struct tls_msg *tls_msg(struct sk_buff *skb) static inline struct sk_buff *tls_strp_msg(struct tls_sw_context_rx *ctx) { - return ctx->recv_pkt; + DEBUG_NET_WARN_ON_ONCE(!ctx->strp.msg_ready || !ctx->strp.anchor->len); + return ctx->strp.anchor; } static inline bool tls_strp_msg_ready(struct tls_sw_context_rx *ctx) { - return ctx->recv_pkt; + return ctx->strp.msg_ready; } #ifdef CONFIG_TLS_DEVICE diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index 9703636cfc60..08ddf9d837ae 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -725,6 +725,10 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval, if (tx) { ctx->sk_write_space = sk->sk_write_space; sk->sk_write_space = tls_write_space; + } else { + struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(ctx); + + tls_strp_check_rcv(&rx_ctx->strp); } return 0; @@ -1141,20 +1145,28 @@ static int __init tls_register(void) if (err) return err; + err = tls_strp_dev_init(); + if (err) + goto err_pernet; + err = tls_device_init(); - if (err) { - unregister_pernet_subsys(&tls_proc_ops); - return err; - } + if (err) + goto err_strp; tcp_register_ulp(&tcp_tls_ulp_ops); return 0; +err_strp: + tls_strp_dev_exit(); +err_pernet: + unregister_pernet_subsys(&tls_proc_ops); + return err; } static void __exit tls_unregister(void) { tcp_unregister_ulp(&tcp_tls_ulp_ops); + tls_strp_dev_exit(); tls_device_cleanup(); unregister_pernet_subsys(&tls_proc_ops); } diff --git a/net/tls/tls_strp.c b/net/tls/tls_strp.c index d9bb4f23f01a..b945288c312e 100644 --- a/net/tls/tls_strp.c +++ b/net/tls/tls_strp.c @@ -1,37 +1,493 @@ // SPDX-License-Identifier: GPL-2.0-only +/* Copyright (c) 2016 Tom Herbert */ #include +#include +#include +#include +#include +#include #include "tls.h" -struct sk_buff *tls_strp_msg_detach(struct tls_sw_context_rx *ctx) +static struct workqueue_struct *tls_strp_wq; + +static void tls_strp_abort_strp(struct tls_strparser *strp, int err) +{ + if (strp->stopped) + return; + + strp->stopped = 1; + + /* Report an error on the lower socket */ + strp->sk->sk_err = -err; + sk_error_report(strp->sk); +} + +static void tls_strp_anchor_free(struct tls_strparser *strp) { + struct skb_shared_info *shinfo = skb_shinfo(strp->anchor); + + DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1); + shinfo->frag_list = NULL; + consume_skb(strp->anchor); + strp->anchor = NULL; +} + +/* Create a new skb with the contents of input copied to its page frags */ +static struct sk_buff *tls_strp_msg_make_copy(struct tls_strparser *strp) +{ + struct strp_msg *rxm; struct sk_buff *skb; + int i, err, offset; + + skb = alloc_skb_with_frags(0, strp->anchor->len, TLS_PAGE_ORDER, + &err, strp->sk->sk_allocation); + if (!skb) + return NULL; + + offset = strp->stm.offset; + for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) { + skb_frag_t *frag = &skb_shinfo(skb)->frags[i]; - skb = ctx->recv_pkt; - ctx->recv_pkt = NULL; + WARN_ON_ONCE(skb_copy_bits(strp->anchor, offset, + skb_frag_address(frag), + skb_frag_size(frag))); + offset += skb_frag_size(frag); + } + + skb_copy_header(skb, strp->anchor); + rxm = strp_msg(skb); + rxm->offset = 0; return skb; } +/* Steal the input skb, input msg is invalid after calling this function */ +struct sk_buff *tls_strp_msg_detach(struct tls_sw_context_rx *ctx) +{ + struct tls_strparser *strp = &ctx->strp; + +#ifdef CONFIG_TLS_DEVICE + DEBUG_NET_WARN_ON_ONCE(!strp->anchor->decrypted); +#else + /* This function turns an input into an output, + * that can only happen if we have offload. + */ + WARN_ON(1); +#endif + + if (strp->copy_mode) { + struct sk_buff *skb; + + /* Replace anchor with an empty skb, this is a little + * dangerous but __tls_cur_msg() warns on empty skbs + * so hopefully we'll catch abuses. + */ + skb = alloc_skb(0, strp->sk->sk_allocation); + if (!skb) + return NULL; + + swap(strp->anchor, skb); + return skb; + } + + return tls_strp_msg_make_copy(strp); +} + +/* Force the input skb to be in copy mode. The data ownership remains + * with the input skb itself (meaning unpause will wipe it) but it can + * be modified. + */ int tls_strp_msg_cow(struct tls_sw_context_rx *ctx) { - struct sk_buff *unused; - int nsg; + struct tls_strparser *strp = &ctx->strp; + struct sk_buff *skb; + + if (strp->copy_mode) + return 0; + + skb = tls_strp_msg_make_copy(strp); + if (!skb) + return -ENOMEM; + + tls_strp_anchor_free(strp); + strp->anchor = skb; + + tcp_read_done(strp->sk, strp->stm.full_len); + strp->copy_mode = 1; + + return 0; +} + +/* Make a clone (in the skb sense) of the input msg to keep a reference + * to the underlying data. The reference-holding skbs get placed on + * @dst. + */ +int tls_strp_msg_hold(struct tls_strparser *strp, struct sk_buff_head *dst) +{ + struct skb_shared_info *shinfo = skb_shinfo(strp->anchor); + + if (strp->copy_mode) { + struct sk_buff *skb; + + WARN_ON_ONCE(!shinfo->nr_frags); + + /* We can't skb_clone() the anchor, it gets wiped by unpause */ + skb = alloc_skb(0, strp->sk->sk_allocation); + if (!skb) + return -ENOMEM; + + __skb_queue_tail(dst, strp->anchor); + strp->anchor = skb; + } else { + struct sk_buff *iter, *clone; + int chunk, len, offset; + + offset = strp->stm.offset; + len = strp->stm.full_len; + iter = shinfo->frag_list; + + while (len > 0) { + if (iter->len <= offset) { + offset -= iter->len; + goto next; + } + + chunk = iter->len - offset; + offset = 0; + + clone = skb_clone(iter, strp->sk->sk_allocation); + if (!clone) + return -ENOMEM; + __skb_queue_tail(dst, clone); + + len -= chunk; +next: + iter = iter->next; + } + } + + return 0; +} + +static void tls_strp_flush_anchor_copy(struct tls_strparser *strp) +{ + struct skb_shared_info *shinfo = skb_shinfo(strp->anchor); + int i; + + DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1); + + for (i = 0; i < shinfo->nr_frags; i++) + __skb_frag_unref(&shinfo->frags[i], false); + shinfo->nr_frags = 0; + strp->copy_mode = 0; +} + +static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb, + unsigned int offset, size_t in_len) +{ + struct tls_strparser *strp = (struct tls_strparser *)desc->arg.data; + size_t sz, len, chunk; + struct sk_buff *skb; + skb_frag_t *frag; + + if (strp->msg_ready) + return 0; + + skb = strp->anchor; + frag = &skb_shinfo(skb)->frags[skb->len / PAGE_SIZE]; + + len = in_len; + /* First make sure we got the header */ + if (!strp->stm.full_len) { + /* Assume one page is more than enough for headers */ + chunk = min_t(size_t, len, PAGE_SIZE - skb_frag_size(frag)); + WARN_ON_ONCE(skb_copy_bits(in_skb, offset, + skb_frag_address(frag) + + skb_frag_size(frag), + chunk)); + + sz = tls_rx_msg_size(strp, strp->anchor); + if (sz < 0) { + desc->error = sz; + return 0; + } + + /* We may have over-read, sz == 0 is guaranteed under-read */ + if (sz > 0) + chunk = min_t(size_t, chunk, sz - skb->len); + + skb->len += chunk; + skb->data_len += chunk; + skb_frag_size_add(frag, chunk); + frag++; + len -= chunk; + offset += chunk; + + strp->stm.full_len = sz; + if (!strp->stm.full_len) + goto read_done; + } + + /* Load up more data */ + while (len && strp->stm.full_len > skb->len) { + chunk = min_t(size_t, len, strp->stm.full_len - skb->len); + chunk = min_t(size_t, chunk, PAGE_SIZE - skb_frag_size(frag)); + WARN_ON_ONCE(skb_copy_bits(in_skb, offset, + skb_frag_address(frag) + + skb_frag_size(frag), + chunk)); + + skb->len += chunk; + skb->data_len += chunk; + skb_frag_size_add(frag, chunk); + frag++; + len -= chunk; + offset += chunk; + } + + if (strp->stm.full_len == skb->len) { + desc->count = 0; + + strp->msg_ready = 1; + tls_rx_msg_ready(strp); + } + +read_done: + return in_len - len; +} + +static int tls_strp_read_copyin(struct tls_strparser *strp) +{ + struct socket *sock = strp->sk->sk_socket; + read_descriptor_t desc; + + desc.arg.data = strp; + desc.error = 0; + desc.count = 1; /* give more than one skb per call */ + + /* sk should be locked here, so okay to do read_sock */ + sock->ops->read_sock(strp->sk, &desc, tls_strp_copyin); + + return desc.error; +} + +static int tls_strp_read_short(struct tls_strparser *strp) +{ + struct skb_shared_info *shinfo; + struct page *page; + int need_spc, len; + + /* If the rbuf is small or rcv window has collapsed to 0 we need + * to read the data out. Otherwise the connection will stall. + * Without pressure threshold of INT_MAX will never be ready. + */ + if (likely(!tcp_epollin_ready(strp->sk, INT_MAX))) + return 0; + + shinfo = skb_shinfo(strp->anchor); + shinfo->frag_list = NULL; + + /* If we don't know the length go max plus page for cipher overhead */ + need_spc = strp->stm.full_len ?: TLS_MAX_PAYLOAD_SIZE + PAGE_SIZE; + + for (len = need_spc; len > 0; len -= PAGE_SIZE) { + page = alloc_page(strp->sk->sk_allocation); + if (!page) { + tls_strp_flush_anchor_copy(strp); + return -ENOMEM; + } + + skb_fill_page_desc(strp->anchor, shinfo->nr_frags++, + page, 0, 0); + } + + strp->copy_mode = 1; + strp->stm.offset = 0; + + strp->anchor->len = 0; + strp->anchor->data_len = 0; + strp->anchor->truesize = round_up(need_spc, PAGE_SIZE); + + tls_strp_read_copyin(strp); + + return 0; +} + +static void tls_strp_load_anchor_with_queue(struct tls_strparser *strp, int len) +{ + struct tcp_sock *tp = tcp_sk(strp->sk); + struct sk_buff *first; + u32 offset; + + first = tcp_recv_skb(strp->sk, tp->copied_seq, &offset); + if (WARN_ON_ONCE(!first)) + return; + + /* Bestow the state onto the anchor */ + strp->anchor->len = offset + len; + strp->anchor->data_len = offset + len; + strp->anchor->truesize = offset + len; + + skb_shinfo(strp->anchor)->frag_list = first; + + skb_copy_header(strp->anchor, first); + strp->anchor->destructor = NULL; + + strp->stm.offset = offset; +} + +void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh) +{ + struct strp_msg *rxm; + struct tls_msg *tlm; + + DEBUG_NET_WARN_ON_ONCE(!strp->msg_ready); + DEBUG_NET_WARN_ON_ONCE(!strp->stm.full_len); + + if (!strp->copy_mode && force_refresh) { + if (WARN_ON(tcp_inq(strp->sk) < strp->stm.full_len)) + return; + + tls_strp_load_anchor_with_queue(strp, strp->stm.full_len); + } + + rxm = strp_msg(strp->anchor); + rxm->full_len = strp->stm.full_len; + rxm->offset = strp->stm.offset; + tlm = tls_msg(strp->anchor); + tlm->control = strp->mark; +} + +/* Called with lock held on lower socket */ +static int tls_strp_read_sock(struct tls_strparser *strp) +{ + int sz, inq; + + inq = tcp_inq(strp->sk); + if (inq < 1) + return 0; + + if (unlikely(strp->copy_mode)) + return tls_strp_read_copyin(strp); + + if (inq < strp->stm.full_len) + return tls_strp_read_short(strp); + + if (!strp->stm.full_len) { + tls_strp_load_anchor_with_queue(strp, inq); + + sz = tls_rx_msg_size(strp, strp->anchor); + if (sz < 0) { + tls_strp_abort_strp(strp, sz); + return sz; + } + + strp->stm.full_len = sz; + + if (!strp->stm.full_len || inq < strp->stm.full_len) + return tls_strp_read_short(strp); + } + + strp->msg_ready = 1; + tls_rx_msg_ready(strp); + + return 0; +} + +void tls_strp_check_rcv(struct tls_strparser *strp) +{ + if (unlikely(strp->stopped) || strp->msg_ready) + return; + + if (tls_strp_read_sock(strp) == -ENOMEM) + queue_work(tls_strp_wq, &strp->work); +} + +/* Lower sock lock held */ +void tls_strp_data_ready(struct tls_strparser *strp) +{ + /* This check is needed to synchronize with do_tls_strp_work. + * do_tls_strp_work acquires a process lock (lock_sock) whereas + * the lock held here is bh_lock_sock. The two locks can be + * held by different threads at the same time, but bh_lock_sock + * allows a thread in BH context to safely check if the process + * lock is held. In this case, if the lock is held, queue work. + */ + if (sock_owned_by_user_nocheck(strp->sk)) { + queue_work(tls_strp_wq, &strp->work); + return; + } + + tls_strp_check_rcv(strp); +} + +static void tls_strp_work(struct work_struct *w) +{ + struct tls_strparser *strp = + container_of(w, struct tls_strparser, work); + + lock_sock(strp->sk); + tls_strp_check_rcv(strp); + release_sock(strp->sk); +} + +void tls_strp_msg_done(struct tls_strparser *strp) +{ + WARN_ON(!strp->stm.full_len); + + if (likely(!strp->copy_mode)) + tcp_read_done(strp->sk, strp->stm.full_len); + else + tls_strp_flush_anchor_copy(strp); + + strp->msg_ready = 0; + memset(&strp->stm, 0, sizeof(strp->stm)); + + tls_strp_check_rcv(strp); +} + +void tls_strp_stop(struct tls_strparser *strp) +{ + strp->stopped = 1; +} + +int tls_strp_init(struct tls_strparser *strp, struct sock *sk) +{ + memset(strp, 0, sizeof(*strp)); + + strp->sk = sk; + + strp->anchor = alloc_skb(0, GFP_KERNEL); + if (!strp->anchor) + return -ENOMEM; + + INIT_WORK(&strp->work, tls_strp_work); - nsg = skb_cow_data(ctx->recv_pkt, 0, &unused); - if (nsg < 0) - return nsg; return 0; } -int tls_strp_msg_hold(struct sock *sk, struct sk_buff *skb, - struct sk_buff_head *dst) +/* strp must already be stopped so that tls_strp_recv will no longer be called. + * Note that tls_strp_done is not called with the lower socket held. + */ +void tls_strp_done(struct tls_strparser *strp) { - struct sk_buff *clone; + WARN_ON(!strp->stopped); - clone = skb_clone(skb, sk->sk_allocation); - if (!clone) + cancel_work_sync(&strp->work); + tls_strp_anchor_free(strp); +} + +int __init tls_strp_dev_init(void) +{ + tls_strp_wq = create_singlethread_workqueue("kstrp"); + if (unlikely(!tls_strp_wq)) return -ENOMEM; - __skb_queue_tail(dst, clone); + return 0; } + +void tls_strp_dev_exit(void) +{ + destroy_workqueue(tls_strp_wq); +} diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index bd4486819e64..0fc24a5ce208 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -1283,7 +1283,7 @@ int tls_sw_sendpage(struct sock *sk, struct page *page, static int tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock, - long timeo) + bool released, long timeo) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); @@ -1297,7 +1297,7 @@ tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock, return sock_error(sk); if (!skb_queue_empty(&sk->sk_receive_queue)) { - __strp_unpause(&ctx->strp); + tls_strp_check_rcv(&ctx->strp); if (tls_strp_msg_ready(ctx)) break; } @@ -1311,6 +1311,7 @@ tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock, if (nonblock || !timeo) return -EAGAIN; + released = true; add_wait_queue(sk_sleep(sk), &wait); sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); sk_wait_event(sk, &timeo, @@ -1325,6 +1326,8 @@ tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock, return sock_intr_errno(timeo); } + tls_strp_msg_load(&ctx->strp, released); + return 1; } @@ -1570,7 +1573,7 @@ static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov, clear_skb = NULL; if (unlikely(darg->async)) { - err = tls_strp_msg_hold(sk, skb, &ctx->async_hold); + err = tls_strp_msg_hold(&ctx->strp, &ctx->async_hold); if (err) __skb_queue_tail(&ctx->async_hold, darg->skb); return err; @@ -1734,9 +1737,7 @@ static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm, static void tls_rx_rec_done(struct tls_sw_context_rx *ctx) { - consume_skb(ctx->recv_pkt); - ctx->recv_pkt = NULL; - __strp_unpause(&ctx->strp); + tls_strp_msg_done(&ctx->strp); } /* This function traverses the rx_list in tls receive context to copies the @@ -1823,7 +1824,7 @@ out: return copied ? : err; } -static void +static bool tls_read_flush_backlog(struct sock *sk, struct tls_prot_info *prot, size_t len_left, size_t decrypted, ssize_t done, size_t *flushed_at) @@ -1831,14 +1832,14 @@ tls_read_flush_backlog(struct sock *sk, struct tls_prot_info *prot, size_t max_rec; if (len_left <= decrypted) - return; + return false; max_rec = prot->overhead_size - prot->tail_size + TLS_MAX_PAYLOAD_SIZE; if (done - *flushed_at < SZ_128K && tcp_inq(sk) > max_rec) - return; + return false; *flushed_at = done; - sk_flush_backlog(sk); + return sk_flush_backlog(sk); } static long tls_rx_reader_lock(struct sock *sk, struct tls_sw_context_rx *ctx, @@ -1916,6 +1917,7 @@ int tls_sw_recvmsg(struct sock *sk, long timeo; bool is_kvec = iov_iter_is_kvec(&msg->msg_iter); bool is_peek = flags & MSG_PEEK; + bool released = true; bool bpf_strp_enabled; bool zc_capable; @@ -1952,7 +1954,8 @@ int tls_sw_recvmsg(struct sock *sk, struct tls_decrypt_arg darg; int to_decrypt, chunk; - err = tls_rx_rec_wait(sk, psock, flags & MSG_DONTWAIT, timeo); + err = tls_rx_rec_wait(sk, psock, flags & MSG_DONTWAIT, released, + timeo); if (err <= 0) { if (psock) { chunk = sk_msg_recvmsg(sk, psock, msg, len, @@ -1968,8 +1971,8 @@ int tls_sw_recvmsg(struct sock *sk, memset(&darg.inargs, 0, sizeof(darg.inargs)); - rxm = strp_msg(ctx->recv_pkt); - tlm = tls_msg(ctx->recv_pkt); + rxm = strp_msg(tls_strp_msg(ctx)); + tlm = tls_msg(tls_strp_msg(ctx)); to_decrypt = rxm->full_len - prot->overhead_size; @@ -2008,8 +2011,9 @@ put_on_rx_list_err: } /* periodically flush backlog, and feed strparser */ - tls_read_flush_backlog(sk, prot, len, to_decrypt, - decrypted + copied, &flushed_at); + released = tls_read_flush_backlog(sk, prot, len, to_decrypt, + decrypted + copied, + &flushed_at); /* TLS 1.3 may have updated the length by more than overhead */ rxm = strp_msg(darg.skb); @@ -2020,7 +2024,7 @@ put_on_rx_list_err: bool partially_consumed = chunk > len; struct sk_buff *skb = darg.skb; - DEBUG_NET_WARN_ON_ONCE(darg.skb == ctx->recv_pkt); + DEBUG_NET_WARN_ON_ONCE(darg.skb == tls_strp_msg(ctx)); if (async) { /* TLS 1.2-only, to_decrypt must be text len */ @@ -2034,6 +2038,7 @@ put_on_rx_list: } if (bpf_strp_enabled) { + released = true; err = sk_psock_tls_strp_read(psock, skb); if (err != __SK_PASS) { rxm->offset = rxm->offset + rxm->full_len; @@ -2140,7 +2145,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, struct tls_decrypt_arg darg; err = tls_rx_rec_wait(sk, NULL, flags & SPLICE_F_NONBLOCK, - timeo); + true, timeo); if (err <= 0) goto splice_read_end; @@ -2204,19 +2209,17 @@ bool tls_sw_sock_is_readable(struct sock *sk) !skb_queue_empty(&ctx->rx_list); } -static int tls_read_size(struct strparser *strp, struct sk_buff *skb) +int tls_rx_msg_size(struct tls_strparser *strp, struct sk_buff *skb) { struct tls_context *tls_ctx = tls_get_ctx(strp->sk); struct tls_prot_info *prot = &tls_ctx->prot_info; char header[TLS_HEADER_SIZE + MAX_IV_SIZE]; - struct strp_msg *rxm = strp_msg(skb); - struct tls_msg *tlm = tls_msg(skb); size_t cipher_overhead; size_t data_len = 0; int ret; /* Verify that we have a full TLS header, or wait for more data */ - if (rxm->offset + prot->prepend_size > skb->len) + if (strp->stm.offset + prot->prepend_size > skb->len) return 0; /* Sanity-check size of on-stack buffer. */ @@ -2226,11 +2229,11 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb) } /* Linearize header to local buffer */ - ret = skb_copy_bits(skb, rxm->offset, header, prot->prepend_size); + ret = skb_copy_bits(skb, strp->stm.offset, header, prot->prepend_size); if (ret < 0) goto read_failure; - tlm->control = header[0]; + strp->mark = header[0]; data_len = ((header[4] & 0xFF) | (header[3] << 8)); @@ -2257,7 +2260,7 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb) } tls_device_rx_resync_new_rec(strp->sk, data_len + TLS_HEADER_SIZE, - TCP_SKB_CB(skb)->seq + rxm->offset); + TCP_SKB_CB(skb)->seq + strp->stm.offset); return data_len + TLS_HEADER_SIZE; read_failure: @@ -2266,14 +2269,11 @@ read_failure: return ret; } -static void tls_queue(struct strparser *strp, struct sk_buff *skb) +void tls_rx_msg_ready(struct tls_strparser *strp) { - struct tls_context *tls_ctx = tls_get_ctx(strp->sk); - struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); - - ctx->recv_pkt = skb; - strp_pause(strp); + struct tls_sw_context_rx *ctx; + ctx = container_of(strp, struct tls_sw_context_rx, strp); ctx->saved_data_ready(strp->sk); } @@ -2283,7 +2283,7 @@ static void tls_data_ready(struct sock *sk) struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct sk_psock *psock; - strp_data_ready(&ctx->strp); + tls_strp_data_ready(&ctx->strp); psock = sk_psock_get(sk); if (psock) { @@ -2359,13 +2359,11 @@ void tls_sw_release_resources_rx(struct sock *sk) kfree(tls_ctx->rx.iv); if (ctx->aead_recv) { - kfree_skb(ctx->recv_pkt); - ctx->recv_pkt = NULL; __skb_queue_purge(&ctx->rx_list); crypto_free_aead(ctx->aead_recv); - strp_stop(&ctx->strp); + tls_strp_stop(&ctx->strp); /* If tls_sw_strparser_arm() was not called (cleanup paths) - * we still want to strp_stop(), but sk->sk_data_ready was + * we still want to tls_strp_stop(), but sk->sk_data_ready was * never swapped. */ if (ctx->saved_data_ready) { @@ -2380,7 +2378,7 @@ void tls_sw_strparser_done(struct tls_context *tls_ctx) { struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); - strp_done(&ctx->strp); + tls_strp_done(&ctx->strp); } void tls_sw_free_ctx_rx(struct tls_context *tls_ctx) @@ -2453,8 +2451,6 @@ void tls_sw_strparser_arm(struct sock *sk, struct tls_context *tls_ctx) rx_ctx->saved_data_ready = sk->sk_data_ready; sk->sk_data_ready = tls_data_ready; write_unlock_bh(&sk->sk_callback_lock); - - strp_check_rcv(&rx_ctx->strp); } void tls_update_rx_zc_capable(struct tls_context *tls_ctx) @@ -2474,7 +2470,6 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) struct tls_sw_context_rx *sw_ctx_rx = NULL; struct cipher_context *cctx; struct crypto_aead **aead; - struct strp_callbacks cb; u16 nonce_size, tag_size, iv_size, rec_seq_size, salt_size; struct crypto_tfm *tfm; char *iv, *rec_seq, *key, *salt, *cipher_name; @@ -2708,12 +2703,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) crypto_info->version != TLS_1_3_VERSION && !!(tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC); - /* Set up strparser */ - memset(&cb, 0, sizeof(cb)); - cb.rcv_msg = tls_queue; - cb.parse_msg = tls_read_size; - - strp_init(&sw_ctx_rx->strp, sk, &cb); + tls_strp_init(&sw_ctx_rx->strp, sk); } goto out; -- cgit v1.2.3