summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--net/ipv4/udp.c56
-rw-r--r--net/ipv6/udp.c50
2 files changed, 106 insertions, 0 deletions
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index e8953e88efef..4bc0a0686fcd 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -420,6 +420,49 @@ u32 udp_ehashfn(const struct net *net, const __be32 laddr, const __u16 lport,
}
EXPORT_SYMBOL(udp_ehashfn);
+/**
+ * udp4_lib_lookup1() - Simplified lookup using primary hash (destination port)
+ * @net: Network namespace
+ * @saddr: Source address, network order
+ * @sport: Source port, network order
+ * @daddr: Destination address, network order
+ * @hnum: Destination port, host order
+ * @dif: Destination interface index
+ * @sdif: Destination bridge port index, if relevant
+ * @udptable: Set of UDP hash tables
+ *
+ * Simplified lookup to be used as fallback if no sockets are found due to a
+ * potential race between (receive) address change, and lookup happening before
+ * the rehash operation. This function ignores SO_REUSEPORT groups while scoring
+ * result sockets, because if we have one, we don't need the fallback at all.
+ *
+ * Called under rcu_read_lock().
+ *
+ * Return: socket with highest matching score if any, NULL if none
+ */
+static struct sock *udp4_lib_lookup1(const struct net *net,
+ __be32 saddr, __be16 sport,
+ __be32 daddr, unsigned int hnum,
+ int dif, int sdif,
+ const struct udp_table *udptable)
+{
+ unsigned int slot = udp_hashfn(net, hnum, udptable->mask);
+ struct udp_hslot *hslot = &udptable->hash[slot];
+ struct sock *sk, *result = NULL;
+ int score, badness = 0;
+
+ sk_for_each_rcu(sk, &hslot->head) {
+ score = compute_score(sk, net,
+ saddr, sport, daddr, hnum, dif, sdif);
+ if (score > badness) {
+ result = sk;
+ badness = score;
+ }
+ }
+
+ return result;
+}
+
/* called with rcu_read_lock() */
static struct sock *udp4_lib_lookup2(const struct net *net,
__be32 saddr, __be16 sport,
@@ -683,6 +726,19 @@ struct sock *__udp4_lib_lookup(const struct net *net, __be32 saddr,
result = udp4_lib_lookup2(net, saddr, sport,
htonl(INADDR_ANY), hnum, dif, sdif,
hslot2, skb);
+ if (!IS_ERR_OR_NULL(result))
+ goto done;
+
+ /* Primary hash (destination port) lookup as fallback for this race:
+ * 1. __ip4_datagram_connect() sets sk_rcv_saddr
+ * 2. lookup (this function): new sk_rcv_saddr, hashes not updated yet
+ * 3. rehash operation updating _secondary and four-tuple_ hashes
+ * The primary hash doesn't need an update after 1., so, thanks to this
+ * further step, 1. and 3. don't need to be atomic against the lookup.
+ */
+ result = udp4_lib_lookup1(net, saddr, sport, daddr, hnum, dif, sdif,
+ udptable);
+
done:
if (IS_ERR(result))
return NULL;
diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c
index 7c14c449804c..6671daa67f4f 100644
--- a/net/ipv6/udp.c
+++ b/net/ipv6/udp.c
@@ -170,6 +170,49 @@ static int compute_score(struct sock *sk, const struct net *net,
return score;
}
+/**
+ * udp6_lib_lookup1() - Simplified lookup using primary hash (destination port)
+ * @net: Network namespace
+ * @saddr: Source address, network order
+ * @sport: Source port, network order
+ * @daddr: Destination address, network order
+ * @hnum: Destination port, host order
+ * @dif: Destination interface index
+ * @sdif: Destination bridge port index, if relevant
+ * @udptable: Set of UDP hash tables
+ *
+ * Simplified lookup to be used as fallback if no sockets are found due to a
+ * potential race between (receive) address change, and lookup happening before
+ * the rehash operation. This function ignores SO_REUSEPORT groups while scoring
+ * result sockets, because if we have one, we don't need the fallback at all.
+ *
+ * Called under rcu_read_lock().
+ *
+ * Return: socket with highest matching score if any, NULL if none
+ */
+static struct sock *udp6_lib_lookup1(const struct net *net,
+ const struct in6_addr *saddr, __be16 sport,
+ const struct in6_addr *daddr,
+ unsigned int hnum, int dif, int sdif,
+ const struct udp_table *udptable)
+{
+ unsigned int slot = udp_hashfn(net, hnum, udptable->mask);
+ struct udp_hslot *hslot = &udptable->hash[slot];
+ struct sock *sk, *result = NULL;
+ int score, badness = 0;
+
+ sk_for_each_rcu(sk, &hslot->head) {
+ score = compute_score(sk, net,
+ saddr, sport, daddr, hnum, dif, sdif);
+ if (score > badness) {
+ result = sk;
+ badness = score;
+ }
+ }
+
+ return result;
+}
+
/* called with rcu_read_lock() */
static struct sock *udp6_lib_lookup2(const struct net *net,
const struct in6_addr *saddr, __be16 sport,
@@ -347,6 +390,13 @@ struct sock *__udp6_lib_lookup(const struct net *net,
result = udp6_lib_lookup2(net, saddr, sport,
&in6addr_any, hnum, dif, sdif,
hslot2, skb);
+ if (!IS_ERR_OR_NULL(result))
+ goto done;
+
+ /* Cover address change/lookup/rehash race: see __udp4_lib_lookup() */
+ result = udp6_lib_lookup1(net, saddr, sport, daddr, hnum, dif, sdif,
+ udptable);
+
done:
if (IS_ERR(result))
return NULL;