summaryrefslogtreecommitdiff
path: root/net/tls
diff options
context:
space:
mode:
Diffstat (limited to 'net/tls')
-rw-r--r--net/tls/tls.h3
-rw-r--r--net/tls/tls_device.c10
-rw-r--r--net/tls/tls_device_fallback.c31
-rw-r--r--net/tls/tls_main.c85
-rw-r--r--net/tls/tls_proc.c5
-rw-r--r--net/tls/tls_strp.c3
-rw-r--r--net/tls/tls_sw.c157
7 files changed, 207 insertions, 87 deletions
diff --git a/net/tls/tls.h b/net/tls/tls.h
index e5e47452308a..774859b63f0d 100644
--- a/net/tls/tls.h
+++ b/net/tls/tls.h
@@ -145,7 +145,8 @@ void tls_err_abort(struct sock *sk, int err);
int init_prot_info(struct tls_prot_info *prot,
const struct tls_crypto_info *crypto_info,
const struct tls_cipher_desc *cipher_desc);
-int tls_set_sw_offload(struct sock *sk, int tx);
+int tls_set_sw_offload(struct sock *sk, int tx,
+ struct tls_crypto_info *new_crypto_info);
void tls_update_rx_zc_capable(struct tls_context *tls_ctx);
void tls_sw_strparser_arm(struct sock *sk, struct tls_context *ctx);
void tls_sw_strparser_done(struct tls_context *tls_ctx);
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c
index dc063c2c7950..f672a62a9a52 100644
--- a/net/tls/tls_device.c
+++ b/net/tls/tls_device.c
@@ -157,7 +157,7 @@ static void delete_all_records(struct tls_offload_context_tx *offload_ctx)
offload_ctx->retransmit_hint = NULL;
}
-static void tls_icsk_clean_acked(struct sock *sk, u32 acked_seq)
+static void tls_tcp_clean_acked(struct sock *sk, u32 acked_seq)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_record_info *info, *temp;
@@ -204,7 +204,7 @@ void tls_device_sk_destruct(struct sock *sk)
destroy_record(ctx->open_record);
delete_all_records(ctx);
crypto_free_aead(ctx->aead_send);
- clean_acked_data_disable(inet_csk(sk));
+ clean_acked_data_disable(tcp_sk(sk));
}
tls_device_queue_ctx_destruction(tls_ctx);
@@ -1126,7 +1126,7 @@ int tls_set_device_offload(struct sock *sk)
start_marker_record->num_frags = 0;
list_add_tail(&start_marker_record->list, &offload_ctx->records_list);
- clean_acked_data_enable(inet_csk(sk), &tls_icsk_clean_acked);
+ clean_acked_data_enable(tcp_sk(sk), &tls_tcp_clean_acked);
ctx->push_pending_record = tls_device_push_pending_record;
/* TLS offload is greatly simplified if we don't send
@@ -1172,7 +1172,7 @@ int tls_set_device_offload(struct sock *sk)
release_lock:
up_read(&device_offload_lock);
- clean_acked_data_disable(inet_csk(sk));
+ clean_acked_data_disable(tcp_sk(sk));
crypto_free_aead(offload_ctx->aead_send);
free_offload_ctx:
kfree(offload_ctx);
@@ -1227,7 +1227,7 @@ int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
context->resync_nh_reset = 1;
ctx->priv_ctx_rx = context;
- rc = tls_set_sw_offload(sk, 0);
+ rc = tls_set_sw_offload(sk, 0, NULL);
if (rc)
goto release_ctx;
diff --git a/net/tls/tls_device_fallback.c b/net/tls/tls_device_fallback.c
index f9e3d3d90dcf..03d508a45aae 100644
--- a/net/tls/tls_device_fallback.c
+++ b/net/tls/tls_device_fallback.c
@@ -37,17 +37,6 @@
#include "tls.h"
-static void chain_to_walk(struct scatterlist *sg, struct scatter_walk *walk)
-{
- struct scatterlist *src = walk->sg;
- int diff = walk->offset - src->offset;
-
- sg_set_page(sg, sg_page(src),
- src->length - diff, walk->offset);
-
- scatterwalk_crypto_chain(sg, sg_next(src), 2);
-}
-
static int tls_enc_record(struct aead_request *aead_req,
struct crypto_aead *aead, char *aad,
char *iv, __be64 rcd_sn,
@@ -69,16 +58,13 @@ static int tls_enc_record(struct aead_request *aead_req,
buf_size = TLS_HEADER_SIZE + cipher_desc->iv;
len = min_t(int, *in_len, buf_size);
- scatterwalk_copychunks(buf, in, len, 0);
- scatterwalk_copychunks(buf, out, len, 1);
+ memcpy_from_scatterwalk(buf, in, len);
+ memcpy_to_scatterwalk(out, buf, len);
*in_len -= len;
if (!*in_len)
return 0;
- scatterwalk_pagedone(in, 0, 1);
- scatterwalk_pagedone(out, 1, 1);
-
len = buf[4] | (buf[3] << 8);
len -= cipher_desc->iv;
@@ -90,8 +76,8 @@ static int tls_enc_record(struct aead_request *aead_req,
sg_init_table(sg_out, ARRAY_SIZE(sg_out));
sg_set_buf(sg_in, aad, TLS_AAD_SPACE_SIZE);
sg_set_buf(sg_out, aad, TLS_AAD_SPACE_SIZE);
- chain_to_walk(sg_in + 1, in);
- chain_to_walk(sg_out + 1, out);
+ scatterwalk_get_sglist(in, sg_in + 1);
+ scatterwalk_get_sglist(out, sg_out + 1);
*in_len -= len;
if (*in_len < 0) {
@@ -110,10 +96,8 @@ static int tls_enc_record(struct aead_request *aead_req,
}
if (*in_len) {
- scatterwalk_copychunks(NULL, in, len, 2);
- scatterwalk_pagedone(in, 0, 1);
- scatterwalk_copychunks(NULL, out, len, 2);
- scatterwalk_pagedone(out, 1, 1);
+ scatterwalk_skip(in, len);
+ scatterwalk_skip(out, len);
}
len -= cipher_desc->tag;
@@ -162,9 +146,6 @@ static int tls_enc_records(struct aead_request *aead_req,
} while (rc == 0 && len);
- scatterwalk_done(&in, 0, 0);
- scatterwalk_done(&out, 1, 0);
-
return rc;
}
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 6b4b9f2749a6..a3ccb3135e51 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -423,9 +423,10 @@ static __poll_t tls_sk_poll(struct file *file, struct socket *sock,
ctx = tls_sw_ctx_rx(tls_ctx);
psock = sk_psock_get(sk);
- if (skb_queue_empty_lockless(&ctx->rx_list) &&
- !tls_strp_msg_ready(ctx) &&
- sk_psock_queue_empty(psock))
+ if ((skb_queue_empty_lockless(&ctx->rx_list) &&
+ !tls_strp_msg_ready(ctx) &&
+ sk_psock_queue_empty(psock)) ||
+ READ_ONCE(ctx->key_update_pending))
mask &= ~(EPOLLIN | EPOLLRDNORM);
if (psock)
@@ -612,11 +613,13 @@ static int validate_crypto_info(const struct tls_crypto_info *crypto_info,
static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
unsigned int optlen, int tx)
{
- struct tls_crypto_info *crypto_info;
- struct tls_crypto_info *alt_crypto_info;
+ struct tls_crypto_info *crypto_info, *alt_crypto_info;
+ struct tls_crypto_info *old_crypto_info = NULL;
struct tls_context *ctx = tls_get_ctx(sk);
const struct tls_cipher_desc *cipher_desc;
union tls_crypto_context *crypto_ctx;
+ union tls_crypto_context tmp = {};
+ bool update = false;
int rc = 0;
int conf;
@@ -633,9 +636,21 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
crypto_info = &crypto_ctx->info;
- /* Currently we don't support set crypto info more than one time */
- if (TLS_CRYPTO_INFO_READY(crypto_info))
- return -EBUSY;
+ if (TLS_CRYPTO_INFO_READY(crypto_info)) {
+ /* Currently we only support setting crypto info more
+ * than one time for TLS 1.3
+ */
+ if (crypto_info->version != TLS_1_3_VERSION) {
+ TLS_INC_STATS(sock_net(sk), tx ? LINUX_MIB_TLSTXREKEYERROR
+ : LINUX_MIB_TLSRXREKEYERROR);
+ return -EBUSY;
+ }
+
+ update = true;
+ old_crypto_info = crypto_info;
+ crypto_info = &tmp.info;
+ crypto_ctx = &tmp;
+ }
rc = copy_from_sockptr(crypto_info, optval, sizeof(*crypto_info));
if (rc) {
@@ -643,7 +658,14 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
goto err_crypto_info;
}
- rc = validate_crypto_info(crypto_info, alt_crypto_info);
+ if (update) {
+ /* Ensure that TLS version and ciphers are not modified */
+ if (crypto_info->version != old_crypto_info->version ||
+ crypto_info->cipher_type != old_crypto_info->cipher_type)
+ rc = -EINVAL;
+ } else {
+ rc = validate_crypto_info(crypto_info, alt_crypto_info);
+ }
if (rc)
goto err_crypto_info;
@@ -673,11 +695,17 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE);
TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
} else {
- rc = tls_set_sw_offload(sk, 1);
+ rc = tls_set_sw_offload(sk, 1,
+ update ? crypto_info : NULL);
if (rc)
goto err_crypto_info;
- TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW);
- TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
+
+ if (update) {
+ TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXREKEYOK);
+ } else {
+ TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW);
+ TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
+ }
conf = TLS_SW;
}
} else {
@@ -687,14 +715,21 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE);
TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
} else {
- rc = tls_set_sw_offload(sk, 0);
+ rc = tls_set_sw_offload(sk, 0,
+ update ? crypto_info : NULL);
if (rc)
goto err_crypto_info;
- TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW);
- TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
+
+ if (update) {
+ TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXREKEYOK);
+ } else {
+ TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW);
+ TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
+ }
conf = TLS_SW;
}
- tls_sw_strparser_arm(sk, ctx);
+ if (!update)
+ tls_sw_strparser_arm(sk, ctx);
}
if (tx)
@@ -702,6 +737,10 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
else
ctx->rx_conf = conf;
update_sk_prot(sk, ctx);
+
+ if (update)
+ return 0;
+
if (tx) {
ctx->sk_write_space = sk->sk_write_space;
sk->sk_write_space = tls_write_space;
@@ -713,6 +752,10 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
return 0;
err_crypto_info:
+ if (update) {
+ TLS_INC_STATS(sock_net(sk), tx ? LINUX_MIB_TLSTXREKEYERROR
+ : LINUX_MIB_TLSRXREKEYERROR);
+ }
memzero_explicit(crypto_ctx, sizeof(*crypto_ctx));
return rc;
}
@@ -809,6 +852,11 @@ static int tls_setsockopt(struct sock *sk, int level, int optname,
return do_tls_setsockopt(sk, optname, optval, optlen);
}
+static int tls_disconnect(struct sock *sk, int flags)
+{
+ return -EOPNOTSUPP;
+}
+
struct tls_context *tls_ctx_create(struct sock *sk)
{
struct inet_connection_sock *icsk = inet_csk(sk);
@@ -904,6 +952,7 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
prot[TLS_BASE][TLS_BASE] = *base;
prot[TLS_BASE][TLS_BASE].setsockopt = tls_setsockopt;
prot[TLS_BASE][TLS_BASE].getsockopt = tls_getsockopt;
+ prot[TLS_BASE][TLS_BASE].disconnect = tls_disconnect;
prot[TLS_BASE][TLS_BASE].close = tls_sk_proto_close;
prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
@@ -1014,7 +1063,7 @@ static u16 tls_user_config(struct tls_context *ctx, bool tx)
return 0;
}
-static int tls_get_info(struct sock *sk, struct sk_buff *skb)
+static int tls_get_info(struct sock *sk, struct sk_buff *skb, bool net_admin)
{
u16 version, cipher_type;
struct tls_context *ctx;
@@ -1072,7 +1121,7 @@ nla_failure:
return err;
}
-static size_t tls_get_info_size(const struct sock *sk)
+static size_t tls_get_info_size(const struct sock *sk, bool net_admin)
{
size_t size = 0;
diff --git a/net/tls/tls_proc.c b/net/tls/tls_proc.c
index 68982728f620..367666aa07b8 100644
--- a/net/tls/tls_proc.c
+++ b/net/tls/tls_proc.c
@@ -22,6 +22,11 @@ static const struct snmp_mib tls_mib_list[] = {
SNMP_MIB_ITEM("TlsRxDeviceResync", LINUX_MIB_TLSRXDEVICERESYNC),
SNMP_MIB_ITEM("TlsDecryptRetry", LINUX_MIB_TLSDECRYPTRETRY),
SNMP_MIB_ITEM("TlsRxNoPadViolation", LINUX_MIB_TLSRXNOPADVIOL),
+ SNMP_MIB_ITEM("TlsRxRekeyOk", LINUX_MIB_TLSRXREKEYOK),
+ SNMP_MIB_ITEM("TlsRxRekeyError", LINUX_MIB_TLSRXREKEYERROR),
+ SNMP_MIB_ITEM("TlsTxRekeyOk", LINUX_MIB_TLSTXREKEYOK),
+ SNMP_MIB_ITEM("TlsTxRekeyError", LINUX_MIB_TLSTXREKEYERROR),
+ SNMP_MIB_ITEM("TlsRxRekeyReceived", LINUX_MIB_TLSRXREKEYRECEIVED),
SNMP_MIB_SENTINEL
};
diff --git a/net/tls/tls_strp.c b/net/tls/tls_strp.c
index 77e33e1e340e..65b0da6fdf6a 100644
--- a/net/tls/tls_strp.c
+++ b/net/tls/tls_strp.c
@@ -396,7 +396,6 @@ static int tls_strp_read_copy(struct tls_strparser *strp, bool qshort)
return 0;
shinfo = skb_shinfo(strp->anchor);
- shinfo->frag_list = NULL;
/* If we don't know the length go max plus page for cipher overhead */
need_spc = strp->stm.full_len ?: TLS_MAX_PAYLOAD_SIZE + PAGE_SIZE;
@@ -412,6 +411,8 @@ static int tls_strp_read_copy(struct tls_strparser *strp, bool qshort)
page, 0, 0);
}
+ shinfo->frag_list = NULL;
+
strp->copy_mode = 1;
strp->stm.offset = 0;
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index bbf26cc4f6ee..fc88e34b7f33 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -458,7 +458,7 @@ int tls_tx_records(struct sock *sk, int flags)
tx_err:
if (rc < 0 && rc != -EAGAIN)
- tls_err_abort(sk, -EBADMSG);
+ tls_err_abort(sk, rc);
return rc;
}
@@ -908,6 +908,13 @@ more_data:
&msg_redir, send, flags);
lock_sock(sk);
if (err < 0) {
+ /* Regardless of whether the data represented by
+ * msg_redir is sent successfully, we have already
+ * uncharged it via sk_msg_return_zero(). The
+ * msg->sg.size represents the remaining unprocessed
+ * data, which needs to be uncharged here.
+ */
+ sk_mem_uncharge(sk, msg->sg.size);
*copied -= sk_msg_free_nocharge(sk, &msg_redir);
msg->sg.size = 0;
}
@@ -1120,9 +1127,13 @@ alloc_encrypted:
num_async++;
else if (ret == -ENOMEM)
goto wait_for_memory;
- else if (ctx->open_rec && ret == -ENOSPC)
+ else if (ctx->open_rec && ret == -ENOSPC) {
+ if (msg_pl->cork_bytes) {
+ ret = 0;
+ goto send_end;
+ }
goto rollback_iter;
- else if (ret != -EAGAIN)
+ } else if (ret != -EAGAIN)
goto send_end;
}
continue;
@@ -1314,6 +1325,10 @@ tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock,
int ret = 0;
long timeo;
+ /* a rekey is pending, let userspace deal with it */
+ if (unlikely(ctx->key_update_pending))
+ return -EKEYEXPIRED;
+
timeo = sock_rcvtimeo(sk, nonblock);
while (!tls_strp_msg_ready(ctx)) {
@@ -1720,6 +1735,36 @@ tls_decrypt_device(struct sock *sk, struct msghdr *msg,
return 1;
}
+static int tls_check_pending_rekey(struct sock *sk, struct tls_context *ctx,
+ struct sk_buff *skb)
+{
+ const struct strp_msg *rxm = strp_msg(skb);
+ const struct tls_msg *tlm = tls_msg(skb);
+ char hs_type;
+ int err;
+
+ if (likely(tlm->control != TLS_RECORD_TYPE_HANDSHAKE))
+ return 0;
+
+ if (rxm->full_len < 1)
+ return 0;
+
+ err = skb_copy_bits(skb, rxm->offset, &hs_type, 1);
+ if (err < 0) {
+ DEBUG_NET_WARN_ON_ONCE(1);
+ return err;
+ }
+
+ if (hs_type == TLS_HANDSHAKE_KEYUPDATE) {
+ struct tls_sw_context_rx *rx_ctx = ctx->priv_ctx_rx;
+
+ WRITE_ONCE(rx_ctx->key_update_pending, true);
+ TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXREKEYRECEIVED);
+ }
+
+ return 0;
+}
+
static int tls_rx_one_record(struct sock *sk, struct msghdr *msg,
struct tls_decrypt_arg *darg)
{
@@ -1739,7 +1784,7 @@ static int tls_rx_one_record(struct sock *sk, struct msghdr *msg,
rxm->full_len -= prot->overhead_size;
tls_advance_record_sn(sk, prot, &tls_ctx->rx);
- return 0;
+ return tls_check_pending_rekey(sk, tls_ctx, darg->skb);
}
int decrypt_skb(struct sock *sk, struct scatterlist *sgout)
@@ -2684,12 +2729,22 @@ int init_prot_info(struct tls_prot_info *prot,
return 0;
}
-int tls_set_sw_offload(struct sock *sk, int tx)
+static void tls_finish_key_update(struct sock *sk, struct tls_context *tls_ctx)
{
+ struct tls_sw_context_rx *ctx = tls_ctx->priv_ctx_rx;
+
+ WRITE_ONCE(ctx->key_update_pending, false);
+ /* wake-up pre-existing poll() */
+ ctx->saved_data_ready(sk);
+}
+
+int tls_set_sw_offload(struct sock *sk, int tx,
+ struct tls_crypto_info *new_crypto_info)
+{
+ struct tls_crypto_info *crypto_info, *src_crypto_info;
struct tls_sw_context_tx *sw_ctx_tx = NULL;
struct tls_sw_context_rx *sw_ctx_rx = NULL;
const struct tls_cipher_desc *cipher_desc;
- struct tls_crypto_info *crypto_info;
char *iv, *rec_seq, *key, *salt;
struct cipher_context *cctx;
struct tls_prot_info *prot;
@@ -2701,44 +2756,47 @@ int tls_set_sw_offload(struct sock *sk, int tx)
ctx = tls_get_ctx(sk);
prot = &ctx->prot_info;
- if (tx) {
- ctx->priv_ctx_tx = init_ctx_tx(ctx, sk);
- if (!ctx->priv_ctx_tx)
- return -ENOMEM;
+ /* new_crypto_info != NULL means rekey */
+ if (!new_crypto_info) {
+ if (tx) {
+ ctx->priv_ctx_tx = init_ctx_tx(ctx, sk);
+ if (!ctx->priv_ctx_tx)
+ return -ENOMEM;
+ } else {
+ ctx->priv_ctx_rx = init_ctx_rx(ctx);
+ if (!ctx->priv_ctx_rx)
+ return -ENOMEM;
+ }
+ }
+ if (tx) {
sw_ctx_tx = ctx->priv_ctx_tx;
crypto_info = &ctx->crypto_send.info;
cctx = &ctx->tx;
aead = &sw_ctx_tx->aead_send;
} else {
- ctx->priv_ctx_rx = init_ctx_rx(ctx);
- if (!ctx->priv_ctx_rx)
- return -ENOMEM;
-
sw_ctx_rx = ctx->priv_ctx_rx;
crypto_info = &ctx->crypto_recv.info;
cctx = &ctx->rx;
aead = &sw_ctx_rx->aead_recv;
}
- cipher_desc = get_cipher_desc(crypto_info->cipher_type);
+ src_crypto_info = new_crypto_info ?: crypto_info;
+
+ cipher_desc = get_cipher_desc(src_crypto_info->cipher_type);
if (!cipher_desc) {
rc = -EINVAL;
goto free_priv;
}
- rc = init_prot_info(prot, crypto_info, cipher_desc);
+ rc = init_prot_info(prot, src_crypto_info, cipher_desc);
if (rc)
goto free_priv;
- iv = crypto_info_iv(crypto_info, cipher_desc);
- key = crypto_info_key(crypto_info, cipher_desc);
- salt = crypto_info_salt(crypto_info, cipher_desc);
- rec_seq = crypto_info_rec_seq(crypto_info, cipher_desc);
-
- memcpy(cctx->iv, salt, cipher_desc->salt);
- memcpy(cctx->iv + cipher_desc->salt, iv, cipher_desc->iv);
- memcpy(cctx->rec_seq, rec_seq, cipher_desc->rec_seq);
+ iv = crypto_info_iv(src_crypto_info, cipher_desc);
+ key = crypto_info_key(src_crypto_info, cipher_desc);
+ salt = crypto_info_salt(src_crypto_info, cipher_desc);
+ rec_seq = crypto_info_rec_seq(src_crypto_info, cipher_desc);
if (!*aead) {
*aead = crypto_alloc_aead(cipher_desc->cipher_name, 0, 0);
@@ -2751,20 +2809,30 @@ int tls_set_sw_offload(struct sock *sk, int tx)
ctx->push_pending_record = tls_sw_push_pending_record;
+ /* setkey is the last operation that could fail during a
+ * rekey. if it succeeds, we can start modifying the
+ * context.
+ */
rc = crypto_aead_setkey(*aead, key, cipher_desc->key);
- if (rc)
- goto free_aead;
+ if (rc) {
+ if (new_crypto_info)
+ goto out;
+ else
+ goto free_aead;
+ }
- rc = crypto_aead_setauthsize(*aead, prot->tag_size);
- if (rc)
- goto free_aead;
+ if (!new_crypto_info) {
+ rc = crypto_aead_setauthsize(*aead, prot->tag_size);
+ if (rc)
+ goto free_aead;
+ }
- if (sw_ctx_rx) {
+ if (!tx && !new_crypto_info) {
tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv);
tls_update_rx_zc_capable(ctx);
sw_ctx_rx->async_capable =
- crypto_info->version != TLS_1_3_VERSION &&
+ src_crypto_info->version != TLS_1_3_VERSION &&
!!(tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC);
rc = tls_strp_init(&sw_ctx_rx->strp, sk);
@@ -2772,18 +2840,33 @@ int tls_set_sw_offload(struct sock *sk, int tx)
goto free_aead;
}
+ memcpy(cctx->iv, salt, cipher_desc->salt);
+ memcpy(cctx->iv + cipher_desc->salt, iv, cipher_desc->iv);
+ memcpy(cctx->rec_seq, rec_seq, cipher_desc->rec_seq);
+
+ if (new_crypto_info) {
+ unsafe_memcpy(crypto_info, new_crypto_info,
+ cipher_desc->crypto_info,
+ /* size was checked in do_tls_setsockopt_conf */);
+ memzero_explicit(new_crypto_info, cipher_desc->crypto_info);
+ if (!tx)
+ tls_finish_key_update(sk, ctx);
+ }
+
goto out;
free_aead:
crypto_free_aead(*aead);
*aead = NULL;
free_priv:
- if (tx) {
- kfree(ctx->priv_ctx_tx);
- ctx->priv_ctx_tx = NULL;
- } else {
- kfree(ctx->priv_ctx_rx);
- ctx->priv_ctx_rx = NULL;
+ if (!new_crypto_info) {
+ if (tx) {
+ kfree(ctx->priv_ctx_tx);
+ ctx->priv_ctx_tx = NULL;
+ } else {
+ kfree(ctx->priv_ctx_rx);
+ ctx->priv_ctx_rx = NULL;
+ }
}
out:
return rc;