Diffstat (limited to 'net/tls/tls_sw.c')
-rw-r--r-- | net/tls/tls_sw.c | 489
1 file changed, 205 insertions, 284 deletions
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index a8976ef95528..939d1673f508 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -44,6 +44,11 @@
 #include <net/strparser.h>
 #include <net/tls.h>
 
+struct tls_decrypt_arg {
+	bool zc;
+	bool async;
+};
+
 noinline void tls_err_abort(struct sock *sk, int err)
 {
 	WARN_ON_ONCE(err >= 0);
@@ -128,32 +133,31 @@ static int skb_nsg(struct sk_buff *skb, int offset, int len)
 	return __skb_nsg(skb, offset, len, 0);
 }
 
-static int padding_length(struct tls_sw_context_rx *ctx,
-			  struct tls_prot_info *prot, struct sk_buff *skb)
+static int padding_length(struct tls_prot_info *prot, struct sk_buff *skb)
 {
 	struct strp_msg *rxm = strp_msg(skb);
+	struct tls_msg *tlm = tls_msg(skb);
 	int sub = 0;
 
 	/* Determine zero-padding length */
 	if (prot->version == TLS_1_3_VERSION) {
+		int offset = rxm->full_len - TLS_TAG_SIZE - 1;
 		char content_type = 0;
 		int err;
-		int back = 17;
 
 		while (content_type == 0) {
-			if (back > rxm->full_len - prot->prepend_size)
+			if (offset < prot->prepend_size)
 				return -EBADMSG;
-			err = skb_copy_bits(skb,
-					    rxm->offset + rxm->full_len - back,
+			err = skb_copy_bits(skb, rxm->offset + offset,
 					    &content_type, 1);
 			if (err)
 				return err;
 			if (content_type)
 				break;
 			sub++;
-			back++;
+			offset--;
 		}
-		ctx->control = content_type;
+		tlm->control = content_type;
 	}
 	return sub;
 }
@@ -169,7 +173,6 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
 	struct scatterlist *sg;
 	struct sk_buff *skb;
 	unsigned int pages;
-	int pending;
 
 	skb = (struct sk_buff *)req->data;
 	tls_ctx = tls_get_ctx(skb->sk);
@@ -185,17 +188,12 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
 		tls_err_abort(skb->sk, err);
 	} else {
 		struct strp_msg *rxm = strp_msg(skb);
-		int pad;
 
-		pad = padding_length(ctx, prot, skb);
-		if (pad < 0) {
-			ctx->async_wait.err = pad;
-			tls_err_abort(skb->sk, pad);
-		} else {
-			rxm->full_len -= pad;
-			rxm->offset += prot->prepend_size;
-			rxm->full_len -= prot->overhead_size;
-		}
+		/* No TLS 1.3 support with async crypto */
+		WARN_ON(prot->tail_size);
+
+		rxm->offset += prot->prepend_size;
+		rxm->full_len -= prot->overhead_size;
 	}
 
 	/* After using skb->sk to propagate sk through crypto async callback
@@ -217,9 +215,7 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
 	kfree(aead_req);
 
 	spin_lock_bh(&ctx->decrypt_compl_lock);
-	pending = atomic_dec_return(&ctx->decrypt_pending);
-
-	if (!pending && ctx->async_notify)
+	if (!atomic_dec_return(&ctx->decrypt_pending))
 		complete(&ctx->async_wait.completion);
 	spin_unlock_bh(&ctx->decrypt_compl_lock);
 }
@@ -231,7 +227,7 @@ static int tls_do_decryption(struct sock *sk,
 			     char *iv_recv,
 			     size_t data_len,
 			     struct aead_request *aead_req,
-			     bool async)
+			     struct tls_decrypt_arg *darg)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_prot_info *prot = &tls_ctx->prot_info;
@@ -244,7 +240,7 @@ static int tls_do_decryption(struct sock *sk,
 			       data_len + prot->tag_size,
 			       (u8 *)iv_recv);
 
-	if (async) {
+	if (darg->async) {
 		/* Using skb->sk to push sk through to crypto async callback
 		 * handler. This allows propagating errors up to the socket
 		 * if needed. It _must_ be cleared in the async handler
@@ -264,14 +260,15 @@ static int tls_do_decryption(struct sock *sk,
 
 	ret = crypto_aead_decrypt(aead_req);
 	if (ret == -EINPROGRESS) {
-		if (async)
-			return ret;
+		if (darg->async)
+			return 0;
 
 		ret = crypto_wait_req(ret, &ctx->async_wait);
 	}
+	darg->async = false;
 
-	if (async)
-		atomic_dec(&ctx->decrypt_pending);
+	if (ret == -EBADMSG)
+		TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
 
 	return ret;
 }
@@ -1346,15 +1343,14 @@ static struct sk_buff *tls_wait_data(struct sock *sk, struct sk_psock *psock,
 	return skb;
 }
 
-static int tls_setup_from_iter(struct sock *sk, struct iov_iter *from,
+static int tls_setup_from_iter(struct iov_iter *from,
 			       int length, int *pages_used,
-			       unsigned int *size_used,
 			       struct scatterlist *to,
 			       int to_max_pages)
 {
 	int rc = 0, i = 0, num_elem = *pages_used, maxpages;
 	struct page *pages[MAX_SKB_FRAGS];
-	unsigned int size = *size_used;
+	unsigned int size = 0;
 	ssize_t copied, use;
 	size_t offset;
 
@@ -1397,8 +1393,7 @@ static int tls_setup_from_iter(struct sock *sk, struct iov_iter *from,
 	sg_mark_end(&to[num_elem - 1]);
 out:
 	if (rc)
-		iov_iter_revert(from, size - *size_used);
-	*size_used = size;
+		iov_iter_revert(from, size);
 	*pages_used = num_elem;
 
 	return rc;
@@ -1415,12 +1410,13 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 			    struct iov_iter *out_iov,
 			    struct scatterlist *out_sg,
-			    int *chunk, bool *zc, bool async)
+			    struct tls_decrypt_arg *darg)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct strp_msg *rxm = strp_msg(skb);
+	struct tls_msg *tlm = tls_msg(skb);
 	int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
 	struct aead_request *aead_req;
 	struct sk_buff *unused;
@@ -1431,7 +1427,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 		prot->tail_size;
 	int iv_offset = 0;
 
-	if (*zc && (out_iov || out_sg)) {
+	if (darg->zc && (out_iov || out_sg)) {
 		if (out_iov)
 			n_sgout = 1 +
 				iov_iter_npages_cap(out_iov, INT_MAX, data_len);
@@ -1441,7 +1437,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 			rxm->full_len - prot->prepend_size);
 	} else {
 		n_sgout = 0;
-		*zc = false;
+		darg->zc = false;
 		n_sgin = skb_cow_data(skb, 0, &unused);
 	}
 
@@ -1456,7 +1452,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 	aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
 	mem_size = aead_size + (nsg * sizeof(struct scatterlist));
 	mem_size = mem_size + prot->aad_size;
-	mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv);
+	mem_size = mem_size + MAX_IV_SIZE;
 
 	/* Allocate a single block of memory which contains
 	 * aead_req || sgin[] || sgout[] || aad || iv.
@@ -1486,26 +1482,26 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 	}
 
 	/* Prepare IV */
-	err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
-			    iv + iv_offset + prot->salt_size,
-			    prot->iv_size);
-	if (err < 0) {
-		kfree(mem);
-		return err;
-	}
 	if (prot->version == TLS_1_3_VERSION ||
-	    prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305)
+	    prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305) {
 		memcpy(iv + iv_offset, tls_ctx->rx.iv,
 		       prot->iv_size + prot->salt_size);
-	else
+	} else {
+		err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
+				    iv + iv_offset + prot->salt_size,
+				    prot->iv_size);
+		if (err < 0) {
+			kfree(mem);
+			return err;
+		}
 		memcpy(iv + iv_offset, tls_ctx->rx.iv, prot->salt_size);
-
+	}
 	xor_iv_with_seq(prot, iv + iv_offset, tls_ctx->rx.rec_seq);
 
 	/* Prepare AAD */
 	tls_make_aad(aad, rxm->full_len - prot->overhead_size +
 		     prot->tail_size,
-		     tls_ctx->rx.rec_seq, ctx->control, prot);
+		     tls_ctx->rx.rec_seq, tlm->control, prot);
 
 	/* Prepare sgin */
 	sg_init_table(sgin, n_sgin);
@@ -1523,9 +1519,8 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 		sg_init_table(sgout, n_sgout);
 		sg_set_buf(&sgout[0], aad, prot->aad_size);
 
-		*chunk = 0;
-		err = tls_setup_from_iter(sk, out_iov, data_len,
-					  &pages, chunk, &sgout[1],
+		err = tls_setup_from_iter(out_iov, data_len,
+					  &pages, &sgout[1],
 					  (n_sgout - 1));
 		if (err < 0)
 			goto fallback_to_reg_recv;
@@ -1538,15 +1533,14 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 fallback_to_reg_recv:
 		sgout = sgin;
 		pages = 0;
-		*chunk = data_len;
-		*zc = false;
+		darg->zc = false;
 	}
 
 	/* Prepare and submit AEAD request */
 	err = tls_do_decryption(sk, skb, sgin, sgout, iv,
-				data_len, aead_req, async);
-	if (err == -EINPROGRESS)
-		return err;
+				data_len, aead_req, darg);
+	if (darg->async)
+		return 0;
 
 	/* Release the pages in case iov was mapped to pages */
 	for (; pages > 0; pages--)
@@ -1557,87 +1551,83 @@ fallback_to_reg_recv:
 }
 
 static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
-			      struct iov_iter *dest, int *chunk, bool *zc,
-			      bool async)
+			      struct iov_iter *dest,
+			      struct tls_decrypt_arg *darg)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
-	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct strp_msg *rxm = strp_msg(skb);
-	int pad, err = 0;
+	struct tls_msg *tlm = tls_msg(skb);
+	int pad, err;
 
-	if (!ctx->decrypted) {
-		if (tls_ctx->rx_conf == TLS_HW) {
-			err = tls_device_decrypted(sk, tls_ctx, skb, rxm);
-			if (err < 0)
-				return err;
-		}
+	if (tlm->decrypted) {
+		darg->zc = false;
+		darg->async = false;
+		return 0;
+	}
 
-		/* Still not decrypted after tls_device */
-		if (!ctx->decrypted) {
-			err = decrypt_internal(sk, skb, dest, NULL, chunk, zc,
-					       async);
-			if (err < 0) {
-				if (err == -EINPROGRESS)
-					tls_advance_record_sn(sk, prot,
-							      &tls_ctx->rx);
-				else if (err == -EBADMSG)
-					TLS_INC_STATS(sock_net(sk),
-						      LINUX_MIB_TLSDECRYPTERROR);
-				return err;
-			}
-		} else {
-			*zc = false;
+	if (tls_ctx->rx_conf == TLS_HW) {
+		err = tls_device_decrypted(sk, tls_ctx, skb, rxm);
+		if (err < 0)
+			return err;
+		if (err > 0) {
+			tlm->decrypted = 1;
+			darg->zc = false;
+			darg->async = false;
+			goto decrypt_done;
 		}
+	}
 
-		pad = padding_length(ctx, prot, skb);
-		if (pad < 0)
-			return pad;
+	err = decrypt_internal(sk, skb, dest, NULL, darg);
+	if (err < 0)
		return err;
+	if (darg->async)
+		goto decrypt_next;
 
-		rxm->full_len -= pad;
-		rxm->offset += prot->prepend_size;
-		rxm->full_len -= prot->overhead_size;
-		tls_advance_record_sn(sk, prot, &tls_ctx->rx);
-		ctx->decrypted = 1;
-		ctx->saved_data_ready(sk);
-	} else {
-		*zc = false;
-	}
+decrypt_done:
+	pad = padding_length(prot, skb);
+	if (pad < 0)
+		return pad;
 
-	return err;
+	rxm->full_len -= pad;
+	rxm->offset += prot->prepend_size;
+	rxm->full_len -= prot->overhead_size;
+	tlm->decrypted = 1;
+decrypt_next:
+	tls_advance_record_sn(sk, prot, &tls_ctx->rx);
+
+	return 0;
 }
 
 int decrypt_skb(struct sock *sk, struct sk_buff *skb,
 		struct scatterlist *sgout)
 {
-	bool zc = true;
-	int chunk;
+	struct tls_decrypt_arg darg = { .zc = true, };
 
-	return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc, false);
+	return decrypt_internal(sk, skb, NULL, sgout, &darg);
 }
 
-static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
-			       unsigned int len)
+static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm,
+				   u8 *control)
 {
-	struct tls_context *tls_ctx = tls_get_ctx(sk);
-	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
-
-	if (skb) {
-		struct strp_msg *rxm = strp_msg(skb);
-
-		if (len < rxm->full_len) {
-			rxm->offset += len;
-			rxm->full_len -= len;
-			return false;
+	int err;
+
+	if (!*control) {
+		*control = tlm->control;
+		if (!*control)
+			return -EBADMSG;
+
+		err = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
+			       sizeof(*control), control);
+		if (*control != TLS_RECORD_TYPE_DATA) {
+			if (err || msg->msg_flags & MSG_CTRUNC)
+				return -EIO;
 		}
-		consume_skb(skb);
+	} else if (*control != tlm->control) {
+		return 0;
 	}
 
-	/* Finished with message */
-	ctx->recv_pkt = NULL;
-	__strp_unpause(&ctx->strp);
-
-	return true;
+	return 1;
 }
 
 /* This function traverses the rx_list in tls receive context to copies the
@@ -1648,31 +1638,23 @@ static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
 static int process_rx_list(struct tls_sw_context_rx *ctx,
 			   struct msghdr *msg,
 			   u8 *control,
-			   bool *cmsg,
 			   size_t skip,
 			   size_t len,
 			   bool zc,
 			   bool is_peek)
 {
 	struct sk_buff *skb = skb_peek(&ctx->rx_list);
-	u8 ctrl = *control;
-	u8 msgc = *cmsg;
 	struct tls_msg *tlm;
 	ssize_t copied = 0;
-
-	/* Set the record type in 'control' if caller didn't pass it */
-	if (!ctrl && skb) {
-		tlm = tls_msg(skb);
-		ctrl = tlm->control;
-	}
+	int err;
 
 	while (skip && skb) {
 		struct strp_msg *rxm = strp_msg(skb);
 		tlm = tls_msg(skb);
 
-		/* Cannot process a record of different type */
-		if (ctrl != tlm->control)
-			return 0;
+		err = tls_record_content_type(msg, tlm, control);
+		if (err <= 0)
+			goto out;
 
 		if (skip < rxm->full_len)
 			break;
@@ -1688,30 +1670,15 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
 
 		tlm = tls_msg(skb);
 
-		/* Cannot process a record of different type */
-		if (ctrl != tlm->control)
-			return 0;
-
-		/* Set record type if not already done. For a non-data record,
-		 * do not proceed if record type could not be copied.
-		 */
-		if (!msgc) {
-			int cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
-					    sizeof(ctrl), &ctrl);
-			msgc = true;
-			if (ctrl != TLS_RECORD_TYPE_DATA) {
-				if (cerr || msg->msg_flags & MSG_CTRUNC)
-					return -EIO;
-
-				*cmsg = msgc;
-			}
-		}
+		err = tls_record_content_type(msg, tlm, control);
+		if (err <= 0)
+			goto out;
 
 		if (!zc || (rxm->full_len - skip) > len) {
-			int err = skb_copy_datagram_msg(skb, rxm->offset + skip,
+			err = skb_copy_datagram_msg(skb, rxm->offset + skip,
 						    msg, chunk);
 			if (err < 0)
-				return err;
+				goto out;
 		}
 
 		len = len - chunk;
@@ -1738,21 +1705,21 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
 		next_skb = skb_peek_next(skb, &ctx->rx_list);
 
 		if (!is_peek) {
-			skb_unlink(skb, &ctx->rx_list);
+			__skb_unlink(skb, &ctx->rx_list);
 			consume_skb(skb);
 		}
 
 		skb = next_skb;
 	}
+	err = 0;
 
-	*control = ctrl;
-	return copied;
+out:
+	return copied ? : err;
 }
 
 int tls_sw_recvmsg(struct sock *sk,
 		   struct msghdr *msg,
 		   size_t len,
-		   int nonblock,
 		   int flags,
 		   int *addr_len)
 {
@@ -1766,16 +1733,13 @@ int tls_sw_recvmsg(struct sock *sk,
 	struct tls_msg *tlm;
 	struct sk_buff *skb;
 	ssize_t copied = 0;
-	bool cmsg = false;
+	bool async = false;
 	int target, err = 0;
 	long timeo;
 	bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
 	bool is_peek = flags & MSG_PEEK;
 	bool bpf_strp_enabled;
-	int num_async = 0;
-	int pending;
-
-	flags |= nonblock;
+	bool zc_capable;
 
 	if (unlikely(flags & MSG_ERRQUEUE))
 		return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
@@ -1784,81 +1748,64 @@ int tls_sw_recvmsg(struct sock *sk,
 	lock_sock(sk);
 	bpf_strp_enabled = sk_psock_strp_enabled(psock);
 
+	/* If crypto failed the connection is broken */
+	err = ctx->async_wait.err;
+	if (err)
+		goto end;
+
 	/* Process pending decrypted records. It must be non-zero-copy */
-	err = process_rx_list(ctx, msg, &control, &cmsg, 0, len, false,
-			      is_peek);
-	if (err < 0) {
-		tls_err_abort(sk, err);
+	err = process_rx_list(ctx, msg, &control, 0, len, false, is_peek);
+	if (err < 0)
		goto end;
-	} else {
-		copied = err;
-	}
+	copied = err;
 
 	if (len <= copied)
-		goto recv_end;
+		goto end;
 
 	target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
 	len = len - copied;
 	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
 
+	zc_capable = !bpf_strp_enabled && !is_kvec && !is_peek &&
+		     prot->version != TLS_1_3_VERSION;
+	decrypted = 0;
 	while (len && (decrypted + copied < target || ctx->recv_pkt)) {
-		bool retain_skb = false;
-		bool zc = false;
-		int to_decrypt;
-		int chunk = 0;
-		bool async_capable;
-		bool async = false;
+		struct tls_decrypt_arg darg = {};
+		int to_decrypt, chunk;
 
 		skb = tls_wait_data(sk, psock, flags & MSG_DONTWAIT, timeo,
 				    &err);
 		if (!skb) {
 			if (psock) {
-				int ret = sk_msg_recvmsg(sk, psock, msg, len,
-							 flags);
-
-				if (ret > 0) {
-					decrypted += ret;
-					len -= ret;
-					continue;
-				}
+				chunk = sk_msg_recvmsg(sk, psock, msg, len,
+						       flags);
+				if (chunk > 0)
+					goto leave_on_list;
 			}
 			goto recv_end;
-		} else {
-			tlm = tls_msg(skb);
-			if (prot->version == TLS_1_3_VERSION)
-				tlm->control = 0;
-			else
-				tlm->control = ctx->control;
 		}
 
 		rxm = strp_msg(skb);
+		tlm = tls_msg(skb);
 
 		to_decrypt = rxm->full_len - prot->overhead_size;
 
-		if (to_decrypt <= len && !is_kvec && !is_peek &&
-		    ctx->control == TLS_RECORD_TYPE_DATA &&
-		    prot->version != TLS_1_3_VERSION &&
-		    !bpf_strp_enabled)
-			zc = true;
+		if (zc_capable && to_decrypt <= len &&
+		    tlm->control == TLS_RECORD_TYPE_DATA)
+			darg.zc = true;
 
 		/* Do not use async mode if record is non-data */
-		if (ctx->control == TLS_RECORD_TYPE_DATA && !bpf_strp_enabled)
-			async_capable = ctx->async_capable;
+		if (tlm->control == TLS_RECORD_TYPE_DATA && !bpf_strp_enabled)
+			darg.async = ctx->async_capable;
 		else
-			async_capable = false;
+			darg.async = false;
 
-		err = decrypt_skb_update(sk, skb, &msg->msg_iter,
-					 &chunk, &zc, async_capable);
-		if (err < 0 && err != -EINPROGRESS) {
+		err = decrypt_skb_update(sk, skb, &msg->msg_iter, &darg);
+		if (err < 0) {
 			tls_err_abort(sk, -EBADMSG);
 			goto recv_end;
 		}
 
-		if (err == -EINPROGRESS) {
-			async = true;
-			num_async++;
-		} else if (prot->version == TLS_1_3_VERSION) {
-			tlm->control = ctx->control;
-		}
+		async |= darg.async;
 
 		/* If the type of records being processed is not known yet,
 		 * set it to record type just dequeued. If it is already known,
@@ -1867,131 +1814,105 @@ int tls_sw_recvmsg(struct sock *sk,
 		 * is known just after record is dequeued from stream parser.
 		 * For tls1.3, we disable async.
 		 */
-
-		if (!control)
-			control = tlm->control;
-		else if (control != tlm->control)
+		err = tls_record_content_type(msg, tlm, &control);
+		if (err <= 0)
 			goto recv_end;
 
-		if (!cmsg) {
-			int cerr;
-
-			cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
-					sizeof(control), &control);
-			cmsg = true;
-			if (control != TLS_RECORD_TYPE_DATA) {
-				if (cerr || msg->msg_flags & MSG_CTRUNC) {
-					err = -EIO;
-					goto recv_end;
-				}
-			}
+		ctx->recv_pkt = NULL;
+		__strp_unpause(&ctx->strp);
+		__skb_queue_tail(&ctx->rx_list, skb);
+
+		if (async) {
+			/* TLS 1.2-only, to_decrypt must be text length */
+			chunk = min_t(int, to_decrypt, len);
+leave_on_list:
+			decrypted += chunk;
+			len -= chunk;
+			continue;
 		}
+		/* TLS 1.3 may have updated the length by more than overhead */
+		chunk = rxm->full_len;
 
-		if (async)
-			goto pick_next_record;
+		if (!darg.zc) {
+			bool partially_consumed = chunk > len;
 
-		if (!zc) {
 			if (bpf_strp_enabled) {
 				err = sk_psock_tls_strp_read(psock, skb);
 				if (err != __SK_PASS) {
 					rxm->offset = rxm->offset + rxm->full_len;
 					rxm->full_len = 0;
+					__skb_unlink(skb, &ctx->rx_list);
 					if (err == __SK_DROP)
 						consume_skb(skb);
-					ctx->recv_pkt = NULL;
-					__strp_unpause(&ctx->strp);
 					continue;
 				}
 			}
 
-			if (rxm->full_len > len) {
-				retain_skb = true;
+			if (partially_consumed)
 				chunk = len;
-			} else {
-				chunk = rxm->full_len;
-			}
 
 			err = skb_copy_datagram_msg(skb, rxm->offset,
 						    msg, chunk);
 			if (err < 0)
 				goto recv_end;
 
-			if (!is_peek) {
-				rxm->offset = rxm->offset + chunk;
-				rxm->full_len = rxm->full_len - chunk;
+			if (is_peek)
+				goto leave_on_list;
+
+			if (partially_consumed) {
+				rxm->offset += chunk;
+				rxm->full_len -= chunk;
+				goto leave_on_list;
 			}
 		}
 
-pick_next_record:
-		if (chunk > len)
-			chunk = len;
-
 		decrypted += chunk;
 		len -= chunk;
 
-		/* For async or peek case, queue the current skb */
-		if (async || is_peek || retain_skb) {
-			skb_queue_tail(&ctx->rx_list, skb);
-			skb = NULL;
-		}
+		__skb_unlink(skb, &ctx->rx_list);
+		consume_skb(skb);
 
-		if (tls_sw_advance_skb(sk, skb, chunk)) {
-			/* Return full control message to
-			 * userspace before trying to parse
-			 * another message type
-			 */
-			msg->msg_flags |= MSG_EOR;
-			if (control != TLS_RECORD_TYPE_DATA)
-				goto recv_end;
-		} else {
+		/* Return full control message to userspace before trying
+		 * to parse another message type
+		 */
+		msg->msg_flags |= MSG_EOR;
+		if (control != TLS_RECORD_TYPE_DATA)
 			break;
-		}
 	}
 
 recv_end:
-	if (num_async) {
+	if (async) {
+		int ret, pending;
+
 		/* Wait for all previously submitted records to be decrypted */
 		spin_lock_bh(&ctx->decrypt_compl_lock);
-		ctx->async_notify = true;
+		reinit_completion(&ctx->async_wait.completion);
 		pending = atomic_read(&ctx->decrypt_pending);
 		spin_unlock_bh(&ctx->decrypt_compl_lock);
 		if (pending) {
-			err = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
-			if (err) {
-				/* one of async decrypt failed */
-				tls_err_abort(sk, err);
-				copied = 0;
+			ret = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
+			if (ret) {
+				if (err >= 0 || err == -EINPROGRESS)
+					err = ret;
 				decrypted = 0;
 				goto end;
 			}
-		} else {
-			reinit_completion(&ctx->async_wait.completion);
 		}
 
-		/* There can be no concurrent accesses, since we have no
-		 * pending decrypt operations
-		 */
-		WRITE_ONCE(ctx->async_notify, false);
-
 		/* Drain records from the rx_list & copy if required */
 		if (is_peek || is_kvec)
-			err = process_rx_list(ctx, msg, &control, &cmsg, copied,
+			err = process_rx_list(ctx, msg, &control, copied,
 					      decrypted, false, is_peek);
 		else
-			err = process_rx_list(ctx, msg, &control, &cmsg, 0,
+			err = process_rx_list(ctx, msg, &control, 0,
 					      decrypted, true, is_peek);
-		if (err < 0) {
-			tls_err_abort(sk, err);
-			copied = 0;
-			goto end;
-		}
+		decrypted = max(err, 0);
 	}
 
 	copied += decrypted;
 
 end:
 	release_sock(sk);
-	sk_defer_free_flush(sk);
 	if (psock)
 		sk_psock_put(sk, psock);
 	return copied ? : err;
@@ -2005,13 +1926,13 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 	struct strp_msg *rxm = NULL;
 	struct sock *sk = sock->sk;
+	struct tls_msg *tlm;
 	struct sk_buff *skb;
 	ssize_t copied = 0;
 	bool from_queue;
 	int err = 0;
 	long timeo;
 	int chunk;
-	bool zc = false;
 
 	lock_sock(sk);
 
@@ -2021,26 +1942,29 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
 	if (from_queue) {
 		skb = __skb_dequeue(&ctx->rx_list);
 	} else {
+		struct tls_decrypt_arg darg = {};
+
 		skb = tls_wait_data(sk, NULL, flags & SPLICE_F_NONBLOCK, timeo,
 				    &err);
 		if (!skb)
 			goto splice_read_end;
 
-		err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false);
+		err = decrypt_skb_update(sk, skb, NULL, &darg);
 		if (err < 0) {
 			tls_err_abort(sk, -EBADMSG);
 			goto splice_read_end;
 		}
 	}
 
+	rxm = strp_msg(skb);
+	tlm = tls_msg(skb);
+
 	/* splice does not support reading control messages */
-	if (ctx->control != TLS_RECORD_TYPE_DATA) {
+	if (tlm->control != TLS_RECORD_TYPE_DATA) {
 		err = -EINVAL;
 		goto splice_read_end;
 	}
 
-	rxm = strp_msg(skb);
-
 	chunk = min_t(unsigned int, rxm->full_len, len);
 	copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
 	if (copied < 0)
@@ -2060,7 +1984,6 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
 
 splice_read_end:
 	release_sock(sk);
-	sk_defer_free_flush(sk);
 	return copied ? : err;
 }
 
@@ -2084,10 +2007,10 @@ bool tls_sw_sock_is_readable(struct sock *sk)
 static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
-	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
 	struct strp_msg *rxm = strp_msg(skb);
+	struct tls_msg *tlm = tls_msg(skb);
 	size_t cipher_overhead;
 	size_t data_len = 0;
 	int ret;
@@ -2104,11 +2027,11 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
 
 	/* Linearize header to local buffer */
 	ret = skb_copy_bits(skb, rxm->offset, header, prot->prepend_size);
-
 	if (ret < 0)
 		goto read_failure;
 
-	ctx->control = header[0];
+	tlm->decrypted = 0;
+	tlm->control = header[0];
 
 	data_len = ((header[4] & 0xFF) | (header[3] << 8));
 
@@ -2149,8 +2072,6 @@ static void tls_queue(struct strparser *strp, struct sk_buff *skb)
 	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 
-	ctx->decrypted = 0;
-
 	ctx->recv_pkt = skb;
 	strp_pause(strp);
 
@@ -2241,7 +2162,7 @@ void tls_sw_release_resources_rx(struct sock *sk)
 	if (ctx->aead_recv) {
 		kfree_skb(ctx->recv_pkt);
 		ctx->recv_pkt = NULL;
-		skb_queue_purge(&ctx->rx_list);
+		__skb_queue_purge(&ctx->rx_list);
 		crypto_free_aead(ctx->aead_recv);
 		strp_stop(&ctx->strp);
 		/* If tls_sw_strparser_arm() was not called (cleanup paths)
@@ -2501,7 +2422,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 
 	/* Sanity-check the sizes for stack allocations. */
 	if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE ||
-	    rec_seq_size > TLS_MAX_REC_SEQ_SIZE) {
+	    rec_seq_size > TLS_MAX_REC_SEQ_SIZE || tag_size != TLS_TAG_SIZE) {
		rc = -EINVAL;
		goto free_priv;
	}
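
Two illustrative sketches follow; they are editorial additions for readers of this patch, not code from it.

First, a standalone model of the backwards content-type scan that the reworked padding_length() performs on a decrypted TLS 1.3 record: start at the last byte before the AEAD tag and walk toward the record header until a non-zero ContentType byte is found, counting zero-padding bytes along the way. A plain buffer stands in for the skb; PREPEND_SIZE and the record layout here are illustrative, only TLS_TAG_SIZE mirrors the patch.

/* Sketch of the TLS 1.3 trailing-padding scan, modelled on the
 * reworked padding_length().  Returns the number of zero padding
 * bytes and stores the ContentType, or returns -1 (the kernel
 * returns -EBADMSG) if the record is all padding.
 */
#include <stdio.h>

#define TLS_TAG_SIZE	16	/* AEAD tag length, as in the patch */
#define PREPEND_SIZE	5	/* TLS record header (illustrative)  */

static int padding_length(const unsigned char *rec, int full_len,
			  unsigned char *type)
{
	int offset = full_len - TLS_TAG_SIZE - 1;
	int sub = 0;

	while (offset >= PREPEND_SIZE) {
		if (rec[offset]) {
			*type = rec[offset];
			return sub;
		}
		sub++;		/* one more byte of zero padding */
		offset--;
	}
	return -1;
}

int main(void)
{
	/* 5-byte header, "hi", ContentType 23 (application_data),
	 * 3 bytes of zero padding, then a (zeroed) 16-byte tag.
	 */
	unsigned char rec[5 + 2 + 1 + 3 + TLS_TAG_SIZE] = {
		[5] = 'h', [6] = 'i', [7] = 23,
	};
	unsigned char type;
	int pad = padding_length(rec, sizeof(rec), &type);

	printf("pad=%d type=%u\n", pad, type);	/* pad=3 type=23 */
	return 0;
}

Second, a minimal sketch of the in/out calling convention that struct tls_decrypt_arg introduces: the caller sets .zc and .async to the modes it can accept, the decrypt path downgrades them when it falls back to copying or completes synchronously, and the caller then re-reads the struct instead of decoding sentinel return values such as -EINPROGRESS. decrypt_record() is a hypothetical stand-in for decrypt_internal(); everything except the struct's field names is illustrative.

#include <stdbool.h>
#include <stdio.h>

struct tls_decrypt_arg {
	bool zc;	/* in: zero-copy allowed; out: zero-copy used  */
	bool async;	/* in: async allowed;     out: still in flight */
};

/* Hypothetical stand-in for decrypt_internal(): downgrade the
 * requested modes instead of encoding them in the return value.
 */
static int decrypt_record(bool can_map_pages, struct tls_decrypt_arg *darg)
{
	if (darg->zc && !can_map_pages)
		darg->zc = false;	/* fall back to copy mode */

	/* ... submit the AEAD request; if it completed synchronously: */
	darg->async = false;
	return 0;
}

int main(void)
{
	struct tls_decrypt_arg darg = { .zc = true, .async = true };

	if (!decrypt_record(false, &darg))
		printf("zc=%d async=%d\n", darg.zc, darg.async);	/* 0 0 */
	return 0;
}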