diff options
Diffstat (limited to 'net/tls/tls_sw.c')
| -rw-r--r-- | net/tls/tls_sw.c | 344 |
1 files changed, 213 insertions, 131 deletions
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index 20b191227969..2b2d0bae14a9 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -42,8 +42,6 @@ #include <net/strparser.h> #include <net/tls.h> -#define MAX_IV_SIZE TLS_CIPHER_AES_GCM_128_IV_SIZE - static int __skb_nsg(struct sk_buff *skb, int offset, int len, unsigned int recursion_level) { @@ -121,23 +119,25 @@ static int skb_nsg(struct sk_buff *skb, int offset, int len) } static int padding_length(struct tls_sw_context_rx *ctx, - struct tls_context *tls_ctx, struct sk_buff *skb) + struct tls_prot_info *prot, struct sk_buff *skb) { struct strp_msg *rxm = strp_msg(skb); int sub = 0; /* Determine zero-padding length */ - if (tls_ctx->prot_info.version == TLS_1_3_VERSION) { + if (prot->version == TLS_1_3_VERSION) { char content_type = 0; int err; int back = 17; while (content_type == 0) { - if (back > rxm->full_len) + if (back > rxm->full_len - prot->prepend_size) return -EBADMSG; err = skb_copy_bits(skb, rxm->offset + rxm->full_len - back, &content_type, 1); + if (err) + return err; if (content_type) break; sub++; @@ -168,13 +168,24 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err) /* Propagate if there was an err */ if (err) { + if (err == -EBADMSG) + TLS_INC_STATS(sock_net(skb->sk), + LINUX_MIB_TLSDECRYPTERROR); ctx->async_wait.err = err; tls_err_abort(skb->sk, err); } else { struct strp_msg *rxm = strp_msg(skb); - rxm->full_len -= padding_length(ctx, tls_ctx, skb); - rxm->offset += prot->prepend_size; - rxm->full_len -= prot->overhead_size; + int pad; + + pad = padding_length(ctx, prot, skb); + if (pad < 0) { + ctx->async_wait.err = pad; + tls_err_abort(skb->sk, pad); + } else { + rxm->full_len -= pad; + rxm->offset += prot->prepend_size; + rxm->full_len -= prot->overhead_size; + } } /* After using skb->sk to propagate sk through crypto async callback @@ -225,7 +236,7 @@ static int tls_do_decryption(struct sock *sk, /* Using skb->sk to push sk through to crypto async callback * handler. This allows propagating errors up to the socket * if needed. It _must_ be cleared in the async handler - * before kfree_skb is called. We _know_ skb->sk is NULL + * before consume_skb is called. We _know_ skb->sk is NULL * because it is a clone from strparser. */ skb->sk = sk; @@ -245,6 +256,8 @@ static int tls_do_decryption(struct sock *sk, return ret; ret = crypto_wait_req(ret, &ctx->async_wait); + } else if (ret == -EBADMSG) { + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR); } if (async) @@ -479,11 +492,18 @@ static int tls_do_encryption(struct sock *sk, struct tls_rec *rec = ctx->open_rec; struct sk_msg *msg_en = &rec->msg_encrypted; struct scatterlist *sge = sk_msg_elem(msg_en, start); - int rc; + int rc, iv_offset = 0; + + /* For CCM based ciphers, first byte of IV is a constant */ + if (prot->cipher_type == TLS_CIPHER_AES_CCM_128) { + rec->iv_data[0] = TLS_AES_CCM_IV_B0_BYTE; + iv_offset = 1; + } - memcpy(rec->iv_data, tls_ctx->tx.iv, sizeof(rec->iv_data)); - xor_iv_with_seq(prot->version, rec->iv_data, - tls_ctx->tx.rec_seq); + memcpy(&rec->iv_data[iv_offset], tls_ctx->tx.iv, + prot->iv_size + prot->salt_size); + + xor_iv_with_seq(prot->version, rec->iv_data, tls_ctx->tx.rec_seq); sge->offset += prot->prepend_size; sge->length -= prot->prepend_size; @@ -519,7 +539,7 @@ static int tls_do_encryption(struct sock *sk, /* Unhook the record from context if encryption is not failure */ ctx->open_rec = NULL; - tls_advance_record_sn(sk, &tls_ctx->tx, prot->version); + tls_advance_record_sn(sk, prot, &tls_ctx->tx); return rc; } @@ -690,8 +710,7 @@ static int tls_push_record(struct sock *sk, int flags, } i = msg_pl->sg.start; - sg_chain(rec->sg_aead_in, 2, rec->inplace_crypto ? - &msg_en->sg.data[i] : &msg_pl->sg.data[i]); + sg_chain(rec->sg_aead_in, 2, &msg_pl->sg.data[i]); i = msg_en->sg.end; sk_msg_iter_var_prev(i); @@ -751,8 +770,14 @@ static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk, policy = !(flags & MSG_SENDPAGE_NOPOLICY); psock = sk_psock_get(sk); - if (!psock || !policy) - return tls_push_record(sk, flags, record_type); + if (!psock || !policy) { + err = tls_push_record(sk, flags, record_type); + if (err) { + *copied -= sk_msg_free(sk, msg); + tls_free_open_rec(sk); + } + return err; + } more_data: enospc = sk_msg_full(msg); if (psock->eval == __SK_NONE) { @@ -882,15 +907,9 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL)) return -ENOTSUPP; + mutex_lock(&tls_ctx->tx_lock); lock_sock(sk); - /* Wait till there is any pending write on socket */ - if (unlikely(sk->sk_write_pending)) { - ret = wait_on_pending_writer(sk, &timeo); - if (unlikely(ret)) - goto send_end; - } - if (unlikely(msg->msg_controllen)) { ret = tls_proccess_cmsg(sk, msg, &record_type); if (ret) { @@ -956,8 +975,6 @@ alloc_encrypted: if (ret) goto fallback_to_reg_send; - rec->inplace_crypto = 0; - num_zc++; copied += try_to_copy; @@ -970,7 +987,7 @@ alloc_encrypted: num_async++; else if (ret == -ENOMEM) goto wait_for_memory; - else if (ret == -ENOSPC) + else if (ctx->open_rec && ret == -ENOSPC) goto rollback_iter; else if (ret != -EAGAIN) goto send_end; @@ -1039,11 +1056,12 @@ wait_for_memory: ret = sk_stream_wait_memory(sk, &timeo); if (ret) { trim_sgl: - tls_trim_both_msgs(sk, orig_size); + if (ctx->open_rec) + tls_trim_both_msgs(sk, orig_size); goto send_end; } - if (msg_en->sg.size < required_size) + if (ctx->open_rec && msg_en->sg.size < required_size) goto alloc_encrypted; } @@ -1076,6 +1094,7 @@ send_end: ret = sk_stream_error(sk, msg->msg_flags, ret); release_sock(sk); + mutex_unlock(&tls_ctx->tx_lock); return copied ? copied : ret; } @@ -1099,13 +1118,6 @@ static int tls_sw_do_sendpage(struct sock *sk, struct page *page, eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST)); sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk); - /* Wait till there is any pending write on socket */ - if (unlikely(sk->sk_write_pending)) { - ret = wait_on_pending_writer(sk, &timeo); - if (unlikely(ret)) - goto sendpage_end; - } - /* Call the sk_stream functions to manage the sndbuf mem. */ while (size > 0) { size_t copy, required_size; @@ -1128,7 +1140,6 @@ static int tls_sw_do_sendpage(struct sock *sk, struct page *page, full_record = false; record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size; - copied = 0; copy = size; if (copy >= record_room) { copy = record_room; @@ -1162,7 +1173,6 @@ alloc_payload: tls_ctx->pending_open_record_frags = true; if (full_record || eor || sk_msg_full(msg_pl)) { - rec->inplace_crypto = 0; ret = bpf_exec_tx_verdict(msg_pl, sk, full_record, record_type, &copied, flags); if (ret) { @@ -1183,11 +1193,13 @@ wait_for_sndbuf: wait_for_memory: ret = sk_stream_wait_memory(sk, &timeo); if (ret) { - tls_trim_both_msgs(sk, msg_pl->sg.size); + if (ctx->open_rec) + tls_trim_both_msgs(sk, msg_pl->sg.size); goto sendpage_end; } - goto alloc_payload; + if (ctx->open_rec) + goto alloc_payload; } if (num_async) { @@ -1202,18 +1214,32 @@ sendpage_end: return copied ? copied : ret; } +int tls_sw_sendpage_locked(struct sock *sk, struct page *page, + int offset, size_t size, int flags) +{ + if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | + MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY | + MSG_NO_SHARED_FRAGS)) + return -ENOTSUPP; + + return tls_sw_do_sendpage(sk, page, offset, size, flags); +} + int tls_sw_sendpage(struct sock *sk, struct page *page, int offset, size_t size, int flags) { + struct tls_context *tls_ctx = tls_get_ctx(sk); int ret; if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY)) return -ENOTSUPP; + mutex_lock(&tls_ctx->tx_lock); lock_sock(sk); ret = tls_sw_do_sendpage(sk, page, offset, size, flags); release_sock(sk); + mutex_unlock(&tls_ctx->tx_lock); return ret; } @@ -1344,6 +1370,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, struct scatterlist *sgout = NULL; const int data_len = rxm->full_len - prot->overhead_size + prot->tail_size; + int iv_offset = 0; if (*zc && (out_iov || out_sg)) { if (out_iov) @@ -1386,18 +1413,25 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, aad = (u8 *)(sgout + n_sgout); iv = aad + prot->aad_size; + /* For CCM based ciphers, first byte of nonce+iv is always '2' */ + if (prot->cipher_type == TLS_CIPHER_AES_CCM_128) { + iv[0] = 2; + iv_offset = 1; + } + /* Prepare IV */ err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE, - iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, + iv + iv_offset + prot->salt_size, prot->iv_size); if (err < 0) { kfree(mem); return err; } if (prot->version == TLS_1_3_VERSION) - memcpy(iv, tls_ctx->rx.iv, crypto_aead_ivsize(ctx->aead_recv)); + memcpy(iv + iv_offset, tls_ctx->rx.iv, + crypto_aead_ivsize(ctx->aead_recv)); else - memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE); + memcpy(iv + iv_offset, tls_ctx->rx.iv, prot->salt_size); xor_iv_with_seq(prot->version, iv, tls_ctx->rx.rec_seq); @@ -1463,24 +1497,24 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct tls_prot_info *prot = &tls_ctx->prot_info; - int version = prot->version; struct strp_msg *rxm = strp_msg(skb); - int err = 0; + int pad, err = 0; if (!ctx->decrypted) { -#ifdef CONFIG_TLS_DEVICE - err = tls_device_decrypted(sk, skb); - if (err < 0) - return err; -#endif + if (tls_ctx->rx_conf == TLS_HW) { + err = tls_device_decrypted(sk, tls_ctx, skb, rxm); + if (err < 0) + return err; + } + /* Still not decrypted after tls_device */ if (!ctx->decrypted) { err = decrypt_internal(sk, skb, dest, NULL, chunk, zc, async); if (err < 0) { if (err == -EINPROGRESS) - tls_advance_record_sn(sk, &tls_ctx->rx, - version); + tls_advance_record_sn(sk, prot, + &tls_ctx->rx); return err; } @@ -1488,11 +1522,15 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, *zc = false; } - rxm->full_len -= padding_length(ctx, tls_ctx, skb); + pad = padding_length(ctx, prot, skb); + if (pad < 0) + return pad; + + rxm->full_len -= pad; rxm->offset += prot->prepend_size; rxm->full_len -= prot->overhead_size; - tls_advance_record_sn(sk, &tls_ctx->rx, version); - ctx->decrypted = true; + tls_advance_record_sn(sk, prot, &tls_ctx->rx); + ctx->decrypted = 1; ctx->saved_data_ready(sk); } else { *zc = false; @@ -1524,7 +1562,7 @@ static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb, rxm->full_len -= len; return false; } - kfree_skb(skb); + consume_skb(skb); } /* Finished with message */ @@ -1633,7 +1671,7 @@ static int process_rx_list(struct tls_sw_context_rx *ctx, if (!is_peek) { skb_unlink(skb, &ctx->rx_list); - kfree_skb(skb); + consume_skb(skb); } skb = next_skb; @@ -1685,15 +1723,14 @@ int tls_sw_recvmsg(struct sock *sk, copied = err; } - len = len - copied; - if (len) { - target = sock_rcvlowat(sk, flags & MSG_WAITALL, len); - timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); - } else { + if (len <= copied) goto recv_end; - } - do { + target = sock_rcvlowat(sk, flags & MSG_WAITALL, len); + len = len - copied; + timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + + while (len && (decrypted + copied < target || ctx->recv_pkt)) { bool retain_skb = false; bool zc = false; int to_decrypt; @@ -1824,11 +1861,7 @@ pick_next_record: } else { break; } - - /* If we have a new message from strparser, continue now. */ - if (decrypted >= target && !ctx->recv_pkt) - break; - } while (len); + } recv_end: if (num_async) { @@ -1907,7 +1940,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, tls_err_abort(sk, EBADMSG); goto splice_read_end; } - ctx->decrypted = true; + ctx->decrypted = 1; } rxm = strp_msg(skb); @@ -1937,7 +1970,8 @@ bool tls_sw_stream_read(const struct sock *sk) ingress_empty = list_empty(&psock->ingress_msg); rcu_read_unlock(); - return !ingress_empty || ctx->recv_pkt; + return !ingress_empty || ctx->recv_pkt || + !skb_queue_empty(&ctx->rx_list); } static int tls_read_size(struct strparser *strp, struct sk_buff *skb) @@ -1991,10 +2025,9 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb) ret = -EINVAL; goto read_failure; } -#ifdef CONFIG_TLS_DEVICE - handle_device_resync(strp->sk, TCP_SKB_CB(skb)->seq + rxm->offset, - *(u64*)tls_ctx->rx.rec_seq); -#endif + + tls_device_rx_resync_new_rec(strp->sk, data_len + TLS_HEADER_SIZE, + TCP_SKB_CB(skb)->seq + rxm->offset); return data_len + TLS_HEADER_SIZE; read_failure: @@ -2008,7 +2041,7 @@ static void tls_queue(struct strparser *strp, struct sk_buff *skb) struct tls_context *tls_ctx = tls_get_ctx(strp->sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); - ctx->decrypted = false; + ctx->decrypted = 0; ctx->recv_pkt = skb; strp_pause(strp); @@ -2031,7 +2064,16 @@ static void tls_data_ready(struct sock *sk) } } -void tls_sw_free_resources_tx(struct sock *sk) +void tls_sw_cancel_work_tx(struct tls_context *tls_ctx) +{ + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + + set_bit(BIT_TX_CLOSING, &ctx->tx_bitmask); + set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask); + cancel_delayed_work_sync(&ctx->tx_work.work); +} + +void tls_sw_release_resources_tx(struct sock *sk) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); @@ -2042,30 +2084,13 @@ void tls_sw_free_resources_tx(struct sock *sk) if (atomic_read(&ctx->encrypt_pending)) crypto_wait_req(-EINPROGRESS, &ctx->async_wait); - release_sock(sk); - cancel_delayed_work_sync(&ctx->tx_work.work); - lock_sock(sk); - - /* Tx whatever records we can transmit and abandon the rest */ tls_tx_records(sk, -1); /* Free up un-sent records in tx_list. First, free * the partially sent record if any at head of tx_list. */ if (tls_ctx->partially_sent_record) { - struct scatterlist *sg = tls_ctx->partially_sent_record; - - while (1) { - put_page(sg_page(sg)); - sk_mem_uncharge(sk, sg->length); - - if (sg_is_last(sg)) - break; - sg++; - } - - tls_ctx->partially_sent_record = NULL; - + tls_free_partial_record(sk, tls_ctx); rec = list_first_entry(&ctx->tx_list, struct tls_rec, list); list_del(&rec->list); @@ -2082,6 +2107,11 @@ void tls_sw_free_resources_tx(struct sock *sk) crypto_free_aead(ctx->aead_send); tls_free_open_rec(sk); +} + +void tls_sw_free_ctx_tx(struct tls_context *tls_ctx) +{ + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); kfree(ctx); } @@ -2091,31 +2121,49 @@ void tls_sw_release_resources_rx(struct sock *sk) struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + kfree(tls_ctx->rx.rec_seq); + kfree(tls_ctx->rx.iv); + if (ctx->aead_recv) { kfree_skb(ctx->recv_pkt); ctx->recv_pkt = NULL; skb_queue_purge(&ctx->rx_list); crypto_free_aead(ctx->aead_recv); strp_stop(&ctx->strp); - write_lock_bh(&sk->sk_callback_lock); - sk->sk_data_ready = ctx->saved_data_ready; - write_unlock_bh(&sk->sk_callback_lock); - release_sock(sk); - strp_done(&ctx->strp); - lock_sock(sk); + /* If tls_sw_strparser_arm() was not called (cleanup paths) + * we still want to strp_stop(), but sk->sk_data_ready was + * never swapped. + */ + if (ctx->saved_data_ready) { + write_lock_bh(&sk->sk_callback_lock); + sk->sk_data_ready = ctx->saved_data_ready; + write_unlock_bh(&sk->sk_callback_lock); + } } } -void tls_sw_free_resources_rx(struct sock *sk) +void tls_sw_strparser_done(struct tls_context *tls_ctx) { - struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); - tls_sw_release_resources_rx(sk); + strp_done(&ctx->strp); +} + +void tls_sw_free_ctx_rx(struct tls_context *tls_ctx) +{ + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); kfree(ctx); } +void tls_sw_free_resources_rx(struct sock *sk) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + + tls_sw_release_resources_rx(sk); + tls_sw_free_ctx_rx(tls_ctx); +} + /* The work handler to transmitt the encrypted records in tx_list */ static void tx_work_handler(struct work_struct *work) { @@ -2124,14 +2172,22 @@ static void tx_work_handler(struct work_struct *work) struct tx_work, work); struct sock *sk = tx_work->sk; struct tls_context *tls_ctx = tls_get_ctx(sk); - struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + struct tls_sw_context_tx *ctx; - if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) + if (unlikely(!tls_ctx)) + return; + + ctx = tls_sw_ctx_tx(tls_ctx); + if (test_bit(BIT_TX_CLOSING, &ctx->tx_bitmask)) return; + if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) + return; + mutex_lock(&tls_ctx->tx_lock); lock_sock(sk); tls_tx_records(sk, -1); release_sock(sk); + mutex_unlock(&tls_ctx->tx_lock); } void tls_sw_write_space(struct sock *sk, struct tls_context *ctx) @@ -2139,12 +2195,21 @@ void tls_sw_write_space(struct sock *sk, struct tls_context *ctx) struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx); /* Schedule the transmission if tx list is ready */ - if (is_tx_ready(tx_ctx) && !sk->sk_write_pending) { - /* Schedule the transmission */ - if (!test_and_set_bit(BIT_TX_SCHEDULED, - &tx_ctx->tx_bitmask)) - schedule_delayed_work(&tx_ctx->tx_work.work, 0); - } + if (is_tx_ready(tx_ctx) && + !test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask)) + schedule_delayed_work(&tx_ctx->tx_work.work, 0); +} + +void tls_sw_strparser_arm(struct sock *sk, struct tls_context *tls_ctx) +{ + struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx); + + write_lock_bh(&sk->sk_callback_lock); + rx_ctx->saved_data_ready = sk->sk_data_ready; + sk->sk_data_ready = tls_data_ready; + write_unlock_bh(&sk->sk_callback_lock); + + strp_check_rcv(&rx_ctx->strp); } int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) @@ -2154,14 +2219,15 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) struct tls_crypto_info *crypto_info; struct tls12_crypto_info_aes_gcm_128 *gcm_128_info; struct tls12_crypto_info_aes_gcm_256 *gcm_256_info; + struct tls12_crypto_info_aes_ccm_128 *ccm_128_info; struct tls_sw_context_tx *sw_ctx_tx = NULL; struct tls_sw_context_rx *sw_ctx_rx = NULL; struct cipher_context *cctx; struct crypto_aead **aead; struct strp_callbacks cb; - u16 nonce_size, tag_size, iv_size, rec_seq_size; + u16 nonce_size, tag_size, iv_size, rec_seq_size, salt_size; struct crypto_tfm *tfm; - char *iv, *rec_seq, *key, *salt; + char *iv, *rec_seq, *key, *salt, *cipher_name; size_t keysize; int rc = 0; @@ -2226,6 +2292,8 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) keysize = TLS_CIPHER_AES_GCM_128_KEY_SIZE; key = gcm_128_info->key; salt = gcm_128_info->salt; + salt_size = TLS_CIPHER_AES_GCM_128_SALT_SIZE; + cipher_name = "gcm(aes)"; break; } case TLS_CIPHER_AES_GCM_256: { @@ -2241,6 +2309,25 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) keysize = TLS_CIPHER_AES_GCM_256_KEY_SIZE; key = gcm_256_info->key; salt = gcm_256_info->salt; + salt_size = TLS_CIPHER_AES_GCM_256_SALT_SIZE; + cipher_name = "gcm(aes)"; + break; + } + case TLS_CIPHER_AES_CCM_128: { + nonce_size = TLS_CIPHER_AES_CCM_128_IV_SIZE; + tag_size = TLS_CIPHER_AES_CCM_128_TAG_SIZE; + iv_size = TLS_CIPHER_AES_CCM_128_IV_SIZE; + iv = ((struct tls12_crypto_info_aes_ccm_128 *)crypto_info)->iv; + rec_seq_size = TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE; + rec_seq = + ((struct tls12_crypto_info_aes_ccm_128 *)crypto_info)->rec_seq; + ccm_128_info = + (struct tls12_crypto_info_aes_ccm_128 *)crypto_info; + keysize = TLS_CIPHER_AES_CCM_128_KEY_SIZE; + key = ccm_128_info->key; + salt = ccm_128_info->salt; + salt_size = TLS_CIPHER_AES_CCM_128_SALT_SIZE; + cipher_name = "ccm(aes)"; break; } default: @@ -2248,8 +2335,9 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) goto free_priv; } - /* Sanity-check the IV size for stack allocations. */ - if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE) { + /* Sanity-check the sizes for stack allocations. */ + if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE || + rec_seq_size > TLS_MAX_REC_SEQ_SIZE) { rc = -EINVAL; goto free_priv; } @@ -2270,16 +2358,16 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) prot->overhead_size = prot->prepend_size + prot->tag_size + prot->tail_size; prot->iv_size = iv_size; - cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE, - GFP_KERNEL); + prot->salt_size = salt_size; + cctx->iv = kmalloc(iv_size + salt_size, GFP_KERNEL); if (!cctx->iv) { rc = -ENOMEM; goto free_priv; } /* Note: 128 & 256 bit salt are the same size */ - memcpy(cctx->iv, salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE); - memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size); prot->rec_seq_size = rec_seq_size; + memcpy(cctx->iv, salt, salt_size); + memcpy(cctx->iv + salt_size, iv, iv_size); cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL); if (!cctx->rec_seq) { rc = -ENOMEM; @@ -2287,7 +2375,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) } if (!*aead) { - *aead = crypto_alloc_aead("gcm(aes)", 0, 0); + *aead = crypto_alloc_aead(cipher_name, 0, 0); if (IS_ERR(*aead)) { rc = PTR_ERR(*aead); *aead = NULL; @@ -2310,10 +2398,11 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv); if (crypto_info->version == TLS_1_3_VERSION) - sw_ctx_rx->async_capable = false; + sw_ctx_rx->async_capable = 0; else sw_ctx_rx->async_capable = - tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC; + !!(tfm->__crt_alg->cra_flags & + CRYPTO_ALG_ASYNC); /* Set up strparser */ memset(&cb, 0, sizeof(cb)); @@ -2321,13 +2410,6 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) cb.parse_msg = tls_read_size; strp_init(&sw_ctx_rx->strp, sk, &cb); - - write_lock_bh(&sk->sk_callback_lock); - sw_ctx_rx->saved_data_ready = sk->sk_data_ready; - sk->sk_data_ready = tls_data_ready; - write_unlock_bh(&sk->sk_callback_lock); - - strp_check_rcv(&sw_ctx_rx->strp); } goto out; |
