diff options
Diffstat (limited to 'net/tls')
-rw-r--r-- | net/tls/Kconfig | 10 | ||||
-rw-r--r-- | net/tls/Makefile | 2 | ||||
-rw-r--r-- | net/tls/tls_device.c | 766 | ||||
-rw-r--r-- | net/tls/tls_device_fallback.c | 450 | ||||
-rw-r--r-- | net/tls/tls_main.c | 139 | ||||
-rw-r--r-- | net/tls/tls_sw.c | 141 |
6 files changed, 1395 insertions, 113 deletions
diff --git a/net/tls/Kconfig b/net/tls/Kconfig index 89b8745a986f..73f05ece53d0 100644 --- a/net/tls/Kconfig +++ b/net/tls/Kconfig @@ -14,3 +14,13 @@ config TLS encryption handling of the TLS protocol to be done in-kernel. If unsure, say N. + +config TLS_DEVICE + bool "Transport Layer Security HW offload" + depends on TLS + select SOCK_VALIDATE_XMIT + default n + help + Enable kernel support for HW offload of the TLS protocol. + + If unsure, say N. diff --git a/net/tls/Makefile b/net/tls/Makefile index a930fd1c4f7b..4d6b728a67d0 100644 --- a/net/tls/Makefile +++ b/net/tls/Makefile @@ -5,3 +5,5 @@ obj-$(CONFIG_TLS) += tls.o tls-y := tls_main.o tls_sw.o + +tls-$(CONFIG_TLS_DEVICE) += tls_device.o tls_device_fallback.o diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c new file mode 100644 index 000000000000..a7a8f8e20ff3 --- /dev/null +++ b/net/tls/tls_device.c @@ -0,0 +1,766 @@ +/* Copyright (c) 2018, Mellanox Technologies All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include <crypto/aead.h> +#include <linux/highmem.h> +#include <linux/module.h> +#include <linux/netdevice.h> +#include <net/dst.h> +#include <net/inet_connection_sock.h> +#include <net/tcp.h> +#include <net/tls.h> + +/* device_offload_lock is used to synchronize tls_dev_add + * against NETDEV_DOWN notifications. + */ +static DECLARE_RWSEM(device_offload_lock); + +static void tls_device_gc_task(struct work_struct *work); + +static DECLARE_WORK(tls_device_gc_work, tls_device_gc_task); +static LIST_HEAD(tls_device_gc_list); +static LIST_HEAD(tls_device_list); +static DEFINE_SPINLOCK(tls_device_lock); + +static void tls_device_free_ctx(struct tls_context *ctx) +{ + struct tls_offload_context *offload_ctx = tls_offload_ctx(ctx); + + kfree(offload_ctx); + kfree(ctx); +} + +static void tls_device_gc_task(struct work_struct *work) +{ + struct tls_context *ctx, *tmp; + unsigned long flags; + LIST_HEAD(gc_list); + + spin_lock_irqsave(&tls_device_lock, flags); + list_splice_init(&tls_device_gc_list, &gc_list); + spin_unlock_irqrestore(&tls_device_lock, flags); + + list_for_each_entry_safe(ctx, tmp, &gc_list, list) { + struct net_device *netdev = ctx->netdev; + + if (netdev) { + netdev->tlsdev_ops->tls_dev_del(netdev, ctx, + TLS_OFFLOAD_CTX_DIR_TX); + dev_put(netdev); + } + + list_del(&ctx->list); + tls_device_free_ctx(ctx); + } +} + +static void tls_device_queue_ctx_destruction(struct tls_context *ctx) +{ + unsigned long flags; + + spin_lock_irqsave(&tls_device_lock, flags); + list_move_tail(&ctx->list, &tls_device_gc_list); + + /* schedule_work inside the spinlock + * to make sure tls_device_down waits for that work. + */ + schedule_work(&tls_device_gc_work); + + spin_unlock_irqrestore(&tls_device_lock, flags); +} + +/* We assume that the socket is already connected */ +static struct net_device *get_netdev_for_sock(struct sock *sk) +{ + struct dst_entry *dst = sk_dst_get(sk); + struct net_device *netdev = NULL; + + if (likely(dst)) { + netdev = dst->dev; + dev_hold(netdev); + } + + dst_release(dst); + + return netdev; +} + +static void destroy_record(struct tls_record_info *record) +{ + int nr_frags = record->num_frags; + skb_frag_t *frag; + + while (nr_frags-- > 0) { + frag = &record->frags[nr_frags]; + __skb_frag_unref(frag); + } + kfree(record); +} + +static void delete_all_records(struct tls_offload_context *offload_ctx) +{ + struct tls_record_info *info, *temp; + + list_for_each_entry_safe(info, temp, &offload_ctx->records_list, list) { + list_del(&info->list); + destroy_record(info); + } + + offload_ctx->retransmit_hint = NULL; +} + +static void tls_icsk_clean_acked(struct sock *sk, u32 acked_seq) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_record_info *info, *temp; + struct tls_offload_context *ctx; + u64 deleted_records = 0; + unsigned long flags; + + if (!tls_ctx) + return; + + ctx = tls_offload_ctx(tls_ctx); + + spin_lock_irqsave(&ctx->lock, flags); + info = ctx->retransmit_hint; + if (info && !before(acked_seq, info->end_seq)) { + ctx->retransmit_hint = NULL; + list_del(&info->list); + destroy_record(info); + deleted_records++; + } + + list_for_each_entry_safe(info, temp, &ctx->records_list, list) { + if (before(acked_seq, info->end_seq)) + break; + list_del(&info->list); + + destroy_record(info); + deleted_records++; + } + + ctx->unacked_record_sn += deleted_records; + spin_unlock_irqrestore(&ctx->lock, flags); +} + +/* At this point, there should be no references on this + * socket and no in-flight SKBs associated with this + * socket, so it is safe to free all the resources. + */ +void tls_device_sk_destruct(struct sock *sk) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx); + + if (ctx->open_record) + destroy_record(ctx->open_record); + + delete_all_records(ctx); + crypto_free_aead(ctx->aead_send); + ctx->sk_destruct(sk); + clean_acked_data_disable(inet_csk(sk)); + + if (refcount_dec_and_test(&tls_ctx->refcount)) + tls_device_queue_ctx_destruction(tls_ctx); +} +EXPORT_SYMBOL(tls_device_sk_destruct); + +static void tls_append_frag(struct tls_record_info *record, + struct page_frag *pfrag, + int size) +{ + skb_frag_t *frag; + + frag = &record->frags[record->num_frags - 1]; + if (frag->page.p == pfrag->page && + frag->page_offset + frag->size == pfrag->offset) { + frag->size += size; + } else { + ++frag; + frag->page.p = pfrag->page; + frag->page_offset = pfrag->offset; + frag->size = size; + ++record->num_frags; + get_page(pfrag->page); + } + + pfrag->offset += size; + record->len += size; +} + +static int tls_push_record(struct sock *sk, + struct tls_context *ctx, + struct tls_offload_context *offload_ctx, + struct tls_record_info *record, + struct page_frag *pfrag, + int flags, + unsigned char record_type) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct page_frag dummy_tag_frag; + skb_frag_t *frag; + int i; + + /* fill prepend */ + frag = &record->frags[0]; + tls_fill_prepend(ctx, + skb_frag_address(frag), + record->len - ctx->tx.prepend_size, + record_type); + + /* HW doesn't care about the data in the tag, because it fills it. */ + dummy_tag_frag.page = skb_frag_page(frag); + dummy_tag_frag.offset = 0; + + tls_append_frag(record, &dummy_tag_frag, ctx->tx.tag_size); + record->end_seq = tp->write_seq + record->len; + spin_lock_irq(&offload_ctx->lock); + list_add_tail(&record->list, &offload_ctx->records_list); + spin_unlock_irq(&offload_ctx->lock); + offload_ctx->open_record = NULL; + set_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags); + tls_advance_record_sn(sk, &ctx->tx); + + for (i = 0; i < record->num_frags; i++) { + frag = &record->frags[i]; + sg_unmark_end(&offload_ctx->sg_tx_data[i]); + sg_set_page(&offload_ctx->sg_tx_data[i], skb_frag_page(frag), + frag->size, frag->page_offset); + sk_mem_charge(sk, frag->size); + get_page(skb_frag_page(frag)); + } + sg_mark_end(&offload_ctx->sg_tx_data[record->num_frags - 1]); + + /* all ready, send */ + return tls_push_sg(sk, ctx, offload_ctx->sg_tx_data, 0, flags); +} + +static int tls_create_new_record(struct tls_offload_context *offload_ctx, + struct page_frag *pfrag, + size_t prepend_size) +{ + struct tls_record_info *record; + skb_frag_t *frag; + + record = kmalloc(sizeof(*record), GFP_KERNEL); + if (!record) + return -ENOMEM; + + frag = &record->frags[0]; + __skb_frag_set_page(frag, pfrag->page); + frag->page_offset = pfrag->offset; + skb_frag_size_set(frag, prepend_size); + + get_page(pfrag->page); + pfrag->offset += prepend_size; + + record->num_frags = 1; + record->len = prepend_size; + offload_ctx->open_record = record; + return 0; +} + +static int tls_do_allocation(struct sock *sk, + struct tls_offload_context *offload_ctx, + struct page_frag *pfrag, + size_t prepend_size) +{ + int ret; + + if (!offload_ctx->open_record) { + if (unlikely(!skb_page_frag_refill(prepend_size, pfrag, + sk->sk_allocation))) { + sk->sk_prot->enter_memory_pressure(sk); + sk_stream_moderate_sndbuf(sk); + return -ENOMEM; + } + + ret = tls_create_new_record(offload_ctx, pfrag, prepend_size); + if (ret) + return ret; + + if (pfrag->size > pfrag->offset) + return 0; + } + + if (!sk_page_frag_refill(sk, pfrag)) + return -ENOMEM; + + return 0; +} + +static int tls_push_data(struct sock *sk, + struct iov_iter *msg_iter, + size_t size, int flags, + unsigned char record_type) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx); + int tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST; + int more = flags & (MSG_SENDPAGE_NOTLAST | MSG_MORE); + struct tls_record_info *record = ctx->open_record; + struct page_frag *pfrag; + size_t orig_size = size; + u32 max_open_record_len; + int copy, rc = 0; + bool done = false; + long timeo; + + if (flags & + ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | MSG_SENDPAGE_NOTLAST)) + return -ENOTSUPP; + + if (sk->sk_err) + return -sk->sk_err; + + timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); + rc = tls_complete_pending_work(sk, tls_ctx, flags, &timeo); + if (rc < 0) + return rc; + + pfrag = sk_page_frag(sk); + + /* TLS_HEADER_SIZE is not counted as part of the TLS record, and + * we need to leave room for an authentication tag. + */ + max_open_record_len = TLS_MAX_PAYLOAD_SIZE + + tls_ctx->tx.prepend_size; + do { + rc = tls_do_allocation(sk, ctx, pfrag, + tls_ctx->tx.prepend_size); + if (rc) { + rc = sk_stream_wait_memory(sk, &timeo); + if (!rc) + continue; + + record = ctx->open_record; + if (!record) + break; +handle_error: + if (record_type != TLS_RECORD_TYPE_DATA) { + /* avoid sending partial + * record with type != + * application_data + */ + size = orig_size; + destroy_record(record); + ctx->open_record = NULL; + } else if (record->len > tls_ctx->tx.prepend_size) { + goto last_record; + } + + break; + } + + record = ctx->open_record; + copy = min_t(size_t, size, (pfrag->size - pfrag->offset)); + copy = min_t(size_t, copy, (max_open_record_len - record->len)); + + if (copy_from_iter_nocache(page_address(pfrag->page) + + pfrag->offset, + copy, msg_iter) != copy) { + rc = -EFAULT; + goto handle_error; + } + tls_append_frag(record, pfrag, copy); + + size -= copy; + if (!size) { +last_record: + tls_push_record_flags = flags; + if (more) { + tls_ctx->pending_open_record_frags = + record->num_frags; + break; + } + + done = true; + } + + if (done || record->len >= max_open_record_len || + (record->num_frags >= MAX_SKB_FRAGS - 1)) { + rc = tls_push_record(sk, + tls_ctx, + ctx, + record, + pfrag, + tls_push_record_flags, + record_type); + if (rc < 0) + break; + } + } while (!done); + + if (orig_size - size > 0) + rc = orig_size - size; + + return rc; +} + +int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) +{ + unsigned char record_type = TLS_RECORD_TYPE_DATA; + int rc; + + lock_sock(sk); + + if (unlikely(msg->msg_controllen)) { + rc = tls_proccess_cmsg(sk, msg, &record_type); + if (rc) + goto out; + } + + rc = tls_push_data(sk, &msg->msg_iter, size, + msg->msg_flags, record_type); + +out: + release_sock(sk); + return rc; +} + +int tls_device_sendpage(struct sock *sk, struct page *page, + int offset, size_t size, int flags) +{ + struct iov_iter msg_iter; + char *kaddr = kmap(page); + struct kvec iov; + int rc; + + if (flags & MSG_SENDPAGE_NOTLAST) + flags |= MSG_MORE; + + lock_sock(sk); + + if (flags & MSG_OOB) { + rc = -ENOTSUPP; + goto out; + } + + iov.iov_base = kaddr + offset; + iov.iov_len = size; + iov_iter_kvec(&msg_iter, WRITE | ITER_KVEC, &iov, 1, size); + rc = tls_push_data(sk, &msg_iter, size, + flags, TLS_RECORD_TYPE_DATA); + kunmap(page); + +out: + release_sock(sk); + return rc; +} + +struct tls_record_info *tls_get_record(struct tls_offload_context *context, + u32 seq, u64 *p_record_sn) +{ + u64 record_sn = context->hint_record_sn; + struct tls_record_info *info; + + info = context->retransmit_hint; + if (!info || + before(seq, info->end_seq - info->len)) { + /* if retransmit_hint is irrelevant start + * from the beggining of the list + */ + info = list_first_entry(&context->records_list, + struct tls_record_info, list); + record_sn = context->unacked_record_sn; + } + + list_for_each_entry_from(info, &context->records_list, list) { + if (before(seq, info->end_seq)) { + if (!context->retransmit_hint || + after(info->end_seq, + context->retransmit_hint->end_seq)) { + context->hint_record_sn = record_sn; + context->retransmit_hint = info; + } + *p_record_sn = record_sn; + return info; + } + record_sn++; + } + + return NULL; +} +EXPORT_SYMBOL(tls_get_record); + +static int tls_device_push_pending_record(struct sock *sk, int flags) +{ + struct iov_iter msg_iter; + + iov_iter_kvec(&msg_iter, WRITE | ITER_KVEC, NULL, 0, 0); + return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA); +} + +int tls_set_device_offload(struct sock *sk, struct tls_context *ctx) +{ + u16 nonce_size, tag_size, iv_size, rec_seq_size; + struct tls_record_info *start_marker_record; + struct tls_offload_context *offload_ctx; + struct tls_crypto_info *crypto_info; + struct net_device *netdev; + char *iv, *rec_seq; + struct sk_buff *skb; + int rc = -EINVAL; + __be64 rcd_sn; + + if (!ctx) + goto out; + + if (ctx->priv_ctx_tx) { + rc = -EEXIST; + goto out; + } + + start_marker_record = kmalloc(sizeof(*start_marker_record), GFP_KERNEL); + if (!start_marker_record) { + rc = -ENOMEM; + goto out; + } + + offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE, GFP_KERNEL); + if (!offload_ctx) { + rc = -ENOMEM; + goto free_marker_record; + } + + crypto_info = &ctx->crypto_send; + switch (crypto_info->cipher_type) { + case TLS_CIPHER_AES_GCM_128: + nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE; + tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE; + iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE; + iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv; + rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE; + rec_seq = + ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq; + break; + default: + rc = -EINVAL; + goto free_offload_ctx; + } + + ctx->tx.prepend_size = TLS_HEADER_SIZE + nonce_size; + ctx->tx.tag_size = tag_size; + ctx->tx.overhead_size = ctx->tx.prepend_size + ctx->tx.tag_size; + ctx->tx.iv_size = iv_size; + ctx->tx.iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE, + GFP_KERNEL); + if (!ctx->tx.iv) { + rc = -ENOMEM; + goto free_offload_ctx; + } + + memcpy(ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size); + + ctx->tx.rec_seq_size = rec_seq_size; + ctx->tx.rec_seq = kmalloc(rec_seq_size, GFP_KERNEL); + if (!ctx->tx.rec_seq) { + rc = -ENOMEM; + goto free_iv; + } + memcpy(ctx->tx.rec_seq, rec_seq, rec_seq_size); + + rc = tls_sw_fallback_init(sk, offload_ctx, crypto_info); + if (rc) + goto free_rec_seq; + + /* start at rec_seq - 1 to account for the start marker record */ + memcpy(&rcd_sn, ctx->tx.rec_seq, sizeof(rcd_sn)); + offload_ctx->unacked_record_sn = be64_to_cpu(rcd_sn) - 1; + + start_marker_record->end_seq = tcp_sk(sk)->write_seq; + start_marker_record->len = 0; + start_marker_record->num_frags = 0; + + INIT_LIST_HEAD(&offload_ctx->records_list); + list_add_tail(&start_marker_record->list, &offload_ctx->records_list); + spin_lock_init(&offload_ctx->lock); + sg_init_table(offload_ctx->sg_tx_data, + ARRAY_SIZE(offload_ctx->sg_tx_data)); + + clean_acked_data_enable(inet_csk(sk), &tls_icsk_clean_acked); + ctx->push_pending_record = tls_device_push_pending_record; + offload_ctx->sk_destruct = sk->sk_destruct; + + /* TLS offload is greatly simplified if we don't send + * SKBs where only part of the payload needs to be encrypted. + * So mark the last skb in the write queue as end of record. + */ + skb = tcp_write_queue_tail(sk); + if (skb) + TCP_SKB_CB(skb)->eor = 1; + + refcount_set(&ctx->refcount, 1); + + /* We support starting offload on multiple sockets + * concurrently, so we only need a read lock here. + * This lock must precede get_netdev_for_sock to prevent races between + * NETDEV_DOWN and setsockopt. + */ + down_read(&device_offload_lock); + netdev = get_netdev_for_sock(sk); + if (!netdev) { + pr_err_ratelimited("%s: netdev not found\n", __func__); + rc = -EINVAL; + goto release_lock; + } + + if (!(netdev->features & NETIF_F_HW_TLS_TX)) { + rc = -ENOTSUPP; + goto release_netdev; + } + + /* Avoid offloading if the device is down + * We don't want to offload new flows after + * the NETDEV_DOWN event + */ + if (!(netdev->flags & IFF_UP)) { + rc = -EINVAL; + goto release_netdev; + } + + ctx->priv_ctx_tx = offload_ctx; + rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_TX, + &ctx->crypto_send, + tcp_sk(sk)->write_seq); + if (rc) + goto release_netdev; + + ctx->netdev = netdev; + + spin_lock_irq(&tls_device_lock); + list_add_tail(&ctx->list, &tls_device_list); + spin_unlock_irq(&tls_device_lock); + + sk->sk_validate_xmit_skb = tls_validate_xmit_skb; + /* following this assignment tls_is_sk_tx_device_offloaded + * will return true and the context might be accessed + * by the netdev's xmit function. + */ + smp_store_release(&sk->sk_destruct, + &tls_device_sk_destruct); + up_read(&device_offload_lock); + goto out; + +release_netdev: + dev_put(netdev); +release_lock: + up_read(&device_offload_lock); + clean_acked_data_disable(inet_csk(sk)); + crypto_free_aead(offload_ctx->aead_send); +free_rec_seq: + kfree(ctx->tx.rec_seq); +free_iv: + kfree(ctx->tx.iv); +free_offload_ctx: + kfree(offload_ctx); + ctx->priv_ctx_tx = NULL; +free_marker_record: + kfree(start_marker_record); +out: + return rc; +} + +static int tls_device_down(struct net_device *netdev) +{ + struct tls_context *ctx, *tmp; + unsigned long flags; + LIST_HEAD(list); + + /* Request a write lock to block new offload attempts */ + down_write(&device_offload_lock); + + spin_lock_irqsave(&tls_device_lock, flags); + list_for_each_entry_safe(ctx, tmp, &tls_device_list, list) { + if (ctx->netdev != netdev || + !refcount_inc_not_zero(&ctx->refcount)) + continue; + + list_move(&ctx->list, &list); + } + spin_unlock_irqrestore(&tls_device_lock, flags); + + list_for_each_entry_safe(ctx, tmp, &list, list) { + netdev->tlsdev_ops->tls_dev_del(netdev, ctx, + TLS_OFFLOAD_CTX_DIR_TX); + ctx->netdev = NULL; + dev_put(netdev); + list_del_init(&ctx->list); + + if (refcount_dec_and_test(&ctx->refcount)) + tls_device_free_ctx(ctx); + } + + up_write(&device_offload_lock); + + flush_work(&tls_device_gc_work); + + return NOTIFY_DONE; +} + +static int tls_dev_event(struct notifier_block *this, unsigned long event, + void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + + if (!(dev->features & NETIF_F_HW_TLS_TX)) + return NOTIFY_DONE; + + switch (event) { + case NETDEV_REGISTER: + case NETDEV_FEAT_CHANGE: + if (dev->tlsdev_ops && + dev->tlsdev_ops->tls_dev_add && + dev->tlsdev_ops->tls_dev_del) + return NOTIFY_DONE; + else + return NOTIFY_BAD; + case NETDEV_DOWN: + return tls_device_down(dev); + } + return NOTIFY_DONE; +} + +static struct notifier_block tls_dev_notifier = { + .notifier_call = tls_dev_event, +}; + +void __init tls_device_init(void) +{ + register_netdevice_notifier(&tls_dev_notifier); +} + +void __exit tls_device_cleanup(void) +{ + unregister_netdevice_notifier(&tls_dev_notifier); + flush_work(&tls_device_gc_work); +} diff --git a/net/tls/tls_device_fallback.c b/net/tls/tls_device_fallback.c new file mode 100644 index 000000000000..748914abdb60 --- /dev/null +++ b/net/tls/tls_device_fallback.c @@ -0,0 +1,450 @@ +/* Copyright (c) 2018, Mellanox Technologies All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include <net/tls.h> +#include <crypto/aead.h> +#include <crypto/scatterwalk.h> +#include <net/ip6_checksum.h> + +static void chain_to_walk(struct scatterlist *sg, struct scatter_walk *walk) +{ + struct scatterlist *src = walk->sg; + int diff = walk->offset - src->offset; + + sg_set_page(sg, sg_page(src), + src->length - diff, walk->offset); + + scatterwalk_crypto_chain(sg, sg_next(src), 0, 2); +} + +static int tls_enc_record(struct aead_request *aead_req, + struct crypto_aead *aead, char *aad, + char *iv, __be64 rcd_sn, + struct scatter_walk *in, + struct scatter_walk *out, int *in_len) +{ + unsigned char buf[TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE]; + struct scatterlist sg_in[3]; + struct scatterlist sg_out[3]; + u16 len; + int rc; + + len = min_t(int, *in_len, ARRAY_SIZE(buf)); + + scatterwalk_copychunks(buf, in, len, 0); + scatterwalk_copychunks(buf, out, len, 1); + + *in_len -= len; + if (!*in_len) + return 0; + + scatterwalk_pagedone(in, 0, 1); + scatterwalk_pagedone(out, 1, 1); + + len = buf[4] | (buf[3] << 8); + len -= TLS_CIPHER_AES_GCM_128_IV_SIZE; + + tls_make_aad(aad, len - TLS_CIPHER_AES_GCM_128_TAG_SIZE, + (char *)&rcd_sn, sizeof(rcd_sn), buf[0]); + + memcpy(iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, buf + TLS_HEADER_SIZE, + TLS_CIPHER_AES_GCM_128_IV_SIZE); + + sg_init_table(sg_in, ARRAY_SIZE(sg_in)); + sg_init_table(sg_out, ARRAY_SIZE(sg_out)); + sg_set_buf(sg_in, aad, TLS_AAD_SPACE_SIZE); + sg_set_buf(sg_out, aad, TLS_AAD_SPACE_SIZE); + chain_to_walk(sg_in + 1, in); + chain_to_walk(sg_out + 1, out); + + *in_len -= len; + if (*in_len < 0) { + *in_len += TLS_CIPHER_AES_GCM_128_TAG_SIZE; + /* the input buffer doesn't contain the entire record. + * trim len accordingly. The resulting authentication tag + * will contain garbage, but we don't care, so we won't + * include any of it in the output skb + * Note that we assume the output buffer length + * is larger then input buffer length + tag size + */ + if (*in_len < 0) + len += *in_len; + + *in_len = 0; + } + + if (*in_len) { + scatterwalk_copychunks(NULL, in, len, 2); + scatterwalk_pagedone(in, 0, 1); + scatterwalk_copychunks(NULL, out, len, 2); + scatterwalk_pagedone(out, 1, 1); + } + + len -= TLS_CIPHER_AES_GCM_128_TAG_SIZE; + aead_request_set_crypt(aead_req, sg_in, sg_out, len, iv); + + rc = crypto_aead_encrypt(aead_req); + + return rc; +} + +static void tls_init_aead_request(struct aead_request *aead_req, + struct crypto_aead *aead) +{ + aead_request_set_tfm(aead_req, aead); + aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE); +} + +static struct aead_request *tls_alloc_aead_request(struct crypto_aead *aead, + gfp_t flags) +{ + unsigned int req_size = sizeof(struct aead_request) + + crypto_aead_reqsize(aead); + struct aead_request *aead_req; + + aead_req = kzalloc(req_size, flags); + if (aead_req) + tls_init_aead_request(aead_req, aead); + return aead_req; +} + +static int tls_enc_records(struct aead_request *aead_req, + struct crypto_aead *aead, struct scatterlist *sg_in, + struct scatterlist *sg_out, char *aad, char *iv, + u64 rcd_sn, int len) +{ + struct scatter_walk out, in; + int rc; + + scatterwalk_start(&in, sg_in); + scatterwalk_start(&out, sg_out); + + do { + rc = tls_enc_record(aead_req, aead, aad, iv, + cpu_to_be64(rcd_sn), &in, &out, &len); + rcd_sn++; + + } while (rc == 0 && len); + + scatterwalk_done(&in, 0, 0); + scatterwalk_done(&out, 1, 0); + + return rc; +} + +/* Can't use icsk->icsk_af_ops->send_check here because the ip addresses + * might have been changed by NAT. + */ +static void update_chksum(struct sk_buff *skb, int headln) +{ + struct tcphdr *th = tcp_hdr(skb); + int datalen = skb->len - headln; + const struct ipv6hdr *ipv6h; + const struct iphdr *iph; + + /* We only changed the payload so if we are using partial we don't + * need to update anything. + */ + if (likely(skb->ip_summed == CHECKSUM_PARTIAL)) + return; + + skb->ip_summed = CHECKSUM_PARTIAL; + skb->csum_start = skb_transport_header(skb) - skb->head; + skb->csum_offset = offsetof(struct tcphdr, check); + + if (skb->sk->sk_family == AF_INET6) { + ipv6h = ipv6_hdr(skb); + th->check = ~csum_ipv6_magic(&ipv6h->saddr, &ipv6h->daddr, + datalen, IPPROTO_TCP, 0); + } else { + iph = ip_hdr(skb); + th->check = ~csum_tcpudp_magic(iph->saddr, iph->daddr, datalen, + IPPROTO_TCP, 0); + } +} + +static void complete_skb(struct sk_buff *nskb, struct sk_buff *skb, int headln) +{ + skb_copy_header(nskb, skb); + + skb_put(nskb, skb->len); + memcpy(nskb->data, skb->data, headln); + update_chksum(nskb, headln); + + nskb->destructor = skb->destructor; + nskb->sk = skb->sk; + skb->destructor = NULL; + skb->sk = NULL; + refcount_add(nskb->truesize - skb->truesize, + &nskb->sk->sk_wmem_alloc); +} + +/* This function may be called after the user socket is already + * closed so make sure we don't use anything freed during + * tls_sk_proto_close here + */ + +static int fill_sg_in(struct scatterlist *sg_in, + struct sk_buff *skb, + struct tls_offload_context *ctx, + u64 *rcd_sn, + s32 *sync_size, + int *resync_sgs) +{ + int tcp_payload_offset = skb_transport_offset(skb) + tcp_hdrlen(skb); + int payload_len = skb->len - tcp_payload_offset; + u32 tcp_seq = ntohl(tcp_hdr(skb)->seq); + struct tls_record_info *record; + unsigned long flags; + int remaining; + int i; + + spin_lock_irqsave(&ctx->lock, flags); + record = tls_get_record(ctx, tcp_seq, rcd_sn); + if (!record) { + spin_unlock_irqrestore(&ctx->lock, flags); + WARN(1, "Record not found for seq %u\n", tcp_seq); + return -EINVAL; + } + + *sync_size = tcp_seq - tls_record_start_seq(record); + if (*sync_size < 0) { + int is_start_marker = tls_record_is_start_marker(record); + + spin_unlock_irqrestore(&ctx->lock, flags); + /* This should only occur if the relevant record was + * already acked. In that case it should be ok + * to drop the packet and avoid retransmission. + * + * There is a corner case where the packet contains + * both an acked and a non-acked record. + * We currently don't handle that case and rely + * on TCP to retranmit a packet that doesn't contain + * already acked payload. + */ + if (!is_start_marker) + *sync_size = 0; + return -EINVAL; + } + + remaining = *sync_size; + for (i = 0; remaining > 0; i++) { + skb_frag_t *frag = &record->frags[i]; + + __skb_frag_ref(frag); + sg_set_page(sg_in + i, skb_frag_page(frag), + skb_frag_size(frag), frag->page_offset); + + remaining -= skb_frag_size(frag); + + if (remaining < 0) + sg_in[i].length += remaining; + } + *resync_sgs = i; + + spin_unlock_irqrestore(&ctx->lock, flags); + if (skb_to_sgvec(skb, &sg_in[i], tcp_payload_offset, payload_len) < 0) + return -EINVAL; + + return 0; +} + +static void fill_sg_out(struct scatterlist sg_out[3], void *buf, + struct tls_context *tls_ctx, + struct sk_buff *nskb, + int tcp_payload_offset, + int payload_len, + int sync_size, + void *dummy_buf) +{ + sg_set_buf(&sg_out[0], dummy_buf, sync_size); + sg_set_buf(&sg_out[1], nskb->data + tcp_payload_offset, payload_len); + /* Add room for authentication tag produced by crypto */ + dummy_buf += sync_size; + sg_set_buf(&sg_out[2], dummy_buf, TLS_CIPHER_AES_GCM_128_TAG_SIZE); +} + +static struct sk_buff *tls_enc_skb(struct tls_context *tls_ctx, + struct scatterlist sg_out[3], + struct scatterlist *sg_in, + struct sk_buff *skb, + s32 sync_size, u64 rcd_sn) +{ + int tcp_payload_offset = skb_transport_offset(skb) + tcp_hdrlen(skb); + struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx); + int payload_len = skb->len - tcp_payload_offset; + void *buf, *iv, *aad, *dummy_buf; + struct aead_request *aead_req; + struct sk_buff *nskb = NULL; + int buf_len; + + aead_req = tls_alloc_aead_request(ctx->aead_send, GFP_ATOMIC); + if (!aead_req) + return NULL; + + buf_len = TLS_CIPHER_AES_GCM_128_SALT_SIZE + + TLS_CIPHER_AES_GCM_128_IV_SIZE + + TLS_AAD_SPACE_SIZE + + sync_size + + TLS_CIPHER_AES_GCM_128_TAG_SIZE; + buf = kmalloc(buf_len, GFP_ATOMIC); + if (!buf) + goto free_req; + + iv = buf; + memcpy(iv, tls_ctx->crypto_send_aes_gcm_128.salt, + TLS_CIPHER_AES_GCM_128_SALT_SIZE); + aad = buf + TLS_CIPHER_AES_GCM_128_SALT_SIZE + + TLS_CIPHER_AES_GCM_128_IV_SIZE; + dummy_buf = aad + TLS_AAD_SPACE_SIZE; + + nskb = alloc_skb(skb_headroom(skb) + skb->len, GFP_ATOMIC); + if (!nskb) + goto free_buf; + + skb_reserve(nskb, skb_headroom(skb)); + + fill_sg_out(sg_out, buf, tls_ctx, nskb, tcp_payload_offset, + payload_len, sync_size, dummy_buf); + + if (tls_enc_records(aead_req, ctx->aead_send, sg_in, sg_out, aad, iv, + rcd_sn, sync_size + payload_len) < 0) + goto free_nskb; + + complete_skb(nskb, skb, tcp_payload_offset); + + /* validate_xmit_skb_list assumes that if the skb wasn't segmented + * nskb->prev will point to the skb itself + */ + nskb->prev = nskb; + +free_buf: + kfree(buf); +free_req: + kfree(aead_req); + return nskb; +free_nskb: + kfree_skb(nskb); + nskb = NULL; + goto free_buf; +} + +static struct sk_buff *tls_sw_fallback(struct sock *sk, struct sk_buff *skb) +{ + int tcp_payload_offset = skb_transport_offset(skb) + tcp_hdrlen(skb); + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx); + int payload_len = skb->len - tcp_payload_offset; + struct scatterlist *sg_in, sg_out[3]; + struct sk_buff *nskb = NULL; + int sg_in_max_elements; + int resync_sgs = 0; + s32 sync_size = 0; + u64 rcd_sn; + + /* worst case is: + * MAX_SKB_FRAGS in tls_record_info + * MAX_SKB_FRAGS + 1 in SKB head and frags. + */ + sg_in_max_elements = 2 * MAX_SKB_FRAGS + 1; + + if (!payload_len) + return skb; + + sg_in = kmalloc_array(sg_in_max_elements, sizeof(*sg_in), GFP_ATOMIC); + if (!sg_in) + goto free_orig; + + sg_init_table(sg_in, sg_in_max_elements); + sg_init_table(sg_out, ARRAY_SIZE(sg_out)); + + if (fill_sg_in(sg_in, skb, ctx, &rcd_sn, &sync_size, &resync_sgs)) { + /* bypass packets before kernel TLS socket option was set */ + if (sync_size < 0 && payload_len <= -sync_size) + nskb = skb_get(skb); + goto put_sg; + } + + nskb = tls_enc_skb(tls_ctx, sg_out, sg_in, skb, sync_size, rcd_sn); + +put_sg: + while (resync_sgs) + put_page(sg_page(&sg_in[--resync_sgs])); + kfree(sg_in); +free_orig: + kfree_skb(skb); + return nskb; +} + +struct sk_buff *tls_validate_xmit_skb(struct sock *sk, + struct net_device *dev, + struct sk_buff *skb) +{ + if (dev == tls_get_ctx(sk)->netdev) + return skb; + + return tls_sw_fallback(sk, skb); +} + +int tls_sw_fallback_init(struct sock *sk, + struct tls_offload_context *offload_ctx, + struct tls_crypto_info *crypto_info) +{ + const u8 *key; + int rc; + + offload_ctx->aead_send = + crypto_alloc_aead("gcm(aes)", 0, CRYPTO_ALG_ASYNC); + if (IS_ERR(offload_ctx->aead_send)) { + rc = PTR_ERR(offload_ctx->aead_send); + pr_err_ratelimited("crypto_alloc_aead failed rc=%d\n", rc); + offload_ctx->aead_send = NULL; + goto err_out; + } + + key = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->key; + + rc = crypto_aead_setkey(offload_ctx->aead_send, key, + TLS_CIPHER_AES_GCM_128_KEY_SIZE); + if (rc) + goto free_aead; + + rc = crypto_aead_setauthsize(offload_ctx->aead_send, + TLS_CIPHER_AES_GCM_128_TAG_SIZE); + if (rc) + goto free_aead; + + return 0; +free_aead: + crypto_free_aead(offload_ctx->aead_send); +err_out: + return rc; +} diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index 20cd93be6236..301f22430469 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -51,12 +51,12 @@ enum { TLSV6, TLS_NUM_PROTS, }; - enum { TLS_BASE, - TLS_SW_TX, - TLS_SW_RX, - TLS_SW_RXTX, + TLS_SW, +#ifdef CONFIG_TLS_DEVICE + TLS_HW, +#endif TLS_HW_RECORD, TLS_NUM_CONFIG, }; @@ -65,14 +65,14 @@ static struct proto *saved_tcpv6_prot; static DEFINE_MUTEX(tcpv6_prot_mutex); static LIST_HEAD(device_list); static DEFINE_MUTEX(device_mutex); -static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG]; +static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG]; static struct proto_ops tls_sw_proto_ops; -static inline void update_sk_prot(struct sock *sk, struct tls_context *ctx) +static void update_sk_prot(struct sock *sk, struct tls_context *ctx) { int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4; - sk->sk_prot = &tls_prots[ip_ver][ctx->conf]; + sk->sk_prot = &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]; } int wait_on_pending_writer(struct sock *sk, long *timeo) @@ -254,7 +254,8 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) lock_sock(sk); sk_proto_close = ctx->sk_proto_close; - if (ctx->conf == TLS_BASE || ctx->conf == TLS_HW_RECORD) { + if ((ctx->tx_conf == TLS_HW_RECORD && ctx->rx_conf == TLS_HW_RECORD) || + (ctx->tx_conf == TLS_BASE && ctx->rx_conf == TLS_BASE)) { free_ctx = true; goto skip_tx_cleanup; } @@ -275,15 +276,26 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) } } - kfree(ctx->tx.rec_seq); - kfree(ctx->tx.iv); - kfree(ctx->rx.rec_seq); - kfree(ctx->rx.iv); + /* We need these for tls_sw_fallback handling of other packets */ + if (ctx->tx_conf == TLS_SW) { + kfree(ctx->tx.rec_seq); + kfree(ctx->tx.iv); + tls_sw_free_resources_tx(sk); + } - if (ctx->conf == TLS_SW_TX || - ctx->conf == TLS_SW_RX || - ctx->conf == TLS_SW_RXTX) { - tls_sw_free_resources(sk); + if (ctx->rx_conf == TLS_SW) { + kfree(ctx->rx.rec_seq); + kfree(ctx->rx.iv); + tls_sw_free_resources_rx(sk); + } + +#ifdef CONFIG_TLS_DEVICE + if (ctx->tx_conf != TLS_HW) { +#else + { +#endif + kfree(ctx); + ctx = NULL; } skip_tx_cleanup: @@ -446,25 +458,29 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval, goto err_crypto_info; } - /* currently SW is default, we will have ethtool in future */ if (tx) { - rc = tls_set_sw_offload(sk, ctx, 1); - if (ctx->conf == TLS_SW_RX) - conf = TLS_SW_RXTX; - else - conf = TLS_SW_TX; +#ifdef CONFIG_TLS_DEVICE + rc = tls_set_device_offload(sk, ctx); + conf = TLS_HW; + if (rc) { +#else + { +#endif + rc = tls_set_sw_offload(sk, ctx, 1); + conf = TLS_SW; + } } else { rc = tls_set_sw_offload(sk, ctx, 0); - if (ctx->conf == TLS_SW_TX) - conf = TLS_SW_RXTX; - else - conf = TLS_SW_RX; + conf = TLS_SW; } if (rc) goto err_crypto_info; - ctx->conf = conf; + if (tx) + ctx->tx_conf = conf; + else + ctx->rx_conf = conf; update_sk_prot(sk, ctx); if (tx) { ctx->sk_write_space = sk->sk_write_space; @@ -540,7 +556,8 @@ static int tls_hw_prot(struct sock *sk) ctx->hash = sk->sk_prot->hash; ctx->unhash = sk->sk_prot->unhash; ctx->sk_proto_close = sk->sk_prot->close; - ctx->conf = TLS_HW_RECORD; + ctx->rx_conf = TLS_HW_RECORD; + ctx->tx_conf = TLS_HW_RECORD; update_sk_prot(sk, ctx); rc = 1; break; @@ -584,29 +601,40 @@ static int tls_hw_hash(struct sock *sk) return err; } -static void build_protos(struct proto *prot, struct proto *base) +static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], + struct proto *base) { - prot[TLS_BASE] = *base; - prot[TLS_BASE].setsockopt = tls_setsockopt; - prot[TLS_BASE].getsockopt = tls_getsockopt; - prot[TLS_BASE].close = tls_sk_proto_close; - - prot[TLS_SW_TX] = prot[TLS_BASE]; - prot[TLS_SW_TX].sendmsg = tls_sw_sendmsg; - prot[TLS_SW_TX].sendpage = tls_sw_sendpage; - - prot[TLS_SW_RX] = prot[TLS_BASE]; - prot[TLS_SW_RX].recvmsg = tls_sw_recvmsg; - prot[TLS_SW_RX].close = tls_sk_proto_close; - - prot[TLS_SW_RXTX] = prot[TLS_SW_TX]; - prot[TLS_SW_RXTX].recvmsg = tls_sw_recvmsg; - prot[TLS_SW_RXTX].close = tls_sk_proto_close; - - prot[TLS_HW_RECORD] = *base; - prot[TLS_HW_RECORD].hash = tls_hw_hash; - prot[TLS_HW_RECORD].unhash = tls_hw_unhash; - prot[TLS_HW_RECORD].close = tls_sk_proto_close; + prot[TLS_BASE][TLS_BASE] = *base; + prot[TLS_BASE][TLS_BASE].setsockopt = tls_setsockopt; + prot[TLS_BASE][TLS_BASE].getsockopt = tls_getsockopt; + prot[TLS_BASE][TLS_BASE].close = tls_sk_proto_close; + + prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE]; + prot[TLS_SW][TLS_BASE].sendmsg = tls_sw_sendmsg; + prot[TLS_SW][TLS_BASE].sendpage = tls_sw_sendpage; + + prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE]; + prot[TLS_BASE][TLS_SW].recvmsg = tls_sw_recvmsg; + prot[TLS_BASE][TLS_SW].close = tls_sk_proto_close; + + prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE]; + prot[TLS_SW][TLS_SW].recvmsg = tls_sw_recvmsg; + prot[TLS_SW][TLS_SW].close = tls_sk_proto_close; + +#ifdef CONFIG_TLS_DEVICE + prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE]; + prot[TLS_HW][TLS_BASE].sendmsg = tls_device_sendmsg; + prot[TLS_HW][TLS_BASE].sendpage = tls_device_sendpage; + + prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW]; + prot[TLS_HW][TLS_SW].sendmsg = tls_device_sendmsg; + prot[TLS_HW][TLS_SW].sendpage = tls_device_sendpage; +#endif + + prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base; + prot[TLS_HW_RECORD][TLS_HW_RECORD].hash = tls_hw_hash; + prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash = tls_hw_unhash; + prot[TLS_HW_RECORD][TLS_HW_RECORD].close = tls_sk_proto_close; } static int tls_init(struct sock *sk) @@ -637,7 +665,7 @@ static int tls_init(struct sock *sk) ctx->getsockopt = sk->sk_prot->getsockopt; ctx->sk_proto_close = sk->sk_prot->close; - /* Build IPv6 TLS whenever the address of tcpv6_prot changes */ + /* Build IPv6 TLS whenever the address of tcpv6 _prot changes */ if (ip_ver == TLSV6 && unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) { mutex_lock(&tcpv6_prot_mutex); @@ -648,7 +676,8 @@ static int tls_init(struct sock *sk) mutex_unlock(&tcpv6_prot_mutex); } - ctx->conf = TLS_BASE; + ctx->tx_conf = TLS_BASE; + ctx->rx_conf = TLS_BASE; update_sk_prot(sk, ctx); out: return rc; @@ -686,6 +715,9 @@ static int __init tls_register(void) tls_sw_proto_ops.poll = tls_sw_poll; tls_sw_proto_ops.splice_read = tls_sw_splice_read; +#ifdef CONFIG_TLS_DEVICE + tls_device_init(); +#endif tcp_register_ulp(&tcp_tls_ulp_ops); return 0; @@ -694,6 +726,9 @@ static int __init tls_register(void) static void __exit tls_unregister(void) { tcp_unregister_ulp(&tcp_tls_ulp_ops); +#ifdef CONFIG_TLS_DEVICE + tls_device_cleanup(); +#endif } module_init(tls_register); diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index e1c93ce74e0f..839e1e165a0c 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -52,7 +52,7 @@ static int tls_do_decryption(struct sock *sk, gfp_t flags) { struct tls_context *tls_ctx = tls_get_ctx(sk); - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct strp_msg *rxm = strp_msg(skb); struct aead_request *aead_req; @@ -122,7 +122,7 @@ out: static void trim_both_sgl(struct sock *sk, int target_size) { struct tls_context *tls_ctx = tls_get_ctx(sk); - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); trim_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem, @@ -141,7 +141,7 @@ static void trim_both_sgl(struct sock *sk, int target_size) static int alloc_encrypted_sg(struct sock *sk, int len) { struct tls_context *tls_ctx = tls_get_ctx(sk); - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); int rc = 0; rc = sk_alloc_sg(sk, len, @@ -155,7 +155,7 @@ static int alloc_encrypted_sg(struct sock *sk, int len) static int alloc_plaintext_sg(struct sock *sk, int len) { struct tls_context *tls_ctx = tls_get_ctx(sk); - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); int rc = 0; rc = sk_alloc_sg(sk, len, ctx->sg_plaintext_data, 0, @@ -181,7 +181,7 @@ static void free_sg(struct sock *sk, struct scatterlist *sg, static void tls_free_both_sg(struct sock *sk) { struct tls_context *tls_ctx = tls_get_ctx(sk); - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); free_sg(sk, ctx->sg_encrypted_data, &ctx->sg_encrypted_num_elem, &ctx->sg_encrypted_size); @@ -191,7 +191,7 @@ static void tls_free_both_sg(struct sock *sk) } static int tls_do_encryption(struct tls_context *tls_ctx, - struct tls_sw_context *ctx, size_t data_len, + struct tls_sw_context_tx *ctx, size_t data_len, gfp_t flags) { unsigned int req_size = sizeof(struct aead_request) + @@ -227,7 +227,7 @@ static int tls_push_record(struct sock *sk, int flags, unsigned char record_type) { struct tls_context *tls_ctx = tls_get_ctx(sk); - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); int rc; sg_mark_end(ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem - 1); @@ -339,7 +339,7 @@ static int memcopy_from_iter(struct sock *sk, struct iov_iter *from, int bytes) { struct tls_context *tls_ctx = tls_get_ctx(sk); - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct scatterlist *sg = ctx->sg_plaintext_data; int copy, i, rc = 0; @@ -367,7 +367,7 @@ out: int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) { struct tls_context *tls_ctx = tls_get_ctx(sk); - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); int ret = 0; int required_size; long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); @@ -522,7 +522,7 @@ int tls_sw_sendpage(struct sock *sk, struct page *page, int offset, size_t size, int flags) { struct tls_context *tls_ctx = tls_get_ctx(sk); - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); int ret = 0; long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); bool eor; @@ -636,7 +636,7 @@ static struct sk_buff *tls_wait_data(struct sock *sk, int flags, long timeo, int *err) { struct tls_context *tls_ctx = tls_get_ctx(sk); - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct sk_buff *skb; DEFINE_WAIT_FUNC(wait, woken_wake_function); @@ -674,7 +674,7 @@ static int decrypt_skb(struct sock *sk, struct sk_buff *skb, struct scatterlist *sgout) { struct tls_context *tls_ctx = tls_get_ctx(sk); - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); char iv[TLS_CIPHER_AES_GCM_128_SALT_SIZE + MAX_IV_SIZE]; struct scatterlist sgin_arr[MAX_SKB_FRAGS + 2]; struct scatterlist *sgin = &sgin_arr[0]; @@ -692,8 +692,7 @@ static int decrypt_skb(struct sock *sk, struct sk_buff *skb, if (!sgout) { nsg = skb_cow_data(skb, 0, &unused) + 1; sgin = kmalloc_array(nsg, sizeof(*sgin), sk->sk_allocation); - if (!sgout) - sgout = sgin; + sgout = sgin; } sg_init_table(sgin, nsg); @@ -723,7 +722,7 @@ static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb, unsigned int len) { struct tls_context *tls_ctx = tls_get_ctx(sk); - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct strp_msg *rxm = strp_msg(skb); if (len < rxm->full_len) { @@ -749,7 +748,7 @@ int tls_sw_recvmsg(struct sock *sk, int *addr_len) { struct tls_context *tls_ctx = tls_get_ctx(sk); - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); unsigned char control; struct strp_msg *rxm; struct sk_buff *skb; @@ -869,7 +868,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, size_t len, unsigned int flags) { struct tls_context *tls_ctx = tls_get_ctx(sock->sk); - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct strp_msg *rxm = NULL; struct sock *sk = sock->sk; struct sk_buff *skb; @@ -922,7 +921,7 @@ unsigned int tls_sw_poll(struct file *file, struct socket *sock, unsigned int ret; struct sock *sk = sock->sk; struct tls_context *tls_ctx = tls_get_ctx(sk); - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); /* Grab POLLOUT and POLLHUP from the underlying socket */ ret = ctx->sk_poll(file, sock, wait); @@ -938,7 +937,7 @@ unsigned int tls_sw_poll(struct file *file, struct socket *sock, static int tls_read_size(struct strparser *strp, struct sk_buff *skb) { struct tls_context *tls_ctx = tls_get_ctx(strp->sk); - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); char header[tls_ctx->rx.prepend_size]; struct strp_msg *rxm = strp_msg(skb); size_t cipher_overhead; @@ -987,7 +986,7 @@ read_failure: static void tls_queue(struct strparser *strp, struct sk_buff *skb) { struct tls_context *tls_ctx = tls_get_ctx(strp->sk); - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct strp_msg *rxm; rxm = strp_msg(skb); @@ -1003,18 +1002,28 @@ static void tls_queue(struct strparser *strp, struct sk_buff *skb) static void tls_data_ready(struct sock *sk) { struct tls_context *tls_ctx = tls_get_ctx(sk); - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); strp_data_ready(&ctx->strp); } -void tls_sw_free_resources(struct sock *sk) +void tls_sw_free_resources_tx(struct sock *sk) { struct tls_context *tls_ctx = tls_get_ctx(sk); - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); if (ctx->aead_send) crypto_free_aead(ctx->aead_send); + tls_free_both_sg(sk); + + kfree(ctx); +} + +void tls_sw_free_resources_rx(struct sock *sk) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + if (ctx->aead_recv) { if (ctx->recv_pkt) { kfree_skb(ctx->recv_pkt); @@ -1030,10 +1039,7 @@ void tls_sw_free_resources(struct sock *sk) lock_sock(sk); } - tls_free_both_sg(sk); - kfree(ctx); - kfree(tls_ctx); } int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) @@ -1041,7 +1047,8 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) char keyval[TLS_CIPHER_AES_GCM_128_KEY_SIZE]; struct tls_crypto_info *crypto_info; struct tls12_crypto_info_aes_gcm_128 *gcm_128_info; - struct tls_sw_context *sw_ctx; + struct tls_sw_context_tx *sw_ctx_tx = NULL; + struct tls_sw_context_rx *sw_ctx_rx = NULL; struct cipher_context *cctx; struct crypto_aead **aead; struct strp_callbacks cb; @@ -1054,27 +1061,32 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) goto out; } - if (!ctx->priv_ctx) { - sw_ctx = kzalloc(sizeof(*sw_ctx), GFP_KERNEL); - if (!sw_ctx) { + if (tx) { + sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL); + if (!sw_ctx_tx) { rc = -ENOMEM; goto out; } - crypto_init_wait(&sw_ctx->async_wait); + crypto_init_wait(&sw_ctx_tx->async_wait); + ctx->priv_ctx_tx = sw_ctx_tx; } else { - sw_ctx = ctx->priv_ctx; + sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL); + if (!sw_ctx_rx) { + rc = -ENOMEM; + goto out; + } + crypto_init_wait(&sw_ctx_rx->async_wait); + ctx->priv_ctx_rx = sw_ctx_rx; } - ctx->priv_ctx = (struct tls_offload_context *)sw_ctx; - if (tx) { crypto_info = &ctx->crypto_send; cctx = &ctx->tx; - aead = &sw_ctx->aead_send; + aead = &sw_ctx_tx->aead_send; } else { crypto_info = &ctx->crypto_recv; cctx = &ctx->rx; - aead = &sw_ctx->aead_recv; + aead = &sw_ctx_rx->aead_recv; } switch (crypto_info->cipher_type) { @@ -1121,22 +1133,24 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) } memcpy(cctx->rec_seq, rec_seq, rec_seq_size); - if (tx) { - sg_init_table(sw_ctx->sg_encrypted_data, - ARRAY_SIZE(sw_ctx->sg_encrypted_data)); - sg_init_table(sw_ctx->sg_plaintext_data, - ARRAY_SIZE(sw_ctx->sg_plaintext_data)); - - sg_init_table(sw_ctx->sg_aead_in, 2); - sg_set_buf(&sw_ctx->sg_aead_in[0], sw_ctx->aad_space, - sizeof(sw_ctx->aad_space)); - sg_unmark_end(&sw_ctx->sg_aead_in[1]); - sg_chain(sw_ctx->sg_aead_in, 2, sw_ctx->sg_plaintext_data); - sg_init_table(sw_ctx->sg_aead_out, 2); - sg_set_buf(&sw_ctx->sg_aead_out[0], sw_ctx->aad_space, - sizeof(sw_ctx->aad_space)); - sg_unmark_end(&sw_ctx->sg_aead_out[1]); - sg_chain(sw_ctx->sg_aead_out, 2, sw_ctx->sg_encrypted_data); + if (sw_ctx_tx) { + sg_init_table(sw_ctx_tx->sg_encrypted_data, + ARRAY_SIZE(sw_ctx_tx->sg_encrypted_data)); + sg_init_table(sw_ctx_tx->sg_plaintext_data, + ARRAY_SIZE(sw_ctx_tx->sg_plaintext_data)); + + sg_init_table(sw_ctx_tx->sg_aead_in, 2); + sg_set_buf(&sw_ctx_tx->sg_aead_in[0], sw_ctx_tx->aad_space, + sizeof(sw_ctx_tx->aad_space)); + sg_unmark_end(&sw_ctx_tx->sg_aead_in[1]); + sg_chain(sw_ctx_tx->sg_aead_in, 2, + sw_ctx_tx->sg_plaintext_data); + sg_init_table(sw_ctx_tx->sg_aead_out, 2); + sg_set_buf(&sw_ctx_tx->sg_aead_out[0], sw_ctx_tx->aad_space, + sizeof(sw_ctx_tx->aad_space)); + sg_unmark_end(&sw_ctx_tx->sg_aead_out[1]); + sg_chain(sw_ctx_tx->sg_aead_out, 2, + sw_ctx_tx->sg_encrypted_data); } if (!*aead) { @@ -1161,22 +1175,22 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) if (rc) goto free_aead; - if (!tx) { + if (sw_ctx_rx) { /* Set up strparser */ memset(&cb, 0, sizeof(cb)); cb.rcv_msg = tls_queue; cb.parse_msg = tls_read_size; - strp_init(&sw_ctx->strp, sk, &cb); + strp_init(&sw_ctx_rx->strp, sk, &cb); write_lock_bh(&sk->sk_callback_lock); - sw_ctx->saved_data_ready = sk->sk_data_ready; + sw_ctx_rx->saved_data_ready = sk->sk_data_ready; sk->sk_data_ready = tls_data_ready; write_unlock_bh(&sk->sk_callback_lock); - sw_ctx->sk_poll = sk->sk_socket->ops->poll; + sw_ctx_rx->sk_poll = sk->sk_socket->ops->poll; - strp_check_rcv(&sw_ctx->strp); + strp_check_rcv(&sw_ctx_rx->strp); } goto out; @@ -1188,11 +1202,16 @@ free_rec_seq: kfree(cctx->rec_seq); cctx->rec_seq = NULL; free_iv: - kfree(ctx->tx.iv); - ctx->tx.iv = NULL; + kfree(cctx->iv); + cctx->iv = NULL; free_priv: - kfree(ctx->priv_ctx); - ctx->priv_ctx = NULL; + if (tx) { + kfree(ctx->priv_ctx_tx); + ctx->priv_ctx_tx = NULL; + } else { + kfree(ctx->priv_ctx_rx); + ctx->priv_ctx_rx = NULL; + } out: return rc; } |