diff options
Diffstat (limited to 'net/tls')
-rw-r--r-- | net/tls/tls.h | 2 | ||||
-rw-r--r-- | net/tls/tls_strp.c | 17 | ||||
-rw-r--r-- | net/tls/tls_sw.c | 23 |
3 files changed, 33 insertions, 9 deletions
diff --git a/net/tls/tls.h b/net/tls/tls.h index 774859b63f0d..4e077068e6d9 100644 --- a/net/tls/tls.h +++ b/net/tls/tls.h @@ -196,7 +196,7 @@ void tls_strp_msg_done(struct tls_strparser *strp); int tls_rx_msg_size(struct tls_strparser *strp, struct sk_buff *skb); void tls_rx_msg_ready(struct tls_strparser *strp); -void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh); +bool tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh); int tls_strp_msg_cow(struct tls_sw_context_rx *ctx); struct sk_buff *tls_strp_msg_detach(struct tls_sw_context_rx *ctx); int tls_strp_msg_hold(struct tls_strparser *strp, struct sk_buff_head *dst); diff --git a/net/tls/tls_strp.c b/net/tls/tls_strp.c index 77e33e1e340e..d71643b494a1 100644 --- a/net/tls/tls_strp.c +++ b/net/tls/tls_strp.c @@ -396,7 +396,6 @@ static int tls_strp_read_copy(struct tls_strparser *strp, bool qshort) return 0; shinfo = skb_shinfo(strp->anchor); - shinfo->frag_list = NULL; /* If we don't know the length go max plus page for cipher overhead */ need_spc = strp->stm.full_len ?: TLS_MAX_PAYLOAD_SIZE + PAGE_SIZE; @@ -412,6 +411,8 @@ static int tls_strp_read_copy(struct tls_strparser *strp, bool qshort) page, 0, 0); } + shinfo->frag_list = NULL; + strp->copy_mode = 1; strp->stm.offset = 0; @@ -474,7 +475,7 @@ static void tls_strp_load_anchor_with_queue(struct tls_strparser *strp, int len) strp->stm.offset = offset; } -void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh) +bool tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh) { struct strp_msg *rxm; struct tls_msg *tlm; @@ -483,8 +484,11 @@ void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh) DEBUG_NET_WARN_ON_ONCE(!strp->stm.full_len); if (!strp->copy_mode && force_refresh) { - if (WARN_ON(tcp_inq(strp->sk) < strp->stm.full_len)) - return; + if (unlikely(tcp_inq(strp->sk) < strp->stm.full_len)) { + WRITE_ONCE(strp->msg_ready, 0); + memset(&strp->stm, 0, sizeof(strp->stm)); + return false; + } tls_strp_load_anchor_with_queue(strp, strp->stm.full_len); } @@ -494,6 +498,8 @@ void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh) rxm->offset = strp->stm.offset; tlm = tls_msg(strp->anchor); tlm->control = strp->mark; + + return true; } /* Called with lock held on lower socket */ @@ -511,9 +517,8 @@ static int tls_strp_read_sock(struct tls_strparser *strp) if (inq < strp->stm.full_len) return tls_strp_read_copy(strp, true); + tls_strp_load_anchor_with_queue(strp, inq); if (!strp->stm.full_len) { - tls_strp_load_anchor_with_queue(strp, inq); - sz = tls_rx_msg_size(strp, strp->anchor); if (sz < 0) { tls_strp_abort_strp(strp, sz); diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index fc88e34b7f33..bac65d0d4e3e 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -872,6 +872,19 @@ more_data: delta = msg->sg.size; psock->eval = sk_psock_msg_verdict(sk, psock, msg); delta -= msg->sg.size; + + if ((s32)delta > 0) { + /* It indicates that we executed bpf_msg_pop_data(), + * causing the plaintext data size to decrease. + * Therefore the encrypted data size also needs to + * correspondingly decrease. We only need to subtract + * delta to calculate the new ciphertext length since + * ktls does not support block encryption. + */ + struct sk_msg *enc = &ctx->open_rec->msg_encrypted; + + sk_msg_trim(sk, enc, enc->sg.size - delta); + } } if (msg->cork_bytes && msg->cork_bytes > msg->sg.size && !enospc && !full_record) { @@ -1371,7 +1384,8 @@ tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock, return sock_intr_errno(timeo); } - tls_strp_msg_load(&ctx->strp, released); + if (unlikely(!tls_strp_msg_load(&ctx->strp, released))) + return tls_rx_rec_wait(sk, psock, nonblock, false); return 1; } @@ -1794,6 +1808,9 @@ int decrypt_skb(struct sock *sk, struct scatterlist *sgout) return tls_decrypt_sg(sk, NULL, sgout, &darg); } +/* All records returned from a recvmsg() call must have the same type. + * 0 is not a valid content type. Use it as "no type reported, yet". + */ static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm, u8 *control) { @@ -2037,8 +2054,10 @@ int tls_sw_recvmsg(struct sock *sk, if (err < 0) goto end; + /* process_rx_list() will set @control if it processed any records */ copied = err; - if (len <= copied || (copied && control != TLS_RECORD_TYPE_DATA) || rx_more) + if (len <= copied || rx_more || + (control && control != TLS_RECORD_TYPE_DATA)) goto end; target = sock_rcvlowat(sk, flags & MSG_WAITALL, len); |