summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/linux/skmsg.h41
-rw-r--r--net/tls/tls_sw.c439
2 files changed, 414 insertions, 66 deletions
diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h
index 4e84b3c2eff8..0b919f0bc6d6 100644
--- a/include/linux/skmsg.h
+++ b/include/linux/skmsg.h
@@ -29,7 +29,11 @@ struct sk_msg_sg {
u32 size;
u32 copybreak;
bool copy[MAX_MSG_FRAGS];
- struct scatterlist data[MAX_MSG_FRAGS];
+ /* The extra element is used for chaining the front and sections when
+ * the list becomes partitioned (e.g. end < start). The crypto APIs
+ * require the chaining.
+ */
+ struct scatterlist data[MAX_MSG_FRAGS + 1];
};
struct sk_msg {
@@ -112,6 +116,7 @@ void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
u32 bytes);
void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes);
+void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes);
int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
struct sk_msg *msg, u32 bytes);
@@ -161,8 +166,9 @@ static inline void sk_msg_clear_meta(struct sk_msg *msg)
static inline void sk_msg_init(struct sk_msg *msg)
{
+ BUILD_BUG_ON(ARRAY_SIZE(msg->sg.data) - 1 != MAX_MSG_FRAGS);
memset(msg, 0, sizeof(*msg));
- sg_init_marker(msg->sg.data, ARRAY_SIZE(msg->sg.data));
+ sg_init_marker(msg->sg.data, MAX_MSG_FRAGS);
}
static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src,
@@ -174,6 +180,12 @@ static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src,
src->sg.data[which].offset += size;
}
+static inline void sk_msg_xfer_full(struct sk_msg *dst, struct sk_msg *src)
+{
+ memcpy(dst, src, sizeof(*src));
+ sk_msg_init(src);
+}
+
static inline u32 sk_msg_elem_used(const struct sk_msg *msg)
{
return msg->sg.end >= msg->sg.start ?
@@ -229,6 +241,26 @@ static inline void sk_msg_page_add(struct sk_msg *msg, struct page *page,
sk_msg_iter_next(msg, end);
}
+static inline void sk_msg_sg_copy(struct sk_msg *msg, u32 i, bool copy_state)
+{
+ do {
+ msg->sg.copy[i] = copy_state;
+ sk_msg_iter_var_next(i);
+ if (i == msg->sg.end)
+ break;
+ } while (1);
+}
+
+static inline void sk_msg_sg_copy_set(struct sk_msg *msg, u32 start)
+{
+ sk_msg_sg_copy(msg, start, true);
+}
+
+static inline void sk_msg_sg_copy_clear(struct sk_msg *msg, u32 start)
+{
+ sk_msg_sg_copy(msg, start, false);
+}
+
static inline struct sk_psock *sk_psock(const struct sock *sk)
{
return rcu_dereference_sk_user_data(sk);
@@ -245,6 +277,11 @@ static inline void sk_psock_queue_msg(struct sk_psock *psock,
list_add_tail(&msg->list, &psock->ingress_msg);
}
+static inline bool sk_psock_queue_empty(const struct sk_psock *psock)
+{
+ return psock ? list_empty(&psock->ingress_msg) : true;
+}
+
static inline void sk_psock_report_error(struct sk_psock *psock, int err)
{
struct sock *sk = psock->sk;
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 3b75e0dd51a2..a525fc4c2a4b 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -4,6 +4,7 @@
* Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
* Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
* Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
+ * Copyright (c) 2018, Covalent IO, Inc. http://covalent.io
*
* This software is available to you under a choice of one of two
* licenses. You may choose to be licensed under the terms of the GNU
@@ -258,21 +259,58 @@ static int tls_clone_plaintext_msg(struct sock *sk, int required)
return sk_msg_clone(sk, msg_pl, msg_en, skip, len);
}
-static void tls_free_open_rec(struct sock *sk)
+static struct tls_rec *tls_get_rec(struct sock *sk)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
- struct tls_rec *rec = ctx->open_rec;
+ struct sk_msg *msg_pl, *msg_en;
+ struct tls_rec *rec;
+ int mem_size;
- /* Return if there is no open record */
+ mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send);
+
+ rec = kzalloc(mem_size, sk->sk_allocation);
if (!rec)
- return;
+ return NULL;
+ msg_pl = &rec->msg_plaintext;
+ msg_en = &rec->msg_encrypted;
+
+ sk_msg_init(msg_pl);
+ sk_msg_init(msg_en);
+
+ sg_init_table(rec->sg_aead_in, 2);
+ sg_set_buf(&rec->sg_aead_in[0], rec->aad_space,
+ sizeof(rec->aad_space));
+ sg_unmark_end(&rec->sg_aead_in[1]);
+
+ sg_init_table(rec->sg_aead_out, 2);
+ sg_set_buf(&rec->sg_aead_out[0], rec->aad_space,
+ sizeof(rec->aad_space));
+ sg_unmark_end(&rec->sg_aead_out[1]);
+
+ return rec;
+}
+
+static void tls_free_rec(struct sock *sk, struct tls_rec *rec)
+{
sk_msg_free(sk, &rec->msg_encrypted);
sk_msg_free(sk, &rec->msg_plaintext);
kfree(rec);
}
+static void tls_free_open_rec(struct sock *sk)
+{
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
+ struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+ struct tls_rec *rec = ctx->open_rec;
+
+ if (rec) {
+ tls_free_rec(sk, rec);
+ ctx->open_rec = NULL;
+ }
+}
+
int tls_tx_records(struct sock *sk, int flags)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
@@ -439,16 +477,135 @@ static int tls_do_encryption(struct sock *sk,
return rc;
}
+static int tls_split_open_record(struct sock *sk, struct tls_rec *from,
+ struct tls_rec **to, struct sk_msg *msg_opl,
+ struct sk_msg *msg_oen, u32 split_point,
+ u32 tx_overhead_size, u32 *orig_end)
+{
+ u32 i, j, bytes = 0, apply = msg_opl->apply_bytes;
+ struct scatterlist *sge, *osge, *nsge;
+ u32 orig_size = msg_opl->sg.size;
+ struct scatterlist tmp = { };
+ struct sk_msg *msg_npl;
+ struct tls_rec *new;
+ int ret;
+
+ new = tls_get_rec(sk);
+ if (!new)
+ return -ENOMEM;
+ ret = sk_msg_alloc(sk, &new->msg_encrypted, msg_opl->sg.size +
+ tx_overhead_size, 0);
+ if (ret < 0) {
+ tls_free_rec(sk, new);
+ return ret;
+ }
+
+ *orig_end = msg_opl->sg.end;
+ i = msg_opl->sg.start;
+ sge = sk_msg_elem(msg_opl, i);
+ while (apply && sge->length) {
+ if (sge->length > apply) {
+ u32 len = sge->length - apply;
+
+ get_page(sg_page(sge));
+ sg_set_page(&tmp, sg_page(sge), len,
+ sge->offset + apply);
+ sge->length = apply;
+ bytes += apply;
+ apply = 0;
+ } else {
+ apply -= sge->length;
+ bytes += sge->length;
+ }
+
+ sk_msg_iter_var_next(i);
+ if (i == msg_opl->sg.end)
+ break;
+ sge = sk_msg_elem(msg_opl, i);
+ }
+
+ msg_opl->sg.end = i;
+ msg_opl->sg.curr = i;
+ msg_opl->sg.copybreak = 0;
+ msg_opl->apply_bytes = 0;
+ msg_opl->sg.size = bytes;
+
+ msg_npl = &new->msg_plaintext;
+ msg_npl->apply_bytes = apply;
+ msg_npl->sg.size = orig_size - bytes;
+
+ j = msg_npl->sg.start;
+ nsge = sk_msg_elem(msg_npl, j);
+ if (tmp.length) {
+ memcpy(nsge, &tmp, sizeof(*nsge));
+ sk_msg_iter_var_next(j);
+ nsge = sk_msg_elem(msg_npl, j);
+ }
+
+ osge = sk_msg_elem(msg_opl, i);
+ while (osge->length) {
+ memcpy(nsge, osge, sizeof(*nsge));
+ sg_unmark_end(nsge);
+ sk_msg_iter_var_next(i);
+ sk_msg_iter_var_next(j);
+ if (i == *orig_end)
+ break;
+ osge = sk_msg_elem(msg_opl, i);
+ nsge = sk_msg_elem(msg_npl, j);
+ }
+
+ msg_npl->sg.end = j;
+ msg_npl->sg.curr = j;
+ msg_npl->sg.copybreak = 0;
+
+ *to = new;
+ return 0;
+}
+
+static void tls_merge_open_record(struct sock *sk, struct tls_rec *to,
+ struct tls_rec *from, u32 orig_end)
+{
+ struct sk_msg *msg_npl = &from->msg_plaintext;
+ struct sk_msg *msg_opl = &to->msg_plaintext;
+ struct scatterlist *osge, *nsge;
+ u32 i, j;
+
+ i = msg_opl->sg.end;
+ sk_msg_iter_var_prev(i);
+ j = msg_npl->sg.start;
+
+ osge = sk_msg_elem(msg_opl, i);
+ nsge = sk_msg_elem(msg_npl, j);
+
+ if (sg_page(osge) == sg_page(nsge) &&
+ osge->offset + osge->length == nsge->offset) {
+ osge->length += nsge->length;
+ put_page(sg_page(nsge));
+ }
+
+ msg_opl->sg.end = orig_end;
+ msg_opl->sg.curr = orig_end;
+ msg_opl->sg.copybreak = 0;
+ msg_opl->apply_bytes = msg_opl->sg.size + msg_npl->sg.size;
+ msg_opl->sg.size += msg_npl->sg.size;
+
+ sk_msg_free(sk, &to->msg_encrypted);
+ sk_msg_xfer_full(&to->msg_encrypted, &from->msg_encrypted);
+
+ kfree(from);
+}
+
static int tls_push_record(struct sock *sk, int flags,
unsigned char record_type)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
- struct tls_rec *rec = ctx->open_rec;
+ struct tls_rec *rec = ctx->open_rec, *tmp = NULL;
+ u32 i, split_point, uninitialized_var(orig_end);
struct sk_msg *msg_pl, *msg_en;
struct aead_request *req;
+ bool split;
int rc;
- u32 i;
if (!rec)
return 0;
@@ -456,6 +613,18 @@ static int tls_push_record(struct sock *sk, int flags,
msg_pl = &rec->msg_plaintext;
msg_en = &rec->msg_encrypted;
+ split_point = msg_pl->apply_bytes;
+ split = split_point && split_point < msg_pl->sg.size;
+ if (split) {
+ rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en,
+ split_point, tls_ctx->tx.overhead_size,
+ &orig_end);
+ if (rc < 0)
+ return rc;
+ sk_msg_trim(sk, msg_en, msg_pl->sg.size +
+ tls_ctx->tx.overhead_size);
+ }
+
rec->tx_flags = flags;
req = &rec->aead_req;
@@ -487,57 +656,139 @@ static int tls_push_record(struct sock *sk, int flags,
rc = tls_do_encryption(sk, tls_ctx, ctx, req, msg_pl->sg.size, i);
if (rc < 0) {
- if (rc != -EINPROGRESS)
+ if (rc != -EINPROGRESS) {
tls_err_abort(sk, EBADMSG);
+ if (split) {
+ tls_ctx->pending_open_record_frags = true;
+ tls_merge_open_record(sk, rec, tmp, orig_end);
+ }
+ }
return rc;
+ } else if (split) {
+ msg_pl = &tmp->msg_plaintext;
+ msg_en = &tmp->msg_encrypted;
+ sk_msg_trim(sk, msg_en, msg_pl->sg.size +
+ tls_ctx->tx.overhead_size);
+ tls_ctx->pending_open_record_frags = true;
+ ctx->open_rec = tmp;
}
return tls_tx_records(sk, flags);
}
-static int tls_sw_push_pending_record(struct sock *sk, int flags)
-{
- return tls_push_record(sk, flags, TLS_RECORD_TYPE_DATA);
-}
-
-static struct tls_rec *get_rec(struct sock *sk)
+static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
+ bool full_record, u8 record_type,
+ size_t *copied, int flags)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
- struct sk_msg *msg_pl, *msg_en;
+ struct sk_msg msg_redir = { };
+ struct sk_psock *psock;
+ struct sock *sk_redir;
struct tls_rec *rec;
- int mem_size;
+ int err = 0, send;
+ bool enospc;
+
+ psock = sk_psock_get(sk);
+ if (!psock)
+ return tls_push_record(sk, flags, record_type);
+more_data:
+ enospc = sk_msg_full(msg);
+ if (psock->eval == __SK_NONE)
+ psock->eval = sk_psock_msg_verdict(sk, psock, msg);
+ if (msg->cork_bytes && msg->cork_bytes > msg->sg.size &&
+ !enospc && !full_record) {
+ err = -ENOSPC;
+ goto out_err;
+ }
+ msg->cork_bytes = 0;
+ send = msg->sg.size;
+ if (msg->apply_bytes && msg->apply_bytes < send)
+ send = msg->apply_bytes;
+
+ switch (psock->eval) {
+ case __SK_PASS:
+ err = tls_push_record(sk, flags, record_type);
+ if (err < 0) {
+ *copied -= sk_msg_free(sk, msg);
+ tls_free_open_rec(sk);
+ goto out_err;
+ }
+ break;
+ case __SK_REDIRECT:
+ sk_redir = psock->sk_redir;
+ memcpy(&msg_redir, msg, sizeof(*msg));
+ if (msg->apply_bytes < send)
+ msg->apply_bytes = 0;
+ else
+ msg->apply_bytes -= send;
+ sk_msg_return_zero(sk, msg, send);
+ msg->sg.size -= send;
+ release_sock(sk);
+ err = tcp_bpf_sendmsg_redir(sk_redir, &msg_redir, send, flags);
+ lock_sock(sk);
+ if (err < 0) {
+ *copied -= sk_msg_free_nocharge(sk, &msg_redir);
+ msg->sg.size = 0;
+ }
+ if (msg->sg.size == 0)
+ tls_free_open_rec(sk);
+ break;
+ case __SK_DROP:
+ default:
+ sk_msg_free_partial(sk, msg, send);
+ if (msg->apply_bytes < send)
+ msg->apply_bytes = 0;
+ else
+ msg->apply_bytes -= send;
+ if (msg->sg.size == 0)
+ tls_free_open_rec(sk);
+ *copied -= send;
+ err = -EACCES;
+ }
- /* Return if we already have an open record */
- if (ctx->open_rec)
- return ctx->open_rec;
+ if (likely(!err)) {
+ bool reset_eval = !ctx->open_rec;
- mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send);
+ rec = ctx->open_rec;
+ if (rec) {
+ msg = &rec->msg_plaintext;
+ if (!msg->apply_bytes)
+ reset_eval = true;
+ }
+ if (reset_eval) {
+ psock->eval = __SK_NONE;
+ if (psock->sk_redir) {
+ sock_put(psock->sk_redir);
+ psock->sk_redir = NULL;
+ }
+ }
+ if (rec)
+ goto more_data;
+ }
+ out_err:
+ sk_psock_put(sk, psock);
+ return err;
+}
+
+static int tls_sw_push_pending_record(struct sock *sk, int flags)
+{
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
+ struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+ struct tls_rec *rec = ctx->open_rec;
+ struct sk_msg *msg_pl;
+ size_t copied;
- rec = kzalloc(mem_size, sk->sk_allocation);
if (!rec)
- return NULL;
+ return 0;
msg_pl = &rec->msg_plaintext;
- msg_en = &rec->msg_encrypted;
-
- sk_msg_init(msg_pl);
- sk_msg_init(msg_en);
-
- sg_init_table(rec->sg_aead_in, 2);
- sg_set_buf(&rec->sg_aead_in[0], rec->aad_space,
- sizeof(rec->aad_space));
- sg_unmark_end(&rec->sg_aead_in[1]);
-
- sg_init_table(rec->sg_aead_out, 2);
- sg_set_buf(&rec->sg_aead_out[0], rec->aad_space,
- sizeof(rec->aad_space));
- sg_unmark_end(&rec->sg_aead_out[1]);
-
- ctx->open_rec = rec;
- rec->inplace_crypto = 1;
+ copied = msg_pl->sg.size;
+ if (!copied)
+ return 0;
- return rec;
+ return bpf_exec_tx_verdict(msg_pl, sk, true, TLS_RECORD_TYPE_DATA,
+ &copied, flags);
}
int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
@@ -589,7 +840,10 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
goto send_end;
}
- rec = get_rec(sk);
+ if (ctx->open_rec)
+ rec = ctx->open_rec;
+ else
+ rec = ctx->open_rec = tls_get_rec(sk);
if (!rec) {
ret = -ENOMEM;
goto send_end;
@@ -628,6 +882,8 @@ alloc_encrypted:
}
if (!is_kvec && (full_record || eor) && !async_capable) {
+ u32 first = msg_pl->sg.end;
+
ret = sk_msg_zerocopy_from_iter(sk, &msg->msg_iter,
msg_pl, try_to_copy);
if (ret)
@@ -637,15 +893,27 @@ alloc_encrypted:
num_zc++;
copied += try_to_copy;
- ret = tls_push_record(sk, msg->msg_flags, record_type);
+
+ sk_msg_sg_copy_set(msg_pl, first);
+ ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
+ record_type, &copied,
+ msg->msg_flags);
if (ret) {
if (ret == -EINPROGRESS)
num_async++;
+ else if (ret == -ENOMEM)
+ goto wait_for_memory;
+ else if (ret == -ENOSPC)
+ goto rollback_iter;
else if (ret != -EAGAIN)
goto send_end;
}
continue;
-
+rollback_iter:
+ copied -= try_to_copy;
+ sk_msg_sg_copy_clear(msg_pl, first);
+ iov_iter_revert(&msg->msg_iter,
+ msg_pl->sg.size - orig_size);
fallback_to_reg_send:
sk_msg_trim(sk, msg_pl, orig_size);
}
@@ -678,12 +946,19 @@ fallback_to_reg_send:
tls_ctx->pending_open_record_frags = true;
copied += try_to_copy;
if (full_record || eor) {
- ret = tls_push_record(sk, msg->msg_flags, record_type);
+ ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
+ record_type, &copied,
+ msg->msg_flags);
if (ret) {
if (ret == -EINPROGRESS)
num_async++;
- else if (ret != -EAGAIN)
+ else if (ret == -ENOMEM)
+ goto wait_for_memory;
+ else if (ret != -EAGAIN) {
+ if (ret == -ENOSPC)
+ ret = 0;
goto send_end;
+ }
}
}
@@ -742,10 +1017,10 @@ int tls_sw_sendpage(struct sock *sk, struct page *page,
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
unsigned char record_type = TLS_RECORD_TYPE_DATA;
- size_t orig_size = size;
struct sk_msg *msg_pl;
struct tls_rec *rec;
int num_async = 0;
+ size_t copied = 0;
bool full_record;
int record_room;
int ret = 0;
@@ -778,7 +1053,10 @@ int tls_sw_sendpage(struct sock *sk, struct page *page,
goto sendpage_end;
}
- rec = get_rec(sk);
+ if (ctx->open_rec)
+ rec = ctx->open_rec;
+ else
+ rec = ctx->open_rec = tls_get_rec(sk);
if (!rec) {
ret = -ENOMEM;
goto sendpage_end;
@@ -788,6 +1066,7 @@ int tls_sw_sendpage(struct sock *sk, struct page *page,
full_record = false;
record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
+ copied = 0;
copy = size;
if (copy >= record_room) {
copy = record_room;
@@ -818,16 +1097,23 @@ alloc_payload:
offset += copy;
size -= copy;
+ copied += copy;
tls_ctx->pending_open_record_frags = true;
if (full_record || eor || sk_msg_full(msg_pl)) {
rec->inplace_crypto = 0;
- ret = tls_push_record(sk, flags, record_type);
+ ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
+ record_type, &copied, flags);
if (ret) {
if (ret == -EINPROGRESS)
num_async++;
- else if (ret != -EAGAIN)
+ else if (ret == -ENOMEM)
+ goto wait_for_memory;
+ else if (ret != -EAGAIN) {
+ if (ret == -ENOSPC)
+ ret = 0;
goto sendpage_end;
+ }
}
}
continue;
@@ -851,24 +1137,20 @@ wait_for_memory:
}
}
sendpage_end:
- if (orig_size > size)
- ret = orig_size - size;
- else
- ret = sk_stream_error(sk, flags, ret);
-
+ ret = sk_stream_error(sk, flags, ret);
release_sock(sk);
- return ret;
+ return copied ? copied : ret;
}
-static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
- long timeo, int *err)
+static struct sk_buff *tls_wait_data(struct sock *sk, struct sk_psock *psock,
+ int flags, long timeo, int *err)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct sk_buff *skb;
DEFINE_WAIT_FUNC(wait, woken_wake_function);
- while (!(skb = ctx->recv_pkt)) {
+ while (!(skb = ctx->recv_pkt) && sk_psock_queue_empty(psock)) {
if (sk->sk_err) {
*err = sock_error(sk);
return NULL;
@@ -887,7 +1169,10 @@ static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
add_wait_queue(sk_sleep(sk), &wait);
sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
- sk_wait_event(sk, &timeo, ctx->recv_pkt != skb, &wait);
+ sk_wait_event(sk, &timeo,
+ ctx->recv_pkt != skb ||
+ !sk_psock_queue_empty(psock),
+ &wait);
sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
remove_wait_queue(sk_sleep(sk), &wait);
@@ -1164,6 +1449,7 @@ int tls_sw_recvmsg(struct sock *sk,
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+ struct sk_psock *psock;
unsigned char control;
struct strp_msg *rxm;
struct sk_buff *skb;
@@ -1179,6 +1465,7 @@ int tls_sw_recvmsg(struct sock *sk,
if (unlikely(flags & MSG_ERRQUEUE))
return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
+ psock = sk_psock_get(sk);
lock_sock(sk);
target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
@@ -1188,9 +1475,19 @@ int tls_sw_recvmsg(struct sock *sk,
bool async = false;
int chunk = 0;
- skb = tls_wait_data(sk, flags, timeo, &err);
- if (!skb)
+ skb = tls_wait_data(sk, psock, flags, timeo, &err);
+ if (!skb) {
+ if (psock) {
+ int ret = __tcp_bpf_recvmsg(sk, psock, msg, len);
+
+ if (ret > 0) {
+ copied += ret;
+ len -= ret;
+ continue;
+ }
+ }
goto recv_end;
+ }
rxm = strp_msg(skb);
@@ -1296,6 +1593,8 @@ recv_end:
}
release_sock(sk);
+ if (psock)
+ sk_psock_put(sk, psock);
return copied ? : err;
}
@@ -1318,7 +1617,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
- skb = tls_wait_data(sk, flags, timeo, &err);
+ skb = tls_wait_data(sk, NULL, flags, timeo, &err);
if (!skb)
goto splice_read_end;
@@ -1356,11 +1655,16 @@ bool tls_sw_stream_read(const struct sock *sk)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+ bool ingress_empty = true;
+ struct sk_psock *psock;
- if (ctx->recv_pkt)
- return true;
+ rcu_read_lock();
+ psock = sk_psock(sk);
+ if (psock)
+ ingress_empty = list_empty(&psock->ingress_msg);
+ rcu_read_unlock();
- return false;
+ return !ingress_empty || ctx->recv_pkt;
}
static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
@@ -1439,8 +1743,15 @@ static void tls_data_ready(struct sock *sk)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+ struct sk_psock *psock;
strp_data_ready(&ctx->strp);
+
+ psock = sk_psock_get(sk);
+ if (psock && !list_empty(&psock->ingress_msg)) {
+ ctx->saved_data_ready(sk);
+ sk_psock_put(sk, psock);
+ }
}
void tls_sw_free_resources_tx(struct sock *sk)