diff options
Diffstat (limited to 'net/ipv4/udp.c')
| -rw-r--r-- | net/ipv4/udp.c | 152 | 
1 files changed, 118 insertions, 34 deletions
| diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index 24ec14f9825c..dc45b538e237 100644 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c @@ -100,7 +100,6 @@  #include <linux/slab.h>  #include <net/tcp_states.h>  #include <linux/skbuff.h> -#include <linux/netdevice.h>  #include <linux/proc_fs.h>  #include <linux/seq_file.h>  #include <net/net_namespace.h> @@ -114,6 +113,7 @@  #include <trace/events/skb.h>  #include <net/busy_poll.h>  #include "udp_impl.h" +#include <net/sock_reuseport.h>  struct udp_table udp_table __read_mostly;  EXPORT_SYMBOL(udp_table); @@ -138,7 +138,8 @@ static int udp_lib_lport_inuse(struct net *net, __u16 num,  			       unsigned long *bitmap,  			       struct sock *sk,  			       int (*saddr_comp)(const struct sock *sk1, -						 const struct sock *sk2), +						 const struct sock *sk2, +						 bool match_wildcard),  			       unsigned int log)  {  	struct sock *sk2; @@ -153,8 +154,9 @@ static int udp_lib_lport_inuse(struct net *net, __u16 num,  		    (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if ||  		     sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&  		    (!sk2->sk_reuseport || !sk->sk_reuseport || +		     rcu_access_pointer(sk->sk_reuseport_cb) ||  		     !uid_eq(uid, sock_i_uid(sk2))) && -		    saddr_comp(sk, sk2)) { +		    saddr_comp(sk, sk2, true)) {  			if (!bitmap)  				return 1;  			__set_bit(udp_sk(sk2)->udp_port_hash >> log, bitmap); @@ -171,7 +173,8 @@ static int udp_lib_lport_inuse2(struct net *net, __u16 num,  				struct udp_hslot *hslot2,  				struct sock *sk,  				int (*saddr_comp)(const struct sock *sk1, -						  const struct sock *sk2)) +						  const struct sock *sk2, +						  bool match_wildcard))  {  	struct sock *sk2;  	struct hlist_nulls_node *node; @@ -187,8 +190,9 @@ static int udp_lib_lport_inuse2(struct net *net, __u16 num,  		    (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if ||  		     sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&  		    (!sk2->sk_reuseport || !sk->sk_reuseport || +		     rcu_access_pointer(sk->sk_reuseport_cb) ||  		     !uid_eq(uid, sock_i_uid(sk2))) && -		    saddr_comp(sk, sk2)) { +		    saddr_comp(sk, sk2, true)) {  			res = 1;  			break;  		} @@ -197,6 +201,35 @@ static int udp_lib_lport_inuse2(struct net *net, __u16 num,  	return res;  } +static int udp_reuseport_add_sock(struct sock *sk, struct udp_hslot *hslot, +				  int (*saddr_same)(const struct sock *sk1, +						    const struct sock *sk2, +						    bool match_wildcard)) +{ +	struct net *net = sock_net(sk); +	struct hlist_nulls_node *node; +	kuid_t uid = sock_i_uid(sk); +	struct sock *sk2; + +	sk_nulls_for_each(sk2, node, &hslot->head) { +		if (net_eq(sock_net(sk2), net) && +		    sk2 != sk && +		    sk2->sk_family == sk->sk_family && +		    ipv6_only_sock(sk2) == ipv6_only_sock(sk) && +		    (udp_sk(sk2)->udp_port_hash == udp_sk(sk)->udp_port_hash) && +		    (sk2->sk_bound_dev_if == sk->sk_bound_dev_if) && +		    sk2->sk_reuseport && uid_eq(uid, sock_i_uid(sk2)) && +		    (*saddr_same)(sk, sk2, false)) { +			return reuseport_add_sock(sk, sk2); +		} +	} + +	/* Initial allocation may have already happened via setsockopt */ +	if (!rcu_access_pointer(sk->sk_reuseport_cb)) +		return reuseport_alloc(sk); +	return 0; +} +  /**   *  udp_lib_get_port  -  UDP/-Lite port lookup for IPv4 and IPv6   * @@ -208,7 +241,8 @@ static int udp_lib_lport_inuse2(struct net *net, __u16 num,   */  int udp_lib_get_port(struct sock *sk, unsigned short snum,  		     int (*saddr_comp)(const struct sock *sk1, -				       const struct sock *sk2), +				       const struct sock *sk2, +				       bool match_wildcard),  		     unsigned int hash2_nulladdr)  {  	struct udp_hslot *hslot, *hslot2; @@ -291,6 +325,14 @@ found:  	udp_sk(sk)->udp_port_hash = snum;  	udp_sk(sk)->udp_portaddr_hash ^= snum;  	if (sk_unhashed(sk)) { +		if (sk->sk_reuseport && +		    udp_reuseport_add_sock(sk, hslot, saddr_comp)) { +			inet_sk(sk)->inet_num = 0; +			udp_sk(sk)->udp_port_hash = 0; +			udp_sk(sk)->udp_portaddr_hash ^= snum; +			goto fail_unlock; +		} +  		sk_nulls_add_node_rcu(sk, &hslot->head);  		hslot->count++;  		sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); @@ -310,13 +352,22 @@ fail:  }  EXPORT_SYMBOL(udp_lib_get_port); -static int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2) +/* match_wildcard == true:  0.0.0.0 equals to any IPv4 addresses + * match_wildcard == false: addresses must be exactly the same, i.e. + *                          0.0.0.0 only equals to 0.0.0.0 + */ +static int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2, +				bool match_wildcard)  {  	struct inet_sock *inet1 = inet_sk(sk1), *inet2 = inet_sk(sk2); -	return 	(!ipv6_only_sock(sk2)  && -		 (!inet1->inet_rcv_saddr || !inet2->inet_rcv_saddr || -		   inet1->inet_rcv_saddr == inet2->inet_rcv_saddr)); +	if (!ipv6_only_sock(sk2)) { +		if (inet1->inet_rcv_saddr == inet2->inet_rcv_saddr) +			return 1; +		if (!inet1->inet_rcv_saddr || !inet2->inet_rcv_saddr) +			return match_wildcard; +	} +	return 0;  }  static u32 udp4_portaddr_hash(const struct net *net, __be32 saddr, @@ -442,7 +493,8 @@ static u32 udp_ehashfn(const struct net *net, const __be32 laddr,  static struct sock *udp4_lib_lookup2(struct net *net,  		__be32 saddr, __be16 sport,  		__be32 daddr, unsigned int hnum, int dif, -		struct udp_hslot *hslot2, unsigned int slot2) +		struct udp_hslot *hslot2, unsigned int slot2, +		struct sk_buff *skb)  {  	struct sock *sk, *result;  	struct hlist_nulls_node *node; @@ -460,8 +512,15 @@ begin:  			badness = score;  			reuseport = sk->sk_reuseport;  			if (reuseport) { +				struct sock *sk2;  				hash = udp_ehashfn(net, daddr, hnum,  						   saddr, sport); +				sk2 = reuseport_select_sock(sk, hash, skb, +							    sizeof(struct udphdr)); +				if (sk2) { +					result = sk2; +					goto found; +				}  				matches = 1;  			}  		} else if (score == badness && reuseport) { @@ -479,6 +538,7 @@ begin:  	if (get_nulls_value(node) != slot2)  		goto begin;  	if (result) { +found:  		if (unlikely(!atomic_inc_not_zero_hint(&result->sk_refcnt, 2)))  			result = NULL;  		else if (unlikely(compute_score2(result, net, saddr, sport, @@ -495,7 +555,7 @@ begin:   */  struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,  		__be16 sport, __be32 daddr, __be16 dport, -		int dif, struct udp_table *udptable) +		int dif, struct udp_table *udptable, struct sk_buff *skb)  {  	struct sock *sk, *result;  	struct hlist_nulls_node *node; @@ -515,7 +575,7 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,  		result = udp4_lib_lookup2(net, saddr, sport,  					  daddr, hnum, dif, -					  hslot2, slot2); +					  hslot2, slot2, skb);  		if (!result) {  			hash2 = udp4_portaddr_hash(net, htonl(INADDR_ANY), hnum);  			slot2 = hash2 & udptable->mask; @@ -525,7 +585,7 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,  			result = udp4_lib_lookup2(net, saddr, sport,  						  htonl(INADDR_ANY), hnum, dif, -						  hslot2, slot2); +						  hslot2, slot2, skb);  		}  		rcu_read_unlock();  		return result; @@ -541,8 +601,15 @@ begin:  			badness = score;  			reuseport = sk->sk_reuseport;  			if (reuseport) { +				struct sock *sk2;  				hash = udp_ehashfn(net, daddr, hnum,  						   saddr, sport); +				sk2 = reuseport_select_sock(sk, hash, skb, +							sizeof(struct udphdr)); +				if (sk2) { +					result = sk2; +					goto found; +				}  				matches = 1;  			}  		} else if (score == badness && reuseport) { @@ -561,6 +628,7 @@ begin:  		goto begin;  	if (result) { +found:  		if (unlikely(!atomic_inc_not_zero_hint(&result->sk_refcnt, 2)))  			result = NULL;  		else if (unlikely(compute_score(result, net, saddr, hnum, sport, @@ -582,13 +650,14 @@ static inline struct sock *__udp4_lib_lookup_skb(struct sk_buff *skb,  	return __udp4_lib_lookup(dev_net(skb_dst(skb)->dev), iph->saddr, sport,  				 iph->daddr, dport, inet_iif(skb), -				 udptable); +				 udptable, skb);  }  struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport,  			     __be32 daddr, __be16 dport, int dif)  { -	return __udp4_lib_lookup(net, saddr, sport, daddr, dport, dif, &udp_table); +	return __udp4_lib_lookup(net, saddr, sport, daddr, dport, dif, +				 &udp_table, NULL);  }  EXPORT_SYMBOL_GPL(udp4_lib_lookup); @@ -636,7 +705,8 @@ void __udp4_lib_err(struct sk_buff *skb, u32 info, struct udp_table *udptable)  	struct net *net = dev_net(skb->dev);  	sk = __udp4_lib_lookup(net, iph->daddr, uh->dest, -			iph->saddr, uh->source, skb->dev->ifindex, udptable); +			iph->saddr, uh->source, skb->dev->ifindex, udptable, +			NULL);  	if (!sk) {  		ICMP_INC_STATS_BH(net, ICMP_MIB_INERRORS);  		return;	/* No socket for error */ @@ -773,7 +843,8 @@ void udp_set_csum(bool nocheck, struct sk_buff *skb,  	else if (skb_is_gso(skb))  		uh->check = ~udp_v4_check(len, saddr, daddr, 0);  	else if (skb_dst(skb) && skb_dst(skb)->dev && -		 (skb_dst(skb)->dev->features & NETIF_F_V4_CSUM)) { +		 (skb_dst(skb)->dev->features & +		  (NETIF_F_IP_CSUM | NETIF_F_HW_CSUM))) {  		BUG_ON(skb->ip_summed == CHECKSUM_PARTIAL); @@ -1026,8 +1097,11 @@ int udp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)  				   flow_flags,  				   faddr, saddr, dport, inet->inet_sport); -		if (!saddr && ipc.oif) -			l3mdev_get_saddr(net, ipc.oif, fl4); +		if (!saddr && ipc.oif) { +			err = l3mdev_get_saddr(net, ipc.oif, fl4); +			if (err < 0) +				goto out; +		}  		security_sk_classify_flow(sk, flowi4_to_flowi(fl4));  		rt = ip_route_output_flow(net, fl4, sk); @@ -1271,6 +1345,7 @@ int udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int noblock,  	int peeked, off = 0;  	int err;  	int is_udplite = IS_UDPLITE(sk); +	bool checksum_valid = false;  	bool slow;  	if (flags & MSG_ERRQUEUE) @@ -1296,11 +1371,12 @@ try_again:  	 */  	if (copied < ulen || UDP_SKB_CB(skb)->partial_cov) { -		if (udp_lib_checksum_complete(skb)) +		checksum_valid = !udp_lib_checksum_complete(skb); +		if (!checksum_valid)  			goto csum_copy_err;  	} -	if (skb_csum_unnecessary(skb)) +	if (checksum_valid || skb_csum_unnecessary(skb))  		err = skb_copy_datagram_msg(skb, sizeof(struct udphdr),  					    msg, copied);  	else { @@ -1396,6 +1472,8 @@ void udp_lib_unhash(struct sock *sk)  		hslot2 = udp_hashslot2(udptable, udp_sk(sk)->udp_portaddr_hash);  		spin_lock_bh(&hslot->lock); +		if (rcu_access_pointer(sk->sk_reuseport_cb)) +			reuseport_detach_sock(sk);  		if (sk_nulls_del_node_init_rcu(sk)) {  			hslot->count--;  			inet_sk(sk)->inet_num = 0; @@ -1423,22 +1501,28 @@ void udp_lib_rehash(struct sock *sk, u16 newhash)  		hslot2 = udp_hashslot2(udptable, udp_sk(sk)->udp_portaddr_hash);  		nhslot2 = udp_hashslot2(udptable, newhash);  		udp_sk(sk)->udp_portaddr_hash = newhash; -		if (hslot2 != nhslot2) { + +		if (hslot2 != nhslot2 || +		    rcu_access_pointer(sk->sk_reuseport_cb)) {  			hslot = udp_hashslot(udptable, sock_net(sk),  					     udp_sk(sk)->udp_port_hash);  			/* we must lock primary chain too */  			spin_lock_bh(&hslot->lock); - -			spin_lock(&hslot2->lock); -			hlist_nulls_del_init_rcu(&udp_sk(sk)->udp_portaddr_node); -			hslot2->count--; -			spin_unlock(&hslot2->lock); - -			spin_lock(&nhslot2->lock); -			hlist_nulls_add_head_rcu(&udp_sk(sk)->udp_portaddr_node, -						 &nhslot2->head); -			nhslot2->count++; -			spin_unlock(&nhslot2->lock); +			if (rcu_access_pointer(sk->sk_reuseport_cb)) +				reuseport_detach_sock(sk); + +			if (hslot2 != nhslot2) { +				spin_lock(&hslot2->lock); +				hlist_nulls_del_init_rcu(&udp_sk(sk)->udp_portaddr_node); +				hslot2->count--; +				spin_unlock(&hslot2->lock); + +				spin_lock(&nhslot2->lock); +				hlist_nulls_add_head_rcu(&udp_sk(sk)->udp_portaddr_node, +							 &nhslot2->head); +				nhslot2->count++; +				spin_unlock(&nhslot2->lock); +			}  			spin_unlock_bh(&hslot->lock);  		} | 
