diff options
Diffstat (limited to 'net/tls')
-rw-r--r-- | net/tls/tls_strp.c | 6 | ||||
-rw-r--r-- | net/tls/tls_sw.c | 44 |
2 files changed, 42 insertions, 8 deletions
diff --git a/net/tls/tls_strp.c b/net/tls/tls_strp.c index f37f4a0fcd3c..b7ed76c0e576 100644 --- a/net/tls/tls_strp.c +++ b/net/tls/tls_strp.c @@ -396,7 +396,6 @@ static int tls_strp_read_copy(struct tls_strparser *strp, bool qshort) return 0; shinfo = skb_shinfo(strp->anchor); - shinfo->frag_list = NULL; /* If we don't know the length go max plus page for cipher overhead */ need_spc = strp->stm.full_len ?: TLS_MAX_PAYLOAD_SIZE + PAGE_SIZE; @@ -412,6 +411,8 @@ static int tls_strp_read_copy(struct tls_strparser *strp, bool qshort) page, 0, 0); } + shinfo->frag_list = NULL; + strp->copy_mode = 1; strp->stm.offset = 0; @@ -511,9 +512,8 @@ static int tls_strp_read_sock(struct tls_strparser *strp) if (inq < strp->stm.full_len) return tls_strp_read_copy(strp, true); + tls_strp_load_anchor_with_queue(strp, inq); if (!strp->stm.full_len) { - tls_strp_load_anchor_with_queue(strp, inq); - sz = tls_rx_msg_size(strp, strp->anchor); if (sz < 0) { tls_strp_abort_strp(strp, sz); diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index 5310441240e7..96e62e8f1dad 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -274,9 +274,15 @@ static int tls_do_decryption(struct sock *sk, DEBUG_NET_WARN_ON_ONCE(atomic_read(&ctx->decrypt_pending) < 1); atomic_inc(&ctx->decrypt_pending); } else { + DECLARE_CRYPTO_WAIT(wait); + aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG, - crypto_req_done, &ctx->async_wait); + crypto_req_done, &wait); + ret = crypto_aead_decrypt(aead_req); + if (ret == -EINPROGRESS || ret == -EBUSY) + ret = crypto_wait_req(ret, &wait); + return ret; } ret = crypto_aead_decrypt(aead_req); @@ -289,7 +295,6 @@ static int tls_do_decryption(struct sock *sk, /* all completions have run, we're not doing async anymore */ darg->async = false; return ret; - ret = ret ?: -EINPROGRESS; } atomic_dec(&ctx->decrypt_pending); @@ -868,6 +873,19 @@ more_data: delta = msg->sg.size; psock->eval = sk_psock_msg_verdict(sk, psock, msg); delta -= msg->sg.size; + + if ((s32)delta > 0) { + /* It indicates that we executed bpf_msg_pop_data(), + * causing the plaintext data size to decrease. + * Therefore the encrypted data size also needs to + * correspondingly decrease. We only need to subtract + * delta to calculate the new ciphertext length since + * ktls does not support block encryption. + */ + struct sk_msg *enc = &ctx->open_rec->msg_encrypted; + + sk_msg_trim(sk, enc, enc->sg.size - delta); + } } if (msg->cork_bytes && msg->cork_bytes > msg->sg.size && !enospc && !full_record) { @@ -904,6 +922,13 @@ more_data: &msg_redir, send, flags); lock_sock(sk); if (err < 0) { + /* Regardless of whether the data represented by + * msg_redir is sent successfully, we have already + * uncharged it via sk_msg_return_zero(). The + * msg->sg.size represents the remaining unprocessed + * data, which needs to be uncharged here. + */ + sk_mem_uncharge(sk, msg->sg.size); *copied -= sk_msg_free_nocharge(sk, &msg_redir); msg->sg.size = 0; } @@ -1075,9 +1100,13 @@ alloc_encrypted: num_async++; else if (ret == -ENOMEM) goto wait_for_memory; - else if (ctx->open_rec && ret == -ENOSPC) + else if (ctx->open_rec && ret == -ENOSPC) { + if (msg_pl->cork_bytes) { + ret = 0; + goto send_end; + } goto rollback_iter; - else if (ret != -EAGAIN) + } else if (ret != -EAGAIN) goto send_end; } continue; @@ -1835,6 +1864,9 @@ int decrypt_skb(struct sock *sk, struct scatterlist *sgout) return tls_decrypt_sg(sk, NULL, sgout, &darg); } +/* All records returned from a recvmsg() call must have the same type. + * 0 is not a valid content type. Use it as "no type reported, yet". + */ static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm, u8 *control) { @@ -2078,8 +2110,10 @@ int tls_sw_recvmsg(struct sock *sk, if (err < 0) goto end; + /* process_rx_list() will set @control if it processed any records */ copied = err; - if (len <= copied || (copied && control != TLS_RECORD_TYPE_DATA) || rx_more) + if (len <= copied || rx_more || + (control && control != TLS_RECORD_TYPE_DATA)) goto end; target = sock_rcvlowat(sk, flags & MSG_WAITALL, len); |