diff options
-rw-r--r-- | include/net/tls.h | 27 | ||||
-rw-r--r-- | include/uapi/linux/tls.h | 2 | ||||
-rw-r--r-- | net/tls/Kconfig | 1 | ||||
-rw-r--r-- | net/tls/tls_main.c | 62 | ||||
-rw-r--r-- | net/tls/tls_sw.c | 587 |
5 files changed, 609 insertions, 70 deletions
diff --git a/include/net/tls.h b/include/net/tls.h index 095b72283861..437a746300bf 100644 --- a/include/net/tls.h +++ b/include/net/tls.h @@ -40,6 +40,7 @@ #include <linux/socket.h> #include <linux/tcp.h> #include <net/tcp.h> +#include <net/strparser.h> #include <uapi/linux/tls.h> @@ -58,8 +59,18 @@ struct tls_sw_context { struct crypto_aead *aead_send; + struct crypto_aead *aead_recv; struct crypto_wait async_wait; + /* Receive context */ + struct strparser strp; + void (*saved_data_ready)(struct sock *sk); + unsigned int (*sk_poll)(struct file *file, struct socket *sock, + struct poll_table_struct *wait); + struct sk_buff *recv_pkt; + u8 control; + bool decrypted; + /* Sending context */ char aad_space[TLS_AAD_SPACE_SIZE]; @@ -96,12 +107,17 @@ struct tls_context { struct tls_crypto_info crypto_send; struct tls12_crypto_info_aes_gcm_128 crypto_send_aes_gcm_128; }; + union { + struct tls_crypto_info crypto_recv; + struct tls12_crypto_info_aes_gcm_128 crypto_recv_aes_gcm_128; + }; void *priv_ctx; u8 conf:2; struct cipher_context tx; + struct cipher_context rx; struct scatterlist *partially_sent_record; u16 partially_sent_offset; @@ -128,12 +144,19 @@ int tls_sk_attach(struct sock *sk, int optname, char __user *optval, unsigned int optlen); -int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx); +int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx); int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size); int tls_sw_sendpage(struct sock *sk, struct page *page, int offset, size_t size, int flags); void tls_sw_close(struct sock *sk, long timeout); -void tls_sw_free_tx_resources(struct sock *sk); +void tls_sw_free_resources(struct sock *sk); +int tls_sw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, + int nonblock, int flags, int *addr_len); +unsigned int tls_sw_poll(struct file *file, struct socket *sock, + struct poll_table_struct *wait); +ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, + struct pipe_inode_info *pipe, + size_t len, unsigned int flags); void tls_sk_destruct(struct sock *sk, struct tls_context *ctx); void tls_icsk_clean_acked(struct sock *sk); diff --git a/include/uapi/linux/tls.h b/include/uapi/linux/tls.h index 293b2cdad88d..c6633e97eca4 100644 --- a/include/uapi/linux/tls.h +++ b/include/uapi/linux/tls.h @@ -38,6 +38,7 @@ /* TLS socket options */ #define TLS_TX 1 /* Set transmit parameters */ +#define TLS_RX 2 /* Set receive parameters */ /* Supported versions */ #define TLS_VERSION_MINOR(ver) ((ver) & 0xFF) @@ -59,6 +60,7 @@ #define TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE 8 #define TLS_SET_RECORD_TYPE 1 +#define TLS_GET_RECORD_TYPE 2 struct tls_crypto_info { __u16 version; diff --git a/net/tls/Kconfig b/net/tls/Kconfig index eb583038c67e..89b8745a986f 100644 --- a/net/tls/Kconfig +++ b/net/tls/Kconfig @@ -7,6 +7,7 @@ config TLS select CRYPTO select CRYPTO_AES select CRYPTO_GCM + select STREAM_PARSER default n ---help--- Enable kernel support for TLS protocol. This allows symmetric diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index c405beeda765..6f5c1146da4a 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -54,12 +54,15 @@ enum { enum { TLS_BASE, TLS_SW_TX, + TLS_SW_RX, + TLS_SW_RXTX, TLS_NUM_CONFIG, }; static struct proto *saved_tcpv6_prot; static DEFINE_MUTEX(tcpv6_prot_mutex); static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG]; +static struct proto_ops tls_sw_proto_ops; static inline void update_sk_prot(struct sock *sk, struct tls_context *ctx) { @@ -261,9 +264,14 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) kfree(ctx->tx.rec_seq); kfree(ctx->tx.iv); + kfree(ctx->rx.rec_seq); + kfree(ctx->rx.iv); - if (ctx->conf == TLS_SW_TX) - tls_sw_free_tx_resources(sk); + if (ctx->conf == TLS_SW_TX || + ctx->conf == TLS_SW_RX || + ctx->conf == TLS_SW_RXTX) { + tls_sw_free_resources(sk); + } skip_tx_cleanup: release_sock(sk); @@ -365,8 +373,8 @@ static int tls_getsockopt(struct sock *sk, int level, int optname, return do_tls_getsockopt(sk, optname, optval, optlen); } -static int do_tls_setsockopt_tx(struct sock *sk, char __user *optval, - unsigned int optlen) +static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval, + unsigned int optlen, int tx) { struct tls_crypto_info *crypto_info; struct tls_context *ctx = tls_get_ctx(sk); @@ -378,7 +386,11 @@ static int do_tls_setsockopt_tx(struct sock *sk, char __user *optval, goto out; } - crypto_info = &ctx->crypto_send; + if (tx) + crypto_info = &ctx->crypto_send; + else + crypto_info = &ctx->crypto_recv; + /* Currently we don't support set crypto info more than one time */ if (TLS_CRYPTO_INFO_READY(crypto_info)) { rc = -EBUSY; @@ -417,15 +429,31 @@ static int do_tls_setsockopt_tx(struct sock *sk, char __user *optval, } /* currently SW is default, we will have ethtool in future */ - rc = tls_set_sw_offload(sk, ctx); - conf = TLS_SW_TX; + if (tx) { + rc = tls_set_sw_offload(sk, ctx, 1); + if (ctx->conf == TLS_SW_RX) + conf = TLS_SW_RXTX; + else + conf = TLS_SW_TX; + } else { + rc = tls_set_sw_offload(sk, ctx, 0); + if (ctx->conf == TLS_SW_TX) + conf = TLS_SW_RXTX; + else + conf = TLS_SW_RX; + } + if (rc) goto err_crypto_info; ctx->conf = conf; update_sk_prot(sk, ctx); - ctx->sk_write_space = sk->sk_write_space; - sk->sk_write_space = tls_write_space; + if (tx) { + ctx->sk_write_space = sk->sk_write_space; + sk->sk_write_space = tls_write_space; + } else { + sk->sk_socket->ops = &tls_sw_proto_ops; + } goto out; err_crypto_info: @@ -441,8 +469,10 @@ static int do_tls_setsockopt(struct sock *sk, int optname, switch (optname) { case TLS_TX: + case TLS_RX: lock_sock(sk); - rc = do_tls_setsockopt_tx(sk, optval, optlen); + rc = do_tls_setsockopt_conf(sk, optval, optlen, + optname == TLS_TX); release_sock(sk); break; default: @@ -473,6 +503,14 @@ static void build_protos(struct proto *prot, struct proto *base) prot[TLS_SW_TX] = prot[TLS_BASE]; prot[TLS_SW_TX].sendmsg = tls_sw_sendmsg; prot[TLS_SW_TX].sendpage = tls_sw_sendpage; + + prot[TLS_SW_RX] = prot[TLS_BASE]; + prot[TLS_SW_RX].recvmsg = tls_sw_recvmsg; + prot[TLS_SW_RX].close = tls_sk_proto_close; + + prot[TLS_SW_RXTX] = prot[TLS_SW_TX]; + prot[TLS_SW_RXTX].recvmsg = tls_sw_recvmsg; + prot[TLS_SW_RXTX].close = tls_sk_proto_close; } static int tls_init(struct sock *sk) @@ -531,6 +569,10 @@ static int __init tls_register(void) { build_protos(tls_prots[TLSV4], &tcp_prot); + tls_sw_proto_ops = inet_stream_ops; + tls_sw_proto_ops.poll = tls_sw_poll; + tls_sw_proto_ops.splice_read = tls_sw_splice_read; + tcp_register_ulp(&tcp_tls_ulp_ops); return 0; diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index 1c79d9ad1731..4dc766b03f00 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -34,11 +34,60 @@ * SOFTWARE. */ +#include <linux/sched/signal.h> #include <linux/module.h> #include <crypto/aead.h> +#include <net/strparser.h> #include <net/tls.h> +static int tls_do_decryption(struct sock *sk, + struct scatterlist *sgin, + struct scatterlist *sgout, + char *iv_recv, + size_t data_len, + struct sk_buff *skb, + gfp_t flags) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + struct strp_msg *rxm = strp_msg(skb); + struct aead_request *aead_req; + + int ret; + unsigned int req_size = sizeof(struct aead_request) + + crypto_aead_reqsize(ctx->aead_recv); + + aead_req = kzalloc(req_size, flags); + if (!aead_req) + return -ENOMEM; + + aead_request_set_tfm(aead_req, ctx->aead_recv); + aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE); + aead_request_set_crypt(aead_req, sgin, sgout, + data_len + tls_ctx->rx.tag_size, + (u8 *)iv_recv); + aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG, + crypto_req_done, &ctx->async_wait); + + ret = crypto_wait_req(crypto_aead_decrypt(aead_req), &ctx->async_wait); + + if (ret < 0) + goto out; + + rxm->offset += tls_ctx->rx.prepend_size; + rxm->full_len -= tls_ctx->rx.overhead_size; + tls_advance_record_sn(sk, &tls_ctx->rx); + + ctx->decrypted = true; + + ctx->saved_data_ready(sk); + +out: + kfree(aead_req); + return ret; +} + static void trim_sg(struct sock *sk, struct scatterlist *sg, int *sg_num_elem, unsigned int *sg_size, int target_size) { @@ -581,13 +630,404 @@ sendpage_end: return ret; } -void tls_sw_free_tx_resources(struct sock *sk) +static struct sk_buff *tls_wait_data(struct sock *sk, int flags, + long timeo, int *err) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + struct sk_buff *skb; + DEFINE_WAIT_FUNC(wait, woken_wake_function); + + while (!(skb = ctx->recv_pkt)) { + if (sk->sk_err) { + *err = sock_error(sk); + return NULL; + } + + if (sock_flag(sk, SOCK_DONE)) + return NULL; + + if ((flags & MSG_DONTWAIT) || !timeo) { + *err = -EAGAIN; + return NULL; + } + + add_wait_queue(sk_sleep(sk), &wait); + sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); + sk_wait_event(sk, &timeo, ctx->recv_pkt != skb, &wait); + sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); + remove_wait_queue(sk_sleep(sk), &wait); + + /* Handle signals */ + if (signal_pending(current)) { + *err = sock_intr_errno(timeo); + return NULL; + } + } + + return skb; +} + +static int decrypt_skb(struct sock *sk, struct sk_buff *skb, + struct scatterlist *sgout) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + char iv[TLS_CIPHER_AES_GCM_128_SALT_SIZE + tls_ctx->rx.iv_size]; + struct scatterlist sgin_arr[MAX_SKB_FRAGS + 2]; + struct scatterlist *sgin = &sgin_arr[0]; + struct strp_msg *rxm = strp_msg(skb); + int ret, nsg = ARRAY_SIZE(sgin_arr); + char aad_recv[TLS_AAD_SPACE_SIZE]; + struct sk_buff *unused; + + ret = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE, + iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, + tls_ctx->rx.iv_size); + if (ret < 0) + return ret; + + memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE); + if (!sgout) { + nsg = skb_cow_data(skb, 0, &unused) + 1; + sgin = kmalloc_array(nsg, sizeof(*sgin), sk->sk_allocation); + if (!sgout) + sgout = sgin; + } + + sg_init_table(sgin, nsg); + sg_set_buf(&sgin[0], aad_recv, sizeof(aad_recv)); + + nsg = skb_to_sgvec(skb, &sgin[1], + rxm->offset + tls_ctx->rx.prepend_size, + rxm->full_len - tls_ctx->rx.prepend_size); + + tls_make_aad(aad_recv, + rxm->full_len - tls_ctx->rx.overhead_size, + tls_ctx->rx.rec_seq, + tls_ctx->rx.rec_seq_size, + ctx->control); + + ret = tls_do_decryption(sk, sgin, sgout, iv, + rxm->full_len - tls_ctx->rx.overhead_size, + skb, sk->sk_allocation); + + if (sgin != &sgin_arr[0]) + kfree(sgin); + + return ret; +} + +static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb, + unsigned int len) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + struct strp_msg *rxm = strp_msg(skb); + + if (len < rxm->full_len) { + rxm->offset += len; + rxm->full_len -= len; + + return false; + } + + /* Finished with message */ + ctx->recv_pkt = NULL; + kfree_skb(skb); + strp_unpause(&ctx->strp); + + return true; +} + +int tls_sw_recvmsg(struct sock *sk, + struct msghdr *msg, + size_t len, + int nonblock, + int flags, + int *addr_len) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + unsigned char control; + struct strp_msg *rxm; + struct sk_buff *skb; + ssize_t copied = 0; + bool cmsg = false; + int err = 0; + long timeo; + + flags |= nonblock; + + if (unlikely(flags & MSG_ERRQUEUE)) + return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR); + + lock_sock(sk); + + timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + do { + bool zc = false; + int chunk = 0; + + skb = tls_wait_data(sk, flags, timeo, &err); + if (!skb) + goto recv_end; + + rxm = strp_msg(skb); + if (!cmsg) { + int cerr; + + cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE, + sizeof(ctx->control), &ctx->control); + cmsg = true; + control = ctx->control; + if (ctx->control != TLS_RECORD_TYPE_DATA) { + if (cerr || msg->msg_flags & MSG_CTRUNC) { + err = -EIO; + goto recv_end; + } + } + } else if (control != ctx->control) { + goto recv_end; + } + + if (!ctx->decrypted) { + int page_count; + int to_copy; + + page_count = iov_iter_npages(&msg->msg_iter, + MAX_SKB_FRAGS); + to_copy = rxm->full_len - tls_ctx->rx.overhead_size; + if (to_copy <= len && page_count < MAX_SKB_FRAGS && + likely(!(flags & MSG_PEEK))) { + struct scatterlist sgin[MAX_SKB_FRAGS + 1]; + char unused[21]; + int pages = 0; + + zc = true; + sg_init_table(sgin, MAX_SKB_FRAGS + 1); + sg_set_buf(&sgin[0], unused, 13); + + err = zerocopy_from_iter(sk, &msg->msg_iter, + to_copy, &pages, + &chunk, &sgin[1], + MAX_SKB_FRAGS, false); + if (err < 0) + goto fallback_to_reg_recv; + + err = decrypt_skb(sk, skb, sgin); + for (; pages > 0; pages--) + put_page(sg_page(&sgin[pages])); + if (err < 0) { + tls_err_abort(sk, EBADMSG); + goto recv_end; + } + } else { +fallback_to_reg_recv: + err = decrypt_skb(sk, skb, NULL); + if (err < 0) { + tls_err_abort(sk, EBADMSG); + goto recv_end; + } + } + ctx->decrypted = true; + } + + if (!zc) { + chunk = min_t(unsigned int, rxm->full_len, len); + err = skb_copy_datagram_msg(skb, rxm->offset, msg, + chunk); + if (err < 0) + goto recv_end; + } + + copied += chunk; + len -= chunk; + if (likely(!(flags & MSG_PEEK))) { + u8 control = ctx->control; + + if (tls_sw_advance_skb(sk, skb, chunk)) { + /* Return full control message to + * userspace before trying to parse + * another message type + */ + msg->msg_flags |= MSG_EOR; + if (control != TLS_RECORD_TYPE_DATA) + goto recv_end; + } + } + } while (len); + +recv_end: + release_sock(sk); + return copied ? : err; +} + +ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, + struct pipe_inode_info *pipe, + size_t len, unsigned int flags) +{ + struct tls_context *tls_ctx = tls_get_ctx(sock->sk); + struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + struct strp_msg *rxm = NULL; + struct sock *sk = sock->sk; + struct sk_buff *skb; + ssize_t copied = 0; + int err = 0; + long timeo; + int chunk; + + lock_sock(sk); + + timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + + skb = tls_wait_data(sk, flags, timeo, &err); + if (!skb) + goto splice_read_end; + + /* splice does not support reading control messages */ + if (ctx->control != TLS_RECORD_TYPE_DATA) { + err = -ENOTSUPP; + goto splice_read_end; + } + + if (!ctx->decrypted) { + err = decrypt_skb(sk, skb, NULL); + + if (err < 0) { + tls_err_abort(sk, EBADMSG); + goto splice_read_end; + } + ctx->decrypted = true; + } + rxm = strp_msg(skb); + + chunk = min_t(unsigned int, rxm->full_len, len); + copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags); + if (copied < 0) + goto splice_read_end; + + if (likely(!(flags & MSG_PEEK))) + tls_sw_advance_skb(sk, skb, copied); + +splice_read_end: + release_sock(sk); + return copied ? : err; +} + +unsigned int tls_sw_poll(struct file *file, struct socket *sock, + struct poll_table_struct *wait) +{ + unsigned int ret; + struct sock *sk = sock->sk; + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + + /* Grab POLLOUT and POLLHUP from the underlying socket */ + ret = ctx->sk_poll(file, sock, wait); + + /* Clear POLLIN bits, and set based on recv_pkt */ + ret &= ~(POLLIN | POLLRDNORM); + if (ctx->recv_pkt) + ret |= POLLIN | POLLRDNORM; + + return ret; +} + +static int tls_read_size(struct strparser *strp, struct sk_buff *skb) +{ + struct tls_context *tls_ctx = tls_get_ctx(strp->sk); + struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + char header[tls_ctx->rx.prepend_size]; + struct strp_msg *rxm = strp_msg(skb); + size_t cipher_overhead; + size_t data_len = 0; + int ret; + + /* Verify that we have a full TLS header, or wait for more data */ + if (rxm->offset + tls_ctx->rx.prepend_size > skb->len) + return 0; + + /* Linearize header to local buffer */ + ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size); + + if (ret < 0) + goto read_failure; + + ctx->control = header[0]; + + data_len = ((header[4] & 0xFF) | (header[3] << 8)); + + cipher_overhead = tls_ctx->rx.tag_size + tls_ctx->rx.iv_size; + + if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead) { + ret = -EMSGSIZE; + goto read_failure; + } + if (data_len < cipher_overhead) { + ret = -EBADMSG; + goto read_failure; + } + + if (header[1] != TLS_VERSION_MINOR(tls_ctx->crypto_recv.version) || + header[2] != TLS_VERSION_MAJOR(tls_ctx->crypto_recv.version)) { + ret = -EINVAL; + goto read_failure; + } + + return data_len + TLS_HEADER_SIZE; + +read_failure: + tls_err_abort(strp->sk, ret); + + return ret; +} + +static void tls_queue(struct strparser *strp, struct sk_buff *skb) +{ + struct tls_context *tls_ctx = tls_get_ctx(strp->sk); + struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + struct strp_msg *rxm; + + rxm = strp_msg(skb); + + ctx->decrypted = false; + + ctx->recv_pkt = skb; + strp_pause(strp); + + strp->sk->sk_state_change(strp->sk); +} + +static void tls_data_ready(struct sock *sk) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + + strp_data_ready(&ctx->strp); +} + +void tls_sw_free_resources(struct sock *sk) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); if (ctx->aead_send) crypto_free_aead(ctx->aead_send); + if (ctx->aead_recv) { + if (ctx->recv_pkt) { + kfree_skb(ctx->recv_pkt); + ctx->recv_pkt = NULL; + } + crypto_free_aead(ctx->aead_recv); + strp_stop(&ctx->strp); + write_lock_bh(&sk->sk_callback_lock); + sk->sk_data_ready = ctx->saved_data_ready; + write_unlock_bh(&sk->sk_callback_lock); + release_sock(sk); + strp_done(&ctx->strp); + lock_sock(sk); + } tls_free_both_sg(sk); @@ -595,12 +1035,15 @@ void tls_sw_free_tx_resources(struct sock *sk) kfree(tls_ctx); } -int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx) +int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) { char keyval[TLS_CIPHER_AES_GCM_128_KEY_SIZE]; struct tls_crypto_info *crypto_info; struct tls12_crypto_info_aes_gcm_128 *gcm_128_info; struct tls_sw_context *sw_ctx; + struct cipher_context *cctx; + struct crypto_aead **aead; + struct strp_callbacks cb; u16 nonce_size, tag_size, iv_size, rec_seq_size; char *iv, *rec_seq; int rc = 0; @@ -610,22 +1053,29 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx) goto out; } - if (ctx->priv_ctx) { - rc = -EEXIST; - goto out; - } - - sw_ctx = kzalloc(sizeof(*sw_ctx), GFP_KERNEL); - if (!sw_ctx) { - rc = -ENOMEM; - goto out; + if (!ctx->priv_ctx) { + sw_ctx = kzalloc(sizeof(*sw_ctx), GFP_KERNEL); + if (!sw_ctx) { + rc = -ENOMEM; + goto out; + } + crypto_init_wait(&sw_ctx->async_wait); + } else { + sw_ctx = ctx->priv_ctx; } - crypto_init_wait(&sw_ctx->async_wait); - ctx->priv_ctx = (struct tls_offload_context *)sw_ctx; - crypto_info = &ctx->crypto_send; + if (tx) { + crypto_info = &ctx->crypto_send; + cctx = &ctx->tx; + aead = &sw_ctx->aead_send; + } else { + crypto_info = &ctx->crypto_recv; + cctx = &ctx->rx; + aead = &sw_ctx->aead_recv; + } + switch (crypto_info->cipher_type) { case TLS_CIPHER_AES_GCM_128: { nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE; @@ -644,48 +1094,49 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx) goto free_priv; } - ctx->tx.prepend_size = TLS_HEADER_SIZE + nonce_size; - ctx->tx.tag_size = tag_size; - ctx->tx.overhead_size = ctx->tx.prepend_size + ctx->tx.tag_size; - ctx->tx.iv_size = iv_size; - ctx->tx.iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE, - GFP_KERNEL); - if (!ctx->tx.iv) { + cctx->prepend_size = TLS_HEADER_SIZE + nonce_size; + cctx->tag_size = tag_size; + cctx->overhead_size = cctx->prepend_size + cctx->tag_size; + cctx->iv_size = iv_size; + cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE, + GFP_KERNEL); + if (!cctx->iv) { rc = -ENOMEM; goto free_priv; } - memcpy(ctx->tx.iv, gcm_128_info->salt, - TLS_CIPHER_AES_GCM_128_SALT_SIZE); - memcpy(ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size); - ctx->tx.rec_seq_size = rec_seq_size; - ctx->tx.rec_seq = kmalloc(rec_seq_size, GFP_KERNEL); - if (!ctx->tx.rec_seq) { + memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE); + memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size); + cctx->rec_seq_size = rec_seq_size; + cctx->rec_seq = kmalloc(rec_seq_size, GFP_KERNEL); + if (!cctx->rec_seq) { rc = -ENOMEM; goto free_iv; } - memcpy(ctx->tx.rec_seq, rec_seq, rec_seq_size); - - sg_init_table(sw_ctx->sg_encrypted_data, - ARRAY_SIZE(sw_ctx->sg_encrypted_data)); - sg_init_table(sw_ctx->sg_plaintext_data, - ARRAY_SIZE(sw_ctx->sg_plaintext_data)); - - sg_init_table(sw_ctx->sg_aead_in, 2); - sg_set_buf(&sw_ctx->sg_aead_in[0], sw_ctx->aad_space, - sizeof(sw_ctx->aad_space)); - sg_unmark_end(&sw_ctx->sg_aead_in[1]); - sg_chain(sw_ctx->sg_aead_in, 2, sw_ctx->sg_plaintext_data); - sg_init_table(sw_ctx->sg_aead_out, 2); - sg_set_buf(&sw_ctx->sg_aead_out[0], sw_ctx->aad_space, - sizeof(sw_ctx->aad_space)); - sg_unmark_end(&sw_ctx->sg_aead_out[1]); - sg_chain(sw_ctx->sg_aead_out, 2, sw_ctx->sg_encrypted_data); - - if (!sw_ctx->aead_send) { - sw_ctx->aead_send = crypto_alloc_aead("gcm(aes)", 0, 0); - if (IS_ERR(sw_ctx->aead_send)) { - rc = PTR_ERR(sw_ctx->aead_send); - sw_ctx->aead_send = NULL; + memcpy(cctx->rec_seq, rec_seq, rec_seq_size); + + if (tx) { + sg_init_table(sw_ctx->sg_encrypted_data, + ARRAY_SIZE(sw_ctx->sg_encrypted_data)); + sg_init_table(sw_ctx->sg_plaintext_data, + ARRAY_SIZE(sw_ctx->sg_plaintext_data)); + + sg_init_table(sw_ctx->sg_aead_in, 2); + sg_set_buf(&sw_ctx->sg_aead_in[0], sw_ctx->aad_space, + sizeof(sw_ctx->aad_space)); + sg_unmark_end(&sw_ctx->sg_aead_in[1]); + sg_chain(sw_ctx->sg_aead_in, 2, sw_ctx->sg_plaintext_data); + sg_init_table(sw_ctx->sg_aead_out, 2); + sg_set_buf(&sw_ctx->sg_aead_out[0], sw_ctx->aad_space, + sizeof(sw_ctx->aad_space)); + sg_unmark_end(&sw_ctx->sg_aead_out[1]); + sg_chain(sw_ctx->sg_aead_out, 2, sw_ctx->sg_encrypted_data); + } + + if (!*aead) { + *aead = crypto_alloc_aead("gcm(aes)", 0, 0); + if (IS_ERR(*aead)) { + rc = PTR_ERR(*aead); + *aead = NULL; goto free_rec_seq; } } @@ -694,21 +1145,41 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx) memcpy(keyval, gcm_128_info->key, TLS_CIPHER_AES_GCM_128_KEY_SIZE); - rc = crypto_aead_setkey(sw_ctx->aead_send, keyval, + rc = crypto_aead_setkey(*aead, keyval, TLS_CIPHER_AES_GCM_128_KEY_SIZE); if (rc) goto free_aead; - rc = crypto_aead_setauthsize(sw_ctx->aead_send, ctx->tx.tag_size); - if (!rc) - return 0; + rc = crypto_aead_setauthsize(*aead, cctx->tag_size); + if (rc) + goto free_aead; + + if (!tx) { + /* Set up strparser */ + memset(&cb, 0, sizeof(cb)); + cb.rcv_msg = tls_queue; + cb.parse_msg = tls_read_size; + + strp_init(&sw_ctx->strp, sk, &cb); + + write_lock_bh(&sk->sk_callback_lock); + sw_ctx->saved_data_ready = sk->sk_data_ready; + sk->sk_data_ready = tls_data_ready; + write_unlock_bh(&sk->sk_callback_lock); + + sw_ctx->sk_poll = sk->sk_socket->ops->poll; + + strp_check_rcv(&sw_ctx->strp); + } + + goto out; free_aead: - crypto_free_aead(sw_ctx->aead_send); - sw_ctx->aead_send = NULL; + crypto_free_aead(*aead); + *aead = NULL; free_rec_seq: - kfree(ctx->tx.rec_seq); - ctx->tx.rec_seq = NULL; + kfree(cctx->rec_seq); + cctx->rec_seq = NULL; free_iv: kfree(ctx->tx.iv); ctx->tx.iv = NULL; |