diff options
Diffstat (limited to 'net/tls')
-rw-r--r-- | net/tls/tls_device.c | 22 | ||||
-rw-r--r-- | net/tls/tls_main.c | 31 |
2 files changed, 37 insertions, 16 deletions
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c index 1ba5a92832bb..a562ebaaa33c 100644 --- a/net/tls/tls_device.c +++ b/net/tls/tls_device.c @@ -366,7 +366,7 @@ static int tls_do_allocation(struct sock *sk, if (!offload_ctx->open_record) { if (unlikely(!skb_page_frag_refill(prepend_size, pfrag, sk->sk_allocation))) { - sk->sk_prot->enter_memory_pressure(sk); + READ_ONCE(sk->sk_prot)->enter_memory_pressure(sk); sk_stream_moderate_sndbuf(sk); return -ENOMEM; } @@ -593,7 +593,7 @@ struct tls_record_info *tls_get_record(struct tls_offload_context_tx *context, u32 seq, u64 *p_record_sn) { u64 record_sn = context->hint_record_sn; - struct tls_record_info *info; + struct tls_record_info *info, *last; info = context->retransmit_hint; if (!info || @@ -605,6 +605,24 @@ struct tls_record_info *tls_get_record(struct tls_offload_context_tx *context, struct tls_record_info, list); if (!info) return NULL; + /* send the start_marker record if seq number is before the + * tls offload start marker sequence number. This record is + * required to handle TCP packets which are before TLS offload + * started. + * And if it's not start marker, look if this seq number + * belongs to the list. + */ + if (likely(!tls_record_is_start_marker(info))) { + /* we have the first record, get the last record to see + * if this seq number belongs to the list. + */ + last = list_last_entry(&context->records_list, + struct tls_record_info, list); + + if (!between(seq, tls_record_start_seq(info), + last->end_seq)) + return NULL; + } record_sn = context->unacked_record_sn; } diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index 94774c0e5ff3..156efce50dbd 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -63,13 +63,14 @@ static DEFINE_MUTEX(tcpv4_prot_mutex); static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG]; static struct proto_ops tls_sw_proto_ops; static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], - struct proto *base); + const struct proto *base); void update_sk_prot(struct sock *sk, struct tls_context *ctx) { int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4; - sk->sk_prot = &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]; + WRITE_ONCE(sk->sk_prot, + &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]); } int wait_on_pending_writer(struct sock *sk, long *timeo) @@ -312,7 +313,7 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) write_lock_bh(&sk->sk_callback_lock); if (free_ctx) rcu_assign_pointer(icsk->icsk_ulp_data, NULL); - sk->sk_prot = ctx->sk_proto; + WRITE_ONCE(sk->sk_prot, ctx->sk_proto); if (sk->sk_write_space == tls_write_space) sk->sk_write_space = ctx->sk_write_space; write_unlock_bh(&sk->sk_callback_lock); @@ -621,38 +622,39 @@ struct tls_context *tls_ctx_create(struct sock *sk) mutex_init(&ctx->tx_lock); rcu_assign_pointer(icsk->icsk_ulp_data, ctx); - ctx->sk_proto = sk->sk_prot; + ctx->sk_proto = READ_ONCE(sk->sk_prot); return ctx; } static void tls_build_proto(struct sock *sk) { int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4; + const struct proto *prot = READ_ONCE(sk->sk_prot); /* Build IPv6 TLS whenever the address of tcpv6 _prot changes */ if (ip_ver == TLSV6 && - unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) { + unlikely(prot != smp_load_acquire(&saved_tcpv6_prot))) { mutex_lock(&tcpv6_prot_mutex); - if (likely(sk->sk_prot != saved_tcpv6_prot)) { - build_protos(tls_prots[TLSV6], sk->sk_prot); - smp_store_release(&saved_tcpv6_prot, sk->sk_prot); + if (likely(prot != saved_tcpv6_prot)) { + build_protos(tls_prots[TLSV6], prot); + smp_store_release(&saved_tcpv6_prot, prot); } mutex_unlock(&tcpv6_prot_mutex); } if (ip_ver == TLSV4 && - unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv4_prot))) { + unlikely(prot != smp_load_acquire(&saved_tcpv4_prot))) { mutex_lock(&tcpv4_prot_mutex); - if (likely(sk->sk_prot != saved_tcpv4_prot)) { - build_protos(tls_prots[TLSV4], sk->sk_prot); - smp_store_release(&saved_tcpv4_prot, sk->sk_prot); + if (likely(prot != saved_tcpv4_prot)) { + build_protos(tls_prots[TLSV4], prot); + smp_store_release(&saved_tcpv4_prot, prot); } mutex_unlock(&tcpv4_prot_mutex); } } static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], - struct proto *base) + const struct proto *base) { prot[TLS_BASE][TLS_BASE] = *base; prot[TLS_BASE][TLS_BASE].setsockopt = tls_setsockopt; @@ -742,7 +744,8 @@ static void tls_update(struct sock *sk, struct proto *p, ctx->sk_write_space = write_space; ctx->sk_proto = p; } else { - sk->sk_prot = p; + /* Pairs with lockless read in sk_clone_lock(). */ + WRITE_ONCE(sk->sk_prot, p); sk->sk_write_space = write_space; } } |