diff options
Diffstat (limited to 'net/tls/tls_main.c')
| -rw-r--r-- | net/tls/tls_main.c | 58 | 
1 files changed, 39 insertions, 19 deletions
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index 311cec8e533d..78cb4a584080 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -55,8 +55,10 @@ enum {  static struct proto *saved_tcpv6_prot;  static DEFINE_MUTEX(tcpv6_prot_mutex); +static struct proto *saved_tcpv4_prot; +static DEFINE_MUTEX(tcpv4_prot_mutex);  static LIST_HEAD(device_list); -static DEFINE_MUTEX(device_mutex); +static DEFINE_SPINLOCK(device_spinlock);  static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];  static struct proto_ops tls_sw_proto_ops; @@ -538,11 +540,14 @@ static struct tls_context *create_ctx(struct sock *sk)  	struct inet_connection_sock *icsk = inet_csk(sk);  	struct tls_context *ctx; -	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL); +	ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC);  	if (!ctx)  		return NULL;  	icsk->icsk_ulp_data = ctx; +	ctx->setsockopt = sk->sk_prot->setsockopt; +	ctx->getsockopt = sk->sk_prot->getsockopt; +	ctx->sk_proto_close = sk->sk_prot->close;  	return ctx;  } @@ -552,7 +557,7 @@ static int tls_hw_prot(struct sock *sk)  	struct tls_device *dev;  	int rc = 0; -	mutex_lock(&device_mutex); +	spin_lock_bh(&device_spinlock);  	list_for_each_entry(dev, &device_list, dev_list) {  		if (dev->feature && dev->feature(dev)) {  			ctx = create_ctx(sk); @@ -570,7 +575,7 @@ static int tls_hw_prot(struct sock *sk)  		}  	}  out: -	mutex_unlock(&device_mutex); +	spin_unlock_bh(&device_spinlock);  	return rc;  } @@ -579,12 +584,17 @@ static void tls_hw_unhash(struct sock *sk)  	struct tls_context *ctx = tls_get_ctx(sk);  	struct tls_device *dev; -	mutex_lock(&device_mutex); +	spin_lock_bh(&device_spinlock);  	list_for_each_entry(dev, &device_list, dev_list) { -		if (dev->unhash) +		if (dev->unhash) { +			kref_get(&dev->kref); +			spin_unlock_bh(&device_spinlock);  			dev->unhash(dev, sk); +			kref_put(&dev->kref, dev->release); +			spin_lock_bh(&device_spinlock); +		}  	} -	mutex_unlock(&device_mutex); +	spin_unlock_bh(&device_spinlock);  	ctx->unhash(sk);  } @@ -595,12 +605,17 @@ static int tls_hw_hash(struct sock *sk)  	int err;  	err = ctx->hash(sk); -	mutex_lock(&device_mutex); +	spin_lock_bh(&device_spinlock);  	list_for_each_entry(dev, &device_list, dev_list) { -		if (dev->hash) +		if (dev->hash) { +			kref_get(&dev->kref); +			spin_unlock_bh(&device_spinlock);  			err |= dev->hash(dev, sk); +			kref_put(&dev->kref, dev->release); +			spin_lock_bh(&device_spinlock); +		}  	} -	mutex_unlock(&device_mutex); +	spin_unlock_bh(&device_spinlock);  	if (err)  		tls_hw_unhash(sk); @@ -675,9 +690,6 @@ static int tls_init(struct sock *sk)  		rc = -ENOMEM;  		goto out;  	} -	ctx->setsockopt = sk->sk_prot->setsockopt; -	ctx->getsockopt = sk->sk_prot->getsockopt; -	ctx->sk_proto_close = sk->sk_prot->close;  	/* Build IPv6 TLS whenever the address of tcpv6	_prot changes */  	if (ip_ver == TLSV6 && @@ -690,6 +702,16 @@ static int tls_init(struct sock *sk)  		mutex_unlock(&tcpv6_prot_mutex);  	} +	if (ip_ver == TLSV4 && +	    unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv4_prot))) { +		mutex_lock(&tcpv4_prot_mutex); +		if (likely(sk->sk_prot != saved_tcpv4_prot)) { +			build_protos(tls_prots[TLSV4], sk->sk_prot); +			smp_store_release(&saved_tcpv4_prot, sk->sk_prot); +		} +		mutex_unlock(&tcpv4_prot_mutex); +	} +  	ctx->tx_conf = TLS_BASE;  	ctx->rx_conf = TLS_BASE;  	update_sk_prot(sk, ctx); @@ -699,17 +721,17 @@ out:  void tls_register_device(struct tls_device *device)  { -	mutex_lock(&device_mutex); +	spin_lock_bh(&device_spinlock);  	list_add_tail(&device->dev_list, &device_list); -	mutex_unlock(&device_mutex); +	spin_unlock_bh(&device_spinlock);  }  EXPORT_SYMBOL(tls_register_device);  void tls_unregister_device(struct tls_device *device)  { -	mutex_lock(&device_mutex); +	spin_lock_bh(&device_spinlock);  	list_del(&device->dev_list); -	mutex_unlock(&device_mutex); +	spin_unlock_bh(&device_spinlock);  }  EXPORT_SYMBOL(tls_unregister_device); @@ -721,8 +743,6 @@ static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {  static int __init tls_register(void)  { -	build_protos(tls_prots[TLSV4], &tcp_prot); -  	tls_sw_proto_ops = inet_stream_ops;  	tls_sw_proto_ops.splice_read = tls_sw_splice_read;  | 
