summaryrefslogtreecommitdiff
path: root/net/tls
diff options
context:
space:
mode:
Diffstat (limited to 'net/tls')
-rw-r--r--net/tls/tls_device.c304
-rw-r--r--net/tls/tls_device_fallback.c11
-rw-r--r--net/tls/tls_main.c32
-rw-r--r--net/tls/tls_sw.c347
4 files changed, 509 insertions, 185 deletions
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c
index a7a8f8e20ff3..292742e50bfa 100644
--- a/net/tls/tls_device.c
+++ b/net/tls/tls_device.c
@@ -52,9 +52,12 @@ static DEFINE_SPINLOCK(tls_device_lock);
static void tls_device_free_ctx(struct tls_context *ctx)
{
- struct tls_offload_context *offload_ctx = tls_offload_ctx(ctx);
+ if (ctx->tx_conf == TLS_HW)
+ kfree(tls_offload_ctx_tx(ctx));
+
+ if (ctx->rx_conf == TLS_HW)
+ kfree(tls_offload_ctx_rx(ctx));
- kfree(offload_ctx);
kfree(ctx);
}
@@ -71,10 +74,11 @@ static void tls_device_gc_task(struct work_struct *work)
list_for_each_entry_safe(ctx, tmp, &gc_list, list) {
struct net_device *netdev = ctx->netdev;
- if (netdev) {
+ if (netdev && ctx->tx_conf == TLS_HW) {
netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
TLS_OFFLOAD_CTX_DIR_TX);
dev_put(netdev);
+ ctx->netdev = NULL;
}
list_del(&ctx->list);
@@ -82,6 +86,22 @@ static void tls_device_gc_task(struct work_struct *work)
}
}
+static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
+ struct net_device *netdev)
+{
+ if (sk->sk_destruct != tls_device_sk_destruct) {
+ refcount_set(&ctx->refcount, 1);
+ dev_hold(netdev);
+ ctx->netdev = netdev;
+ spin_lock_irq(&tls_device_lock);
+ list_add_tail(&ctx->list, &tls_device_list);
+ spin_unlock_irq(&tls_device_lock);
+
+ ctx->sk_destruct = sk->sk_destruct;
+ sk->sk_destruct = tls_device_sk_destruct;
+ }
+}
+
static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
{
unsigned long flags;
@@ -125,7 +145,7 @@ static void destroy_record(struct tls_record_info *record)
kfree(record);
}
-static void delete_all_records(struct tls_offload_context *offload_ctx)
+static void delete_all_records(struct tls_offload_context_tx *offload_ctx)
{
struct tls_record_info *info, *temp;
@@ -141,14 +161,14 @@ static void tls_icsk_clean_acked(struct sock *sk, u32 acked_seq)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_record_info *info, *temp;
- struct tls_offload_context *ctx;
+ struct tls_offload_context_tx *ctx;
u64 deleted_records = 0;
unsigned long flags;
if (!tls_ctx)
return;
- ctx = tls_offload_ctx(tls_ctx);
+ ctx = tls_offload_ctx_tx(tls_ctx);
spin_lock_irqsave(&ctx->lock, flags);
info = ctx->retransmit_hint;
@@ -179,15 +199,17 @@ static void tls_icsk_clean_acked(struct sock *sk, u32 acked_seq)
void tls_device_sk_destruct(struct sock *sk)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
- struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx);
+ struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
- if (ctx->open_record)
- destroy_record(ctx->open_record);
+ tls_ctx->sk_destruct(sk);
- delete_all_records(ctx);
- crypto_free_aead(ctx->aead_send);
- ctx->sk_destruct(sk);
- clean_acked_data_disable(inet_csk(sk));
+ if (tls_ctx->tx_conf == TLS_HW) {
+ if (ctx->open_record)
+ destroy_record(ctx->open_record);
+ delete_all_records(ctx);
+ crypto_free_aead(ctx->aead_send);
+ clean_acked_data_disable(inet_csk(sk));
+ }
if (refcount_dec_and_test(&tls_ctx->refcount))
tls_device_queue_ctx_destruction(tls_ctx);
@@ -219,7 +241,7 @@ static void tls_append_frag(struct tls_record_info *record,
static int tls_push_record(struct sock *sk,
struct tls_context *ctx,
- struct tls_offload_context *offload_ctx,
+ struct tls_offload_context_tx *offload_ctx,
struct tls_record_info *record,
struct page_frag *pfrag,
int flags,
@@ -264,7 +286,7 @@ static int tls_push_record(struct sock *sk,
return tls_push_sg(sk, ctx, offload_ctx->sg_tx_data, 0, flags);
}
-static int tls_create_new_record(struct tls_offload_context *offload_ctx,
+static int tls_create_new_record(struct tls_offload_context_tx *offload_ctx,
struct page_frag *pfrag,
size_t prepend_size)
{
@@ -290,7 +312,7 @@ static int tls_create_new_record(struct tls_offload_context *offload_ctx,
}
static int tls_do_allocation(struct sock *sk,
- struct tls_offload_context *offload_ctx,
+ struct tls_offload_context_tx *offload_ctx,
struct page_frag *pfrag,
size_t prepend_size)
{
@@ -324,7 +346,7 @@ static int tls_push_data(struct sock *sk,
unsigned char record_type)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
- struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx);
+ struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
int tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST;
int more = flags & (MSG_SENDPAGE_NOTLAST | MSG_MORE);
struct tls_record_info *record = ctx->open_record;
@@ -477,7 +499,7 @@ out:
return rc;
}
-struct tls_record_info *tls_get_record(struct tls_offload_context *context,
+struct tls_record_info *tls_get_record(struct tls_offload_context_tx *context,
u32 seq, u64 *p_record_sn)
{
u64 record_sn = context->hint_record_sn;
@@ -520,11 +542,123 @@ static int tls_device_push_pending_record(struct sock *sk, int flags)
return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA);
}
+void handle_device_resync(struct sock *sk, u32 seq, u64 rcd_sn)
+{
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
+ struct net_device *netdev = tls_ctx->netdev;
+ struct tls_offload_context_rx *rx_ctx;
+ u32 is_req_pending;
+ s64 resync_req;
+ u32 req_seq;
+
+ if (tls_ctx->rx_conf != TLS_HW)
+ return;
+
+ rx_ctx = tls_offload_ctx_rx(tls_ctx);
+ resync_req = atomic64_read(&rx_ctx->resync_req);
+ req_seq = ntohl(resync_req >> 32) - ((u32)TLS_HEADER_SIZE - 1);
+ is_req_pending = resync_req;
+
+ if (unlikely(is_req_pending) && req_seq == seq &&
+ atomic64_try_cmpxchg(&rx_ctx->resync_req, &resync_req, 0))
+ netdev->tlsdev_ops->tls_dev_resync_rx(netdev, sk,
+ seq + TLS_HEADER_SIZE - 1,
+ rcd_sn);
+}
+
+static int tls_device_reencrypt(struct sock *sk, struct sk_buff *skb)
+{
+ struct strp_msg *rxm = strp_msg(skb);
+ int err = 0, offset = rxm->offset, copy, nsg;
+ struct sk_buff *skb_iter, *unused;
+ struct scatterlist sg[1];
+ char *orig_buf, *buf;
+
+ orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE +
+ TLS_CIPHER_AES_GCM_128_IV_SIZE, sk->sk_allocation);
+ if (!orig_buf)
+ return -ENOMEM;
+ buf = orig_buf;
+
+ nsg = skb_cow_data(skb, 0, &unused);
+ if (unlikely(nsg < 0)) {
+ err = nsg;
+ goto free_buf;
+ }
+
+ sg_init_table(sg, 1);
+ sg_set_buf(&sg[0], buf,
+ rxm->full_len + TLS_HEADER_SIZE +
+ TLS_CIPHER_AES_GCM_128_IV_SIZE);
+ skb_copy_bits(skb, offset, buf,
+ TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE);
+
+ /* We are interested only in the decrypted data not the auth */
+ err = decrypt_skb(sk, skb, sg);
+ if (err != -EBADMSG)
+ goto free_buf;
+ else
+ err = 0;
+
+ copy = min_t(int, skb_pagelen(skb) - offset,
+ rxm->full_len - TLS_CIPHER_AES_GCM_128_TAG_SIZE);
+
+ if (skb->decrypted)
+ skb_store_bits(skb, offset, buf, copy);
+
+ offset += copy;
+ buf += copy;
+
+ skb_walk_frags(skb, skb_iter) {
+ copy = min_t(int, skb_iter->len,
+ rxm->full_len - offset + rxm->offset -
+ TLS_CIPHER_AES_GCM_128_TAG_SIZE);
+
+ if (skb_iter->decrypted)
+ skb_store_bits(skb_iter, offset, buf, copy);
+
+ offset += copy;
+ buf += copy;
+ }
+
+free_buf:
+ kfree(orig_buf);
+ return err;
+}
+
+int tls_device_decrypted(struct sock *sk, struct sk_buff *skb)
+{
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
+ struct tls_offload_context_rx *ctx = tls_offload_ctx_rx(tls_ctx);
+ int is_decrypted = skb->decrypted;
+ int is_encrypted = !is_decrypted;
+ struct sk_buff *skb_iter;
+
+ /* Skip if it is already decrypted */
+ if (ctx->sw.decrypted)
+ return 0;
+
+ /* Check if all the data is decrypted already */
+ skb_walk_frags(skb, skb_iter) {
+ is_decrypted &= skb_iter->decrypted;
+ is_encrypted &= !skb_iter->decrypted;
+ }
+
+ ctx->sw.decrypted |= is_decrypted;
+
+ /* Return immedeatly if the record is either entirely plaintext or
+ * entirely ciphertext. Otherwise handle reencrypt partially decrypted
+ * record.
+ */
+ return (is_encrypted || is_decrypted) ? 0 :
+ tls_device_reencrypt(sk, skb);
+}
+
int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
{
u16 nonce_size, tag_size, iv_size, rec_seq_size;
struct tls_record_info *start_marker_record;
- struct tls_offload_context *offload_ctx;
+ struct tls_offload_context_tx *offload_ctx;
struct tls_crypto_info *crypto_info;
struct net_device *netdev;
char *iv, *rec_seq;
@@ -546,7 +680,7 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
goto out;
}
- offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE, GFP_KERNEL);
+ offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_TX, GFP_KERNEL);
if (!offload_ctx) {
rc = -ENOMEM;
goto free_marker_record;
@@ -582,12 +716,11 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
memcpy(ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
ctx->tx.rec_seq_size = rec_seq_size;
- ctx->tx.rec_seq = kmalloc(rec_seq_size, GFP_KERNEL);
+ ctx->tx.rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
if (!ctx->tx.rec_seq) {
rc = -ENOMEM;
goto free_iv;
}
- memcpy(ctx->tx.rec_seq, rec_seq, rec_seq_size);
rc = tls_sw_fallback_init(sk, offload_ctx, crypto_info);
if (rc)
@@ -609,7 +742,6 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
clean_acked_data_enable(inet_csk(sk), &tls_icsk_clean_acked);
ctx->push_pending_record = tls_device_push_pending_record;
- offload_ctx->sk_destruct = sk->sk_destruct;
/* TLS offload is greatly simplified if we don't send
* SKBs where only part of the payload needs to be encrypted.
@@ -619,8 +751,6 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
if (skb)
TCP_SKB_CB(skb)->eor = 1;
- refcount_set(&ctx->refcount, 1);
-
/* We support starting offload on multiple sockets
* concurrently, so we only need a read lock here.
* This lock must precede get_netdev_for_sock to prevent races between
@@ -655,19 +785,14 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
if (rc)
goto release_netdev;
- ctx->netdev = netdev;
-
- spin_lock_irq(&tls_device_lock);
- list_add_tail(&ctx->list, &tls_device_list);
- spin_unlock_irq(&tls_device_lock);
+ tls_device_attach(ctx, sk, netdev);
- sk->sk_validate_xmit_skb = tls_validate_xmit_skb;
/* following this assignment tls_is_sk_tx_device_offloaded
* will return true and the context might be accessed
* by the netdev's xmit function.
*/
- smp_store_release(&sk->sk_destruct,
- &tls_device_sk_destruct);
+ smp_store_release(&sk->sk_validate_xmit_skb, tls_validate_xmit_skb);
+ dev_put(netdev);
up_read(&device_offload_lock);
goto out;
@@ -690,6 +815,105 @@ out:
return rc;
}
+int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
+{
+ struct tls_offload_context_rx *context;
+ struct net_device *netdev;
+ int rc = 0;
+
+ /* We support starting offload on multiple sockets
+ * concurrently, so we only need a read lock here.
+ * This lock must precede get_netdev_for_sock to prevent races between
+ * NETDEV_DOWN and setsockopt.
+ */
+ down_read(&device_offload_lock);
+ netdev = get_netdev_for_sock(sk);
+ if (!netdev) {
+ pr_err_ratelimited("%s: netdev not found\n", __func__);
+ rc = -EINVAL;
+ goto release_lock;
+ }
+
+ if (!(netdev->features & NETIF_F_HW_TLS_RX)) {
+ pr_err_ratelimited("%s: netdev %s with no TLS offload\n",
+ __func__, netdev->name);
+ rc = -ENOTSUPP;
+ goto release_netdev;
+ }
+
+ /* Avoid offloading if the device is down
+ * We don't want to offload new flows after
+ * the NETDEV_DOWN event
+ */
+ if (!(netdev->flags & IFF_UP)) {
+ rc = -EINVAL;
+ goto release_netdev;
+ }
+
+ context = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_RX, GFP_KERNEL);
+ if (!context) {
+ rc = -ENOMEM;
+ goto release_netdev;
+ }
+
+ ctx->priv_ctx_rx = context;
+ rc = tls_set_sw_offload(sk, ctx, 0);
+ if (rc)
+ goto release_ctx;
+
+ rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_RX,
+ &ctx->crypto_recv,
+ tcp_sk(sk)->copied_seq);
+ if (rc) {
+ pr_err_ratelimited("%s: The netdev has refused to offload this socket\n",
+ __func__);
+ goto free_sw_resources;
+ }
+
+ tls_device_attach(ctx, sk, netdev);
+ goto release_netdev;
+
+free_sw_resources:
+ tls_sw_free_resources_rx(sk);
+release_ctx:
+ ctx->priv_ctx_rx = NULL;
+release_netdev:
+ dev_put(netdev);
+release_lock:
+ up_read(&device_offload_lock);
+ return rc;
+}
+
+void tls_device_offload_cleanup_rx(struct sock *sk)
+{
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
+ struct net_device *netdev;
+
+ down_read(&device_offload_lock);
+ netdev = tls_ctx->netdev;
+ if (!netdev)
+ goto out;
+
+ if (!(netdev->features & NETIF_F_HW_TLS_RX)) {
+ pr_err_ratelimited("%s: device is missing NETIF_F_HW_TLS_RX cap\n",
+ __func__);
+ goto out;
+ }
+
+ netdev->tlsdev_ops->tls_dev_del(netdev, tls_ctx,
+ TLS_OFFLOAD_CTX_DIR_RX);
+
+ if (tls_ctx->tx_conf != TLS_HW) {
+ dev_put(netdev);
+ tls_ctx->netdev = NULL;
+ }
+out:
+ up_read(&device_offload_lock);
+ kfree(tls_ctx->rx.rec_seq);
+ kfree(tls_ctx->rx.iv);
+ tls_sw_release_resources_rx(sk);
+}
+
static int tls_device_down(struct net_device *netdev)
{
struct tls_context *ctx, *tmp;
@@ -710,8 +934,12 @@ static int tls_device_down(struct net_device *netdev)
spin_unlock_irqrestore(&tls_device_lock, flags);
list_for_each_entry_safe(ctx, tmp, &list, list) {
- netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
- TLS_OFFLOAD_CTX_DIR_TX);
+ if (ctx->tx_conf == TLS_HW)
+ netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
+ TLS_OFFLOAD_CTX_DIR_TX);
+ if (ctx->rx_conf == TLS_HW)
+ netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
+ TLS_OFFLOAD_CTX_DIR_RX);
ctx->netdev = NULL;
dev_put(netdev);
list_del_init(&ctx->list);
@@ -732,12 +960,16 @@ static int tls_dev_event(struct notifier_block *this, unsigned long event,
{
struct net_device *dev = netdev_notifier_info_to_dev(ptr);
- if (!(dev->features & NETIF_F_HW_TLS_TX))
+ if (!(dev->features & (NETIF_F_HW_TLS_RX | NETIF_F_HW_TLS_TX)))
return NOTIFY_DONE;
switch (event) {
case NETDEV_REGISTER:
case NETDEV_FEAT_CHANGE:
+ if ((dev->features & NETIF_F_HW_TLS_RX) &&
+ !dev->tlsdev_ops->tls_dev_resync_rx)
+ return NOTIFY_BAD;
+
if (dev->tlsdev_ops &&
dev->tlsdev_ops->tls_dev_add &&
dev->tlsdev_ops->tls_dev_del)
diff --git a/net/tls/tls_device_fallback.c b/net/tls/tls_device_fallback.c
index 748914abdb60..6102169239d1 100644
--- a/net/tls/tls_device_fallback.c
+++ b/net/tls/tls_device_fallback.c
@@ -42,7 +42,7 @@ static void chain_to_walk(struct scatterlist *sg, struct scatter_walk *walk)
sg_set_page(sg, sg_page(src),
src->length - diff, walk->offset);
- scatterwalk_crypto_chain(sg, sg_next(src), 0, 2);
+ scatterwalk_crypto_chain(sg, sg_next(src), 2);
}
static int tls_enc_record(struct aead_request *aead_req,
@@ -214,7 +214,7 @@ static void complete_skb(struct sk_buff *nskb, struct sk_buff *skb, int headln)
static int fill_sg_in(struct scatterlist *sg_in,
struct sk_buff *skb,
- struct tls_offload_context *ctx,
+ struct tls_offload_context_tx *ctx,
u64 *rcd_sn,
s32 *sync_size,
int *resync_sgs)
@@ -299,7 +299,7 @@ static struct sk_buff *tls_enc_skb(struct tls_context *tls_ctx,
s32 sync_size, u64 rcd_sn)
{
int tcp_payload_offset = skb_transport_offset(skb) + tcp_hdrlen(skb);
- struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx);
+ struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
int payload_len = skb->len - tcp_payload_offset;
void *buf, *iv, *aad, *dummy_buf;
struct aead_request *aead_req;
@@ -361,7 +361,7 @@ static struct sk_buff *tls_sw_fallback(struct sock *sk, struct sk_buff *skb)
{
int tcp_payload_offset = skb_transport_offset(skb) + tcp_hdrlen(skb);
struct tls_context *tls_ctx = tls_get_ctx(sk);
- struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx);
+ struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
int payload_len = skb->len - tcp_payload_offset;
struct scatterlist *sg_in, sg_out[3];
struct sk_buff *nskb = NULL;
@@ -413,9 +413,10 @@ struct sk_buff *tls_validate_xmit_skb(struct sock *sk,
return tls_sw_fallback(sk, skb);
}
+EXPORT_SYMBOL_GPL(tls_validate_xmit_skb);
int tls_sw_fallback_init(struct sock *sk,
- struct tls_offload_context *offload_ctx,
+ struct tls_offload_context_tx *offload_ctx,
struct tls_crypto_info *crypto_info)
{
const u8 *key;
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 301f22430469..b09867c8b817 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -51,15 +51,6 @@ enum {
TLSV6,
TLS_NUM_PROTS,
};
-enum {
- TLS_BASE,
- TLS_SW,
-#ifdef CONFIG_TLS_DEVICE
- TLS_HW,
-#endif
- TLS_HW_RECORD,
- TLS_NUM_CONFIG,
-};
static struct proto *saved_tcpv6_prot;
static DEFINE_MUTEX(tcpv6_prot_mutex);
@@ -290,7 +281,10 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
}
#ifdef CONFIG_TLS_DEVICE
- if (ctx->tx_conf != TLS_HW) {
+ if (ctx->rx_conf == TLS_HW)
+ tls_device_offload_cleanup_rx(sk);
+
+ if (ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW) {
#else
{
#endif
@@ -470,8 +464,16 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
conf = TLS_SW;
}
} else {
- rc = tls_set_sw_offload(sk, ctx, 0);
- conf = TLS_SW;
+#ifdef CONFIG_TLS_DEVICE
+ rc = tls_set_device_offload_rx(sk, ctx);
+ conf = TLS_HW;
+ if (rc) {
+#else
+ {
+#endif
+ rc = tls_set_sw_offload(sk, ctx, 0);
+ conf = TLS_SW;
+ }
}
if (rc)
@@ -629,6 +631,12 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW];
prot[TLS_HW][TLS_SW].sendmsg = tls_device_sendmsg;
prot[TLS_HW][TLS_SW].sendpage = tls_device_sendpage;
+
+ prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW];
+
+ prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW];
+
+ prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW];
#endif
prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 1f3d9789af30..52fbe727d7c1 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -48,21 +48,11 @@ static int tls_do_decryption(struct sock *sk,
struct scatterlist *sgout,
char *iv_recv,
size_t data_len,
- struct sk_buff *skb,
- gfp_t flags)
+ struct aead_request *aead_req)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
- struct strp_msg *rxm = strp_msg(skb);
- struct aead_request *aead_req;
-
int ret;
- unsigned int req_size = sizeof(struct aead_request) +
- crypto_aead_reqsize(ctx->aead_recv);
-
- aead_req = kzalloc(req_size, flags);
- if (!aead_req)
- return -ENOMEM;
aead_request_set_tfm(aead_req, ctx->aead_recv);
aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
@@ -73,20 +63,6 @@ static int tls_do_decryption(struct sock *sk,
crypto_req_done, &ctx->async_wait);
ret = crypto_wait_req(crypto_aead_decrypt(aead_req), &ctx->async_wait);
-
- if (ret < 0)
- goto out;
-
- rxm->offset += tls_ctx->rx.prepend_size;
- rxm->full_len -= tls_ctx->rx.overhead_size;
- tls_advance_record_sn(sk, &tls_ctx->rx);
-
- ctx->decrypted = true;
-
- ctx->saved_data_ready(sk);
-
-out:
- kfree(aead_req);
return ret;
}
@@ -224,8 +200,7 @@ static int tls_push_record(struct sock *sk, int flags,
struct aead_request *req;
int rc;
- req = kzalloc(sizeof(struct aead_request) +
- crypto_aead_reqsize(ctx->aead_send), sk->sk_allocation);
+ req = aead_request_alloc(ctx->aead_send, sk->sk_allocation);
if (!req)
return -ENOMEM;
@@ -267,7 +242,7 @@ static int tls_push_record(struct sock *sk, int flags,
tls_advance_record_sn(sk, &tls_ctx->tx);
out_req:
- kfree(req);
+ aead_request_free(req);
return rc;
}
@@ -328,7 +303,12 @@ static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
}
}
+ /* Mark the end in the last sg entry if newly added */
+ if (num_elem > *pages_used)
+ sg_mark_end(&to[num_elem - 1]);
out:
+ if (rc)
+ iov_iter_revert(from, size - *size_used);
*size_used = size;
*pages_used = num_elem;
@@ -377,6 +357,7 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
int record_room;
bool full_record;
int orig_size;
+ bool is_kvec = msg->msg_iter.type & ITER_KVEC;
if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
return -ENOTSUPP;
@@ -425,8 +406,7 @@ alloc_encrypted:
try_to_copy -= required_size - ctx->sg_encrypted_size;
full_record = true;
}
-
- if (full_record || eor) {
+ if (!is_kvec && (full_record || eor)) {
ret = zerocopy_from_iter(sk, &msg->msg_iter,
try_to_copy, &ctx->sg_plaintext_num_elem,
&ctx->sg_plaintext_size,
@@ -438,15 +418,11 @@ alloc_encrypted:
copied += try_to_copy;
ret = tls_push_record(sk, msg->msg_flags, record_type);
- if (!ret)
- continue;
- if (ret < 0)
+ if (ret)
goto send_end;
+ continue;
- copied -= try_to_copy;
fallback_to_reg_send:
- iov_iter_revert(&msg->msg_iter,
- ctx->sg_plaintext_size - orig_size);
trim_sg(sk, ctx->sg_plaintext_data,
&ctx->sg_plaintext_num_elem,
&ctx->sg_plaintext_size,
@@ -673,57 +649,167 @@ static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
return skb;
}
-static int decrypt_skb(struct sock *sk, struct sk_buff *skb,
- struct scatterlist *sgout)
+/* This function decrypts the input skb into either out_iov or in out_sg
+ * or in skb buffers itself. The input parameter 'zc' indicates if
+ * zero-copy mode needs to be tried or not. With zero-copy mode, either
+ * out_iov or out_sg must be non-NULL. In case both out_iov and out_sg are
+ * NULL, then the decryption happens inside skb buffers itself, i.e.
+ * zero-copy gets disabled and 'zc' is updated.
+ */
+
+static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
+ struct iov_iter *out_iov,
+ struct scatterlist *out_sg,
+ int *chunk, bool *zc)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
- char iv[TLS_CIPHER_AES_GCM_128_SALT_SIZE + MAX_IV_SIZE];
- struct scatterlist sgin_arr[MAX_SKB_FRAGS + 2];
- struct scatterlist *sgin = &sgin_arr[0];
struct strp_msg *rxm = strp_msg(skb);
- int ret, nsg = ARRAY_SIZE(sgin_arr);
+ int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
+ struct aead_request *aead_req;
struct sk_buff *unused;
+ u8 *aad, *iv, *mem = NULL;
+ struct scatterlist *sgin = NULL;
+ struct scatterlist *sgout = NULL;
+ const int data_len = rxm->full_len - tls_ctx->rx.overhead_size;
+
+ if (*zc && (out_iov || out_sg)) {
+ if (out_iov)
+ n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1;
+ else
+ n_sgout = sg_nents(out_sg);
+ } else {
+ n_sgout = 0;
+ *zc = false;
+ }
- ret = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
+ n_sgin = skb_cow_data(skb, 0, &unused);
+ if (n_sgin < 1)
+ return -EBADMSG;
+
+ /* Increment to accommodate AAD */
+ n_sgin = n_sgin + 1;
+
+ nsg = n_sgin + n_sgout;
+
+ aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
+ mem_size = aead_size + (nsg * sizeof(struct scatterlist));
+ mem_size = mem_size + TLS_AAD_SPACE_SIZE;
+ mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv);
+
+ /* Allocate a single block of memory which contains
+ * aead_req || sgin[] || sgout[] || aad || iv.
+ * This order achieves correct alignment for aead_req, sgin, sgout.
+ */
+ mem = kmalloc(mem_size, sk->sk_allocation);
+ if (!mem)
+ return -ENOMEM;
+
+ /* Segment the allocated memory */
+ aead_req = (struct aead_request *)mem;
+ sgin = (struct scatterlist *)(mem + aead_size);
+ sgout = sgin + n_sgin;
+ aad = (u8 *)(sgout + n_sgout);
+ iv = aad + TLS_AAD_SPACE_SIZE;
+
+ /* Prepare IV */
+ err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
tls_ctx->rx.iv_size);
- if (ret < 0)
- return ret;
-
- memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
- if (!sgout) {
- nsg = skb_cow_data(skb, 0, &unused) + 1;
- sgin = kmalloc_array(nsg, sizeof(*sgin), sk->sk_allocation);
- sgout = sgin;
+ if (err < 0) {
+ kfree(mem);
+ return err;
}
+ memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
- sg_init_table(sgin, nsg);
- sg_set_buf(&sgin[0], ctx->rx_aad_ciphertext, TLS_AAD_SPACE_SIZE);
+ /* Prepare AAD */
+ tls_make_aad(aad, rxm->full_len - tls_ctx->rx.overhead_size,
+ tls_ctx->rx.rec_seq, tls_ctx->rx.rec_seq_size,
+ ctx->control);
- nsg = skb_to_sgvec(skb, &sgin[1],
+ /* Prepare sgin */
+ sg_init_table(sgin, n_sgin);
+ sg_set_buf(&sgin[0], aad, TLS_AAD_SPACE_SIZE);
+ err = skb_to_sgvec(skb, &sgin[1],
rxm->offset + tls_ctx->rx.prepend_size,
rxm->full_len - tls_ctx->rx.prepend_size);
- if (nsg < 0) {
- ret = nsg;
- goto out;
+ if (err < 0) {
+ kfree(mem);
+ return err;
}
- tls_make_aad(ctx->rx_aad_ciphertext,
- rxm->full_len - tls_ctx->rx.overhead_size,
- tls_ctx->rx.rec_seq,
- tls_ctx->rx.rec_seq_size,
- ctx->control);
+ if (n_sgout) {
+ if (out_iov) {
+ sg_init_table(sgout, n_sgout);
+ sg_set_buf(&sgout[0], aad, TLS_AAD_SPACE_SIZE);
- ret = tls_do_decryption(sk, sgin, sgout, iv,
- rxm->full_len - tls_ctx->rx.overhead_size,
- skb, sk->sk_allocation);
+ *chunk = 0;
+ err = zerocopy_from_iter(sk, out_iov, data_len, &pages,
+ chunk, &sgout[1],
+ (n_sgout - 1), false);
+ if (err < 0)
+ goto fallback_to_reg_recv;
+ } else if (out_sg) {
+ memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
+ } else {
+ goto fallback_to_reg_recv;
+ }
+ } else {
+fallback_to_reg_recv:
+ sgout = sgin;
+ pages = 0;
+ *chunk = 0;
+ *zc = false;
+ }
-out:
- if (sgin != &sgin_arr[0])
- kfree(sgin);
+ /* Prepare and submit AEAD request */
+ err = tls_do_decryption(sk, sgin, sgout, iv, data_len, aead_req);
- return ret;
+ /* Release the pages in case iov was mapped to pages */
+ for (; pages > 0; pages--)
+ put_page(sg_page(&sgout[pages]));
+
+ kfree(mem);
+ return err;
+}
+
+static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
+ struct iov_iter *dest, int *chunk, bool *zc)
+{
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
+ struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+ struct strp_msg *rxm = strp_msg(skb);
+ int err = 0;
+
+#ifdef CONFIG_TLS_DEVICE
+ err = tls_device_decrypted(sk, skb);
+ if (err < 0)
+ return err;
+#endif
+ if (!ctx->decrypted) {
+ err = decrypt_internal(sk, skb, dest, NULL, chunk, zc);
+ if (err < 0)
+ return err;
+ } else {
+ *zc = false;
+ }
+
+ rxm->offset += tls_ctx->rx.prepend_size;
+ rxm->full_len -= tls_ctx->rx.overhead_size;
+ tls_advance_record_sn(sk, &tls_ctx->rx);
+ ctx->decrypted = true;
+ ctx->saved_data_ready(sk);
+
+ return err;
+}
+
+int decrypt_skb(struct sock *sk, struct sk_buff *skb,
+ struct scatterlist *sgout)
+{
+ bool zc = true;
+ int chunk;
+
+ return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc);
}
static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
@@ -764,6 +850,7 @@ int tls_sw_recvmsg(struct sock *sk,
bool cmsg = false;
int target, err = 0;
long timeo;
+ bool is_kvec = msg->msg_iter.type & ITER_KVEC;
flags |= nonblock;
@@ -801,43 +888,17 @@ int tls_sw_recvmsg(struct sock *sk,
}
if (!ctx->decrypted) {
- int page_count;
- int to_copy;
-
- page_count = iov_iter_npages(&msg->msg_iter,
- MAX_SKB_FRAGS);
- to_copy = rxm->full_len - tls_ctx->rx.overhead_size;
- if (to_copy <= len && page_count < MAX_SKB_FRAGS &&
- likely(!(flags & MSG_PEEK))) {
- struct scatterlist sgin[MAX_SKB_FRAGS + 1];
- int pages = 0;
+ int to_copy = rxm->full_len - tls_ctx->rx.overhead_size;
+ if (!is_kvec && to_copy <= len &&
+ likely(!(flags & MSG_PEEK)))
zc = true;
- sg_init_table(sgin, MAX_SKB_FRAGS + 1);
- sg_set_buf(&sgin[0], ctx->rx_aad_plaintext,
- TLS_AAD_SPACE_SIZE);
-
- err = zerocopy_from_iter(sk, &msg->msg_iter,
- to_copy, &pages,
- &chunk, &sgin[1],
- MAX_SKB_FRAGS, false);
- if (err < 0)
- goto fallback_to_reg_recv;
-
- err = decrypt_skb(sk, skb, sgin);
- for (; pages > 0; pages--)
- put_page(sg_page(&sgin[pages]));
- if (err < 0) {
- tls_err_abort(sk, EBADMSG);
- goto recv_end;
- }
- } else {
-fallback_to_reg_recv:
- err = decrypt_skb(sk, skb, NULL);
- if (err < 0) {
- tls_err_abort(sk, EBADMSG);
- goto recv_end;
- }
+
+ err = decrypt_skb_update(sk, skb, &msg->msg_iter,
+ &chunk, &zc);
+ if (err < 0) {
+ tls_err_abort(sk, EBADMSG);
+ goto recv_end;
}
ctx->decrypted = true;
}
@@ -888,6 +949,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
int err = 0;
long timeo;
int chunk;
+ bool zc = false;
lock_sock(sk);
@@ -904,7 +966,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
}
if (!ctx->decrypted) {
- err = decrypt_skb(sk, skb, NULL);
+ err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc);
if (err < 0) {
tls_err_abort(sk, EBADMSG);
@@ -950,7 +1012,7 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
{
struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
- char header[tls_ctx->rx.prepend_size];
+ char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
struct strp_msg *rxm = strp_msg(skb);
size_t cipher_overhead;
size_t data_len = 0;
@@ -960,6 +1022,12 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
if (rxm->offset + tls_ctx->rx.prepend_size > skb->len)
return 0;
+ /* Sanity-check size of on-stack buffer. */
+ if (WARN_ON(tls_ctx->rx.prepend_size > sizeof(header))) {
+ ret = -EINVAL;
+ goto read_failure;
+ }
+
/* Linearize header to local buffer */
ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size);
@@ -987,6 +1055,10 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
goto read_failure;
}
+#ifdef CONFIG_TLS_DEVICE
+ handle_device_resync(strp->sk, TCP_SKB_CB(skb)->seq + rxm->offset,
+ *(u64*)tls_ctx->rx.rec_seq);
+#endif
return data_len + TLS_HEADER_SIZE;
read_failure:
@@ -999,16 +1071,13 @@ static void tls_queue(struct strparser *strp, struct sk_buff *skb)
{
struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
- struct strp_msg *rxm;
-
- rxm = strp_msg(skb);
ctx->decrypted = false;
ctx->recv_pkt = skb;
strp_pause(strp);
- strp->sk->sk_state_change(strp->sk);
+ ctx->saved_data_ready(strp->sk);
}
static void tls_data_ready(struct sock *sk)
@@ -1024,23 +1093,20 @@ void tls_sw_free_resources_tx(struct sock *sk)
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
- if (ctx->aead_send)
- crypto_free_aead(ctx->aead_send);
+ crypto_free_aead(ctx->aead_send);
tls_free_both_sg(sk);
kfree(ctx);
}
-void tls_sw_free_resources_rx(struct sock *sk)
+void tls_sw_release_resources_rx(struct sock *sk)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
if (ctx->aead_recv) {
- if (ctx->recv_pkt) {
- kfree_skb(ctx->recv_pkt);
- ctx->recv_pkt = NULL;
- }
+ kfree_skb(ctx->recv_pkt);
+ ctx->recv_pkt = NULL;
crypto_free_aead(ctx->aead_recv);
strp_stop(&ctx->strp);
write_lock_bh(&sk->sk_callback_lock);
@@ -1050,6 +1116,14 @@ void tls_sw_free_resources_rx(struct sock *sk)
strp_done(&ctx->strp);
lock_sock(sk);
}
+}
+
+void tls_sw_free_resources_rx(struct sock *sk)
+{
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
+ struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+
+ tls_sw_release_resources_rx(sk);
kfree(ctx);
}
@@ -1074,28 +1148,38 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
}
if (tx) {
- sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL);
- if (!sw_ctx_tx) {
- rc = -ENOMEM;
- goto out;
+ if (!ctx->priv_ctx_tx) {
+ sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL);
+ if (!sw_ctx_tx) {
+ rc = -ENOMEM;
+ goto out;
+ }
+ ctx->priv_ctx_tx = sw_ctx_tx;
+ } else {
+ sw_ctx_tx =
+ (struct tls_sw_context_tx *)ctx->priv_ctx_tx;
}
- crypto_init_wait(&sw_ctx_tx->async_wait);
- ctx->priv_ctx_tx = sw_ctx_tx;
} else {
- sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL);
- if (!sw_ctx_rx) {
- rc = -ENOMEM;
- goto out;
+ if (!ctx->priv_ctx_rx) {
+ sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL);
+ if (!sw_ctx_rx) {
+ rc = -ENOMEM;
+ goto out;
+ }
+ ctx->priv_ctx_rx = sw_ctx_rx;
+ } else {
+ sw_ctx_rx =
+ (struct tls_sw_context_rx *)ctx->priv_ctx_rx;
}
- crypto_init_wait(&sw_ctx_rx->async_wait);
- ctx->priv_ctx_rx = sw_ctx_rx;
}
if (tx) {
+ crypto_init_wait(&sw_ctx_tx->async_wait);
crypto_info = &ctx->crypto_send;
cctx = &ctx->tx;
aead = &sw_ctx_tx->aead_send;
} else {
+ crypto_init_wait(&sw_ctx_rx->async_wait);
crypto_info = &ctx->crypto_recv;
cctx = &ctx->rx;
aead = &sw_ctx_rx->aead_recv;
@@ -1120,7 +1204,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
}
/* Sanity-check the IV size for stack allocations. */
- if (iv_size > MAX_IV_SIZE) {
+ if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE) {
rc = -EINVAL;
goto free_priv;
}
@@ -1138,12 +1222,11 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
cctx->rec_seq_size = rec_seq_size;
- cctx->rec_seq = kmalloc(rec_seq_size, GFP_KERNEL);
+ cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
if (!cctx->rec_seq) {
rc = -ENOMEM;
goto free_iv;
}
- memcpy(cctx->rec_seq, rec_seq, rec_seq_size);
if (sw_ctx_tx) {
sg_init_table(sw_ctx_tx->sg_encrypted_data,