summaryrefslogtreecommitdiff
path: root/net/tls
diff options
context:
space:
mode:
Diffstat (limited to 'net/tls')
-rw-r--r--net/tls/tls_sw.c224
1 files changed, 206 insertions, 18 deletions
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index e28a6ff25d96..8aa4c1dafd6a 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -43,12 +43,126 @@
#define MAX_IV_SIZE TLS_CIPHER_AES_GCM_128_IV_SIZE
+static int __skb_nsg(struct sk_buff *skb, int offset, int len,
+ unsigned int recursion_level)
+{
+ int start = skb_headlen(skb);
+ int i, chunk = start - offset;
+ struct sk_buff *frag_iter;
+ int elt = 0;
+
+ if (unlikely(recursion_level >= 24))
+ return -EMSGSIZE;
+
+ if (chunk > 0) {
+ if (chunk > len)
+ chunk = len;
+ elt++;
+ len -= chunk;
+ if (len == 0)
+ return elt;
+ offset += chunk;
+ }
+
+ for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
+ int end;
+
+ WARN_ON(start > offset + len);
+
+ end = start + skb_frag_size(&skb_shinfo(skb)->frags[i]);
+ chunk = end - offset;
+ if (chunk > 0) {
+ if (chunk > len)
+ chunk = len;
+ elt++;
+ len -= chunk;
+ if (len == 0)
+ return elt;
+ offset += chunk;
+ }
+ start = end;
+ }
+
+ if (unlikely(skb_has_frag_list(skb))) {
+ skb_walk_frags(skb, frag_iter) {
+ int end, ret;
+
+ WARN_ON(start > offset + len);
+
+ end = start + frag_iter->len;
+ chunk = end - offset;
+ if (chunk > 0) {
+ if (chunk > len)
+ chunk = len;
+ ret = __skb_nsg(frag_iter, offset - start, chunk,
+ recursion_level + 1);
+ if (unlikely(ret < 0))
+ return ret;
+ elt += ret;
+ len -= chunk;
+ if (len == 0)
+ return elt;
+ offset += chunk;
+ }
+ start = end;
+ }
+ }
+ BUG_ON(len);
+ return elt;
+}
+
+/* Return the number of scatterlist elements required to completely map the
+ * skb, or -EMSGSIZE if the recursion depth is exceeded.
+ */
+static int skb_nsg(struct sk_buff *skb, int offset, int len)
+{
+ return __skb_nsg(skb, offset, len, 0);
+}
+
+static void tls_decrypt_done(struct crypto_async_request *req, int err)
+{
+ struct aead_request *aead_req = (struct aead_request *)req;
+ struct decrypt_req_ctx *req_ctx =
+ (struct decrypt_req_ctx *)(aead_req + 1);
+
+ struct scatterlist *sgout = aead_req->dst;
+
+ struct tls_context *tls_ctx = tls_get_ctx(req_ctx->sk);
+ struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+ int pending = atomic_dec_return(&ctx->decrypt_pending);
+ struct scatterlist *sg;
+ unsigned int pages;
+
+ /* Propagate if there was an err */
+ if (err) {
+ ctx->async_wait.err = err;
+ tls_err_abort(req_ctx->sk, err);
+ }
+
+ /* Release the skb, pages and memory allocated for crypto req */
+ kfree_skb(req->data);
+
+ /* Skip the first S/G entry as it points to AAD */
+ for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
+ if (!sg)
+ break;
+ put_page(sg_page(sg));
+ }
+
+ kfree(aead_req);
+
+ if (!pending && READ_ONCE(ctx->async_notify))
+ complete(&ctx->async_wait.completion);
+}
+
static int tls_do_decryption(struct sock *sk,
+ struct sk_buff *skb,
struct scatterlist *sgin,
struct scatterlist *sgout,
char *iv_recv,
size_t data_len,
- struct aead_request *aead_req)
+ struct aead_request *aead_req,
+ bool async)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
@@ -59,10 +173,34 @@ static int tls_do_decryption(struct sock *sk,
aead_request_set_crypt(aead_req, sgin, sgout,
data_len + tls_ctx->rx.tag_size,
(u8 *)iv_recv);
- aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
- crypto_req_done, &ctx->async_wait);
- ret = crypto_wait_req(crypto_aead_decrypt(aead_req), &ctx->async_wait);
+ if (async) {
+ struct decrypt_req_ctx *req_ctx;
+
+ req_ctx = (struct decrypt_req_ctx *)(aead_req + 1);
+ req_ctx->sk = sk;
+
+ aead_request_set_callback(aead_req,
+ CRYPTO_TFM_REQ_MAY_BACKLOG,
+ tls_decrypt_done, skb);
+ atomic_inc(&ctx->decrypt_pending);
+ } else {
+ aead_request_set_callback(aead_req,
+ CRYPTO_TFM_REQ_MAY_BACKLOG,
+ crypto_req_done, &ctx->async_wait);
+ }
+
+ ret = crypto_aead_decrypt(aead_req);
+ if (ret == -EINPROGRESS) {
+ if (async)
+ return ret;
+
+ ret = crypto_wait_req(ret, &ctx->async_wait);
+ }
+
+ if (async)
+ atomic_dec(&ctx->decrypt_pending);
+
return ret;
}
@@ -354,7 +492,7 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
- int ret = 0;
+ int ret;
int required_size;
long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
bool eor = !(msg->msg_flags & MSG_MORE);
@@ -370,7 +508,8 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
lock_sock(sk);
- if (tls_complete_pending_work(sk, tls_ctx, msg->msg_flags, &timeo))
+ ret = tls_complete_pending_work(sk, tls_ctx, msg->msg_flags, &timeo);
+ if (ret)
goto send_end;
if (unlikely(msg->msg_controllen)) {
@@ -505,7 +644,7 @@ int tls_sw_sendpage(struct sock *sk, struct page *page,
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
- int ret = 0;
+ int ret;
long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
bool eor;
size_t orig_size = size;
@@ -525,7 +664,8 @@ int tls_sw_sendpage(struct sock *sk, struct page *page,
sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
- if (tls_complete_pending_work(sk, tls_ctx, flags, &timeo))
+ ret = tls_complete_pending_work(sk, tls_ctx, flags, &timeo);
+ if (ret)
goto sendpage_end;
/* Call the sk_stream functions to manage the sndbuf mem. */
@@ -684,12 +824,14 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1;
else
n_sgout = sg_nents(out_sg);
+ n_sgin = skb_nsg(skb, rxm->offset + tls_ctx->rx.prepend_size,
+ rxm->full_len - tls_ctx->rx.prepend_size);
} else {
n_sgout = 0;
*zc = false;
+ n_sgin = skb_cow_data(skb, 0, &unused);
}
- n_sgin = skb_cow_data(skb, 0, &unused);
if (n_sgin < 1)
return -EBADMSG;
@@ -769,7 +911,10 @@ fallback_to_reg_recv:
}
/* Prepare and submit AEAD request */
- err = tls_do_decryption(sk, sgin, sgout, iv, data_len, aead_req);
+ err = tls_do_decryption(sk, skb, sgin, sgout, iv,
+ data_len, aead_req, *zc);
+ if (err == -EINPROGRESS)
+ return err;
/* Release the pages in case iov was mapped to pages */
for (; pages > 0; pages--)
@@ -794,8 +939,12 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
#endif
if (!ctx->decrypted) {
err = decrypt_internal(sk, skb, dest, NULL, chunk, zc);
- if (err < 0)
+ if (err < 0) {
+ if (err == -EINPROGRESS)
+ tls_advance_record_sn(sk, &tls_ctx->rx);
+
return err;
+ }
} else {
*zc = false;
}
@@ -823,18 +972,20 @@ static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
- struct strp_msg *rxm = strp_msg(skb);
- if (len < rxm->full_len) {
- rxm->offset += len;
- rxm->full_len -= len;
+ if (skb) {
+ struct strp_msg *rxm = strp_msg(skb);
- return false;
+ if (len < rxm->full_len) {
+ rxm->offset += len;
+ rxm->full_len -= len;
+ return false;
+ }
+ kfree_skb(skb);
}
/* Finished with message */
ctx->recv_pkt = NULL;
- kfree_skb(skb);
__strp_unpause(&ctx->strp);
return true;
@@ -857,6 +1008,7 @@ int tls_sw_recvmsg(struct sock *sk,
int target, err = 0;
long timeo;
bool is_kvec = msg->msg_iter.type & ITER_KVEC;
+ int num_async = 0;
flags |= nonblock;
@@ -869,6 +1021,7 @@ int tls_sw_recvmsg(struct sock *sk,
timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
do {
bool zc = false;
+ bool async = false;
int chunk = 0;
skb = tls_wait_data(sk, flags, timeo, &err);
@@ -876,6 +1029,7 @@ int tls_sw_recvmsg(struct sock *sk,
goto recv_end;
rxm = strp_msg(skb);
+
if (!cmsg) {
int cerr;
@@ -902,26 +1056,39 @@ int tls_sw_recvmsg(struct sock *sk,
err = decrypt_skb_update(sk, skb, &msg->msg_iter,
&chunk, &zc);
- if (err < 0) {
+ if (err < 0 && err != -EINPROGRESS) {
tls_err_abort(sk, EBADMSG);
goto recv_end;
}
+
+ if (err == -EINPROGRESS) {
+ async = true;
+ num_async++;
+ goto pick_next_record;
+ }
+
ctx->decrypted = true;
}
if (!zc) {
chunk = min_t(unsigned int, rxm->full_len, len);
+
err = skb_copy_datagram_msg(skb, rxm->offset, msg,
chunk);
if (err < 0)
goto recv_end;
}
+pick_next_record:
copied += chunk;
len -= chunk;
if (likely(!(flags & MSG_PEEK))) {
u8 control = ctx->control;
+ /* For async, drop current skb reference */
+ if (async)
+ skb = NULL;
+
if (tls_sw_advance_skb(sk, skb, chunk)) {
/* Return full control message to
* userspace before trying to parse
@@ -930,14 +1097,33 @@ int tls_sw_recvmsg(struct sock *sk,
msg->msg_flags |= MSG_EOR;
if (control != TLS_RECORD_TYPE_DATA)
goto recv_end;
+ } else {
+ break;
}
}
+
/* If we have a new message from strparser, continue now. */
if (copied >= target && !ctx->recv_pkt)
break;
} while (len);
recv_end:
+ if (num_async) {
+ /* Wait for all previously submitted records to be decrypted */
+ smp_store_mb(ctx->async_notify, true);
+ if (atomic_read(&ctx->decrypt_pending)) {
+ err = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
+ if (err) {
+ /* one of async decrypt failed */
+ tls_err_abort(sk, err);
+ copied = 0;
+ }
+ } else {
+ reinit_completion(&ctx->async_wait.completion);
+ }
+ WRITE_ONCE(ctx->async_notify, false);
+ }
+
release_sock(sk);
return copied ? : err;
}
@@ -1277,6 +1463,8 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
goto free_aead;
if (sw_ctx_rx) {
+ (*aead)->reqsize = sizeof(struct decrypt_req_ctx);
+
/* Set up strparser */
memset(&cb, 0, sizeof(cb));
cb.rcv_msg = tls_queue;