diff options
Diffstat (limited to 'net/tls/tls_main.c')
-rw-r--r-- | net/tls/tls_main.c | 66 |
1 files changed, 45 insertions, 21 deletions
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index a127d61e8af9..523622dc74f8 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -45,21 +45,13 @@ MODULE_AUTHOR("Mellanox Technologies"); MODULE_DESCRIPTION("Transport Layer Security Support"); MODULE_LICENSE("Dual BSD/GPL"); +MODULE_ALIAS_TCP_ULP("tls"); enum { TLSV4, TLSV6, TLS_NUM_PROTS, }; -enum { - TLS_BASE, - TLS_SW, -#ifdef CONFIG_TLS_DEVICE - TLS_HW, -#endif - TLS_HW_RECORD, - TLS_NUM_CONFIG, -}; static struct proto *saved_tcpv6_prot; static DEFINE_MUTEX(tcpv6_prot_mutex); @@ -221,9 +213,14 @@ static void tls_write_space(struct sock *sk) { struct tls_context *ctx = tls_get_ctx(sk); - /* We are already sending pages, ignore notification */ - if (ctx->in_tcp_sendpages) + /* If in_tcp_sendpages call lower protocol write space handler + * to ensure we wake up any waiting operations there. For example + * if do_tcp_sendpages where to call sk_wait_event. + */ + if (ctx->in_tcp_sendpages) { + ctx->sk_write_space(sk); return; + } if (!sk->sk_write_pending && tls_is_pending_closed_record(ctx)) { gfp_t sk_allocation = sk->sk_allocation; @@ -244,6 +241,16 @@ static void tls_write_space(struct sock *sk) ctx->sk_write_space(sk); } +static void tls_ctx_free(struct tls_context *ctx) +{ + if (!ctx) + return; + + memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send)); + memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv)); + kfree(ctx); +} + static void tls_sk_proto_close(struct sock *sk, long timeout) { struct tls_context *ctx = tls_get_ctx(sk); @@ -290,11 +297,14 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) } #ifdef CONFIG_TLS_DEVICE - if (ctx->tx_conf != TLS_HW) { + if (ctx->rx_conf == TLS_HW) + tls_device_offload_cleanup_rx(sk); + + if (ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW) { #else { #endif - kfree(ctx); + tls_ctx_free(ctx); ctx = NULL; } @@ -305,7 +315,7 @@ skip_tx_cleanup: * for sk->sk_prot->unhash [tls_hw_unhash] */ if (free_ctx) - kfree(ctx); + tls_ctx_free(ctx); } static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval, @@ -330,7 +340,7 @@ static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval, } /* get user crypto info */ - crypto_info = &ctx->crypto_send; + crypto_info = &ctx->crypto_send.info; if (!TLS_CRYPTO_INFO_READY(crypto_info)) { rc = -EBUSY; @@ -417,9 +427,9 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval, } if (tx) - crypto_info = &ctx->crypto_send; + crypto_info = &ctx->crypto_send.info; else - crypto_info = &ctx->crypto_recv; + crypto_info = &ctx->crypto_recv.info; /* Currently we don't support set crypto info more than one time */ if (TLS_CRYPTO_INFO_READY(crypto_info)) { @@ -470,8 +480,16 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval, conf = TLS_SW; } } else { - rc = tls_set_sw_offload(sk, ctx, 0); - conf = TLS_SW; +#ifdef CONFIG_TLS_DEVICE + rc = tls_set_device_offload_rx(sk, ctx); + conf = TLS_HW; + if (rc) { +#else + { +#endif + rc = tls_set_sw_offload(sk, ctx, 0); + conf = TLS_SW; + } } if (rc) @@ -491,7 +509,7 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval, goto out; err_crypto_info: - memset(crypto_info, 0, sizeof(*crypto_info)); + memzero_explicit(crypto_info, sizeof(union tls_crypto_context)); out: return rc; } @@ -629,6 +647,12 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW]; prot[TLS_HW][TLS_SW].sendmsg = tls_device_sendmsg; prot[TLS_HW][TLS_SW].sendpage = tls_device_sendpage; + + prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW]; + + prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW]; + + prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW]; #endif prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base; @@ -712,7 +736,7 @@ static int __init tls_register(void) build_protos(tls_prots[TLSV4], &tcp_prot); tls_sw_proto_ops = inet_stream_ops; - tls_sw_proto_ops.poll_mask = tls_sw_poll_mask; + tls_sw_proto_ops.poll = tls_sw_poll; tls_sw_proto_ops.splice_read = tls_sw_splice_read; #ifdef CONFIG_TLS_DEVICE |