summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJakub Kicinski <kuba@kernel.org>2022-09-23 03:27:45 +0300
committerJakub Kicinski <kuba@kernel.org>2022-09-23 03:27:46 +0300
commit03d25cf7a0e9e83448a67cd1e3c51712f3273984 (patch)
treea8ddd9adfa7f787556161af4b0d385637dd02f3b
parent8db3d514e96715c897fe793c4d5fc0fd86aca517 (diff)
parent4960c414db3582b266dce660bd8eff41157fe2f9 (diff)
downloadlinux-03d25cf7a0e9e83448a67cd1e3c51712f3273984.tar.xz
Merge branch 'support-256-bit-tls-keys-with-device-offload'
Gal Pressman says: ==================== Support 256 bit TLS keys with device offload This series adds support for 256 bit TLS keys with device offload, and a cleanup patch to remove repeating code: - Patches #1-2 add cipher sizes descriptors which allow reducing the amount of code duplications. - Patch #3 allows 256 bit keys to be TX offloaded in the tls module (RX already supported). - Patch #4 adds 256 bit keys support to the mlx5 driver. ==================== Link: https://lore.kernel.org/r/20220920130150.3546-1-gal@nvidia.com Signed-off-by: Jakub Kicinski <kuba@kernel.org>
-rw-r--r--drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls.h7
-rw-r--r--drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls_rx.c45
-rw-r--r--drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls_tx.c41
-rw-r--r--drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls_txrx.c27
-rw-r--r--drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls_utils.h8
-rw-r--r--include/net/tls.h10
-rw-r--r--net/tls/tls_device.c61
-rw-r--r--net/tls/tls_device_fallback.c79
-rw-r--r--net/tls/tls_main.c17
9 files changed, 227 insertions, 68 deletions
diff --git a/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls.h b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls.h
index 948400dee525..299334b2f935 100644
--- a/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls.h
+++ b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls.h
@@ -25,7 +25,8 @@ static inline bool mlx5e_is_ktls_device(struct mlx5_core_dev *mdev)
if (!MLX5_CAP_GEN(mdev, log_max_dek))
return false;
- return MLX5_CAP_TLS(mdev, tls_1_2_aes_gcm_128);
+ return (MLX5_CAP_TLS(mdev, tls_1_2_aes_gcm_128) ||
+ MLX5_CAP_TLS(mdev, tls_1_2_aes_gcm_256));
}
static inline bool mlx5e_ktls_type_check(struct mlx5_core_dev *mdev,
@@ -36,6 +37,10 @@ static inline bool mlx5e_ktls_type_check(struct mlx5_core_dev *mdev,
if (crypto_info->version == TLS_1_2_VERSION)
return MLX5_CAP_TLS(mdev, tls_1_2_aes_gcm_128);
break;
+ case TLS_CIPHER_AES_GCM_256:
+ if (crypto_info->version == TLS_1_2_VERSION)
+ return MLX5_CAP_TLS(mdev, tls_1_2_aes_gcm_256);
+ break;
}
return false;
diff --git a/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls_rx.c b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls_rx.c
index 5203393adf88..3e54834747ce 100644
--- a/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls_rx.c
+++ b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls_rx.c
@@ -43,7 +43,7 @@ struct mlx5e_ktls_rx_resync_ctx {
};
struct mlx5e_ktls_offload_context_rx {
- struct tls12_crypto_info_aes_gcm_128 crypto_info;
+ union mlx5e_crypto_info crypto_info;
struct accel_rule rule;
struct sock *sk;
struct mlx5e_rq_stats *rq_stats;
@@ -362,7 +362,6 @@ static void resync_init(struct mlx5e_ktls_rx_resync_ctx *resync,
static void resync_handle_seq_match(struct mlx5e_ktls_offload_context_rx *priv_rx,
struct mlx5e_channel *c)
{
- struct tls12_crypto_info_aes_gcm_128 *info = &priv_rx->crypto_info;
struct mlx5e_ktls_resync_resp *ktls_resync;
struct mlx5e_icosq *sq;
bool trigger_poll;
@@ -373,7 +372,31 @@ static void resync_handle_seq_match(struct mlx5e_ktls_offload_context_rx *priv_r
spin_lock_bh(&ktls_resync->lock);
spin_lock_bh(&priv_rx->lock);
- memcpy(info->rec_seq, &priv_rx->resync.sw_rcd_sn_be, sizeof(info->rec_seq));
+ switch (priv_rx->crypto_info.crypto_info.cipher_type) {
+ case TLS_CIPHER_AES_GCM_128: {
+ struct tls12_crypto_info_aes_gcm_128 *info =
+ &priv_rx->crypto_info.crypto_info_128;
+
+ memcpy(info->rec_seq, &priv_rx->resync.sw_rcd_sn_be,
+ sizeof(info->rec_seq));
+ break;
+ }
+ case TLS_CIPHER_AES_GCM_256: {
+ struct tls12_crypto_info_aes_gcm_256 *info =
+ &priv_rx->crypto_info.crypto_info_256;
+
+ memcpy(info->rec_seq, &priv_rx->resync.sw_rcd_sn_be,
+ sizeof(info->rec_seq));
+ break;
+ }
+ default:
+ WARN_ONCE(1, "Unsupported cipher type %u\n",
+ priv_rx->crypto_info.crypto_info.cipher_type);
+ spin_unlock_bh(&priv_rx->lock);
+ spin_unlock_bh(&ktls_resync->lock);
+ return;
+ }
+
if (list_empty(&priv_rx->list)) {
list_add_tail(&priv_rx->list, &ktls_resync->list);
trigger_poll = !test_and_set_bit(MLX5E_SQ_STATE_PENDING_TLS_RX_RESYNC, &sq->state);
@@ -604,8 +627,20 @@ int mlx5e_ktls_add_rx(struct net_device *netdev, struct sock *sk,
INIT_LIST_HEAD(&priv_rx->list);
spin_lock_init(&priv_rx->lock);
- priv_rx->crypto_info =
- *(struct tls12_crypto_info_aes_gcm_128 *)crypto_info;
+ switch (crypto_info->cipher_type) {
+ case TLS_CIPHER_AES_GCM_128:
+ priv_rx->crypto_info.crypto_info_128 =
+ *(struct tls12_crypto_info_aes_gcm_128 *)crypto_info;
+ break;
+ case TLS_CIPHER_AES_GCM_256:
+ priv_rx->crypto_info.crypto_info_256 =
+ *(struct tls12_crypto_info_aes_gcm_256 *)crypto_info;
+ break;
+ default:
+ WARN_ONCE(1, "Unsupported cipher type %u\n",
+ crypto_info->cipher_type);
+ return -EOPNOTSUPP;
+ }
rxq = mlx5e_ktls_sk_get_rxq(sk);
priv_rx->rxq = rxq;
diff --git a/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls_tx.c b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls_tx.c
index 3a1f76eac542..2e0335246967 100644
--- a/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls_tx.c
+++ b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls_tx.c
@@ -93,7 +93,7 @@ struct mlx5e_ktls_offload_context_tx {
bool ctx_post_pending;
/* control / resync */
struct list_head list_node; /* member of the pool */
- struct tls12_crypto_info_aes_gcm_128 crypto_info;
+ union mlx5e_crypto_info crypto_info;
struct tls_offload_context_tx *tx_ctx;
struct mlx5_core_dev *mdev;
struct mlx5e_tls_sw_stats *sw_stats;
@@ -485,8 +485,20 @@ int mlx5e_ktls_add_tx(struct net_device *netdev, struct sock *sk,
goto err_create_key;
priv_tx->expected_seq = start_offload_tcp_sn;
- priv_tx->crypto_info =
- *(struct tls12_crypto_info_aes_gcm_128 *)crypto_info;
+ switch (crypto_info->cipher_type) {
+ case TLS_CIPHER_AES_GCM_128:
+ priv_tx->crypto_info.crypto_info_128 =
+ *(struct tls12_crypto_info_aes_gcm_128 *)crypto_info;
+ break;
+ case TLS_CIPHER_AES_GCM_256:
+ priv_tx->crypto_info.crypto_info_256 =
+ *(struct tls12_crypto_info_aes_gcm_256 *)crypto_info;
+ break;
+ default:
+ WARN_ONCE(1, "Unsupported cipher type %u\n",
+ crypto_info->cipher_type);
+ return -EOPNOTSUPP;
+ }
priv_tx->tx_ctx = tls_offload_ctx_tx(tls_ctx);
mlx5e_set_ktls_tx_priv_ctx(tls_ctx, priv_tx);
@@ -671,14 +683,31 @@ tx_post_resync_params(struct mlx5e_txqsq *sq,
struct mlx5e_ktls_offload_context_tx *priv_tx,
u64 rcd_sn)
{
- struct tls12_crypto_info_aes_gcm_128 *info = &priv_tx->crypto_info;
__be64 rn_be = cpu_to_be64(rcd_sn);
bool skip_static_post;
u16 rec_seq_sz;
char *rec_seq;
- rec_seq = info->rec_seq;
- rec_seq_sz = sizeof(info->rec_seq);
+ switch (priv_tx->crypto_info.crypto_info.cipher_type) {
+ case TLS_CIPHER_AES_GCM_128: {
+ struct tls12_crypto_info_aes_gcm_128 *info = &priv_tx->crypto_info.crypto_info_128;
+
+ rec_seq = info->rec_seq;
+ rec_seq_sz = sizeof(info->rec_seq);
+ break;
+ }
+ case TLS_CIPHER_AES_GCM_256: {
+ struct tls12_crypto_info_aes_gcm_256 *info = &priv_tx->crypto_info.crypto_info_256;
+
+ rec_seq = info->rec_seq;
+ rec_seq_sz = sizeof(info->rec_seq);
+ break;
+ }
+ default:
+ WARN_ONCE(1, "Unsupported cipher type %u\n",
+ priv_tx->crypto_info.crypto_info.cipher_type);
+ return;
+ }
skip_static_post = !memcmp(rec_seq, &rn_be, rec_seq_sz);
if (!skip_static_post)
diff --git a/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls_txrx.c b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls_txrx.c
index ac29aeb8af49..570a912dd6fa 100644
--- a/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls_txrx.c
+++ b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls_txrx.c
@@ -21,7 +21,7 @@ enum {
static void
fill_static_params(struct mlx5_wqe_tls_static_params_seg *params,
- struct tls12_crypto_info_aes_gcm_128 *info,
+ union mlx5e_crypto_info *crypto_info,
u32 key_id, u32 resync_tcp_sn)
{
char *initial_rn, *gcm_iv;
@@ -32,7 +32,26 @@ fill_static_params(struct mlx5_wqe_tls_static_params_seg *params,
ctx = params->ctx;
- EXTRACT_INFO_FIELDS;
+ switch (crypto_info->crypto_info.cipher_type) {
+ case TLS_CIPHER_AES_GCM_128: {
+ struct tls12_crypto_info_aes_gcm_128 *info =
+ &crypto_info->crypto_info_128;
+
+ EXTRACT_INFO_FIELDS;
+ break;
+ }
+ case TLS_CIPHER_AES_GCM_256: {
+ struct tls12_crypto_info_aes_gcm_256 *info =
+ &crypto_info->crypto_info_256;
+
+ EXTRACT_INFO_FIELDS;
+ break;
+ }
+ default:
+ WARN_ONCE(1, "Unsupported cipher type %u\n",
+ crypto_info->crypto_info.cipher_type);
+ return;
+ }
gcm_iv = MLX5_ADDR_OF(tls_static_params, ctx, gcm_iv);
initial_rn = MLX5_ADDR_OF(tls_static_params, ctx, initial_record_number);
@@ -54,7 +73,7 @@ fill_static_params(struct mlx5_wqe_tls_static_params_seg *params,
void
mlx5e_ktls_build_static_params(struct mlx5e_set_tls_static_params_wqe *wqe,
u16 pc, u32 sqn,
- struct tls12_crypto_info_aes_gcm_128 *info,
+ union mlx5e_crypto_info *crypto_info,
u32 tis_tir_num, u32 key_id, u32 resync_tcp_sn,
bool fence, enum tls_offload_ctx_dir direction)
{
@@ -75,7 +94,7 @@ mlx5e_ktls_build_static_params(struct mlx5e_set_tls_static_params_wqe *wqe,
ucseg->flags = MLX5_UMR_INLINE;
ucseg->bsf_octowords = cpu_to_be16(MLX5_ST_SZ_BYTES(tls_static_params) / 16);
- fill_static_params(&wqe->params, info, key_id, resync_tcp_sn);
+ fill_static_params(&wqe->params, crypto_info, key_id, resync_tcp_sn);
}
static void
diff --git a/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls_utils.h b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls_utils.h
index 0dc715c4c10d..3d79cd379890 100644
--- a/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls_utils.h
+++ b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls_utils.h
@@ -27,6 +27,12 @@ int mlx5e_ktls_add_rx(struct net_device *netdev, struct sock *sk,
void mlx5e_ktls_del_rx(struct net_device *netdev, struct tls_context *tls_ctx);
void mlx5e_ktls_rx_resync(struct net_device *netdev, struct sock *sk, u32 seq, u8 *rcd_sn);
+union mlx5e_crypto_info {
+ struct tls_crypto_info crypto_info;
+ struct tls12_crypto_info_aes_gcm_128 crypto_info_128;
+ struct tls12_crypto_info_aes_gcm_256 crypto_info_256;
+};
+
struct mlx5e_set_tls_static_params_wqe {
struct mlx5_wqe_ctrl_seg ctrl;
struct mlx5_wqe_umr_ctrl_seg uctrl;
@@ -72,7 +78,7 @@ struct mlx5e_get_tls_progress_params_wqe {
void
mlx5e_ktls_build_static_params(struct mlx5e_set_tls_static_params_wqe *wqe,
u16 pc, u32 sqn,
- struct tls12_crypto_info_aes_gcm_128 *info,
+ union mlx5e_crypto_info *crypto_info,
u32 tis_tir_num, u32 key_id, u32 resync_tcp_sn,
bool fence, enum tls_offload_ctx_dir direction);
void
diff --git a/include/net/tls.h b/include/net/tls.h
index cb205f9d9473..154949c7b0c8 100644
--- a/include/net/tls.h
+++ b/include/net/tls.h
@@ -51,6 +51,16 @@
struct tls_rec;
+struct tls_cipher_size_desc {
+ unsigned int iv;
+ unsigned int key;
+ unsigned int salt;
+ unsigned int tag;
+ unsigned int rec_seq;
+};
+
+extern const struct tls_cipher_size_desc tls_cipher_size_desc[];
+
/* Maximum data size carried in a TLS record */
#define TLS_MAX_PAYLOAD_SIZE ((size_t)1 << 14)
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c
index 0f983e5f7dde..a03d66046ca3 100644
--- a/net/tls/tls_device.c
+++ b/net/tls/tls_device.c
@@ -902,17 +902,28 @@ static void tls_device_core_ctrl_rx_resync(struct tls_context *tls_ctx,
}
static int
-tls_device_reencrypt(struct sock *sk, struct tls_sw_context_rx *sw_ctx)
+tls_device_reencrypt(struct sock *sk, struct tls_context *tls_ctx)
{
+ struct tls_sw_context_rx *sw_ctx = tls_sw_ctx_rx(tls_ctx);
+ const struct tls_cipher_size_desc *cipher_sz;
int err, offset, copy, data_len, pos;
struct sk_buff *skb, *skb_iter;
struct scatterlist sg[1];
struct strp_msg *rxm;
char *orig_buf, *buf;
+ switch (tls_ctx->crypto_recv.info.cipher_type) {
+ case TLS_CIPHER_AES_GCM_128:
+ case TLS_CIPHER_AES_GCM_256:
+ break;
+ default:
+ return -EINVAL;
+ }
+ cipher_sz = &tls_cipher_size_desc[tls_ctx->crypto_recv.info.cipher_type];
+
rxm = strp_msg(tls_strp_msg(sw_ctx));
- orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE +
- TLS_CIPHER_AES_GCM_128_IV_SIZE, sk->sk_allocation);
+ orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE + cipher_sz->iv,
+ sk->sk_allocation);
if (!orig_buf)
return -ENOMEM;
buf = orig_buf;
@@ -927,10 +938,8 @@ tls_device_reencrypt(struct sock *sk, struct tls_sw_context_rx *sw_ctx)
sg_init_table(sg, 1);
sg_set_buf(&sg[0], buf,
- rxm->full_len + TLS_HEADER_SIZE +
- TLS_CIPHER_AES_GCM_128_IV_SIZE);
- err = skb_copy_bits(skb, offset, buf,
- TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE);
+ rxm->full_len + TLS_HEADER_SIZE + cipher_sz->iv);
+ err = skb_copy_bits(skb, offset, buf, TLS_HEADER_SIZE + cipher_sz->iv);
if (err)
goto free_buf;
@@ -941,7 +950,7 @@ tls_device_reencrypt(struct sock *sk, struct tls_sw_context_rx *sw_ctx)
else
err = 0;
- data_len = rxm->full_len - TLS_CIPHER_AES_GCM_128_TAG_SIZE;
+ data_len = rxm->full_len - cipher_sz->tag;
if (skb_pagelen(skb) > offset) {
copy = min_t(int, skb_pagelen(skb) - offset, data_len);
@@ -1024,7 +1033,7 @@ int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx)
* likely have initial fragments decrypted, and final ones not
* decrypted. We need to reencrypt that single SKB.
*/
- return tls_device_reencrypt(sk, sw_ctx);
+ return tls_device_reencrypt(sk, tls_ctx);
}
/* Return immediately if the record is either entirely plaintext or
@@ -1041,7 +1050,7 @@ int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx)
}
ctx->resync_nh_reset = 1;
- return tls_device_reencrypt(sk, sw_ctx);
+ return tls_device_reencrypt(sk, tls_ctx);
}
static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
@@ -1062,9 +1071,9 @@ static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
{
- u16 nonce_size, tag_size, iv_size, rec_seq_size, salt_size;
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_prot_info *prot = &tls_ctx->prot_info;
+ const struct tls_cipher_size_desc *cipher_sz;
struct tls_record_info *start_marker_record;
struct tls_offload_context_tx *offload_ctx;
struct tls_crypto_info *crypto_info;
@@ -1099,44 +1108,44 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
switch (crypto_info->cipher_type) {
case TLS_CIPHER_AES_GCM_128:
- nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
- tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
- iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
- rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
- salt_size = TLS_CIPHER_AES_GCM_128_SALT_SIZE;
rec_seq =
((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
break;
+ case TLS_CIPHER_AES_GCM_256:
+ iv = ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->iv;
+ rec_seq =
+ ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->rec_seq;
+ break;
default:
rc = -EINVAL;
goto release_netdev;
}
+ cipher_sz = &tls_cipher_size_desc[crypto_info->cipher_type];
/* Sanity-check the rec_seq_size for stack allocations */
- if (rec_seq_size > TLS_MAX_REC_SEQ_SIZE) {
+ if (cipher_sz->rec_seq > TLS_MAX_REC_SEQ_SIZE) {
rc = -EINVAL;
goto release_netdev;
}
prot->version = crypto_info->version;
prot->cipher_type = crypto_info->cipher_type;
- prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
- prot->tag_size = tag_size;
+ prot->prepend_size = TLS_HEADER_SIZE + cipher_sz->iv;
+ prot->tag_size = cipher_sz->tag;
prot->overhead_size = prot->prepend_size + prot->tag_size;
- prot->iv_size = iv_size;
- prot->salt_size = salt_size;
- ctx->tx.iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
- GFP_KERNEL);
+ prot->iv_size = cipher_sz->iv;
+ prot->salt_size = cipher_sz->salt;
+ ctx->tx.iv = kmalloc(cipher_sz->iv + cipher_sz->salt, GFP_KERNEL);
if (!ctx->tx.iv) {
rc = -ENOMEM;
goto release_netdev;
}
- memcpy(ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
+ memcpy(ctx->tx.iv + cipher_sz->salt, iv, cipher_sz->iv);
- prot->rec_seq_size = rec_seq_size;
- ctx->tx.rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
+ prot->rec_seq_size = cipher_sz->rec_seq;
+ ctx->tx.rec_seq = kmemdup(rec_seq, cipher_sz->rec_seq, GFP_KERNEL);
if (!ctx->tx.rec_seq) {
rc = -ENOMEM;
goto free_iv;
diff --git a/net/tls/tls_device_fallback.c b/net/tls/tls_device_fallback.c
index 7dfc8023e0f1..cdb391a8754b 100644
--- a/net/tls/tls_device_fallback.c
+++ b/net/tls/tls_device_fallback.c
@@ -54,13 +54,25 @@ static int tls_enc_record(struct aead_request *aead_req,
struct scatter_walk *out, int *in_len,
struct tls_prot_info *prot)
{
- unsigned char buf[TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE];
+ unsigned char buf[TLS_HEADER_SIZE + MAX_IV_SIZE];
+ const struct tls_cipher_size_desc *cipher_sz;
struct scatterlist sg_in[3];
struct scatterlist sg_out[3];
+ unsigned int buf_size;
u16 len;
int rc;
- len = min_t(int, *in_len, ARRAY_SIZE(buf));
+ switch (prot->cipher_type) {
+ case TLS_CIPHER_AES_GCM_128:
+ case TLS_CIPHER_AES_GCM_256:
+ break;
+ default:
+ return -EINVAL;
+ }
+ cipher_sz = &tls_cipher_size_desc[prot->cipher_type];
+
+ buf_size = TLS_HEADER_SIZE + cipher_sz->iv;
+ len = min_t(int, *in_len, buf_size);
scatterwalk_copychunks(buf, in, len, 0);
scatterwalk_copychunks(buf, out, len, 1);
@@ -73,13 +85,11 @@ static int tls_enc_record(struct aead_request *aead_req,
scatterwalk_pagedone(out, 1, 1);
len = buf[4] | (buf[3] << 8);
- len -= TLS_CIPHER_AES_GCM_128_IV_SIZE;
+ len -= cipher_sz->iv;
- tls_make_aad(aad, len - TLS_CIPHER_AES_GCM_128_TAG_SIZE,
- (char *)&rcd_sn, buf[0], prot);
+ tls_make_aad(aad, len - cipher_sz->tag, (char *)&rcd_sn, buf[0], prot);
- memcpy(iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, buf + TLS_HEADER_SIZE,
- TLS_CIPHER_AES_GCM_128_IV_SIZE);
+ memcpy(iv + cipher_sz->salt, buf + TLS_HEADER_SIZE, cipher_sz->iv);
sg_init_table(sg_in, ARRAY_SIZE(sg_in));
sg_init_table(sg_out, ARRAY_SIZE(sg_out));
@@ -90,7 +100,7 @@ static int tls_enc_record(struct aead_request *aead_req,
*in_len -= len;
if (*in_len < 0) {
- *in_len += TLS_CIPHER_AES_GCM_128_TAG_SIZE;
+ *in_len += cipher_sz->tag;
/* the input buffer doesn't contain the entire record.
* trim len accordingly. The resulting authentication tag
* will contain garbage, but we don't care, so we won't
@@ -111,7 +121,7 @@ static int tls_enc_record(struct aead_request *aead_req,
scatterwalk_pagedone(out, 1, 1);
}
- len -= TLS_CIPHER_AES_GCM_128_TAG_SIZE;
+ len -= cipher_sz->tag;
aead_request_set_crypt(aead_req, sg_in, sg_out, len, iv);
rc = crypto_aead_encrypt(aead_req);
@@ -299,11 +309,14 @@ static void fill_sg_out(struct scatterlist sg_out[3], void *buf,
int sync_size,
void *dummy_buf)
{
+ const struct tls_cipher_size_desc *cipher_sz =
+ &tls_cipher_size_desc[tls_ctx->crypto_send.info.cipher_type];
+
sg_set_buf(&sg_out[0], dummy_buf, sync_size);
sg_set_buf(&sg_out[1], nskb->data + tcp_payload_offset, payload_len);
/* Add room for authentication tag produced by crypto */
dummy_buf += sync_size;
- sg_set_buf(&sg_out[2], dummy_buf, TLS_CIPHER_AES_GCM_128_TAG_SIZE);
+ sg_set_buf(&sg_out[2], dummy_buf, cipher_sz->tag);
}
static struct sk_buff *tls_enc_skb(struct tls_context *tls_ctx,
@@ -315,7 +328,8 @@ static struct sk_buff *tls_enc_skb(struct tls_context *tls_ctx,
struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
int tcp_payload_offset = skb_tcp_all_headers(skb);
int payload_len = skb->len - tcp_payload_offset;
- void *buf, *iv, *aad, *dummy_buf;
+ const struct tls_cipher_size_desc *cipher_sz;
+ void *buf, *iv, *aad, *dummy_buf, *salt;
struct aead_request *aead_req;
struct sk_buff *nskb = NULL;
int buf_len;
@@ -324,20 +338,26 @@ static struct sk_buff *tls_enc_skb(struct tls_context *tls_ctx,
if (!aead_req)
return NULL;
- buf_len = TLS_CIPHER_AES_GCM_128_SALT_SIZE +
- TLS_CIPHER_AES_GCM_128_IV_SIZE +
- TLS_AAD_SPACE_SIZE +
- sync_size +
- TLS_CIPHER_AES_GCM_128_TAG_SIZE;
+ switch (tls_ctx->crypto_send.info.cipher_type) {
+ case TLS_CIPHER_AES_GCM_128:
+ salt = tls_ctx->crypto_send.aes_gcm_128.salt;
+ break;
+ case TLS_CIPHER_AES_GCM_256:
+ salt = tls_ctx->crypto_send.aes_gcm_256.salt;
+ break;
+ default:
+ return NULL;
+ }
+ cipher_sz = &tls_cipher_size_desc[tls_ctx->crypto_send.info.cipher_type];
+ buf_len = cipher_sz->salt + cipher_sz->iv + TLS_AAD_SPACE_SIZE +
+ sync_size + cipher_sz->tag;
buf = kmalloc(buf_len, GFP_ATOMIC);
if (!buf)
goto free_req;
iv = buf;
- memcpy(iv, tls_ctx->crypto_send.aes_gcm_128.salt,
- TLS_CIPHER_AES_GCM_128_SALT_SIZE);
- aad = buf + TLS_CIPHER_AES_GCM_128_SALT_SIZE +
- TLS_CIPHER_AES_GCM_128_IV_SIZE;
+ memcpy(iv, salt, cipher_sz->salt);
+ aad = buf + cipher_sz->salt + cipher_sz->iv;
dummy_buf = aad + TLS_AAD_SPACE_SIZE;
nskb = alloc_skb(skb_headroom(skb) + skb->len, GFP_ATOMIC);
@@ -451,6 +471,7 @@ int tls_sw_fallback_init(struct sock *sk,
struct tls_offload_context_tx *offload_ctx,
struct tls_crypto_info *crypto_info)
{
+ const struct tls_cipher_size_desc *cipher_sz;
const u8 *key;
int rc;
@@ -463,15 +484,23 @@ int tls_sw_fallback_init(struct sock *sk,
goto err_out;
}
- key = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->key;
+ switch (crypto_info->cipher_type) {
+ case TLS_CIPHER_AES_GCM_128:
+ key = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->key;
+ break;
+ case TLS_CIPHER_AES_GCM_256:
+ key = ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->key;
+ break;
+ default:
+ return -EINVAL;
+ }
+ cipher_sz = &tls_cipher_size_desc[crypto_info->cipher_type];
- rc = crypto_aead_setkey(offload_ctx->aead_send, key,
- TLS_CIPHER_AES_GCM_128_KEY_SIZE);
+ rc = crypto_aead_setkey(offload_ctx->aead_send, key, cipher_sz->key);
if (rc)
goto free_aead;
- rc = crypto_aead_setauthsize(offload_ctx->aead_send,
- TLS_CIPHER_AES_GCM_128_TAG_SIZE);
+ rc = crypto_aead_setauthsize(offload_ctx->aead_send, cipher_sz->tag);
if (rc)
goto free_aead;
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 08ddf9d837ae..5cc6911cc97d 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -58,6 +58,23 @@ enum {
TLS_NUM_PROTS,
};
+#define CIPHER_SIZE_DESC(cipher) [cipher] = { \
+ .iv = cipher ## _IV_SIZE, \
+ .key = cipher ## _KEY_SIZE, \
+ .salt = cipher ## _SALT_SIZE, \
+ .tag = cipher ## _TAG_SIZE, \
+ .rec_seq = cipher ## _REC_SEQ_SIZE, \
+}
+
+const struct tls_cipher_size_desc tls_cipher_size_desc[] = {
+ CIPHER_SIZE_DESC(TLS_CIPHER_AES_GCM_128),
+ CIPHER_SIZE_DESC(TLS_CIPHER_AES_GCM_256),
+ CIPHER_SIZE_DESC(TLS_CIPHER_AES_CCM_128),
+ CIPHER_SIZE_DESC(TLS_CIPHER_CHACHA20_POLY1305),
+ CIPHER_SIZE_DESC(TLS_CIPHER_SM4_GCM),
+ CIPHER_SIZE_DESC(TLS_CIPHER_SM4_CCM),
+};
+
static const struct proto *saved_tcpv6_prot;
static DEFINE_MUTEX(tcpv6_prot_mutex);
static const struct proto *saved_tcpv4_prot;