summaryrefslogtreecommitdiff
path: root/net/tls
diff options
context:
space:
mode:
Diffstat (limited to 'net/tls')
-rw-r--r--net/tls/tls_strp.c6
-rw-r--r--net/tls/tls_sw.c44
2 files changed, 42 insertions, 8 deletions
diff --git a/net/tls/tls_strp.c b/net/tls/tls_strp.c
index f37f4a0fcd3c..b7ed76c0e576 100644
--- a/net/tls/tls_strp.c
+++ b/net/tls/tls_strp.c
@@ -396,7 +396,6 @@ static int tls_strp_read_copy(struct tls_strparser *strp, bool qshort)
return 0;
shinfo = skb_shinfo(strp->anchor);
- shinfo->frag_list = NULL;
/* If we don't know the length go max plus page for cipher overhead */
need_spc = strp->stm.full_len ?: TLS_MAX_PAYLOAD_SIZE + PAGE_SIZE;
@@ -412,6 +411,8 @@ static int tls_strp_read_copy(struct tls_strparser *strp, bool qshort)
page, 0, 0);
}
+ shinfo->frag_list = NULL;
+
strp->copy_mode = 1;
strp->stm.offset = 0;
@@ -511,9 +512,8 @@ static int tls_strp_read_sock(struct tls_strparser *strp)
if (inq < strp->stm.full_len)
return tls_strp_read_copy(strp, true);
+ tls_strp_load_anchor_with_queue(strp, inq);
if (!strp->stm.full_len) {
- tls_strp_load_anchor_with_queue(strp, inq);
-
sz = tls_rx_msg_size(strp, strp->anchor);
if (sz < 0) {
tls_strp_abort_strp(strp, sz);
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 5310441240e7..96e62e8f1dad 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -274,9 +274,15 @@ static int tls_do_decryption(struct sock *sk,
DEBUG_NET_WARN_ON_ONCE(atomic_read(&ctx->decrypt_pending) < 1);
atomic_inc(&ctx->decrypt_pending);
} else {
+ DECLARE_CRYPTO_WAIT(wait);
+
aead_request_set_callback(aead_req,
CRYPTO_TFM_REQ_MAY_BACKLOG,
- crypto_req_done, &ctx->async_wait);
+ crypto_req_done, &wait);
+ ret = crypto_aead_decrypt(aead_req);
+ if (ret == -EINPROGRESS || ret == -EBUSY)
+ ret = crypto_wait_req(ret, &wait);
+ return ret;
}
ret = crypto_aead_decrypt(aead_req);
@@ -289,7 +295,6 @@ static int tls_do_decryption(struct sock *sk,
/* all completions have run, we're not doing async anymore */
darg->async = false;
return ret;
- ret = ret ?: -EINPROGRESS;
}
atomic_dec(&ctx->decrypt_pending);
@@ -868,6 +873,19 @@ more_data:
delta = msg->sg.size;
psock->eval = sk_psock_msg_verdict(sk, psock, msg);
delta -= msg->sg.size;
+
+ if ((s32)delta > 0) {
+ /* It indicates that we executed bpf_msg_pop_data(),
+ * causing the plaintext data size to decrease.
+ * Therefore the encrypted data size also needs to
+ * correspondingly decrease. We only need to subtract
+ * delta to calculate the new ciphertext length since
+ * ktls does not support block encryption.
+ */
+ struct sk_msg *enc = &ctx->open_rec->msg_encrypted;
+
+ sk_msg_trim(sk, enc, enc->sg.size - delta);
+ }
}
if (msg->cork_bytes && msg->cork_bytes > msg->sg.size &&
!enospc && !full_record) {
@@ -904,6 +922,13 @@ more_data:
&msg_redir, send, flags);
lock_sock(sk);
if (err < 0) {
+ /* Regardless of whether the data represented by
+ * msg_redir is sent successfully, we have already
+ * uncharged it via sk_msg_return_zero(). The
+ * msg->sg.size represents the remaining unprocessed
+ * data, which needs to be uncharged here.
+ */
+ sk_mem_uncharge(sk, msg->sg.size);
*copied -= sk_msg_free_nocharge(sk, &msg_redir);
msg->sg.size = 0;
}
@@ -1075,9 +1100,13 @@ alloc_encrypted:
num_async++;
else if (ret == -ENOMEM)
goto wait_for_memory;
- else if (ctx->open_rec && ret == -ENOSPC)
+ else if (ctx->open_rec && ret == -ENOSPC) {
+ if (msg_pl->cork_bytes) {
+ ret = 0;
+ goto send_end;
+ }
goto rollback_iter;
- else if (ret != -EAGAIN)
+ } else if (ret != -EAGAIN)
goto send_end;
}
continue;
@@ -1835,6 +1864,9 @@ int decrypt_skb(struct sock *sk, struct scatterlist *sgout)
return tls_decrypt_sg(sk, NULL, sgout, &darg);
}
+/* All records returned from a recvmsg() call must have the same type.
+ * 0 is not a valid content type. Use it as "no type reported, yet".
+ */
static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm,
u8 *control)
{
@@ -2078,8 +2110,10 @@ int tls_sw_recvmsg(struct sock *sk,
if (err < 0)
goto end;
+ /* process_rx_list() will set @control if it processed any records */
copied = err;
- if (len <= copied || (copied && control != TLS_RECORD_TYPE_DATA) || rx_more)
+ if (len <= copied || rx_more ||
+ (control && control != TLS_RECORD_TYPE_DATA))
goto end;
target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);