summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r-- include/linux/bpf-cgroup.h 2
-rw-r--r-- include/net/inet_sock.h 8
-rw-r--r-- include/net/ip.h 3
-rw-r--r-- include/net/sock.h 19
-rw-r--r-- net/core/filter.c 6
-rw-r--r-- net/core/sock.c 9
-rw-r--r-- net/ipv4/ip_output.c 5
-rw-r--r-- net/ipv4/tcp_ipv4.c 4
-rw-r--r-- net/ipv4/tcp_output.c 2
-rw-r--r-- net/ipv6/tcp_ipv6.c 3
-rw-r--r-- net/sched/sch_fq.c 3
11 files changed, 44 insertions, 20 deletions
diff --git a/include/linux/bpf-cgroup.h b/include/linux/bpf-cgroup.h
index ce91d9b2acb9..f0f219271daf 100644
--- a/include/linux/bpf-cgroup.h
+++ b/include/linux/bpf-cgroup.h
@@ -209,7 +209,7 @@ static inline bool cgroup_bpf_sock_enabled(struct sock *sk,
int __ret = 0; \
if (cgroup_bpf_enabled(CGROUP_INET_EGRESS) && sk) { \
typeof(sk) __sk = sk_to_full_sk(sk); \
- if (sk_fullsock(__sk) && __sk == skb_to_full_sk(skb) && \
+ if (__sk && __sk == skb_to_full_sk(skb) && \
cgroup_bpf_sock_enabled(__sk, CGROUP_INET_EGRESS)) \
__ret = __cgroup_bpf_run_filter_skb(__sk, skb, \
CGROUP_INET_EGRESS); \
diff --git a/include/net/inet_sock.h b/include/net/inet_sock.h
index f01dd273bea6..56d8bc5593d3 100644
--- a/include/net/inet_sock.h
+++ b/include/net/inet_sock.h
@@ -321,8 +321,10 @@ static inline unsigned long inet_cmsg_flags(const struct inet_sock *inet)
static inline struct sock *sk_to_full_sk(struct sock *sk) /* returns sk itself, its rsk_listener, or NULL */
{
#ifdef CONFIG_INET
- if (sk && sk->sk_state == TCP_NEW_SYN_RECV)
+ if (sk && READ_ONCE(sk->sk_state) == TCP_NEW_SYN_RECV) /* NEW_SYN_RECV: sk is treated as a request socket */
sk = inet_reqsk(sk)->rsk_listener; /* continue with the listener recorded in the request */
+ if (sk && READ_ONCE(sk->sk_state) == TCP_TIME_WAIT) /* evaluated on the possibly substituted sk */
+ sk = NULL; /* a TIME_WAIT socket is reported to callers as NULL */
#endif
return sk; /* unchanged when CONFIG_INET is unset or neither state matched */
}
@@ -331,8 +333,10 @@ static inline struct sock *sk_to_full_sk(struct sock *sk)
static inline const struct sock *sk_const_to_full_sk(const struct sock *sk) /* const-qualified counterpart: returns sk, its rsk_listener, or NULL */
{
#ifdef CONFIG_INET
- if (sk && sk->sk_state == TCP_NEW_SYN_RECV)
+ if (sk && READ_ONCE(sk->sk_state) == TCP_NEW_SYN_RECV) /* NEW_SYN_RECV: sk is viewed as a struct request_sock */
sk = ((const struct request_sock *)sk)->rsk_listener; /* cast keeps constness; continue with the recorded listener */
+ if (sk && READ_ONCE(sk->sk_state) == TCP_TIME_WAIT) /* evaluated on the possibly substituted sk */
+ sk = NULL; /* a TIME_WAIT socket is reported to callers as NULL */
#endif
return sk; /* unchanged when CONFIG_INET is unset or neither state matched */
}
diff --git a/include/net/ip.h b/include/net/ip.h
index bab084df1567..4be0a6a603b2 100644
--- a/include/net/ip.h
+++ b/include/net/ip.h
@@ -288,7 +288,8 @@ static inline __u8 ip_reply_arg_flowi_flags(const struct ip_reply_arg *arg)
return (arg->flags & IP_REPLY_ARG_NOSRCCHECK) ? FLOWI_FLAG_ANYSRC : 0;
}
-void ip_send_unicast_reply(struct sock *sk, struct sk_buff *skb,
+void ip_send_unicast_reply(struct sock *sk, const struct sock *orig_sk,
+ struct sk_buff *skb,
const struct ip_options *sopt,
__be32 daddr, __be32 saddr,
const struct ip_reply_arg *arg,
diff --git a/include/net/sock.h b/include/net/sock.h
index 6da420ab1ee1..bf7fa3db10ae 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -1760,6 +1760,15 @@ void sock_efree(struct sk_buff *skb);
#ifdef CONFIG_INET
void sock_edemux(struct sk_buff *skb);
void sock_pfree(struct sk_buff *skb);
+
+static inline void skb_set_owner_edemux(struct sk_buff *skb, struct sock *sk) /* attach sk to skb with sock_edemux as destructor, if a reference can be taken */
+{
+ skb_orphan(skb); /* always runs first, even when the attach below is skipped */
+ if (refcount_inc_not_zero(&sk->sk_refcnt)) { /* proceed only if sk_refcnt was non-zero and got incremented */
+ skb->sk = sk;
+ skb->destructor = sock_edemux; /* NOTE(review): presumably releases the reference taken above - confirm in sock_edemux() */
+ }
+} /* on refcount failure skb->sk and skb->destructor are not assigned here */
#else
#define sock_edemux sock_efree
#endif
@@ -2802,6 +2811,16 @@ static inline bool sk_listener(const struct sock *sk)
return (1 << sk->sk_state) & (TCPF_LISTEN | TCPF_NEW_SYN_RECV);
}
+/* Return true if @sk is in the LISTEN, NEW_SYN_RECV or TIME_WAIT state.
+ * TCP SYNACK messages can be attached to LISTEN or NEW_SYN_RECV (depending on SYNCOOKIE).
+ * TCP RST and ACK can be attached to TIME_WAIT.
+ */
+static inline bool sk_listener_or_tw(const struct sock *sk)
+{
+ return (1 << READ_ONCE(sk->sk_state)) &
+ (TCPF_LISTEN | TCPF_NEW_SYN_RECV | TCPF_TIME_WAIT);
+}
+
void sock_enable_timestamp(struct sock *sk, enum sock_flags flag);
int sock_recv_errqueue(struct sock *sk, struct msghdr *msg, int len, int level,
int type);
diff --git a/net/core/filter.c b/net/core/filter.c
index bd0d08bf76bb..202c1d386e19 100644
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -6778,8 +6778,6 @@ __bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
/* sk_to_full_sk() may return (sk)->rsk_listener, so make sure the original sk
* sock refcnt is decremented to prevent a request_sock leak.
*/
- if (!sk_fullsock(sk2))
- sk2 = NULL;
if (sk2 != sk) {
sock_gen_put(sk);
/* Ensure there is no need to bump sk2 refcnt */
@@ -6826,8 +6824,6 @@ bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
/* sk_to_full_sk() may return (sk)->rsk_listener, so make sure the original sk
* sock refcnt is decremented to prevent a request_sock leak.
*/
- if (!sk_fullsock(sk2))
- sk2 = NULL;
if (sk2 != sk) {
sock_gen_put(sk);
/* Ensure there is no need to bump sk2 refcnt */
@@ -7276,7 +7272,7 @@ BPF_CALL_1(bpf_get_listener_sock, struct sock *, sk)
{
sk = sk_to_full_sk(sk);
- if (sk->sk_state == TCP_LISTEN && sock_flag(sk, SOCK_RCU_FREE))
+ if (sk && sk->sk_state == TCP_LISTEN && sock_flag(sk, SOCK_RCU_FREE))
return (unsigned long)sk;
return (unsigned long)NULL;
diff --git a/net/core/sock.c b/net/core/sock.c
index 083d438d8b6f..f8c0d4eda888 100644
--- a/net/core/sock.c
+++ b/net/core/sock.c
@@ -2592,14 +2592,11 @@ void __sock_wfree(struct sk_buff *skb)
void skb_set_owner_w(struct sk_buff *skb, struct sock *sk)
{
skb_orphan(skb);
- skb->sk = sk;
#ifdef CONFIG_INET
- if (unlikely(!sk_fullsock(sk))) {
- skb->destructor = sock_edemux;
- sock_hold(sk);
- return;
- }
+ if (unlikely(!sk_fullsock(sk)))
+ return skb_set_owner_edemux(skb, sk);
#endif
+ skb->sk = sk;
skb->destructor = sock_wfree;
skb_set_hash_from_sk(skb, sk);
/*
diff --git a/net/ipv4/ip_output.c b/net/ipv4/ip_output.c
index e5c55a95063d..0065b1996c94 100644
--- a/net/ipv4/ip_output.c
+++ b/net/ipv4/ip_output.c
@@ -1596,7 +1596,8 @@ static int ip_reply_glue_bits(void *dptr, char *to, int offset,
* Generic function to send a packet as reply to another packet.
* Used to send some TCP resets/acks so far.
*/
-void ip_send_unicast_reply(struct sock *sk, struct sk_buff *skb,
+void ip_send_unicast_reply(struct sock *sk, const struct sock *orig_sk,
+ struct sk_buff *skb,
const struct ip_options *sopt,
__be32 daddr, __be32 saddr,
const struct ip_reply_arg *arg,
@@ -1662,6 +1663,8 @@ void ip_send_unicast_reply(struct sock *sk, struct sk_buff *skb,
arg->csumoffset) = csum_fold(csum_add(nskb->csum,
arg->csum));
nskb->ip_summed = CHECKSUM_NONE;
+ if (orig_sk)
+ skb_set_owner_edemux(nskb, (struct sock *)orig_sk);
if (transmit_time)
nskb->tstamp_type = SKB_CLOCK_MONOTONIC;
if (txhash)
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index 985028434f64..9d3dd101ea71 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -907,7 +907,7 @@ static void tcp_v4_send_reset(const struct sock *sk, struct sk_buff *skb,
ctl_sk->sk_mark = 0;
ctl_sk->sk_priority = 0;
}
- ip_send_unicast_reply(ctl_sk,
+ ip_send_unicast_reply(ctl_sk, sk,
skb, &TCP_SKB_CB(skb)->header.h4.opt,
ip_hdr(skb)->saddr, ip_hdr(skb)->daddr,
&arg, arg.iov[0].iov_len,
@@ -1021,7 +1021,7 @@ static void tcp_v4_send_ack(const struct sock *sk,
ctl_sk->sk_priority = (sk->sk_state == TCP_TIME_WAIT) ?
inet_twsk(sk)->tw_priority : READ_ONCE(sk->sk_priority);
transmit_time = tcp_transmit_time(sk);
- ip_send_unicast_reply(ctl_sk,
+ ip_send_unicast_reply(ctl_sk, sk,
skb, &TCP_SKB_CB(skb)->header.h4.opt,
ip_hdr(skb)->saddr, ip_hdr(skb)->daddr,
&arg, arg.iov[0].iov_len,
diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c
index 06200bb111f8..054244ce5117 100644
--- a/net/ipv4/tcp_output.c
+++ b/net/ipv4/tcp_output.c
@@ -3728,7 +3728,7 @@ struct sk_buff *tcp_make_synack(const struct sock *sk, struct dst_entry *dst,
switch (synack_type) {
case TCP_SYNACK_NORMAL:
- skb_set_owner_w(skb, req_to_sk(req));
+ skb_set_owner_edemux(skb, req_to_sk(req));
break;
case TCP_SYNACK_COOKIE:
/* Under synflood, we do not attach skb to a socket,
diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c
index 7634c0be6acb..597920061a3a 100644
--- a/net/ipv6/tcp_ipv6.c
+++ b/net/ipv6/tcp_ipv6.c
@@ -967,6 +967,9 @@ static void tcp_v6_send_response(const struct sock *sk, struct sk_buff *skb, u32
}
if (sk) {
+ /* unconstify the socket only to attach it to buff with care. */
+ skb_set_owner_edemux(buff, (struct sock *)sk);
+
if (sk->sk_state == TCP_TIME_WAIT)
mark = inet_twsk(sk)->tw_mark;
else
diff --git a/net/sched/sch_fq.c b/net/sched/sch_fq.c
index aeabf45c9200..a97638bef6da 100644
--- a/net/sched/sch_fq.c
+++ b/net/sched/sch_fq.c
@@ -362,8 +362,9 @@ static struct fq_flow *fq_classify(struct Qdisc *sch, struct sk_buff *skb,
* 3) We do not want to rate limit them (eg SYNFLOOD attack),
* especially if the listener set SO_MAX_PACING_RATE
* 4) We pretend they are orphaned
+ * TCP can also associate TIME_WAIT sockets with RST or ACK packets.
*/
- if (!sk || sk_listener(sk)) {
+ if (!sk || sk_listener_or_tw(sk)) {
unsigned long hash = skb_get_hash(skb) & q->orphan_mask;
/* By forcing low order bit to 1, we make sure to not