diff options
Diffstat (limited to 'net/tls')
-rw-r--r-- | net/tls/tls_sw.c | 224 |
1 files changed, 206 insertions, 18 deletions
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index e28a6ff25d96..8aa4c1dafd6a 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -43,12 +43,126 @@ #define MAX_IV_SIZE TLS_CIPHER_AES_GCM_128_IV_SIZE +static int __skb_nsg(struct sk_buff *skb, int offset, int len, + unsigned int recursion_level) +{ + int start = skb_headlen(skb); + int i, chunk = start - offset; + struct sk_buff *frag_iter; + int elt = 0; + + if (unlikely(recursion_level >= 24)) + return -EMSGSIZE; + + if (chunk > 0) { + if (chunk > len) + chunk = len; + elt++; + len -= chunk; + if (len == 0) + return elt; + offset += chunk; + } + + for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) { + int end; + + WARN_ON(start > offset + len); + + end = start + skb_frag_size(&skb_shinfo(skb)->frags[i]); + chunk = end - offset; + if (chunk > 0) { + if (chunk > len) + chunk = len; + elt++; + len -= chunk; + if (len == 0) + return elt; + offset += chunk; + } + start = end; + } + + if (unlikely(skb_has_frag_list(skb))) { + skb_walk_frags(skb, frag_iter) { + int end, ret; + + WARN_ON(start > offset + len); + + end = start + frag_iter->len; + chunk = end - offset; + if (chunk > 0) { + if (chunk > len) + chunk = len; + ret = __skb_nsg(frag_iter, offset - start, chunk, + recursion_level + 1); + if (unlikely(ret < 0)) + return ret; + elt += ret; + len -= chunk; + if (len == 0) + return elt; + offset += chunk; + } + start = end; + } + } + BUG_ON(len); + return elt; +} + +/* Return the number of scatterlist elements required to completely map the + * skb, or -EMSGSIZE if the recursion depth is exceeded. + */ +static int skb_nsg(struct sk_buff *skb, int offset, int len) +{ + return __skb_nsg(skb, offset, len, 0); +} + +static void tls_decrypt_done(struct crypto_async_request *req, int err) +{ + struct aead_request *aead_req = (struct aead_request *)req; + struct decrypt_req_ctx *req_ctx = + (struct decrypt_req_ctx *)(aead_req + 1); + + struct scatterlist *sgout = aead_req->dst; + + struct tls_context *tls_ctx = tls_get_ctx(req_ctx->sk); + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + int pending = atomic_dec_return(&ctx->decrypt_pending); + struct scatterlist *sg; + unsigned int pages; + + /* Propagate if there was an err */ + if (err) { + ctx->async_wait.err = err; + tls_err_abort(req_ctx->sk, err); + } + + /* Release the skb, pages and memory allocated for crypto req */ + kfree_skb(req->data); + + /* Skip the first S/G entry as it points to AAD */ + for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) { + if (!sg) + break; + put_page(sg_page(sg)); + } + + kfree(aead_req); + + if (!pending && READ_ONCE(ctx->async_notify)) + complete(&ctx->async_wait.completion); +} + static int tls_do_decryption(struct sock *sk, + struct sk_buff *skb, struct scatterlist *sgin, struct scatterlist *sgout, char *iv_recv, size_t data_len, - struct aead_request *aead_req) + struct aead_request *aead_req, + bool async) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); @@ -59,10 +173,34 @@ static int tls_do_decryption(struct sock *sk, aead_request_set_crypt(aead_req, sgin, sgout, data_len + tls_ctx->rx.tag_size, (u8 *)iv_recv); - aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG, - crypto_req_done, &ctx->async_wait); - ret = crypto_wait_req(crypto_aead_decrypt(aead_req), &ctx->async_wait); + if (async) { + struct decrypt_req_ctx *req_ctx; + + req_ctx = (struct decrypt_req_ctx *)(aead_req + 1); + req_ctx->sk = sk; + + aead_request_set_callback(aead_req, + CRYPTO_TFM_REQ_MAY_BACKLOG, + tls_decrypt_done, skb); + atomic_inc(&ctx->decrypt_pending); + } else { + aead_request_set_callback(aead_req, + CRYPTO_TFM_REQ_MAY_BACKLOG, + crypto_req_done, &ctx->async_wait); + } + + ret = crypto_aead_decrypt(aead_req); + if (ret == -EINPROGRESS) { + if (async) + return ret; + + ret = crypto_wait_req(ret, &ctx->async_wait); + } + + if (async) + atomic_dec(&ctx->decrypt_pending); + return ret; } @@ -354,7 +492,7 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); - int ret = 0; + int ret; int required_size; long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); bool eor = !(msg->msg_flags & MSG_MORE); @@ -370,7 +508,8 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) lock_sock(sk); - if (tls_complete_pending_work(sk, tls_ctx, msg->msg_flags, &timeo)) + ret = tls_complete_pending_work(sk, tls_ctx, msg->msg_flags, &timeo); + if (ret) goto send_end; if (unlikely(msg->msg_controllen)) { @@ -505,7 +644,7 @@ int tls_sw_sendpage(struct sock *sk, struct page *page, { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); - int ret = 0; + int ret; long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); bool eor; size_t orig_size = size; @@ -525,7 +664,8 @@ int tls_sw_sendpage(struct sock *sk, struct page *page, sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk); - if (tls_complete_pending_work(sk, tls_ctx, flags, &timeo)) + ret = tls_complete_pending_work(sk, tls_ctx, flags, &timeo); + if (ret) goto sendpage_end; /* Call the sk_stream functions to manage the sndbuf mem. */ @@ -684,12 +824,14 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1; else n_sgout = sg_nents(out_sg); + n_sgin = skb_nsg(skb, rxm->offset + tls_ctx->rx.prepend_size, + rxm->full_len - tls_ctx->rx.prepend_size); } else { n_sgout = 0; *zc = false; + n_sgin = skb_cow_data(skb, 0, &unused); } - n_sgin = skb_cow_data(skb, 0, &unused); if (n_sgin < 1) return -EBADMSG; @@ -769,7 +911,10 @@ fallback_to_reg_recv: } /* Prepare and submit AEAD request */ - err = tls_do_decryption(sk, sgin, sgout, iv, data_len, aead_req); + err = tls_do_decryption(sk, skb, sgin, sgout, iv, + data_len, aead_req, *zc); + if (err == -EINPROGRESS) + return err; /* Release the pages in case iov was mapped to pages */ for (; pages > 0; pages--) @@ -794,8 +939,12 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, #endif if (!ctx->decrypted) { err = decrypt_internal(sk, skb, dest, NULL, chunk, zc); - if (err < 0) + if (err < 0) { + if (err == -EINPROGRESS) + tls_advance_record_sn(sk, &tls_ctx->rx); + return err; + } } else { *zc = false; } @@ -823,18 +972,20 @@ static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb, { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); - struct strp_msg *rxm = strp_msg(skb); - if (len < rxm->full_len) { - rxm->offset += len; - rxm->full_len -= len; + if (skb) { + struct strp_msg *rxm = strp_msg(skb); - return false; + if (len < rxm->full_len) { + rxm->offset += len; + rxm->full_len -= len; + return false; + } + kfree_skb(skb); } /* Finished with message */ ctx->recv_pkt = NULL; - kfree_skb(skb); __strp_unpause(&ctx->strp); return true; @@ -857,6 +1008,7 @@ int tls_sw_recvmsg(struct sock *sk, int target, err = 0; long timeo; bool is_kvec = msg->msg_iter.type & ITER_KVEC; + int num_async = 0; flags |= nonblock; @@ -869,6 +1021,7 @@ int tls_sw_recvmsg(struct sock *sk, timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); do { bool zc = false; + bool async = false; int chunk = 0; skb = tls_wait_data(sk, flags, timeo, &err); @@ -876,6 +1029,7 @@ int tls_sw_recvmsg(struct sock *sk, goto recv_end; rxm = strp_msg(skb); + if (!cmsg) { int cerr; @@ -902,26 +1056,39 @@ int tls_sw_recvmsg(struct sock *sk, err = decrypt_skb_update(sk, skb, &msg->msg_iter, &chunk, &zc); - if (err < 0) { + if (err < 0 && err != -EINPROGRESS) { tls_err_abort(sk, EBADMSG); goto recv_end; } + + if (err == -EINPROGRESS) { + async = true; + num_async++; + goto pick_next_record; + } + ctx->decrypted = true; } if (!zc) { chunk = min_t(unsigned int, rxm->full_len, len); + err = skb_copy_datagram_msg(skb, rxm->offset, msg, chunk); if (err < 0) goto recv_end; } +pick_next_record: copied += chunk; len -= chunk; if (likely(!(flags & MSG_PEEK))) { u8 control = ctx->control; + /* For async, drop current skb reference */ + if (async) + skb = NULL; + if (tls_sw_advance_skb(sk, skb, chunk)) { /* Return full control message to * userspace before trying to parse @@ -930,14 +1097,33 @@ int tls_sw_recvmsg(struct sock *sk, msg->msg_flags |= MSG_EOR; if (control != TLS_RECORD_TYPE_DATA) goto recv_end; + } else { + break; } } + /* If we have a new message from strparser, continue now. */ if (copied >= target && !ctx->recv_pkt) break; } while (len); recv_end: + if (num_async) { + /* Wait for all previously submitted records to be decrypted */ + smp_store_mb(ctx->async_notify, true); + if (atomic_read(&ctx->decrypt_pending)) { + err = crypto_wait_req(-EINPROGRESS, &ctx->async_wait); + if (err) { + /* one of async decrypt failed */ + tls_err_abort(sk, err); + copied = 0; + } + } else { + reinit_completion(&ctx->async_wait.completion); + } + WRITE_ONCE(ctx->async_notify, false); + } + release_sock(sk); return copied ? : err; } @@ -1277,6 +1463,8 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) goto free_aead; if (sw_ctx_rx) { + (*aead)->reqsize = sizeof(struct decrypt_req_ctx); + /* Set up strparser */ memset(&cb, 0, sizeof(cb)); cb.rcv_msg = tls_queue; |