diff options
Diffstat (limited to 'net/tls/tls_sw.c')
-rw-r--r-- | net/tls/tls_sw.c | 366 |
1 files changed, 229 insertions, 137 deletions
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index f127fac88acf..52fbe727d7c1 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -48,21 +48,11 @@ static int tls_do_decryption(struct sock *sk, struct scatterlist *sgout, char *iv_recv, size_t data_len, - struct sk_buff *skb, - gfp_t flags) + struct aead_request *aead_req) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); - struct strp_msg *rxm = strp_msg(skb); - struct aead_request *aead_req; - int ret; - unsigned int req_size = sizeof(struct aead_request) + - crypto_aead_reqsize(ctx->aead_recv); - - aead_req = kzalloc(req_size, flags); - if (!aead_req) - return -ENOMEM; aead_request_set_tfm(aead_req, ctx->aead_recv); aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE); @@ -73,20 +63,6 @@ static int tls_do_decryption(struct sock *sk, crypto_req_done, &ctx->async_wait); ret = crypto_wait_req(crypto_aead_decrypt(aead_req), &ctx->async_wait); - - if (ret < 0) - goto out; - - rxm->offset += tls_ctx->rx.prepend_size; - rxm->full_len -= tls_ctx->rx.overhead_size; - tls_advance_record_sn(sk, &tls_ctx->rx); - - ctx->decrypted = true; - - ctx->saved_data_ready(sk); - -out: - kfree(aead_req); return ret; } @@ -224,8 +200,7 @@ static int tls_push_record(struct sock *sk, int flags, struct aead_request *req; int rc; - req = kzalloc(sizeof(struct aead_request) + - crypto_aead_reqsize(ctx->aead_send), sk->sk_allocation); + req = aead_request_alloc(ctx->aead_send, sk->sk_allocation); if (!req) return -ENOMEM; @@ -267,7 +242,7 @@ static int tls_push_record(struct sock *sk, int flags, tls_advance_record_sn(sk, &tls_ctx->tx); out_req: - kfree(req); + aead_request_free(req); return rc; } @@ -328,7 +303,12 @@ static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from, } } + /* Mark the end in the last sg entry if newly added */ + if (num_elem > *pages_used) + sg_mark_end(&to[num_elem - 1]); out: + if (rc) + iov_iter_revert(from, size - *size_used); *size_used = size; *pages_used = num_elem; @@ -377,6 +357,7 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) int record_room; bool full_record; int orig_size; + bool is_kvec = msg->msg_iter.type & ITER_KVEC; if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL)) return -ENOTSUPP; @@ -425,8 +406,7 @@ alloc_encrypted: try_to_copy -= required_size - ctx->sg_encrypted_size; full_record = true; } - - if (full_record || eor) { + if (!is_kvec && (full_record || eor)) { ret = zerocopy_from_iter(sk, &msg->msg_iter, try_to_copy, &ctx->sg_plaintext_num_elem, &ctx->sg_plaintext_size, @@ -438,15 +418,11 @@ alloc_encrypted: copied += try_to_copy; ret = tls_push_record(sk, msg->msg_flags, record_type); - if (!ret) - continue; - if (ret == -EAGAIN) + if (ret) goto send_end; + continue; - copied -= try_to_copy; fallback_to_reg_send: - iov_iter_revert(&msg->msg_iter, - ctx->sg_plaintext_size - orig_size); trim_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem, &ctx->sg_plaintext_size, @@ -646,6 +622,9 @@ static struct sk_buff *tls_wait_data(struct sock *sk, int flags, return NULL; } + if (sk->sk_shutdown & RCV_SHUTDOWN) + return NULL; + if (sock_flag(sk, SOCK_DONE)) return NULL; @@ -670,52 +649,167 @@ static struct sk_buff *tls_wait_data(struct sock *sk, int flags, return skb; } -static int decrypt_skb(struct sock *sk, struct sk_buff *skb, - struct scatterlist *sgout) +/* This function decrypts the input skb into either out_iov or in out_sg + * or in skb buffers itself. The input parameter 'zc' indicates if + * zero-copy mode needs to be tried or not. With zero-copy mode, either + * out_iov or out_sg must be non-NULL. In case both out_iov and out_sg are + * NULL, then the decryption happens inside skb buffers itself, i.e. + * zero-copy gets disabled and 'zc' is updated. + */ + +static int decrypt_internal(struct sock *sk, struct sk_buff *skb, + struct iov_iter *out_iov, + struct scatterlist *out_sg, + int *chunk, bool *zc) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); - char iv[TLS_CIPHER_AES_GCM_128_SALT_SIZE + MAX_IV_SIZE]; - struct scatterlist sgin_arr[MAX_SKB_FRAGS + 2]; - struct scatterlist *sgin = &sgin_arr[0]; struct strp_msg *rxm = strp_msg(skb); - int ret, nsg = ARRAY_SIZE(sgin_arr); + int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0; + struct aead_request *aead_req; struct sk_buff *unused; + u8 *aad, *iv, *mem = NULL; + struct scatterlist *sgin = NULL; + struct scatterlist *sgout = NULL; + const int data_len = rxm->full_len - tls_ctx->rx.overhead_size; + + if (*zc && (out_iov || out_sg)) { + if (out_iov) + n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1; + else + n_sgout = sg_nents(out_sg); + } else { + n_sgout = 0; + *zc = false; + } + + n_sgin = skb_cow_data(skb, 0, &unused); + if (n_sgin < 1) + return -EBADMSG; + + /* Increment to accommodate AAD */ + n_sgin = n_sgin + 1; + + nsg = n_sgin + n_sgout; + + aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv); + mem_size = aead_size + (nsg * sizeof(struct scatterlist)); + mem_size = mem_size + TLS_AAD_SPACE_SIZE; + mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv); + + /* Allocate a single block of memory which contains + * aead_req || sgin[] || sgout[] || aad || iv. + * This order achieves correct alignment for aead_req, sgin, sgout. + */ + mem = kmalloc(mem_size, sk->sk_allocation); + if (!mem) + return -ENOMEM; + + /* Segment the allocated memory */ + aead_req = (struct aead_request *)mem; + sgin = (struct scatterlist *)(mem + aead_size); + sgout = sgin + n_sgin; + aad = (u8 *)(sgout + n_sgout); + iv = aad + TLS_AAD_SPACE_SIZE; - ret = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE, + /* Prepare IV */ + err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE, iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, tls_ctx->rx.iv_size); - if (ret < 0) - return ret; - - memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE); - if (!sgout) { - nsg = skb_cow_data(skb, 0, &unused) + 1; - sgin = kmalloc_array(nsg, sizeof(*sgin), sk->sk_allocation); - sgout = sgin; + if (err < 0) { + kfree(mem); + return err; } + memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE); - sg_init_table(sgin, nsg); - sg_set_buf(&sgin[0], ctx->rx_aad_ciphertext, TLS_AAD_SPACE_SIZE); + /* Prepare AAD */ + tls_make_aad(aad, rxm->full_len - tls_ctx->rx.overhead_size, + tls_ctx->rx.rec_seq, tls_ctx->rx.rec_seq_size, + ctx->control); - nsg = skb_to_sgvec(skb, &sgin[1], + /* Prepare sgin */ + sg_init_table(sgin, n_sgin); + sg_set_buf(&sgin[0], aad, TLS_AAD_SPACE_SIZE); + err = skb_to_sgvec(skb, &sgin[1], rxm->offset + tls_ctx->rx.prepend_size, rxm->full_len - tls_ctx->rx.prepend_size); + if (err < 0) { + kfree(mem); + return err; + } - tls_make_aad(ctx->rx_aad_ciphertext, - rxm->full_len - tls_ctx->rx.overhead_size, - tls_ctx->rx.rec_seq, - tls_ctx->rx.rec_seq_size, - ctx->control); + if (n_sgout) { + if (out_iov) { + sg_init_table(sgout, n_sgout); + sg_set_buf(&sgout[0], aad, TLS_AAD_SPACE_SIZE); - ret = tls_do_decryption(sk, sgin, sgout, iv, - rxm->full_len - tls_ctx->rx.overhead_size, - skb, sk->sk_allocation); + *chunk = 0; + err = zerocopy_from_iter(sk, out_iov, data_len, &pages, + chunk, &sgout[1], + (n_sgout - 1), false); + if (err < 0) + goto fallback_to_reg_recv; + } else if (out_sg) { + memcpy(sgout, out_sg, n_sgout * sizeof(*sgout)); + } else { + goto fallback_to_reg_recv; + } + } else { +fallback_to_reg_recv: + sgout = sgin; + pages = 0; + *chunk = 0; + *zc = false; + } - if (sgin != &sgin_arr[0]) - kfree(sgin); + /* Prepare and submit AEAD request */ + err = tls_do_decryption(sk, sgin, sgout, iv, data_len, aead_req); - return ret; + /* Release the pages in case iov was mapped to pages */ + for (; pages > 0; pages--) + put_page(sg_page(&sgout[pages])); + + kfree(mem); + return err; +} + +static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, + struct iov_iter *dest, int *chunk, bool *zc) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + struct strp_msg *rxm = strp_msg(skb); + int err = 0; + +#ifdef CONFIG_TLS_DEVICE + err = tls_device_decrypted(sk, skb); + if (err < 0) + return err; +#endif + if (!ctx->decrypted) { + err = decrypt_internal(sk, skb, dest, NULL, chunk, zc); + if (err < 0) + return err; + } else { + *zc = false; + } + + rxm->offset += tls_ctx->rx.prepend_size; + rxm->full_len -= tls_ctx->rx.overhead_size; + tls_advance_record_sn(sk, &tls_ctx->rx); + ctx->decrypted = true; + ctx->saved_data_ready(sk); + + return err; +} + +int decrypt_skb(struct sock *sk, struct sk_buff *skb, + struct scatterlist *sgout) +{ + bool zc = true; + int chunk; + + return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc); } static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb, @@ -756,6 +850,7 @@ int tls_sw_recvmsg(struct sock *sk, bool cmsg = false; int target, err = 0; long timeo; + bool is_kvec = msg->msg_iter.type & ITER_KVEC; flags |= nonblock; @@ -793,43 +888,17 @@ int tls_sw_recvmsg(struct sock *sk, } if (!ctx->decrypted) { - int page_count; - int to_copy; - - page_count = iov_iter_npages(&msg->msg_iter, - MAX_SKB_FRAGS); - to_copy = rxm->full_len - tls_ctx->rx.overhead_size; - if (to_copy <= len && page_count < MAX_SKB_FRAGS && - likely(!(flags & MSG_PEEK))) { - struct scatterlist sgin[MAX_SKB_FRAGS + 1]; - int pages = 0; + int to_copy = rxm->full_len - tls_ctx->rx.overhead_size; + if (!is_kvec && to_copy <= len && + likely(!(flags & MSG_PEEK))) zc = true; - sg_init_table(sgin, MAX_SKB_FRAGS + 1); - sg_set_buf(&sgin[0], ctx->rx_aad_plaintext, - TLS_AAD_SPACE_SIZE); - - err = zerocopy_from_iter(sk, &msg->msg_iter, - to_copy, &pages, - &chunk, &sgin[1], - MAX_SKB_FRAGS, false); - if (err < 0) - goto fallback_to_reg_recv; - - err = decrypt_skb(sk, skb, sgin); - for (; pages > 0; pages--) - put_page(sg_page(&sgin[pages])); - if (err < 0) { - tls_err_abort(sk, EBADMSG); - goto recv_end; - } - } else { -fallback_to_reg_recv: - err = decrypt_skb(sk, skb, NULL); - if (err < 0) { - tls_err_abort(sk, EBADMSG); - goto recv_end; - } + + err = decrypt_skb_update(sk, skb, &msg->msg_iter, + &chunk, &zc); + if (err < 0) { + tls_err_abort(sk, EBADMSG); + goto recv_end; } ctx->decrypted = true; } @@ -880,6 +949,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, int err = 0; long timeo; int chunk; + bool zc = false; lock_sock(sk); @@ -896,7 +966,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, } if (!ctx->decrypted) { - err = decrypt_skb(sk, skb, NULL); + err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc); if (err < 0) { tls_err_abort(sk, EBADMSG); @@ -919,29 +989,30 @@ splice_read_end: return copied ? : err; } -__poll_t tls_sw_poll_mask(struct socket *sock, __poll_t events) +unsigned int tls_sw_poll(struct file *file, struct socket *sock, + struct poll_table_struct *wait) { + unsigned int ret; struct sock *sk = sock->sk; struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); - __poll_t mask; - /* Grab EPOLLOUT and EPOLLHUP from the underlying socket */ - mask = ctx->sk_poll_mask(sock, events); + /* Grab POLLOUT and POLLHUP from the underlying socket */ + ret = ctx->sk_poll(file, sock, wait); - /* Clear EPOLLIN bits, and set based on recv_pkt */ - mask &= ~(EPOLLIN | EPOLLRDNORM); + /* Clear POLLIN bits, and set based on recv_pkt */ + ret &= ~(POLLIN | POLLRDNORM); if (ctx->recv_pkt) - mask |= EPOLLIN | EPOLLRDNORM; + ret |= POLLIN | POLLRDNORM; - return mask; + return ret; } static int tls_read_size(struct strparser *strp, struct sk_buff *skb) { struct tls_context *tls_ctx = tls_get_ctx(strp->sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); - char header[tls_ctx->rx.prepend_size]; + char header[TLS_HEADER_SIZE + MAX_IV_SIZE]; struct strp_msg *rxm = strp_msg(skb); size_t cipher_overhead; size_t data_len = 0; @@ -951,6 +1022,12 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb) if (rxm->offset + tls_ctx->rx.prepend_size > skb->len) return 0; + /* Sanity-check size of on-stack buffer. */ + if (WARN_ON(tls_ctx->rx.prepend_size > sizeof(header))) { + ret = -EINVAL; + goto read_failure; + } + /* Linearize header to local buffer */ ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size); @@ -978,6 +1055,10 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb) goto read_failure; } +#ifdef CONFIG_TLS_DEVICE + handle_device_resync(strp->sk, TCP_SKB_CB(skb)->seq + rxm->offset, + *(u64*)tls_ctx->rx.rec_seq); +#endif return data_len + TLS_HEADER_SIZE; read_failure: @@ -990,16 +1071,13 @@ static void tls_queue(struct strparser *strp, struct sk_buff *skb) { struct tls_context *tls_ctx = tls_get_ctx(strp->sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); - struct strp_msg *rxm; - - rxm = strp_msg(skb); ctx->decrypted = false; ctx->recv_pkt = skb; strp_pause(strp); - strp->sk->sk_state_change(strp->sk); + ctx->saved_data_ready(strp->sk); } static void tls_data_ready(struct sock *sk) @@ -1015,23 +1093,20 @@ void tls_sw_free_resources_tx(struct sock *sk) struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); - if (ctx->aead_send) - crypto_free_aead(ctx->aead_send); + crypto_free_aead(ctx->aead_send); tls_free_both_sg(sk); kfree(ctx); } -void tls_sw_free_resources_rx(struct sock *sk) +void tls_sw_release_resources_rx(struct sock *sk) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); if (ctx->aead_recv) { - if (ctx->recv_pkt) { - kfree_skb(ctx->recv_pkt); - ctx->recv_pkt = NULL; - } + kfree_skb(ctx->recv_pkt); + ctx->recv_pkt = NULL; crypto_free_aead(ctx->aead_recv); strp_stop(&ctx->strp); write_lock_bh(&sk->sk_callback_lock); @@ -1041,6 +1116,14 @@ void tls_sw_free_resources_rx(struct sock *sk) strp_done(&ctx->strp); lock_sock(sk); } +} + +void tls_sw_free_resources_rx(struct sock *sk) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + + tls_sw_release_resources_rx(sk); kfree(ctx); } @@ -1065,28 +1148,38 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) } if (tx) { - sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL); - if (!sw_ctx_tx) { - rc = -ENOMEM; - goto out; + if (!ctx->priv_ctx_tx) { + sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL); + if (!sw_ctx_tx) { + rc = -ENOMEM; + goto out; + } + ctx->priv_ctx_tx = sw_ctx_tx; + } else { + sw_ctx_tx = + (struct tls_sw_context_tx *)ctx->priv_ctx_tx; } - crypto_init_wait(&sw_ctx_tx->async_wait); - ctx->priv_ctx_tx = sw_ctx_tx; } else { - sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL); - if (!sw_ctx_rx) { - rc = -ENOMEM; - goto out; + if (!ctx->priv_ctx_rx) { + sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL); + if (!sw_ctx_rx) { + rc = -ENOMEM; + goto out; + } + ctx->priv_ctx_rx = sw_ctx_rx; + } else { + sw_ctx_rx = + (struct tls_sw_context_rx *)ctx->priv_ctx_rx; } - crypto_init_wait(&sw_ctx_rx->async_wait); - ctx->priv_ctx_rx = sw_ctx_rx; } if (tx) { + crypto_init_wait(&sw_ctx_tx->async_wait); crypto_info = &ctx->crypto_send; cctx = &ctx->tx; aead = &sw_ctx_tx->aead_send; } else { + crypto_init_wait(&sw_ctx_rx->async_wait); crypto_info = &ctx->crypto_recv; cctx = &ctx->rx; aead = &sw_ctx_rx->aead_recv; @@ -1111,7 +1204,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) } /* Sanity-check the IV size for stack allocations. */ - if (iv_size > MAX_IV_SIZE) { + if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE) { rc = -EINVAL; goto free_priv; } @@ -1129,12 +1222,11 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE); memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size); cctx->rec_seq_size = rec_seq_size; - cctx->rec_seq = kmalloc(rec_seq_size, GFP_KERNEL); + cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL); if (!cctx->rec_seq) { rc = -ENOMEM; goto free_iv; } - memcpy(cctx->rec_seq, rec_seq, rec_seq_size); if (sw_ctx_tx) { sg_init_table(sw_ctx_tx->sg_encrypted_data, @@ -1191,7 +1283,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) sk->sk_data_ready = tls_data_ready; write_unlock_bh(&sk->sk_callback_lock); - sw_ctx_rx->sk_poll_mask = sk->sk_socket->ops->poll_mask; + sw_ctx_rx->sk_poll = sk->sk_socket->ops->poll; strp_check_rcv(&sw_ctx_rx->strp); } |