summaryrefslogtreecommitdiff
path: root/drivers/net/wireguard
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/net/wireguard')
-rw-r--r--drivers/net/wireguard/device.c59
-rw-r--r--drivers/net/wireguard/device.h3
-rw-r--r--drivers/net/wireguard/netlink.c14
-rw-r--r--drivers/net/wireguard/noise.c4
-rw-r--r--drivers/net/wireguard/queueing.h19
-rw-r--r--drivers/net/wireguard/receive.c12
-rw-r--r--drivers/net/wireguard/socket.c25
7 files changed, 63 insertions, 73 deletions
diff --git a/drivers/net/wireguard/device.c b/drivers/net/wireguard/device.c
index 3ac3f8570ca1..c9f65e96ccb0 100644
--- a/drivers/net/wireguard/device.c
+++ b/drivers/net/wireguard/device.c
@@ -45,17 +45,18 @@ static int wg_open(struct net_device *dev)
if (dev_v6)
dev_v6->cnf.addr_gen_mode = IN6_ADDR_GEN_MODE_NONE;
+ mutex_lock(&wg->device_update_lock);
ret = wg_socket_init(wg, wg->incoming_port);
if (ret < 0)
- return ret;
- mutex_lock(&wg->device_update_lock);
+ goto out;
list_for_each_entry(peer, &wg->peer_list, peer_list) {
wg_packet_send_staged_packets(peer);
if (peer->persistent_keepalive_interval)
wg_packet_send_keepalive(peer);
}
+out:
mutex_unlock(&wg->device_update_lock);
- return 0;
+ return ret;
}
#ifdef CONFIG_PM_SLEEP
@@ -225,6 +226,7 @@ static void wg_destruct(struct net_device *dev)
list_del(&wg->device_list);
rtnl_unlock();
mutex_lock(&wg->device_update_lock);
+ rcu_assign_pointer(wg->creating_net, NULL);
wg->incoming_port = 0;
wg_socket_reinit(wg, NULL, NULL);
/* The final references are cleared in the below calls to destroy_workqueue. */
@@ -240,13 +242,11 @@ static void wg_destruct(struct net_device *dev)
skb_queue_purge(&wg->incoming_handshakes);
free_percpu(dev->tstats);
free_percpu(wg->incoming_handshakes_worker);
- if (wg->have_creating_net_ref)
- put_net(wg->creating_net);
kvfree(wg->index_hashtable);
kvfree(wg->peer_hashtable);
mutex_unlock(&wg->device_update_lock);
- pr_debug("%s: Interface deleted\n", dev->name);
+ pr_debug("%s: Interface destroyed\n", dev->name);
free_netdev(dev);
}
@@ -262,6 +262,7 @@ static void wg_setup(struct net_device *dev)
max(sizeof(struct ipv6hdr), sizeof(struct iphdr));
dev->netdev_ops = &netdev_ops;
+ dev->header_ops = &ip_tunnel_header_ops;
dev->hard_header_len = 0;
dev->addr_len = 0;
dev->needed_headroom = DATA_PACKET_HEAD_ROOM;
@@ -292,7 +293,7 @@ static int wg_newlink(struct net *src_net, struct net_device *dev,
struct wg_device *wg = netdev_priv(dev);
int ret = -ENOMEM;
- wg->creating_net = src_net;
+ rcu_assign_pointer(wg->creating_net, src_net);
init_rwsem(&wg->static_identity.lock);
mutex_init(&wg->socket_update_lock);
mutex_init(&wg->device_update_lock);
@@ -393,30 +394,26 @@ static struct rtnl_link_ops link_ops __read_mostly = {
.newlink = wg_newlink,
};
-static int wg_netdevice_notification(struct notifier_block *nb,
- unsigned long action, void *data)
+static void wg_netns_pre_exit(struct net *net)
{
- struct net_device *dev = ((struct netdev_notifier_info *)data)->dev;
- struct wg_device *wg = netdev_priv(dev);
-
- ASSERT_RTNL();
-
- if (action != NETDEV_REGISTER || dev->netdev_ops != &netdev_ops)
- return 0;
+ struct wg_device *wg;
- if (dev_net(dev) == wg->creating_net && wg->have_creating_net_ref) {
- put_net(wg->creating_net);
- wg->have_creating_net_ref = false;
- } else if (dev_net(dev) != wg->creating_net &&
- !wg->have_creating_net_ref) {
- wg->have_creating_net_ref = true;
- get_net(wg->creating_net);
+ rtnl_lock();
+ list_for_each_entry(wg, &device_list, device_list) {
+ if (rcu_access_pointer(wg->creating_net) == net) {
+ pr_debug("%s: Creating namespace exiting\n", wg->dev->name);
+ netif_carrier_off(wg->dev);
+ mutex_lock(&wg->device_update_lock);
+ rcu_assign_pointer(wg->creating_net, NULL);
+ wg_socket_reinit(wg, NULL, NULL);
+ mutex_unlock(&wg->device_update_lock);
+ }
}
- return 0;
+ rtnl_unlock();
}
-static struct notifier_block netdevice_notifier = {
- .notifier_call = wg_netdevice_notification
+static struct pernet_operations pernet_ops = {
+ .pre_exit = wg_netns_pre_exit
};
int __init wg_device_init(void)
@@ -429,18 +426,18 @@ int __init wg_device_init(void)
return ret;
#endif
- ret = register_netdevice_notifier(&netdevice_notifier);
+ ret = register_pernet_device(&pernet_ops);
if (ret)
goto error_pm;
ret = rtnl_link_register(&link_ops);
if (ret)
- goto error_netdevice;
+ goto error_pernet;
return 0;
-error_netdevice:
- unregister_netdevice_notifier(&netdevice_notifier);
+error_pernet:
+ unregister_pernet_device(&pernet_ops);
error_pm:
#ifdef CONFIG_PM_SLEEP
unregister_pm_notifier(&pm_notifier);
@@ -451,7 +448,7 @@ error_pm:
void wg_device_uninit(void)
{
rtnl_link_unregister(&link_ops);
- unregister_netdevice_notifier(&netdevice_notifier);
+ unregister_pernet_device(&pernet_ops);
#ifdef CONFIG_PM_SLEEP
unregister_pm_notifier(&pm_notifier);
#endif
diff --git a/drivers/net/wireguard/device.h b/drivers/net/wireguard/device.h
index b15a8be9d816..4d0144e16947 100644
--- a/drivers/net/wireguard/device.h
+++ b/drivers/net/wireguard/device.h
@@ -40,7 +40,7 @@ struct wg_device {
struct net_device *dev;
struct crypt_queue encrypt_queue, decrypt_queue;
struct sock __rcu *sock4, *sock6;
- struct net *creating_net;
+ struct net __rcu *creating_net;
struct noise_static_identity static_identity;
struct workqueue_struct *handshake_receive_wq, *handshake_send_wq;
struct workqueue_struct *packet_crypt_wq;
@@ -56,7 +56,6 @@ struct wg_device {
unsigned int num_peers, device_update_gen;
u32 fwmark;
u16 incoming_port;
- bool have_creating_net_ref;
};
int wg_device_init(void);
diff --git a/drivers/net/wireguard/netlink.c b/drivers/net/wireguard/netlink.c
index 802099c8828a..20a4f3c0a0a1 100644
--- a/drivers/net/wireguard/netlink.c
+++ b/drivers/net/wireguard/netlink.c
@@ -511,11 +511,15 @@ static int wg_set_device(struct sk_buff *skb, struct genl_info *info)
if (flags & ~__WGDEVICE_F_ALL)
goto out;
- ret = -EPERM;
- if ((info->attrs[WGDEVICE_A_LISTEN_PORT] ||
- info->attrs[WGDEVICE_A_FWMARK]) &&
- !ns_capable(wg->creating_net->user_ns, CAP_NET_ADMIN))
- goto out;
+ if (info->attrs[WGDEVICE_A_LISTEN_PORT] || info->attrs[WGDEVICE_A_FWMARK]) {
+ struct net *net;
+ rcu_read_lock();
+ net = rcu_dereference(wg->creating_net);
+ ret = !net || !ns_capable(net->user_ns, CAP_NET_ADMIN) ? -EPERM : 0;
+ rcu_read_unlock();
+ if (ret)
+ goto out;
+ }
++wg->device_update_gen;
diff --git a/drivers/net/wireguard/noise.c b/drivers/net/wireguard/noise.c
index 626433690abb..201a22681945 100644
--- a/drivers/net/wireguard/noise.c
+++ b/drivers/net/wireguard/noise.c
@@ -617,8 +617,8 @@ wg_noise_handshake_consume_initiation(struct message_handshake_initiation *src,
memcpy(handshake->hash, hash, NOISE_HASH_LEN);
memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
handshake->remote_index = src->sender_index;
- if ((s64)(handshake->last_initiation_consumption -
- (initiation_consumption = ktime_get_coarse_boottime_ns())) < 0)
+ initiation_consumption = ktime_get_coarse_boottime_ns();
+ if ((s64)(handshake->last_initiation_consumption - initiation_consumption) < 0)
handshake->last_initiation_consumption = initiation_consumption;
handshake->state = HANDSHAKE_CONSUMED_INITIATION;
up_write(&handshake->lock);
diff --git a/drivers/net/wireguard/queueing.h b/drivers/net/wireguard/queueing.h
index c58df439dbbe..dfb674e03076 100644
--- a/drivers/net/wireguard/queueing.h
+++ b/drivers/net/wireguard/queueing.h
@@ -11,6 +11,7 @@
#include <linux/skbuff.h>
#include <linux/ip.h>
#include <linux/ipv6.h>
+#include <net/ip_tunnels.h>
struct wg_device;
struct wg_peer;
@@ -65,25 +66,9 @@ struct packet_cb {
#define PACKET_CB(skb) ((struct packet_cb *)((skb)->cb))
#define PACKET_PEER(skb) (PACKET_CB(skb)->keypair->entry.peer)
-/* Returns either the correct skb->protocol value, or 0 if invalid. */
-static inline __be16 wg_examine_packet_protocol(struct sk_buff *skb)
-{
- if (skb_network_header(skb) >= skb->head &&
- (skb_network_header(skb) + sizeof(struct iphdr)) <=
- skb_tail_pointer(skb) &&
- ip_hdr(skb)->version == 4)
- return htons(ETH_P_IP);
- if (skb_network_header(skb) >= skb->head &&
- (skb_network_header(skb) + sizeof(struct ipv6hdr)) <=
- skb_tail_pointer(skb) &&
- ipv6_hdr(skb)->version == 6)
- return htons(ETH_P_IPV6);
- return 0;
-}
-
static inline bool wg_check_packet_protocol(struct sk_buff *skb)
{
- __be16 real_protocol = wg_examine_packet_protocol(skb);
+ __be16 real_protocol = ip_tunnel_parse_protocol(skb);
return real_protocol && skb->protocol == real_protocol;
}
diff --git a/drivers/net/wireguard/receive.c b/drivers/net/wireguard/receive.c
index 91438144e4f7..2c9551ea6dc7 100644
--- a/drivers/net/wireguard/receive.c
+++ b/drivers/net/wireguard/receive.c
@@ -387,7 +387,7 @@ static void wg_packet_consume_data_done(struct wg_peer *peer,
*/
skb->ip_summed = CHECKSUM_UNNECESSARY;
skb->csum_level = ~0; /* All levels */
- skb->protocol = wg_examine_packet_protocol(skb);
+ skb->protocol = ip_tunnel_parse_protocol(skb);
if (skb->protocol == htons(ETH_P_IP)) {
len = ntohs(ip_hdr(skb)->tot_len);
if (unlikely(len < sizeof(struct iphdr)))
@@ -414,14 +414,8 @@ static void wg_packet_consume_data_done(struct wg_peer *peer,
if (unlikely(routed_peer != peer))
goto dishonest_packet_peer;
- if (unlikely(napi_gro_receive(&peer->napi, skb) == GRO_DROP)) {
- ++dev->stats.rx_dropped;
- net_dbg_ratelimited("%s: Failed to give packet to userspace from peer %llu (%pISpfsc)\n",
- dev->name, peer->internal_id,
- &peer->endpoint.addr);
- } else {
- update_rx_stats(peer, message_data_len(len_before_trim));
- }
+ napi_gro_receive(&peer->napi, skb);
+ update_rx_stats(peer, message_data_len(len_before_trim));
return;
dishonest_packet_peer:
diff --git a/drivers/net/wireguard/socket.c b/drivers/net/wireguard/socket.c
index f9018027fc13..c33e2c81635f 100644
--- a/drivers/net/wireguard/socket.c
+++ b/drivers/net/wireguard/socket.c
@@ -347,6 +347,7 @@ static void set_sock_opts(struct socket *sock)
int wg_socket_init(struct wg_device *wg, u16 port)
{
+ struct net *net;
int ret;
struct udp_tunnel_sock_cfg cfg = {
.sk_user_data = wg,
@@ -371,37 +372,47 @@ int wg_socket_init(struct wg_device *wg, u16 port)
};
#endif
+ rcu_read_lock();
+ net = rcu_dereference(wg->creating_net);
+ net = net ? maybe_get_net(net) : NULL;
+ rcu_read_unlock();
+ if (unlikely(!net))
+ return -ENONET;
+
#if IS_ENABLED(CONFIG_IPV6)
retry:
#endif
- ret = udp_sock_create(wg->creating_net, &port4, &new4);
+ ret = udp_sock_create(net, &port4, &new4);
if (ret < 0) {
pr_err("%s: Could not create IPv4 socket\n", wg->dev->name);
- return ret;
+ goto out;
}
set_sock_opts(new4);
- setup_udp_tunnel_sock(wg->creating_net, new4, &cfg);
+ setup_udp_tunnel_sock(net, new4, &cfg);
#if IS_ENABLED(CONFIG_IPV6)
if (ipv6_mod_enabled()) {
port6.local_udp_port = inet_sk(new4->sk)->inet_sport;
- ret = udp_sock_create(wg->creating_net, &port6, &new6);
+ ret = udp_sock_create(net, &port6, &new6);
if (ret < 0) {
udp_tunnel_sock_release(new4);
if (ret == -EADDRINUSE && !port && retries++ < 100)
goto retry;
pr_err("%s: Could not create IPv6 socket\n",
wg->dev->name);
- return ret;
+ goto out;
}
set_sock_opts(new6);
- setup_udp_tunnel_sock(wg->creating_net, new6, &cfg);
+ setup_udp_tunnel_sock(net, new6, &cfg);
}
#endif
wg_socket_reinit(wg, new4->sk, new6 ? new6->sk : NULL);
- return 0;
+ ret = 0;
+out:
+ put_net(net);
+ return ret;
}
void wg_socket_reinit(struct wg_device *wg, struct sock *new4,