summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--MAINTAINERS1
-rw-r--r--drivers/vhost/vsock.c44
-rw-r--r--include/linux/virtio_vsock.h9
-rw-r--r--include/net/af_vsock.h61
-rw-r--r--include/net/net_namespace.h4
-rw-r--r--include/net/netns/vsock.h21
-rw-r--r--net/vmw_vsock/af_vsock.c335
-rw-r--r--net/vmw_vsock/hyperv_transport.c7
-rw-r--r--net/vmw_vsock/virtio_transport.c22
-rw-r--r--net/vmw_vsock/virtio_transport_common.c62
-rw-r--r--net/vmw_vsock/vmci_transport.c26
-rw-r--r--net/vmw_vsock/vsock_loopback.c22
-rw-r--r--tools/testing/selftests/vsock/settings2
-rwxr-xr-xtools/testing/selftests/vsock/vmtest.sh1055
14 files changed, 1531 insertions, 140 deletions
diff --git a/MAINTAINERS b/MAINTAINERS
index 62392b61a52f..c3df85fd5acd 100644
--- a/MAINTAINERS
+++ b/MAINTAINERS
@@ -27556,6 +27556,7 @@ L: netdev@vger.kernel.org
S: Maintained
F: drivers/vhost/vsock.c
F: include/linux/virtio_vsock.h
+F: include/net/netns/vsock.h
F: include/uapi/linux/virtio_vsock.h
F: net/vmw_vsock/virtio_transport.c
F: net/vmw_vsock/virtio_transport_common.c
diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
index 552cfb53498a..488d7fa6e4ec 100644
--- a/drivers/vhost/vsock.c
+++ b/drivers/vhost/vsock.c
@@ -48,6 +48,8 @@ static DEFINE_READ_MOSTLY_HASHTABLE(vhost_vsock_hash, 8);
struct vhost_vsock {
struct vhost_dev dev;
struct vhost_virtqueue vqs[2];
+ struct net *net;
+ netns_tracker ns_tracker;
/* Link to global vhost_vsock_hash, writes use vhost_vsock_mutex */
struct hlist_node hash;
@@ -69,7 +71,7 @@ static u32 vhost_transport_get_local_cid(void)
/* Callers must be in an RCU read section or hold the vhost_vsock_mutex.
* The return value can only be dereferenced while within the section.
*/
-static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
+static struct vhost_vsock *vhost_vsock_get(u32 guest_cid, struct net *net)
{
struct vhost_vsock *vsock;
@@ -81,9 +83,9 @@ static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
if (other_cid == 0)
continue;
- if (other_cid == guest_cid)
+ if (other_cid == guest_cid &&
+ vsock_net_check_mode(net, vsock->net))
return vsock;
-
}
return NULL;
@@ -272,7 +274,7 @@ static void vhost_transport_send_pkt_work(struct vhost_work *work)
}
static int
-vhost_transport_send_pkt(struct sk_buff *skb)
+vhost_transport_send_pkt(struct sk_buff *skb, struct net *net)
{
struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
struct vhost_vsock *vsock;
@@ -281,7 +283,7 @@ vhost_transport_send_pkt(struct sk_buff *skb)
rcu_read_lock();
/* Find the vhost_vsock according to guest context id */
- vsock = vhost_vsock_get(le64_to_cpu(hdr->dst_cid));
+ vsock = vhost_vsock_get(le64_to_cpu(hdr->dst_cid), net);
if (!vsock) {
rcu_read_unlock();
kfree_skb(skb);
@@ -308,7 +310,8 @@ vhost_transport_cancel_pkt(struct vsock_sock *vsk)
rcu_read_lock();
/* Find the vhost_vsock according to guest context id */
- vsock = vhost_vsock_get(vsk->remote_addr.svm_cid);
+ vsock = vhost_vsock_get(vsk->remote_addr.svm_cid,
+ sock_net(sk_vsock(vsk)));
if (!vsock)
goto out;
@@ -407,7 +410,14 @@ static bool vhost_transport_msgzerocopy_allow(void)
return true;
}
-static bool vhost_transport_seqpacket_allow(u32 remote_cid);
+static bool vhost_transport_seqpacket_allow(struct vsock_sock *vsk,
+ u32 remote_cid);
+
+static bool
+vhost_transport_stream_allow(struct vsock_sock *vsk, u32 cid, u32 port)
+{
+ return true;
+}
static struct virtio_transport vhost_transport = {
.transport = {
@@ -433,7 +443,7 @@ static struct virtio_transport vhost_transport = {
.stream_has_space = virtio_transport_stream_has_space,
.stream_rcvhiwat = virtio_transport_stream_rcvhiwat,
.stream_is_active = virtio_transport_stream_is_active,
- .stream_allow = virtio_transport_stream_allow,
+ .stream_allow = vhost_transport_stream_allow,
.seqpacket_dequeue = virtio_transport_seqpacket_dequeue,
.seqpacket_enqueue = virtio_transport_seqpacket_enqueue,
@@ -463,13 +473,15 @@ static struct virtio_transport vhost_transport = {
.send_pkt = vhost_transport_send_pkt,
};
-static bool vhost_transport_seqpacket_allow(u32 remote_cid)
+static bool vhost_transport_seqpacket_allow(struct vsock_sock *vsk,
+ u32 remote_cid)
{
+ struct net *net = sock_net(sk_vsock(vsk));
struct vhost_vsock *vsock;
bool seqpacket_allow = false;
rcu_read_lock();
- vsock = vhost_vsock_get(remote_cid);
+ vsock = vhost_vsock_get(remote_cid, net);
if (vsock)
seqpacket_allow = vsock->seqpacket_allow;
@@ -540,7 +552,8 @@ static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
if (le64_to_cpu(hdr->src_cid) == vsock->guest_cid &&
le64_to_cpu(hdr->dst_cid) ==
vhost_transport_get_local_cid())
- virtio_transport_recv_pkt(&vhost_transport, skb);
+ virtio_transport_recv_pkt(&vhost_transport, skb,
+ vsock->net);
else
kfree_skb(skb);
@@ -657,6 +670,7 @@ static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
{
struct vhost_virtqueue **vqs;
struct vhost_vsock *vsock;
+ struct net *net;
int ret;
/* This struct is large and allocation could fail, fall back to vmalloc
@@ -672,6 +686,9 @@ static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
goto out;
}
+ net = current->nsproxy->net_ns;
+ vsock->net = get_net_track(net, &vsock->ns_tracker, GFP_KERNEL);
+
vsock->guest_cid = 0; /* no CID assigned yet */
vsock->seqpacket_allow = false;
@@ -713,7 +730,7 @@ static void vhost_vsock_reset_orphans(struct sock *sk)
rcu_read_lock();
/* If the peer is still valid, no need to reset connection */
- if (vhost_vsock_get(vsk->remote_addr.svm_cid)) {
+ if (vhost_vsock_get(vsk->remote_addr.svm_cid, sock_net(sk))) {
rcu_read_unlock();
return;
}
@@ -762,6 +779,7 @@ static int vhost_vsock_dev_release(struct inode *inode, struct file *file)
virtio_vsock_skb_queue_purge(&vsock->send_pkt_queue);
vhost_dev_cleanup(&vsock->dev);
+ put_net_track(vsock->net, &vsock->ns_tracker);
kfree(vsock->dev.vqs);
vhost_vsock_free(vsock);
return 0;
@@ -788,7 +806,7 @@ static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid)
/* Refuse if CID is already in use */
mutex_lock(&vhost_vsock_mutex);
- other = vhost_vsock_get(guest_cid);
+ other = vhost_vsock_get(guest_cid, vsock->net);
if (other && other != vsock) {
mutex_unlock(&vhost_vsock_mutex);
return -EADDRINUSE;
diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h
index 0c67543a45c8..f91704731057 100644
--- a/include/linux/virtio_vsock.h
+++ b/include/linux/virtio_vsock.h
@@ -173,6 +173,7 @@ struct virtio_vsock_pkt_info {
u32 remote_cid, remote_port;
struct vsock_sock *vsk;
struct msghdr *msg;
+ struct net *net;
u32 pkt_len;
u16 type;
u16 op;
@@ -185,7 +186,7 @@ struct virtio_transport {
struct vsock_transport transport;
/* Takes ownership of the packet */
- int (*send_pkt)(struct sk_buff *skb);
+ int (*send_pkt)(struct sk_buff *skb, struct net *net);
/* Used in MSG_ZEROCOPY mode. Checks, that provided data
* (number of buffers) could be transmitted with zerocopy
@@ -256,10 +257,10 @@ void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val);
u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk);
bool virtio_transport_stream_is_active(struct vsock_sock *vsk);
-bool virtio_transport_stream_allow(u32 cid, u32 port);
+bool virtio_transport_stream_allow(struct vsock_sock *vsk, u32 cid, u32 port);
int virtio_transport_dgram_bind(struct vsock_sock *vsk,
struct sockaddr_vm *addr);
-bool virtio_transport_dgram_allow(u32 cid, u32 port);
+bool virtio_transport_dgram_allow(struct vsock_sock *vsk, u32 cid, u32 port);
int virtio_transport_connect(struct vsock_sock *vsk);
@@ -280,7 +281,7 @@ virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
void virtio_transport_destruct(struct vsock_sock *vsk);
void virtio_transport_recv_pkt(struct virtio_transport *t,
- struct sk_buff *skb);
+ struct sk_buff *skb, struct net *net);
void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct sk_buff *skb);
u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 wanted);
void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit);
diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h
index d40e978126e3..d3ff48a2fbe0 100644
--- a/include/net/af_vsock.h
+++ b/include/net/af_vsock.h
@@ -10,6 +10,7 @@
#include <linux/kernel.h>
#include <linux/workqueue.h>
+#include <net/netns/vsock.h>
#include <net/sock.h>
#include <uapi/linux/vm_sockets.h>
@@ -124,7 +125,7 @@ struct vsock_transport {
size_t len, int flags);
int (*dgram_enqueue)(struct vsock_sock *, struct sockaddr_vm *,
struct msghdr *, size_t len);
- bool (*dgram_allow)(u32 cid, u32 port);
+ bool (*dgram_allow)(struct vsock_sock *vsk, u32 cid, u32 port);
/* STREAM. */
/* TODO: stream_bind() */
@@ -136,14 +137,14 @@ struct vsock_transport {
s64 (*stream_has_space)(struct vsock_sock *);
u64 (*stream_rcvhiwat)(struct vsock_sock *);
bool (*stream_is_active)(struct vsock_sock *);
- bool (*stream_allow)(u32 cid, u32 port);
+ bool (*stream_allow)(struct vsock_sock *vsk, u32 cid, u32 port);
/* SEQ_PACKET. */
ssize_t (*seqpacket_dequeue)(struct vsock_sock *vsk, struct msghdr *msg,
int flags);
int (*seqpacket_enqueue)(struct vsock_sock *vsk, struct msghdr *msg,
size_t len);
- bool (*seqpacket_allow)(u32 remote_cid);
+ bool (*seqpacket_allow)(struct vsock_sock *vsk, u32 remote_cid);
u32 (*seqpacket_has_data)(struct vsock_sock *vsk);
/* Notification. */
@@ -216,6 +217,11 @@ void vsock_remove_connected(struct vsock_sock *vsk);
struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr);
struct sock *vsock_find_connected_socket(struct sockaddr_vm *src,
struct sockaddr_vm *dst);
+struct sock *vsock_find_bound_socket_net(struct sockaddr_vm *addr,
+ struct net *net);
+struct sock *vsock_find_connected_socket_net(struct sockaddr_vm *src,
+ struct sockaddr_vm *dst,
+ struct net *net);
void vsock_remove_sock(struct vsock_sock *vsk);
void vsock_for_each_connected_socket(struct vsock_transport *transport,
void (*fn)(struct sock *sk));
@@ -256,4 +262,53 @@ static inline bool vsock_msgzerocopy_allow(const struct vsock_transport *t)
{
return t->msgzerocopy_allow && t->msgzerocopy_allow();
}
+
+static inline enum vsock_net_mode vsock_net_mode(struct net *net)
+{
+ if (!net)
+ return VSOCK_NET_MODE_GLOBAL;
+
+ return READ_ONCE(net->vsock.mode);
+}
+
+static inline bool vsock_net_mode_global(struct vsock_sock *vsk)
+{
+ return vsock_net_mode(sock_net(sk_vsock(vsk))) == VSOCK_NET_MODE_GLOBAL;
+}
+
+static inline void vsock_net_set_child_mode(struct net *net,
+ enum vsock_net_mode mode)
+{
+ WRITE_ONCE(net->vsock.child_ns_mode, mode);
+}
+
+static inline enum vsock_net_mode vsock_net_child_mode(struct net *net)
+{
+ return READ_ONCE(net->vsock.child_ns_mode);
+}
+
+/* Return true if two namespaces pass the mode rules. Otherwise, return false.
+ *
+ * A NULL namespace is treated as VSOCK_NET_MODE_GLOBAL.
+ *
+ * Read more about modes in the comment header of net/vmw_vsock/af_vsock.c.
+ */
+static inline bool vsock_net_check_mode(struct net *ns0, struct net *ns1)
+{
+ enum vsock_net_mode mode0, mode1;
+
+ /* Any vsocks within the same network namespace are always reachable,
+ * regardless of the mode.
+ */
+ if (net_eq(ns0, ns1))
+ return true;
+
+ mode0 = vsock_net_mode(ns0);
+ mode1 = vsock_net_mode(ns1);
+
+ /* Different namespaces are only reachable if they are both
+ * global mode.
+ */
+ return mode0 == VSOCK_NET_MODE_GLOBAL && mode0 == mode1;
+}
#endif /* __AF_VSOCK_H__ */
diff --git a/include/net/net_namespace.h b/include/net/net_namespace.h
index cb664f6e3558..66d3de1d935f 100644
--- a/include/net/net_namespace.h
+++ b/include/net/net_namespace.h
@@ -37,6 +37,7 @@
#include <net/netns/smc.h>
#include <net/netns/bpf.h>
#include <net/netns/mctp.h>
+#include <net/netns/vsock.h>
#include <net/net_trackers.h>
#include <linux/ns_common.h>
#include <linux/idr.h>
@@ -196,6 +197,9 @@ struct net {
/* Move to a better place when the config guard is removed. */
struct mutex rtnl_mutex;
#endif
+#if IS_ENABLED(CONFIG_VSOCKETS)
+ struct netns_vsock vsock;
+#endif
} __randomize_layout;
#include <linux/seq_file_net.h>
diff --git a/include/net/netns/vsock.h b/include/net/netns/vsock.h
new file mode 100644
index 000000000000..b34d69a22fa8
--- /dev/null
+++ b/include/net/netns/vsock.h
@@ -0,0 +1,21 @@
+/* SPDX-License-Identifier: GPL-2.0 */
+#ifndef __NET_NET_NAMESPACE_VSOCK_H
+#define __NET_NET_NAMESPACE_VSOCK_H
+
+#include <linux/types.h>
+
+enum vsock_net_mode {
+ VSOCK_NET_MODE_GLOBAL,
+ VSOCK_NET_MODE_LOCAL,
+};
+
+struct netns_vsock {
+ struct ctl_table_header *sysctl_hdr;
+
+ /* protected by the vsock_table_lock in af_vsock.c */
+ u32 port;
+
+ enum vsock_net_mode mode;
+ enum vsock_net_mode child_ns_mode;
+};
+#endif /* __NET_NET_NAMESPACE_VSOCK_H */
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index a3505a4dcee0..20ad2b2dc17b 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -83,6 +83,50 @@
* TCP_ESTABLISHED - connected
* TCP_CLOSING - disconnecting
* TCP_LISTEN - listening
+ *
+ * - Namespaces in vsock support two different modes: "local" and "global".
+ * Each mode defines how the namespace interacts with CIDs.
+ * Each namespace exposes two sysctl files:
+ *
+ * - /proc/sys/net/vsock/ns_mode (read-only) reports the current namespace's
+ * mode, which is set at namespace creation and immutable thereafter.
+ * - /proc/sys/net/vsock/child_ns_mode (writable) controls what mode future
+ * child namespaces will inherit when created. The default is "global".
+ *
+ * Changing child_ns_mode only affects newly created namespaces, not the
+ * current namespace or existing children. At namespace creation, ns_mode
+ * is inherited from the parent's child_ns_mode.
+ *
+ * The init_net mode is "global" and cannot be modified.
+ *
+ * The modes affect the allocation and accessibility of CIDs as follows:
+ *
+ * - global - access and allocation are all system-wide
+ * - all CID allocation from global namespaces draw from the same
+ * system-wide pool.
+ * - if one global namespace has already allocated some CID, another
+ * global namespace will not be able to allocate the same CID.
+ * - global mode AF_VSOCK sockets can reach any VM or socket in any global
+ * namespace, they are not contained to only their own namespace.
+ * - AF_VSOCK sockets in a global mode namespace cannot reach VMs or
+ * sockets in any local mode namespace.
+ * - local - access and allocation are contained within the namespace
+ * - CID allocation draws only from a private pool local only to the
+ * namespace, and does not affect the CIDs available for allocation in any
+ * other namespace (global or local).
+ * - VMs in a local namespace do not collide with CIDs in any other local
+ * namespace or any global namespace. For example, if a VM in a local mode
+ * namespace is given CID 10, then CID 10 is still available for
+ * allocation in any other namespace, but not in the same namespace.
+ * - AF_VSOCK sockets in a local mode namespace can connect only to VMs or
+ * other sockets within their own namespace.
+ * - sockets bound to VMADDR_CID_ANY in local namespaces will never resolve
+ * to any transport that is not compatible with local mode. There is no
+ * error that propagates to the user (as there is for connection attempts)
+ * because it is possible for some packet to reach this socket from
+ * a different transport that *does* support local mode. For
+ * example, virtio-vsock may not support local mode, but the socket
+ * may still accept a connection from vhost-vsock which does.
*/
#include <linux/compat.h>
@@ -100,20 +144,31 @@
#include <linux/module.h>
#include <linux/mutex.h>
#include <linux/net.h>
+#include <linux/proc_fs.h>
#include <linux/poll.h>
#include <linux/random.h>
#include <linux/skbuff.h>
#include <linux/smp.h>
#include <linux/socket.h>
#include <linux/stddef.h>
+#include <linux/sysctl.h>
#include <linux/unistd.h>
#include <linux/wait.h>
#include <linux/workqueue.h>
#include <net/sock.h>
#include <net/af_vsock.h>
+#include <net/netns/vsock.h>
#include <uapi/linux/vm_sockets.h>
#include <uapi/asm-generic/ioctls.h>
+#define VSOCK_NET_MODE_STR_GLOBAL "global"
+#define VSOCK_NET_MODE_STR_LOCAL "local"
+
+/* 6 chars for "global", 1 for null-terminator, and 1 more for '\n'.
+ * The newline is added by proc_dostring() for read operations.
+ */
+#define VSOCK_NET_MODE_STR_MAX 8
+
static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr);
static void vsock_sk_destruct(struct sock *sk);
static int vsock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb);
@@ -235,33 +290,42 @@ static void __vsock_remove_connected(struct vsock_sock *vsk)
sock_put(&vsk->sk);
}
-static struct sock *__vsock_find_bound_socket(struct sockaddr_vm *addr)
+static struct sock *__vsock_find_bound_socket_net(struct sockaddr_vm *addr,
+ struct net *net)
{
struct vsock_sock *vsk;
list_for_each_entry(vsk, vsock_bound_sockets(addr), bound_table) {
- if (vsock_addr_equals_addr(addr, &vsk->local_addr))
- return sk_vsock(vsk);
+ struct sock *sk = sk_vsock(vsk);
+
+ if (vsock_addr_equals_addr(addr, &vsk->local_addr) &&
+ vsock_net_check_mode(sock_net(sk), net))
+ return sk;
if (addr->svm_port == vsk->local_addr.svm_port &&
(vsk->local_addr.svm_cid == VMADDR_CID_ANY ||
- addr->svm_cid == VMADDR_CID_ANY))
- return sk_vsock(vsk);
+ addr->svm_cid == VMADDR_CID_ANY) &&
+ vsock_net_check_mode(sock_net(sk), net))
+ return sk;
}
return NULL;
}
-static struct sock *__vsock_find_connected_socket(struct sockaddr_vm *src,
- struct sockaddr_vm *dst)
+static struct sock *
+__vsock_find_connected_socket_net(struct sockaddr_vm *src,
+ struct sockaddr_vm *dst, struct net *net)
{
struct vsock_sock *vsk;
list_for_each_entry(vsk, vsock_connected_sockets(src, dst),
connected_table) {
+ struct sock *sk = sk_vsock(vsk);
+
if (vsock_addr_equals_addr(src, &vsk->remote_addr) &&
- dst->svm_port == vsk->local_addr.svm_port) {
- return sk_vsock(vsk);
+ dst->svm_port == vsk->local_addr.svm_port &&
+ vsock_net_check_mode(sock_net(sk), net)) {
+ return sk;
}
}
@@ -304,12 +368,18 @@ void vsock_remove_connected(struct vsock_sock *vsk)
}
EXPORT_SYMBOL_GPL(vsock_remove_connected);
-struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr)
+/* Find a bound socket, filtering by namespace and namespace mode.
+ *
+ * Use this in transports that are namespace-aware and can provide the
+ * network namespace context.
+ */
+struct sock *vsock_find_bound_socket_net(struct sockaddr_vm *addr,
+ struct net *net)
{
struct sock *sk;
spin_lock_bh(&vsock_table_lock);
- sk = __vsock_find_bound_socket(addr);
+ sk = __vsock_find_bound_socket_net(addr, net);
if (sk)
sock_hold(sk);
@@ -317,15 +387,32 @@ struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr)
return sk;
}
+EXPORT_SYMBOL_GPL(vsock_find_bound_socket_net);
+
+/* Find a bound socket without namespace filtering.
+ *
+ * Use this in transports that lack namespace context. All sockets are
+ * treated as if in global mode.
+ */
+struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr)
+{
+ return vsock_find_bound_socket_net(addr, NULL);
+}
EXPORT_SYMBOL_GPL(vsock_find_bound_socket);
-struct sock *vsock_find_connected_socket(struct sockaddr_vm *src,
- struct sockaddr_vm *dst)
+/* Find a connected socket, filtering by namespace and namespace mode.
+ *
+ * Use this in transports that are namespace-aware and can provide the
+ * network namespace context.
+ */
+struct sock *vsock_find_connected_socket_net(struct sockaddr_vm *src,
+ struct sockaddr_vm *dst,
+ struct net *net)
{
struct sock *sk;
spin_lock_bh(&vsock_table_lock);
- sk = __vsock_find_connected_socket(src, dst);
+ sk = __vsock_find_connected_socket_net(src, dst, net);
if (sk)
sock_hold(sk);
@@ -333,6 +420,18 @@ struct sock *vsock_find_connected_socket(struct sockaddr_vm *src,
return sk;
}
+EXPORT_SYMBOL_GPL(vsock_find_connected_socket_net);
+
+/* Find a connected socket without namespace filtering.
+ *
+ * Use this in transports that lack namespace context. All sockets are
+ * treated as if in global mode.
+ */
+struct sock *vsock_find_connected_socket(struct sockaddr_vm *src,
+ struct sockaddr_vm *dst)
+{
+ return vsock_find_connected_socket_net(src, dst, NULL);
+}
EXPORT_SYMBOL_GPL(vsock_find_connected_socket);
void vsock_remove_sock(struct vsock_sock *vsk)
@@ -528,7 +627,7 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
if (sk->sk_type == SOCK_SEQPACKET) {
if (!new_transport->seqpacket_allow ||
- !new_transport->seqpacket_allow(remote_cid)) {
+ !new_transport->seqpacket_allow(vsk, remote_cid)) {
module_put(new_transport->module);
return -ESOCKTNOSUPPORT;
}
@@ -676,11 +775,11 @@ out:
static int __vsock_bind_connectible(struct vsock_sock *vsk,
struct sockaddr_vm *addr)
{
- static u32 port;
+ struct net *net = sock_net(sk_vsock(vsk));
struct sockaddr_vm new_addr;
- if (!port)
- port = get_random_u32_above(LAST_RESERVED_PORT);
+ if (!net->vsock.port)
+ net->vsock.port = get_random_u32_above(LAST_RESERVED_PORT);
vsock_addr_init(&new_addr, addr->svm_cid, addr->svm_port);
@@ -689,13 +788,13 @@ static int __vsock_bind_connectible(struct vsock_sock *vsk,
unsigned int i;
for (i = 0; i < MAX_PORT_RETRIES; i++) {
- if (port == VMADDR_PORT_ANY ||
- port <= LAST_RESERVED_PORT)
- port = LAST_RESERVED_PORT + 1;
+ if (net->vsock.port == VMADDR_PORT_ANY ||
+ net->vsock.port <= LAST_RESERVED_PORT)
+ net->vsock.port = LAST_RESERVED_PORT + 1;
- new_addr.svm_port = port++;
+ new_addr.svm_port = net->vsock.port++;
- if (!__vsock_find_bound_socket(&new_addr)) {
+ if (!__vsock_find_bound_socket_net(&new_addr, net)) {
found = true;
break;
}
@@ -712,7 +811,7 @@ static int __vsock_bind_connectible(struct vsock_sock *vsk,
return -EACCES;
}
- if (__vsock_find_bound_socket(&new_addr))
+ if (__vsock_find_bound_socket_net(&new_addr, net))
return -EADDRINUSE;
}
@@ -1314,7 +1413,7 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
goto out;
}
- if (!transport->dgram_allow(remote_addr->svm_cid,
+ if (!transport->dgram_allow(vsk, remote_addr->svm_cid,
remote_addr->svm_port)) {
err = -EINVAL;
goto out;
@@ -1355,7 +1454,7 @@ static int vsock_dgram_connect(struct socket *sock,
if (err)
goto out;
- if (!vsk->transport->dgram_allow(remote_addr->svm_cid,
+ if (!vsk->transport->dgram_allow(vsk, remote_addr->svm_cid,
remote_addr->svm_port)) {
err = -EINVAL;
goto out;
@@ -1585,7 +1684,7 @@ static int vsock_connect(struct socket *sock, struct sockaddr_unsized *addr,
* endpoints.
*/
if (!transport ||
- !transport->stream_allow(remote_addr->svm_cid,
+ !transport->stream_allow(vsk, remote_addr->svm_cid,
remote_addr->svm_port)) {
err = -ENETUNREACH;
goto out;
@@ -2662,6 +2761,180 @@ static struct miscdevice vsock_device = {
.fops = &vsock_device_ops,
};
+static int __vsock_net_mode_string(const struct ctl_table *table, int write,
+ void *buffer, size_t *lenp, loff_t *ppos,
+ enum vsock_net_mode mode,
+ enum vsock_net_mode *new_mode)
+{
+ char data[VSOCK_NET_MODE_STR_MAX] = {0};
+ struct ctl_table tmp;
+ int ret;
+
+ if (!table->data || !table->maxlen || !*lenp) {
+ *lenp = 0;
+ return 0;
+ }
+
+ tmp = *table;
+ tmp.data = data;
+
+ if (!write) {
+ const char *p;
+
+ switch (mode) {
+ case VSOCK_NET_MODE_GLOBAL:
+ p = VSOCK_NET_MODE_STR_GLOBAL;
+ break;
+ case VSOCK_NET_MODE_LOCAL:
+ p = VSOCK_NET_MODE_STR_LOCAL;
+ break;
+ default:
+ WARN_ONCE(true, "netns has invalid vsock mode");
+ *lenp = 0;
+ return 0;
+ }
+
+ strscpy(data, p, sizeof(data));
+ tmp.maxlen = strlen(p);
+ }
+
+ ret = proc_dostring(&tmp, write, buffer, lenp, ppos);
+ if (ret || !write)
+ return ret;
+
+ if (*lenp >= sizeof(data))
+ return -EINVAL;
+
+ if (!strncmp(data, VSOCK_NET_MODE_STR_GLOBAL, sizeof(data)))
+ *new_mode = VSOCK_NET_MODE_GLOBAL;
+ else if (!strncmp(data, VSOCK_NET_MODE_STR_LOCAL, sizeof(data)))
+ *new_mode = VSOCK_NET_MODE_LOCAL;
+ else
+ return -EINVAL;
+
+ return 0;
+}
+
+static int vsock_net_mode_string(const struct ctl_table *table, int write,
+ void *buffer, size_t *lenp, loff_t *ppos)
+{
+ struct net *net;
+
+ if (write)
+ return -EPERM;
+
+ net = current->nsproxy->net_ns;
+
+ return __vsock_net_mode_string(table, write, buffer, lenp, ppos,
+ vsock_net_mode(net), NULL);
+}
+
+static int vsock_net_child_mode_string(const struct ctl_table *table, int write,
+ void *buffer, size_t *lenp, loff_t *ppos)
+{
+ enum vsock_net_mode new_mode;
+ struct net *net;
+ int ret;
+
+ net = current->nsproxy->net_ns;
+
+ ret = __vsock_net_mode_string(table, write, buffer, lenp, ppos,
+ vsock_net_child_mode(net), &new_mode);
+ if (ret)
+ return ret;
+
+ if (write)
+ vsock_net_set_child_mode(net, new_mode);
+
+ return 0;
+}
+
+static struct ctl_table vsock_table[] = {
+ {
+ .procname = "ns_mode",
+ .data = &init_net.vsock.mode,
+ .maxlen = VSOCK_NET_MODE_STR_MAX,
+ .mode = 0444,
+ .proc_handler = vsock_net_mode_string
+ },
+ {
+ .procname = "child_ns_mode",
+ .data = &init_net.vsock.child_ns_mode,
+ .maxlen = VSOCK_NET_MODE_STR_MAX,
+ .mode = 0644,
+ .proc_handler = vsock_net_child_mode_string
+ },
+};
+
+static int __net_init vsock_sysctl_register(struct net *net)
+{
+ struct ctl_table *table;
+
+ if (net_eq(net, &init_net)) {
+ table = vsock_table;
+ } else {
+ table = kmemdup(vsock_table, sizeof(vsock_table), GFP_KERNEL);
+ if (!table)
+ goto err_alloc;
+
+ table[0].data = &net->vsock.mode;
+ table[1].data = &net->vsock.child_ns_mode;
+ }
+
+ net->vsock.sysctl_hdr = register_net_sysctl_sz(net, "net/vsock", table,
+ ARRAY_SIZE(vsock_table));
+ if (!net->vsock.sysctl_hdr)
+ goto err_reg;
+
+ return 0;
+
+err_reg:
+ if (!net_eq(net, &init_net))
+ kfree(table);
+err_alloc:
+ return -ENOMEM;
+}
+
+static void vsock_sysctl_unregister(struct net *net)
+{
+ const struct ctl_table *table;
+
+ table = net->vsock.sysctl_hdr->ctl_table_arg;
+ unregister_net_sysctl_table(net->vsock.sysctl_hdr);
+ if (!net_eq(net, &init_net))
+ kfree(table);
+}
+
+static void vsock_net_init(struct net *net)
+{
+ if (net_eq(net, &init_net))
+ net->vsock.mode = VSOCK_NET_MODE_GLOBAL;
+ else
+ net->vsock.mode = vsock_net_child_mode(current->nsproxy->net_ns);
+
+ net->vsock.child_ns_mode = VSOCK_NET_MODE_GLOBAL;
+}
+
+static __net_init int vsock_sysctl_init_net(struct net *net)
+{
+ vsock_net_init(net);
+
+ if (vsock_sysctl_register(net))
+ return -ENOMEM;
+
+ return 0;
+}
+
+static __net_exit void vsock_sysctl_exit_net(struct net *net)
+{
+ vsock_sysctl_unregister(net);
+}
+
+static struct pernet_operations vsock_sysctl_ops = {
+ .init = vsock_sysctl_init_net,
+ .exit = vsock_sysctl_exit_net,
+};
+
static int __init vsock_init(void)
{
int err = 0;
@@ -2689,10 +2962,17 @@ static int __init vsock_init(void)
goto err_unregister_proto;
}
+ if (register_pernet_subsys(&vsock_sysctl_ops)) {
+ err = -ENOMEM;
+ goto err_unregister_sock;
+ }
+
vsock_bpf_build_proto();
return 0;
+err_unregister_sock:
+ sock_unregister(AF_VSOCK);
err_unregister_proto:
proto_unregister(&vsock_proto);
err_deregister_misc:
@@ -2706,6 +2986,7 @@ static void __exit vsock_exit(void)
misc_deregister(&vsock_device);
sock_unregister(AF_VSOCK);
proto_unregister(&vsock_proto);
+ unregister_pernet_subsys(&vsock_sysctl_ops);
}
const struct vsock_transport *vsock_core_get_transport(struct vsock_sock *vsk)
diff --git a/net/vmw_vsock/hyperv_transport.c b/net/vmw_vsock/hyperv_transport.c
index 432fcbbd14d4..c3010c874308 100644
--- a/net/vmw_vsock/hyperv_transport.c
+++ b/net/vmw_vsock/hyperv_transport.c
@@ -570,7 +570,7 @@ static int hvs_dgram_enqueue(struct vsock_sock *vsk,
return -EOPNOTSUPP;
}
-static bool hvs_dgram_allow(u32 cid, u32 port)
+static bool hvs_dgram_allow(struct vsock_sock *vsk, u32 cid, u32 port)
{
return false;
}
@@ -745,8 +745,11 @@ static bool hvs_stream_is_active(struct vsock_sock *vsk)
return hvs->chan != NULL;
}
-static bool hvs_stream_allow(u32 cid, u32 port)
+static bool hvs_stream_allow(struct vsock_sock *vsk, u32 cid, u32 port)
{
+ if (!vsock_net_mode_global(vsk))
+ return false;
+
if (cid == VMADDR_CID_HOST)
return true;
diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c
index 8c867023a2e5..3f7ea2db9bd7 100644
--- a/net/vmw_vsock/virtio_transport.c
+++ b/net/vmw_vsock/virtio_transport.c
@@ -231,7 +231,7 @@ static int virtio_transport_send_skb_fast_path(struct virtio_vsock *vsock, struc
}
static int
-virtio_transport_send_pkt(struct sk_buff *skb)
+virtio_transport_send_pkt(struct sk_buff *skb, struct net *net)
{
struct virtio_vsock_hdr *hdr;
struct virtio_vsock *vsock;
@@ -536,7 +536,13 @@ static bool virtio_transport_msgzerocopy_allow(void)
return true;
}
-static bool virtio_transport_seqpacket_allow(u32 remote_cid);
+bool virtio_transport_stream_allow(struct vsock_sock *vsk, u32 cid, u32 port)
+{
+ return vsock_net_mode_global(vsk);
+}
+
+static bool virtio_transport_seqpacket_allow(struct vsock_sock *vsk,
+ u32 remote_cid);
static struct virtio_transport virtio_transport = {
.transport = {
@@ -593,11 +599,15 @@ static struct virtio_transport virtio_transport = {
.can_msgzerocopy = virtio_transport_can_msgzerocopy,
};
-static bool virtio_transport_seqpacket_allow(u32 remote_cid)
+static bool
+virtio_transport_seqpacket_allow(struct vsock_sock *vsk, u32 remote_cid)
{
struct virtio_vsock *vsock;
bool seqpacket_allow;
+ if (!vsock_net_mode_global(vsk))
+ return false;
+
seqpacket_allow = false;
rcu_read_lock();
vsock = rcu_dereference(the_virtio_vsock);
@@ -660,7 +670,11 @@ static void virtio_transport_rx_work(struct work_struct *work)
virtio_vsock_skb_put(skb, payload_len);
virtio_transport_deliver_tap_pkt(skb);
- virtio_transport_recv_pkt(&virtio_transport, skb);
+
+ /* Force virtio-transport into global mode since it
+ * does not yet support local-mode namespacing.
+ */
+ virtio_transport_recv_pkt(&virtio_transport, skb, NULL);
}
} while (!virtqueue_enable_cb(vq));
diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
index d3e26025ef58..d017ab318a7e 100644
--- a/net/vmw_vsock/virtio_transport_common.c
+++ b/net/vmw_vsock/virtio_transport_common.c
@@ -414,7 +414,7 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
virtio_transport_inc_tx_pkt(vvs, skb);
- ret = t_ops->send_pkt(skb);
+ ret = t_ops->send_pkt(skb, info->net);
if (ret < 0)
break;
@@ -526,6 +526,7 @@ static int virtio_transport_send_credit_update(struct vsock_sock *vsk)
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
.vsk = vsk,
+ .net = sock_net(sk_vsock(vsk)),
};
return virtio_transport_send_pkt_info(vsk, &info);
@@ -1055,12 +1056,6 @@ bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);
-bool virtio_transport_stream_allow(u32 cid, u32 port)
-{
- return true;
-}
-EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);
-
int virtio_transport_dgram_bind(struct vsock_sock *vsk,
struct sockaddr_vm *addr)
{
@@ -1068,7 +1063,7 @@ int virtio_transport_dgram_bind(struct vsock_sock *vsk,
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
-bool virtio_transport_dgram_allow(u32 cid, u32 port)
+bool virtio_transport_dgram_allow(struct vsock_sock *vsk, u32 cid, u32 port)
{
return false;
}
@@ -1079,6 +1074,7 @@ int virtio_transport_connect(struct vsock_sock *vsk)
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_REQUEST,
.vsk = vsk,
+ .net = sock_net(sk_vsock(vsk)),
};
return virtio_transport_send_pkt_info(vsk, &info);
@@ -1094,6 +1090,7 @@ int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
(mode & SEND_SHUTDOWN ?
VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
.vsk = vsk,
+ .net = sock_net(sk_vsock(vsk)),
};
return virtio_transport_send_pkt_info(vsk, &info);
@@ -1120,6 +1117,7 @@ virtio_transport_stream_enqueue(struct vsock_sock *vsk,
.msg = msg,
.pkt_len = len,
.vsk = vsk,
+ .net = sock_net(sk_vsock(vsk)),
};
return virtio_transport_send_pkt_info(vsk, &info);
@@ -1157,6 +1155,7 @@ static int virtio_transport_reset(struct vsock_sock *vsk,
.op = VIRTIO_VSOCK_OP_RST,
.reply = !!skb,
.vsk = vsk,
+ .net = sock_net(sk_vsock(vsk)),
};
/* Send RST only if the original pkt is not a RST pkt */
@@ -1168,15 +1167,31 @@ static int virtio_transport_reset(struct vsock_sock *vsk,
/* Normally packets are associated with a socket. There may be no socket if an
* attempt was made to connect to a socket that does not exist.
+ *
+ * net refers to the namespace of whoever sent the invalid message. For
+ * loopback, this is the namespace of the socket. For vhost, this is the
+ * namespace of the VM (i.e., vhost_vsock).
*/
static int virtio_transport_reset_no_sock(const struct virtio_transport *t,
- struct sk_buff *skb)
+ struct sk_buff *skb, struct net *net)
{
struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_RST,
.type = le16_to_cpu(hdr->type),
.reply = true,
+
+ /* Set sk owner to socket we are replying to (may be NULL for
+ * non-loopback). This keeps a reference to the sock and
+ * sock_net(sk) until the reply skb is freed.
+ */
+ .vsk = vsock_sk(skb->sk),
+
+ /* net is not defined here because we pass it directly to
+ * t->send_pkt(), instead of relying on
+ * virtio_transport_send_pkt_info() to pass it. It is not needed
+ * by virtio_transport_alloc_skb().
+ */
};
struct sk_buff *reply;
@@ -1195,7 +1210,7 @@ static int virtio_transport_reset_no_sock(const struct virtio_transport *t,
if (!reply)
return -ENOMEM;
- return t->send_pkt(reply);
+ return t->send_pkt(reply, net);
}
/* This function should be called with sk_lock held and SOCK_DONE set */
@@ -1479,6 +1494,7 @@ virtio_transport_send_response(struct vsock_sock *vsk,
.remote_port = le32_to_cpu(hdr->src_port),
.reply = true,
.vsk = vsk,
+ .net = sock_net(sk_vsock(vsk)),
};
return virtio_transport_send_pkt_info(vsk, &info);
@@ -1521,12 +1537,12 @@ virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb,
int ret;
if (le16_to_cpu(hdr->op) != VIRTIO_VSOCK_OP_REQUEST) {
- virtio_transport_reset_no_sock(t, skb);
+ virtio_transport_reset_no_sock(t, skb, sock_net(sk));
return -EINVAL;
}
if (sk_acceptq_is_full(sk)) {
- virtio_transport_reset_no_sock(t, skb);
+ virtio_transport_reset_no_sock(t, skb, sock_net(sk));
return -ENOMEM;
}
@@ -1534,13 +1550,13 @@ virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb,
* Subsequent enqueues would lead to a memory leak.
*/
if (sk->sk_shutdown == SHUTDOWN_MASK) {
- virtio_transport_reset_no_sock(t, skb);
+ virtio_transport_reset_no_sock(t, skb, sock_net(sk));
return -ESHUTDOWN;
}
child = vsock_create_connected(sk);
if (!child) {
- virtio_transport_reset_no_sock(t, skb);
+ virtio_transport_reset_no_sock(t, skb, sock_net(sk));
return -ENOMEM;
}
@@ -1562,7 +1578,7 @@ virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb,
*/
if (ret || vchild->transport != &t->transport) {
release_sock(child);
- virtio_transport_reset_no_sock(t, skb);
+ virtio_transport_reset_no_sock(t, skb, sock_net(sk));
sock_put(child);
return ret;
}
@@ -1590,7 +1606,7 @@ static bool virtio_transport_valid_type(u16 type)
* lock.
*/
void virtio_transport_recv_pkt(struct virtio_transport *t,
- struct sk_buff *skb)
+ struct sk_buff *skb, struct net *net)
{
struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
struct sockaddr_vm src, dst;
@@ -1613,24 +1629,24 @@ void virtio_transport_recv_pkt(struct virtio_transport *t,
le32_to_cpu(hdr->fwd_cnt));
if (!virtio_transport_valid_type(le16_to_cpu(hdr->type))) {
- (void)virtio_transport_reset_no_sock(t, skb);
+ (void)virtio_transport_reset_no_sock(t, skb, net);
goto free_pkt;
}
/* The socket must be in connected or bound table
* otherwise send reset back
*/
- sk = vsock_find_connected_socket(&src, &dst);
+ sk = vsock_find_connected_socket_net(&src, &dst, net);
if (!sk) {
- sk = vsock_find_bound_socket(&dst);
+ sk = vsock_find_bound_socket_net(&dst, net);
if (!sk) {
- (void)virtio_transport_reset_no_sock(t, skb);
+ (void)virtio_transport_reset_no_sock(t, skb, net);
goto free_pkt;
}
}
if (virtio_transport_get_type(sk) != le16_to_cpu(hdr->type)) {
- (void)virtio_transport_reset_no_sock(t, skb);
+ (void)virtio_transport_reset_no_sock(t, skb, net);
sock_put(sk);
goto free_pkt;
}
@@ -1649,7 +1665,7 @@ void virtio_transport_recv_pkt(struct virtio_transport *t,
*/
if (sock_flag(sk, SOCK_DONE) ||
(sk->sk_state != TCP_LISTEN && vsk->transport != &t->transport)) {
- (void)virtio_transport_reset_no_sock(t, skb);
+ (void)virtio_transport_reset_no_sock(t, skb, net);
release_sock(sk);
sock_put(sk);
goto free_pkt;
@@ -1681,7 +1697,7 @@ void virtio_transport_recv_pkt(struct virtio_transport *t,
kfree_skb(skb);
break;
default:
- (void)virtio_transport_reset_no_sock(t, skb);
+ (void)virtio_transport_reset_no_sock(t, skb, net);
kfree_skb(skb);
break;
}
diff --git a/net/vmw_vsock/vmci_transport.c b/net/vmw_vsock/vmci_transport.c
index 7eccd6708d66..00f6bbdb035a 100644
--- a/net/vmw_vsock/vmci_transport.c
+++ b/net/vmw_vsock/vmci_transport.c
@@ -646,13 +646,17 @@ static int vmci_transport_recv_dgram_cb(void *data, struct vmci_datagram *dg)
return VMCI_SUCCESS;
}
-static bool vmci_transport_stream_allow(u32 cid, u32 port)
+static bool vmci_transport_stream_allow(struct vsock_sock *vsk, u32 cid,
+ u32 port)
{
static const u32 non_socket_contexts[] = {
VMADDR_CID_LOCAL,
};
int i;
+ if (!vsock_net_mode_global(vsk))
+ return false;
+
BUILD_BUG_ON(sizeof(cid) != sizeof(*non_socket_contexts));
for (i = 0; i < ARRAY_SIZE(non_socket_contexts); i++) {
@@ -682,12 +686,10 @@ static int vmci_transport_recv_stream_cb(void *data, struct vmci_datagram *dg)
err = VMCI_SUCCESS;
bh_process_pkt = false;
- /* Ignore incoming packets from contexts without sockets, or resources
- * that aren't vsock implementations.
+ /* Ignore incoming packets from resources that aren't vsock
+ * implementations.
*/
-
- if (!vmci_transport_stream_allow(dg->src.context, -1)
- || vmci_transport_peer_rid(dg->src.context) != dg->src.resource)
+ if (vmci_transport_peer_rid(dg->src.context) != dg->src.resource)
return VMCI_ERROR_NO_ACCESS;
if (VMCI_DG_SIZE(dg) < sizeof(*pkt))
@@ -749,6 +751,12 @@ static int vmci_transport_recv_stream_cb(void *data, struct vmci_datagram *dg)
goto out;
}
+ /* Ignore incoming packets from contexts without sockets. */
+ if (!vmci_transport_stream_allow(vsk, dg->src.context, -1)) {
+ err = VMCI_ERROR_NO_ACCESS;
+ goto out;
+ }
+
/* We do most everything in a work queue, but let's fast path the
* notification of reads and writes to help data transfer performance.
* We can only do this if there is no process context code executing
@@ -1784,8 +1792,12 @@ out:
return err;
}
-static bool vmci_transport_dgram_allow(u32 cid, u32 port)
+static bool vmci_transport_dgram_allow(struct vsock_sock *vsk, u32 cid,
+ u32 port)
{
+ if (!vsock_net_mode_global(vsk))
+ return false;
+
if (cid == VMADDR_CID_HYPERVISOR) {
/* Registrations of PBRPC Servers do not modify VMX/Hypervisor
* state and are allowed.
diff --git a/net/vmw_vsock/vsock_loopback.c b/net/vmw_vsock/vsock_loopback.c
index bc2ff918b315..8068d1b6e851 100644
--- a/net/vmw_vsock/vsock_loopback.c
+++ b/net/vmw_vsock/vsock_loopback.c
@@ -26,7 +26,7 @@ static u32 vsock_loopback_get_local_cid(void)
return VMADDR_CID_LOCAL;
}
-static int vsock_loopback_send_pkt(struct sk_buff *skb)
+static int vsock_loopback_send_pkt(struct sk_buff *skb, struct net *net)
{
struct vsock_loopback *vsock = &the_vsock_loopback;
int len = skb->len;
@@ -46,7 +46,15 @@ static int vsock_loopback_cancel_pkt(struct vsock_sock *vsk)
return 0;
}
-static bool vsock_loopback_seqpacket_allow(u32 remote_cid);
+static bool vsock_loopback_seqpacket_allow(struct vsock_sock *vsk,
+ u32 remote_cid);
+
+static bool vsock_loopback_stream_allow(struct vsock_sock *vsk, u32 cid,
+ u32 port)
+{
+ return true;
+}
+
static bool vsock_loopback_msgzerocopy_allow(void)
{
return true;
@@ -76,7 +84,7 @@ static struct virtio_transport loopback_transport = {
.stream_has_space = virtio_transport_stream_has_space,
.stream_rcvhiwat = virtio_transport_stream_rcvhiwat,
.stream_is_active = virtio_transport_stream_is_active,
- .stream_allow = virtio_transport_stream_allow,
+ .stream_allow = vsock_loopback_stream_allow,
.seqpacket_dequeue = virtio_transport_seqpacket_dequeue,
.seqpacket_enqueue = virtio_transport_seqpacket_enqueue,
@@ -106,9 +114,10 @@ static struct virtio_transport loopback_transport = {
.send_pkt = vsock_loopback_send_pkt,
};
-static bool vsock_loopback_seqpacket_allow(u32 remote_cid)
+static bool
+vsock_loopback_seqpacket_allow(struct vsock_sock *vsk, u32 remote_cid)
{
- return true;
+ return vsock_net_mode_global(vsk);
}
static void vsock_loopback_work(struct work_struct *work)
@@ -130,7 +139,8 @@ static void vsock_loopback_work(struct work_struct *work)
*/
virtio_transport_consume_skb_sent(skb, false);
virtio_transport_deliver_tap_pkt(skb);
- virtio_transport_recv_pkt(&loopback_transport, skb);
+ virtio_transport_recv_pkt(&loopback_transport, skb,
+ sock_net(skb->sk));
}
}
diff --git a/tools/testing/selftests/vsock/settings b/tools/testing/selftests/vsock/settings
index 694d70710ff0..79b65bdf05db 100644
--- a/tools/testing/selftests/vsock/settings
+++ b/tools/testing/selftests/vsock/settings
@@ -1 +1 @@
-timeout=300
+timeout=1200
diff --git a/tools/testing/selftests/vsock/vmtest.sh b/tools/testing/selftests/vsock/vmtest.sh
index c7b270dd77a9..dc8dbe74a6d0 100755
--- a/tools/testing/selftests/vsock/vmtest.sh
+++ b/tools/testing/selftests/vsock/vmtest.sh
@@ -7,6 +7,7 @@
# * virtme-ng
# * busybox-static (used by virtme-ng)
# * qemu (used by virtme-ng)
+# * socat
#
# shellcheck disable=SC2317,SC2119
@@ -41,14 +42,119 @@ readonly KERNEL_CMDLINE="\
virtme.ssh virtme_ssh_channel=tcp virtme_ssh_user=$USER \
"
readonly LOG=$(mktemp /tmp/vsock_vmtest_XXXX.log)
-readonly TEST_NAMES=(vm_server_host_client vm_client_host_server vm_loopback)
+
+# Namespace tests must use the ns_ prefix. This is checked in check_netns() and
+# is used to determine if a test needs namespace setup before test execution.
+readonly TEST_NAMES=(
+ vm_server_host_client
+ vm_client_host_server
+ vm_loopback
+ ns_host_vsock_ns_mode_ok
+ ns_host_vsock_child_ns_mode_ok
+ ns_global_same_cid_fails
+ ns_local_same_cid_ok
+ ns_global_local_same_cid_ok
+ ns_local_global_same_cid_ok
+ ns_diff_global_host_connect_to_global_vm_ok
+ ns_diff_global_host_connect_to_local_vm_fails
+ ns_diff_global_vm_connect_to_global_host_ok
+ ns_diff_global_vm_connect_to_local_host_fails
+ ns_diff_local_host_connect_to_local_vm_fails
+ ns_diff_local_vm_connect_to_local_host_fails
+ ns_diff_global_to_local_loopback_local_fails
+ ns_diff_local_to_global_loopback_fails
+ ns_diff_local_to_local_loopback_fails
+ ns_diff_global_to_global_loopback_ok
+ ns_same_local_loopback_ok
+ ns_same_local_host_connect_to_local_vm_ok
+ ns_same_local_vm_connect_to_local_host_ok
+ ns_delete_vm_ok
+ ns_delete_host_ok
+ ns_delete_both_ok
+)
readonly TEST_DESCS=(
+ # vm_server_host_client
"Run vsock_test in server mode on the VM and in client mode on the host."
+
+ # vm_client_host_server
"Run vsock_test in client mode on the VM and in server mode on the host."
+
+ # vm_loopback
"Run vsock_test using the loopback transport in the VM."
+
+ # ns_host_vsock_ns_mode_ok
+ "Check /proc/sys/net/vsock/ns_mode strings on the host."
+
+ # ns_host_vsock_child_ns_mode_ok
+ "Check /proc/sys/net/vsock/ns_mode is read-only and child_ns_mode is writable."
+
+ # ns_global_same_cid_fails
+ "Check QEMU fails to start two VMs with same CID in two different global namespaces."
+
+ # ns_local_same_cid_ok
+ "Check QEMU successfully starts two VMs with same CID in two different local namespaces."
+
+ # ns_global_local_same_cid_ok
+ "Check QEMU successfully starts one VM in a global ns and then another VM in a local ns with the same CID."
+
+ # ns_local_global_same_cid_ok
+ "Check QEMU successfully starts one VM in a local ns and then another VM in a global ns with the same CID."
+
+ # ns_diff_global_host_connect_to_global_vm_ok
+ "Run vsock_test client in global ns with server in VM in another global ns."
+
+ # ns_diff_global_host_connect_to_local_vm_fails
+ "Run socat to test a process in a global ns fails to connect to a VM in a local ns."
+
+ # ns_diff_global_vm_connect_to_global_host_ok
+ "Run vsock_test client in VM in a global ns with server in another global ns."
+
+ # ns_diff_global_vm_connect_to_local_host_fails
+ "Run socat to test a VM in a global ns fails to connect to a host process in a local ns."
+
+ # ns_diff_local_host_connect_to_local_vm_fails
+ "Run socat to test a host process in a local ns fails to connect to a VM in another local ns."
+
+ # ns_diff_local_vm_connect_to_local_host_fails
+ "Run socat to test a VM in a local ns fails to connect to a host process in another local ns."
+
+ # ns_diff_global_to_local_loopback_local_fails
+ "Run socat to test a loopback vsock in a global ns fails to connect to a vsock in a local ns."
+
+ # ns_diff_local_to_global_loopback_fails
+ "Run socat to test a loopback vsock in a local ns fails to connect to a vsock in a global ns."
+
+ # ns_diff_local_to_local_loopback_fails
+ "Run socat to test a loopback vsock in a local ns fails to connect to a vsock in another local ns."
+
+ # ns_diff_global_to_global_loopback_ok
+ "Run socat to test a loopback vsock in a global ns successfully connects to a vsock in another global ns."
+
+ # ns_same_local_loopback_ok
+ "Run socat to test a loopback vsock in a local ns successfully connects to a vsock in the same ns."
+
+ # ns_same_local_host_connect_to_local_vm_ok
+ "Run vsock_test client in a local ns with server in VM in same ns."
+
+ # ns_same_local_vm_connect_to_local_host_ok
+ "Run vsock_test client in VM in a local ns with server in same ns."
+
+ # ns_delete_vm_ok
+ "Check that deleting the VM's namespace does not break the socket connection"
+
+ # ns_delete_host_ok
+ "Check that deleting the host's namespace does not break the socket connection"
+
+ # ns_delete_both_ok
+ "Check that deleting the VM and host's namespaces does not break the socket connection"
)
-readonly USE_SHARED_VM=(vm_server_host_client vm_client_host_server vm_loopback)
+readonly USE_SHARED_VM=(
+ vm_server_host_client
+ vm_client_host_server
+ vm_loopback
+)
+readonly NS_MODES=("local" "global")
VERBOSE=0
@@ -71,7 +177,7 @@ usage() {
for ((i = 0; i < ${#TEST_NAMES[@]}; i++)); do
name=${TEST_NAMES[${i}]}
desc=${TEST_DESCS[${i}]}
- printf "\t%-35s%-35s\n" "${name}" "${desc}"
+ printf "\t%-55s%-35s\n" "${name}" "${desc}"
done
echo
@@ -103,13 +209,55 @@ check_result() {
fi
}
+add_namespaces() {
+ local orig_mode
+ orig_mode=$(cat /proc/sys/net/vsock/child_ns_mode)
+
+ for mode in "${NS_MODES[@]}"; do
+ echo "${mode}" > /proc/sys/net/vsock/child_ns_mode
+ ip netns add "${mode}0" 2>/dev/null
+ ip netns add "${mode}1" 2>/dev/null
+ done
+
+ echo "${orig_mode}" > /proc/sys/net/vsock/child_ns_mode
+}
+
+init_namespaces() {
+ for mode in "${NS_MODES[@]}"; do
+ # we need lo for qemu port forwarding
+ ip netns exec "${mode}0" ip link set dev lo up
+ ip netns exec "${mode}1" ip link set dev lo up
+ done
+}
+
+del_namespaces() {
+ for mode in "${NS_MODES[@]}"; do
+ ip netns del "${mode}0" &>/dev/null
+ ip netns del "${mode}1" &>/dev/null
+ log_host "removed ns ${mode}0"
+ log_host "removed ns ${mode}1"
+ done
+}
+
vm_ssh() {
- ssh -q -o UserKnownHostsFile=/dev/null -p ${SSH_HOST_PORT} localhost "$@"
+ local ns_exec
+
+ if [[ "${1}" == init_ns ]]; then
+ ns_exec=""
+ else
+ ns_exec="ip netns exec ${1}"
+ fi
+
+ shift
+
+ ${ns_exec} ssh -q -o UserKnownHostsFile=/dev/null -p "${SSH_HOST_PORT}" localhost "$@"
+
return $?
}
cleanup() {
terminate_pidfiles "${!PIDFILES[@]}"
+ del_namespaces
}
check_args() {
@@ -139,7 +287,7 @@ check_args() {
}
check_deps() {
- for dep in vng ${QEMU} busybox pkill ssh; do
+ for dep in vng ${QEMU} busybox pkill ssh ss socat; do
if [[ ! -x $(command -v "${dep}") ]]; then
echo -e "skip: dependency ${dep} not found!\n"
exit "${KSFT_SKIP}"
@@ -153,6 +301,20 @@ check_deps() {
fi
}
+check_netns() {
+ local tname=$1
+
+ # If the test requires NS support, check if NS support exists
+ # using /proc/self/ns
+ if [[ "${tname}" =~ ^ns_ ]] &&
+ [[ ! -e /proc/self/ns ]]; then
+ log_host "No NS support detected for test ${tname}"
+ return 1
+ fi
+
+ return 0
+}
+
check_vng() {
local tested_versions
local version
@@ -176,6 +338,20 @@ check_vng() {
fi
}
+check_socat() {
+ local support_string
+
+ support_string="$(socat -V)"
+
+ if [[ "${support_string}" != *"WITH_VSOCK 1"* ]]; then
+ die "err: socat is missing vsock support"
+ fi
+
+ if [[ "${support_string}" != *"WITH_UNIX 1"* ]]; then
+ die "err: socat is missing unix support"
+ fi
+}
+
handle_build() {
if [[ ! "${BUILD}" -eq 1 ]]; then
return
@@ -224,12 +400,22 @@ terminate_pidfiles() {
done
}
+terminate_pids() {
+ local pid
+
+ for pid in "$@"; do
+ kill -SIGTERM "${pid}" &>/dev/null || :
+ done
+}
+
vm_start() {
local pidfile=$1
+ local ns=$2
local logfile=/dev/null
local verbose_opt=""
local kernel_opt=""
local qemu_opts=""
+ local ns_exec=""
local qemu
qemu=$(command -v "${QEMU}")
@@ -250,7 +436,11 @@ vm_start() {
kernel_opt="${KERNEL_CHECKOUT}"
fi
- vng \
+ if [[ "${ns}" != "init_ns" ]]; then
+ ns_exec="ip netns exec ${ns}"
+ fi
+
+ ${ns_exec} vng \
--run \
${kernel_opt} \
${verbose_opt} \
@@ -265,6 +455,7 @@ vm_start() {
}
vm_wait_for_ssh() {
+ local ns=$1
local i
i=0
@@ -272,7 +463,8 @@ vm_wait_for_ssh() {
if [[ ${i} -gt ${WAIT_PERIOD_MAX} ]]; then
die "Timed out waiting for guest ssh"
fi
- if vm_ssh -- true; then
+
+ if vm_ssh "${ns}" -- true; then
break
fi
i=$(( i + 1 ))
@@ -286,50 +478,107 @@ wait_for_listener()
local port=$1
local interval=$2
local max_intervals=$3
- local protocol=tcp
- local pattern
+ local protocol=$4
local i
- pattern=":$(printf "%04X" "${port}") "
-
- # for tcp protocol additionally check the socket state
- [ "${protocol}" = "tcp" ] && pattern="${pattern}0A"
-
for i in $(seq "${max_intervals}"); do
- if awk -v pattern="${pattern}" \
- 'BEGIN {rc=1} $2" "$4 ~ pattern {rc=0} END {exit rc}' \
- /proc/net/"${protocol}"*; then
+ case "${protocol}" in
+ tcp)
+ if ss --listening --tcp --numeric | grep -q ":${port} "; then
+ break
+ fi
+ ;;
+ vsock)
+ if ss --listening --vsock --numeric | grep -q ":${port} "; then
+ break
+ fi
+ ;;
+ unix)
+ # For unix sockets, port is actually the socket path
+ if ss --listening --unix | grep -q "${port}"; then
+ break
+ fi
+ ;;
+ *)
+ echo "Unknown protocol: ${protocol}" >&2
break
- fi
+ ;;
+ esac
sleep "${interval}"
done
}
vm_wait_for_listener() {
- local port=$1
+ local ns=$1
+ local port=$2
+ local protocol=$3
- vm_ssh <<EOF
+ vm_ssh "${ns}" <<EOF
$(declare -f wait_for_listener)
-wait_for_listener ${port} ${WAIT_PERIOD} ${WAIT_PERIOD_MAX}
+wait_for_listener ${port} ${WAIT_PERIOD} ${WAIT_PERIOD_MAX} ${protocol}
EOF
}
host_wait_for_listener() {
- local port=$1
+ local ns=$1
+ local port=$2
+ local protocol=$3
+
+ if [[ "${ns}" == "init_ns" ]]; then
+ wait_for_listener "${port}" "${WAIT_PERIOD}" "${WAIT_PERIOD_MAX}" "${protocol}"
+ else
+ ip netns exec "${ns}" bash <<-EOF
+ $(declare -f wait_for_listener)
+ wait_for_listener ${port} ${WAIT_PERIOD} ${WAIT_PERIOD_MAX} ${protocol}
+ EOF
+ fi
+}
+
+vm_dmesg_oops_count() {
+ local ns=$1
+
+ vm_ssh "${ns}" -- dmesg 2>/dev/null | grep -c -i 'Oops'
+}
+
+vm_dmesg_warn_count() {
+ local ns=$1
+
+ vm_ssh "${ns}" -- dmesg --level=warn 2>/dev/null | grep -c -i 'vsock'
+}
+
+vm_dmesg_check() {
+ local pidfile=$1
+ local ns=$2
+ local oops_before=$3
+ local warn_before=$4
+ local oops_after warn_after
+
+ oops_after=$(vm_dmesg_oops_count "${ns}")
+ if [[ "${oops_after}" -gt "${oops_before}" ]]; then
+ echo "FAIL: kernel oops detected on vm in ns ${ns}" | log_host
+ return 1
+ fi
+
+ warn_after=$(vm_dmesg_warn_count "${ns}")
+ if [[ "${warn_after}" -gt "${warn_before}" ]]; then
+ echo "FAIL: kernel warning detected on vm in ns ${ns}" | log_host
+ return 1
+ fi
- wait_for_listener "${port}" "${WAIT_PERIOD}" "${WAIT_PERIOD_MAX}"
+ return 0
}
vm_vsock_test() {
- local host=$1
- local cid=$2
- local port=$3
+ local ns=$1
+ local host=$2
+ local cid=$3
+ local port=$4
local rc
# log output and use pipefail to respect vsock_test errors
set -o pipefail
if [[ "${host}" != server ]]; then
- vm_ssh -- "${VSOCK_TEST}" \
+ vm_ssh "${ns}" -- "${VSOCK_TEST}" \
--mode=client \
--control-host="${host}" \
--peer-cid="${cid}" \
@@ -337,7 +586,7 @@ vm_vsock_test() {
2>&1 | log_guest
rc=$?
else
- vm_ssh -- "${VSOCK_TEST}" \
+ vm_ssh "${ns}" -- "${VSOCK_TEST}" \
--mode=server \
--peer-cid="${cid}" \
--control-port="${port}" \
@@ -349,7 +598,7 @@ vm_vsock_test() {
return $rc
fi
- vm_wait_for_listener "${port}"
+ vm_wait_for_listener "${ns}" "${port}" "tcp"
rc=$?
fi
set +o pipefail
@@ -358,25 +607,35 @@ vm_vsock_test() {
}
host_vsock_test() {
- local host=$1
- local cid=$2
- local port=$3
+ local ns=$1
+ local host=$2
+ local cid=$3
+ local port=$4
+ shift 4
+ local extra_args=("$@")
local rc
+ local cmd="${VSOCK_TEST}"
+ if [[ "${ns}" != "init_ns" ]]; then
+ cmd="ip netns exec ${ns} ${cmd}"
+ fi
+
# log output and use pipefail to respect vsock_test errors
set -o pipefail
if [[ "${host}" != server ]]; then
- ${VSOCK_TEST} \
+ ${cmd} \
--mode=client \
--peer-cid="${cid}" \
--control-host="${host}" \
- --control-port="${port}" 2>&1 | log_host
+ --control-port="${port}" \
+ "${extra_args[@]}" 2>&1 | log_host
rc=$?
else
- ${VSOCK_TEST} \
+ ${cmd} \
--mode=server \
--peer-cid="${cid}" \
- --control-port="${port}" 2>&1 | log_host &
+ --control-port="${port}" \
+ "${extra_args[@]}" 2>&1 | log_host &
rc=$?
if [[ $rc -ne 0 ]]; then
@@ -384,7 +643,7 @@ host_vsock_test() {
return $rc
fi
- host_wait_for_listener "${port}"
+ host_wait_for_listener "${ns}" "${port}" "tcp"
rc=$?
fi
set +o pipefail
@@ -427,12 +686,584 @@ log_guest() {
LOG_PREFIX=guest log "$@"
}
+ns_get_mode() {
+ local ns=$1
+
+ ip netns exec "${ns}" cat /proc/sys/net/vsock/ns_mode 2>/dev/null
+}
+
+test_ns_host_vsock_ns_mode_ok() {
+ for mode in "${NS_MODES[@]}"; do
+ local actual
+
+ actual=$(ns_get_mode "${mode}0")
+ if [[ "${actual}" != "${mode}" ]]; then
+ log_host "expected mode ${mode}, got ${actual}"
+ return "${KSFT_FAIL}"
+ fi
+ done
+
+ return "${KSFT_PASS}"
+}
+
+test_ns_diff_global_host_connect_to_global_vm_ok() {
+ local oops_before warn_before
+ local pids pid pidfile
+ local ns0 ns1 port
+ declare -a pids
+ local unixfile
+ ns0="global0"
+ ns1="global1"
+ port=1234
+ local rc
+
+ init_namespaces
+
+ pidfile="$(create_pidfile)"
+
+ if ! vm_start "${pidfile}" "${ns0}"; then
+ return "${KSFT_FAIL}"
+ fi
+
+ vm_wait_for_ssh "${ns0}"
+ oops_before=$(vm_dmesg_oops_count "${ns0}")
+ warn_before=$(vm_dmesg_warn_count "${ns0}")
+
+ unixfile=$(mktemp -u /tmp/XXXX.sock)
+ ip netns exec "${ns1}" \
+ socat TCP-LISTEN:"${TEST_HOST_PORT}",fork \
+ UNIX-CONNECT:"${unixfile}" &
+ pids+=($!)
+ host_wait_for_listener "${ns1}" "${TEST_HOST_PORT}" "tcp"
+
+ ip netns exec "${ns0}" socat UNIX-LISTEN:"${unixfile}",fork \
+ TCP-CONNECT:localhost:"${TEST_HOST_PORT}" &
+ pids+=($!)
+ host_wait_for_listener "${ns0}" "${unixfile}" "unix"
+
+ vm_vsock_test "${ns0}" "server" 2 "${TEST_GUEST_PORT}"
+ vm_wait_for_listener "${ns0}" "${TEST_GUEST_PORT}" "tcp"
+ host_vsock_test "${ns1}" "127.0.0.1" "${VSOCK_CID}" "${TEST_HOST_PORT}"
+ rc=$?
+
+ vm_dmesg_check "${pidfile}" "${ns0}" "${oops_before}" "${warn_before}"
+ dmesg_rc=$?
+
+ terminate_pids "${pids[@]}"
+ terminate_pidfiles "${pidfile}"
+
+ if [[ "${rc}" -ne 0 ]] || [[ "${dmesg_rc}" -ne 0 ]]; then
+ return "${KSFT_FAIL}"
+ fi
+
+ return "${KSFT_PASS}"
+}
+
+test_ns_diff_global_host_connect_to_local_vm_fails() {
+ local oops_before warn_before
+ local ns0="global0"
+ local ns1="local0"
+ local port=12345
+ local dmesg_rc
+ local pidfile
+ local result
+ local pid
+
+ init_namespaces
+
+ outfile=$(mktemp)
+
+ pidfile="$(create_pidfile)"
+ if ! vm_start "${pidfile}" "${ns1}"; then
+ log_host "failed to start vm (cid=${VSOCK_CID}, ns=${ns0})"
+ return "${KSFT_FAIL}"
+ fi
+
+ vm_wait_for_ssh "${ns1}"
+ oops_before=$(vm_dmesg_oops_count "${ns1}")
+ warn_before=$(vm_dmesg_warn_count "${ns1}")
+
+ vm_ssh "${ns1}" -- socat VSOCK-LISTEN:"${port}" STDOUT > "${outfile}" &
+ vm_wait_for_listener "${ns1}" "${port}" "vsock"
+ echo TEST | ip netns exec "${ns0}" \
+ socat STDIN VSOCK-CONNECT:"${VSOCK_CID}":"${port}" 2>/dev/null
+
+ vm_dmesg_check "${pidfile}" "${ns1}" "${oops_before}" "${warn_before}"
+ dmesg_rc=$?
+
+ terminate_pidfiles "${pidfile}"
+ result=$(cat "${outfile}")
+ rm -f "${outfile}"
+
+ if [[ "${result}" == "TEST" ]] || [[ "${dmesg_rc}" -ne 0 ]]; then
+ return "${KSFT_FAIL}"
+ fi
+
+ return "${KSFT_PASS}"
+}
+
+test_ns_diff_global_vm_connect_to_global_host_ok() {
+ local oops_before warn_before
+ local ns0="global0"
+ local ns1="global1"
+ local port=12345
+ local unixfile
+ local dmesg_rc
+ local pidfile
+ local pids
+ local rc
+
+ init_namespaces
+
+ declare -a pids
+
+ log_host "Setup socat bridge from ns ${ns0} to ns ${ns1} over port ${port}"
+
+ unixfile=$(mktemp -u /tmp/XXXX.sock)
+
+ ip netns exec "${ns0}" \
+ socat TCP-LISTEN:"${port}" UNIX-CONNECT:"${unixfile}" &
+ pids+=($!)
+ host_wait_for_listener "${ns0}" "${port}" "tcp"
+
+ ip netns exec "${ns1}" \
+ socat UNIX-LISTEN:"${unixfile}" TCP-CONNECT:127.0.0.1:"${port}" &
+ pids+=($!)
+ host_wait_for_listener "${ns1}" "${unixfile}" "unix"
+
+ log_host "Launching ${VSOCK_TEST} in ns ${ns1}"
+ host_vsock_test "${ns1}" "server" "${VSOCK_CID}" "${port}"
+
+ pidfile="$(create_pidfile)"
+ if ! vm_start "${pidfile}" "${ns0}"; then
+ log_host "failed to start vm (cid=${cid}, ns=${ns0})"
+ terminate_pids "${pids[@]}"
+ rm -f "${unixfile}"
+ return "${KSFT_FAIL}"
+ fi
+
+ vm_wait_for_ssh "${ns0}"
+
+ oops_before=$(vm_dmesg_oops_count "${ns0}")
+ warn_before=$(vm_dmesg_warn_count "${ns0}")
+
+ vm_vsock_test "${ns0}" "10.0.2.2" 2 "${port}"
+ rc=$?
+
+ vm_dmesg_check "${pidfile}" "${ns0}" "${oops_before}" "${warn_before}"
+ dmesg_rc=$?
+
+ terminate_pidfiles "${pidfile}"
+ terminate_pids "${pids[@]}"
+ rm -f "${unixfile}"
+
+ if [[ "${rc}" -ne 0 ]] || [[ "${dmesg_rc}" -ne 0 ]]; then
+ return "${KSFT_FAIL}"
+ fi
+
+ return "${KSFT_PASS}"
+
+}
+
+test_ns_diff_global_vm_connect_to_local_host_fails() {
+ local ns0="global0"
+ local ns1="local0"
+ local port=12345
+ local oops_before warn_before
+ local dmesg_rc
+ local pidfile
+ local result
+ local pid
+
+ init_namespaces
+
+ log_host "Launching socat in ns ${ns1}"
+ outfile=$(mktemp)
+
+ ip netns exec "${ns1}" socat VSOCK-LISTEN:"${port}" STDOUT &> "${outfile}" &
+ pid=$!
+ host_wait_for_listener "${ns1}" "${port}" "vsock"
+
+ pidfile="$(create_pidfile)"
+ if ! vm_start "${pidfile}" "${ns0}"; then
+ log_host "failed to start vm (cid=${cid}, ns=${ns0})"
+ terminate_pids "${pid}"
+ rm -f "${outfile}"
+ return "${KSFT_FAIL}"
+ fi
+
+ vm_wait_for_ssh "${ns0}"
+
+ oops_before=$(vm_dmesg_oops_count "${ns0}")
+ warn_before=$(vm_dmesg_warn_count "${ns0}")
+
+ vm_ssh "${ns0}" -- \
+ bash -c "echo TEST | socat STDIN VSOCK-CONNECT:2:${port}" 2>&1 | log_guest
+
+ vm_dmesg_check "${pidfile}" "${ns0}" "${oops_before}" "${warn_before}"
+ dmesg_rc=$?
+
+ terminate_pidfiles "${pidfile}"
+ terminate_pids "${pid}"
+
+ result=$(cat "${outfile}")
+ rm -f "${outfile}"
+
+ if [[ "${result}" != TEST ]] && [[ "${dmesg_rc}" -eq 0 ]]; then
+ return "${KSFT_PASS}"
+ fi
+
+ return "${KSFT_FAIL}"
+}
+
+test_ns_diff_local_host_connect_to_local_vm_fails() {
+ local ns0="local0"
+ local ns1="local1"
+ local port=12345
+ local oops_before warn_before
+ local dmesg_rc
+ local pidfile
+ local result
+ local pid
+
+ init_namespaces
+
+ outfile=$(mktemp)
+
+ pidfile="$(create_pidfile)"
+ if ! vm_start "${pidfile}" "${ns1}"; then
+ log_host "failed to start vm (cid=${cid}, ns=${ns0})"
+ return "${KSFT_FAIL}"
+ fi
+
+ vm_wait_for_ssh "${ns1}"
+ oops_before=$(vm_dmesg_oops_count "${ns1}")
+ warn_before=$(vm_dmesg_warn_count "${ns1}")
+
+ vm_ssh "${ns1}" -- socat VSOCK-LISTEN:"${port}" STDOUT > "${outfile}" &
+ vm_wait_for_listener "${ns1}" "${port}" "vsock"
+
+ echo TEST | ip netns exec "${ns0}" \
+ socat STDIN VSOCK-CONNECT:"${VSOCK_CID}":"${port}" 2>/dev/null
+
+ vm_dmesg_check "${pidfile}" "${ns1}" "${oops_before}" "${warn_before}"
+ dmesg_rc=$?
+
+ terminate_pidfiles "${pidfile}"
+
+ result=$(cat "${outfile}")
+ rm -f "${outfile}"
+
+ if [[ "${result}" != TEST ]] && [[ "${dmesg_rc}" -eq 0 ]]; then
+ return "${KSFT_PASS}"
+ fi
+
+ return "${KSFT_FAIL}"
+}
+
+test_ns_diff_local_vm_connect_to_local_host_fails() {
+ local oops_before warn_before
+ local ns0="local0"
+ local ns1="local1"
+ local port=12345
+ local dmesg_rc
+ local pidfile
+ local result
+ local pid
+
+ init_namespaces
+
+ log_host "Launching socat in ns ${ns1}"
+ outfile=$(mktemp)
+ ip netns exec "${ns1}" socat VSOCK-LISTEN:"${port}" STDOUT &> "${outfile}" &
+ pid=$!
+ host_wait_for_listener "${ns1}" "${port}" "vsock"
+
+ pidfile="$(create_pidfile)"
+ if ! vm_start "${pidfile}" "${ns0}"; then
+ log_host "failed to start vm (cid=${cid}, ns=${ns0})"
+ rm -f "${outfile}"
+ return "${KSFT_FAIL}"
+ fi
+
+ vm_wait_for_ssh "${ns0}"
+ oops_before=$(vm_dmesg_oops_count "${ns0}")
+ warn_before=$(vm_dmesg_warn_count "${ns0}")
+
+ vm_ssh "${ns0}" -- \
+ bash -c "echo TEST | socat STDIN VSOCK-CONNECT:2:${port}" 2>&1 | log_guest
+
+ vm_dmesg_check "${pidfile}" "${ns0}" "${oops_before}" "${warn_before}"
+ dmesg_rc=$?
+
+ terminate_pidfiles "${pidfile}"
+ terminate_pids "${pid}"
+
+ result=$(cat "${outfile}")
+ rm -f "${outfile}"
+
+ if [[ "${result}" != TEST ]] && [[ "${dmesg_rc}" -eq 0 ]]; then
+ return "${KSFT_PASS}"
+ fi
+
+ return "${KSFT_FAIL}"
+}
+
+__test_loopback_two_netns() {
+ local ns0=$1
+ local ns1=$2
+ local port=12345
+ local result
+ local pid
+
+ modprobe vsock_loopback &> /dev/null || :
+
+ log_host "Launching socat in ns ${ns1}"
+ outfile=$(mktemp)
+
+ ip netns exec "${ns1}" socat VSOCK-LISTEN:"${port}" STDOUT > "${outfile}" 2>/dev/null &
+ pid=$!
+ host_wait_for_listener "${ns1}" "${port}" "vsock"
+
+ log_host "Launching socat in ns ${ns0}"
+ echo TEST | ip netns exec "${ns0}" socat STDIN VSOCK-CONNECT:1:"${port}" 2>/dev/null
+ terminate_pids "${pid}"
+
+ result=$(cat "${outfile}")
+ rm -f "${outfile}"
+
+ if [[ "${result}" == TEST ]]; then
+ return 0
+ fi
+
+ return 1
+}
+
+test_ns_diff_global_to_local_loopback_local_fails() {
+ init_namespaces
+
+ if ! __test_loopback_two_netns "global0" "local0"; then
+ return "${KSFT_PASS}"
+ fi
+
+ return "${KSFT_FAIL}"
+}
+
+test_ns_diff_local_to_global_loopback_fails() {
+ init_namespaces
+
+ if ! __test_loopback_two_netns "local0" "global0"; then
+ return "${KSFT_PASS}"
+ fi
+
+ return "${KSFT_FAIL}"
+}
+
+test_ns_diff_local_to_local_loopback_fails() {
+ init_namespaces
+
+ if ! __test_loopback_two_netns "local0" "local1"; then
+ return "${KSFT_PASS}"
+ fi
+
+ return "${KSFT_FAIL}"
+}
+
+test_ns_diff_global_to_global_loopback_ok() {
+ init_namespaces
+
+ if __test_loopback_two_netns "global0" "global1"; then
+ return "${KSFT_PASS}"
+ fi
+
+ return "${KSFT_FAIL}"
+}
+
+test_ns_same_local_loopback_ok() {
+ init_namespaces
+
+ if __test_loopback_two_netns "local0" "local0"; then
+ return "${KSFT_PASS}"
+ fi
+
+ return "${KSFT_FAIL}"
+}
+
+test_ns_same_local_host_connect_to_local_vm_ok() {
+ local oops_before warn_before
+ local ns="local0"
+ local port=1234
+ local dmesg_rc
+ local pidfile
+ local rc
+
+ init_namespaces
+
+ pidfile="$(create_pidfile)"
+
+ if ! vm_start "${pidfile}" "${ns}"; then
+ return "${KSFT_FAIL}"
+ fi
+
+ vm_wait_for_ssh "${ns}"
+ oops_before=$(vm_dmesg_oops_count "${ns}")
+ warn_before=$(vm_dmesg_warn_count "${ns}")
+
+ vm_vsock_test "${ns}" "server" 2 "${TEST_GUEST_PORT}"
+
+ # Skip test 29 (transport release use-after-free): This test attempts
+ # binding both G2H and H2G CIDs. Because virtio-vsock (G2H) doesn't
+ # support local namespaces the test will fail when
+ # transport_g2h->stream_allow() returns false. This edge case only
+ # happens for vsock_test in client mode on the host in a local
+ # namespace. This is a false positive.
+ host_vsock_test "${ns}" "127.0.0.1" "${VSOCK_CID}" "${TEST_HOST_PORT}" --skip=29
+ rc=$?
+
+ vm_dmesg_check "${pidfile}" "${ns}" "${oops_before}" "${warn_before}"
+ dmesg_rc=$?
+
+ terminate_pidfiles "${pidfile}"
+
+ if [[ "${rc}" -ne 0 ]] || [[ "${dmesg_rc}" -ne 0 ]]; then
+ return "${KSFT_FAIL}"
+ fi
+
+ return "${KSFT_PASS}"
+}
+
+test_ns_same_local_vm_connect_to_local_host_ok() {
+ local oops_before warn_before
+ local ns="local0"
+ local port=1234
+ local dmesg_rc
+ local pidfile
+ local rc
+
+ init_namespaces
+
+ pidfile="$(create_pidfile)"
+
+ if ! vm_start "${pidfile}" "${ns}"; then
+ return "${KSFT_FAIL}"
+ fi
+
+ vm_wait_for_ssh "${ns}"
+ oops_before=$(vm_dmesg_oops_count "${ns}")
+ warn_before=$(vm_dmesg_warn_count "${ns}")
+
+ host_vsock_test "${ns}" "server" "${VSOCK_CID}" "${port}"
+ vm_vsock_test "${ns}" "10.0.2.2" 2 "${port}"
+ rc=$?
+
+ vm_dmesg_check "${pidfile}" "${ns}" "${oops_before}" "${warn_before}"
+ dmesg_rc=$?
+
+ terminate_pidfiles "${pidfile}"
+
+ if [[ "${rc}" -ne 0 ]] || [[ "${dmesg_rc}" -ne 0 ]]; then
+ return "${KSFT_FAIL}"
+ fi
+
+ return "${KSFT_PASS}"
+}
+
+namespaces_can_boot_same_cid() {
+ local ns0=$1
+ local ns1=$2
+ local pidfile1 pidfile2
+ local rc
+
+ pidfile1="$(create_pidfile)"
+
+ # The first VM should be able to start. If it can't then we have
+ # problems and need to return non-zero.
+ if ! vm_start "${pidfile1}" "${ns0}"; then
+ return 1
+ fi
+
+ pidfile2="$(create_pidfile)"
+ vm_start "${pidfile2}" "${ns1}"
+ rc=$?
+ terminate_pidfiles "${pidfile1}" "${pidfile2}"
+
+ return "${rc}"
+}
+
+test_ns_global_same_cid_fails() {
+ init_namespaces
+
+ if namespaces_can_boot_same_cid "global0" "global1"; then
+ return "${KSFT_FAIL}"
+ fi
+
+ return "${KSFT_PASS}"
+}
+
+test_ns_local_global_same_cid_ok() {
+ init_namespaces
+
+ if namespaces_can_boot_same_cid "local0" "global0"; then
+ return "${KSFT_PASS}"
+ fi
+
+ return "${KSFT_FAIL}"
+}
+
+test_ns_global_local_same_cid_ok() {
+ init_namespaces
+
+ if namespaces_can_boot_same_cid "global0" "local0"; then
+ return "${KSFT_PASS}"
+ fi
+
+ return "${KSFT_FAIL}"
+}
+
+test_ns_local_same_cid_ok() {
+ init_namespaces
+
+ if namespaces_can_boot_same_cid "local0" "local1"; then
+ return "${KSFT_PASS}"
+ fi
+
+ return "${KSFT_FAIL}"
+}
+
+test_ns_host_vsock_child_ns_mode_ok() {
+ local orig_mode
+ local rc
+
+ orig_mode=$(cat /proc/sys/net/vsock/child_ns_mode)
+
+ rc="${KSFT_PASS}"
+ for mode in "${NS_MODES[@]}"; do
+ local ns="${mode}0"
+
+ if echo "${mode}" 2>/dev/null > /proc/sys/net/vsock/ns_mode; then
+ log_host "ns_mode should be read-only but write succeeded"
+ rc="${KSFT_FAIL}"
+ continue
+ fi
+
+ if ! echo "${mode}" > /proc/sys/net/vsock/child_ns_mode; then
+ log_host "child_ns_mode should be writable to ${mode}"
+ rc="${KSFT_FAIL}"
+ continue
+ fi
+ done
+
+ echo "${orig_mode}" > /proc/sys/net/vsock/child_ns_mode
+
+ return "${rc}"
+}
+
test_vm_server_host_client() {
- if ! vm_vsock_test "server" 2 "${TEST_GUEST_PORT}"; then
+ if ! vm_vsock_test "init_ns" "server" 2 "${TEST_GUEST_PORT}"; then
return "${KSFT_FAIL}"
fi
- if ! host_vsock_test "127.0.0.1" "${VSOCK_CID}" "${TEST_HOST_PORT}"; then
+ if ! host_vsock_test "init_ns" "127.0.0.1" "${VSOCK_CID}" "${TEST_HOST_PORT}"; then
return "${KSFT_FAIL}"
fi
@@ -440,11 +1271,11 @@ test_vm_server_host_client() {
}
test_vm_client_host_server() {
- if ! host_vsock_test "server" "${VSOCK_CID}" "${TEST_HOST_PORT_LISTENER}"; then
+ if ! host_vsock_test "init_ns" "server" "${VSOCK_CID}" "${TEST_HOST_PORT_LISTENER}"; then
return "${KSFT_FAIL}"
fi
- if ! vm_vsock_test "10.0.2.2" 2 "${TEST_HOST_PORT_LISTENER}"; then
+ if ! vm_vsock_test "init_ns" "10.0.2.2" 2 "${TEST_HOST_PORT_LISTENER}"; then
return "${KSFT_FAIL}"
fi
@@ -454,19 +1285,92 @@ test_vm_client_host_server() {
test_vm_loopback() {
local port=60000 # non-forwarded local port
- vm_ssh -- modprobe vsock_loopback &> /dev/null || :
+ vm_ssh "init_ns" -- modprobe vsock_loopback &> /dev/null || :
- if ! vm_vsock_test "server" 1 "${port}"; then
+ if ! vm_vsock_test "init_ns" "server" 1 "${port}"; then
return "${KSFT_FAIL}"
fi
- if ! vm_vsock_test "127.0.0.1" 1 "${port}"; then
+
+ if ! vm_vsock_test "init_ns" "127.0.0.1" 1 "${port}"; then
return "${KSFT_FAIL}"
fi
return "${KSFT_PASS}"
}
+check_ns_delete_doesnt_break_connection() {
+ local pipefile pidfile outfile
+ local ns0="global0"
+ local ns1="global1"
+ local port=12345
+ local pids=()
+ local rc=0
+
+ init_namespaces
+
+ pidfile="$(create_pidfile)"
+ if ! vm_start "${pidfile}" "${ns0}"; then
+ return "${KSFT_FAIL}"
+ fi
+ vm_wait_for_ssh "${ns0}"
+
+ outfile=$(mktemp)
+ vm_ssh "${ns0}" -- \
+ socat VSOCK-LISTEN:"${port}",fork STDOUT > "${outfile}" 2>/dev/null &
+ pids+=($!)
+ vm_wait_for_listener "${ns0}" "${port}" "vsock"
+
+ # We use a pipe here so that we can echo into the pipe instead of using
+ # socat and a unix socket file. We just need a name for the pipe (not a
+ # regular file) so use -u.
+ pipefile=$(mktemp -u /tmp/vmtest_pipe_XXXX)
+ ip netns exec "${ns1}" \
+ socat PIPE:"${pipefile}" VSOCK-CONNECT:"${VSOCK_CID}":"${port}" &
+ pids+=($!)
+
+ timeout "${WAIT_PERIOD}" \
+ bash -c 'while [[ ! -e '"${pipefile}"' ]]; do sleep 1; done; exit 0'
+
+ if [[ "$1" == "vm" ]]; then
+ ip netns del "${ns0}"
+ elif [[ "$1" == "host" ]]; then
+ ip netns del "${ns1}"
+ elif [[ "$1" == "both" ]]; then
+ ip netns del "${ns0}"
+ ip netns del "${ns1}"
+ fi
+
+ echo "TEST" > "${pipefile}"
+
+ timeout "${WAIT_PERIOD}" \
+ bash -c 'while [[ ! -s '"${outfile}"' ]]; do sleep 1; done; exit 0'
+
+ if grep -q "TEST" "${outfile}"; then
+ rc="${KSFT_PASS}"
+ else
+ rc="${KSFT_FAIL}"
+ fi
+
+ terminate_pidfiles "${pidfile}"
+ terminate_pids "${pids[@]}"
+ rm -f "${outfile}" "${pipefile}"
+
+ return "${rc}"
+}
+
+test_ns_delete_vm_ok() {
+ check_ns_delete_doesnt_break_connection "vm"
+}
+
+test_ns_delete_host_ok() {
+ check_ns_delete_doesnt_break_connection "host"
+}
+
+test_ns_delete_both_ok() {
+ check_ns_delete_doesnt_break_connection "both"
+}
+
shared_vm_test() {
local tname
@@ -499,6 +1403,11 @@ run_shared_vm_tests() {
continue
fi
+ if ! check_netns "${arg}"; then
+ check_result "${KSFT_SKIP}" "${arg}"
+ continue
+ fi
+
run_shared_vm_test "${arg}"
check_result "$?" "${arg}"
done
@@ -518,8 +1427,8 @@ run_shared_vm_test() {
host_oops_cnt_before=$(dmesg | grep -c -i 'Oops')
host_warn_cnt_before=$(dmesg --level=warn | grep -c -i 'vsock')
- vm_oops_cnt_before=$(vm_ssh -- dmesg | grep -c -i 'Oops')
- vm_warn_cnt_before=$(vm_ssh -- dmesg --level=warn | grep -c -i 'vsock')
+ vm_oops_cnt_before=$(vm_dmesg_oops_count "init_ns")
+ vm_warn_cnt_before=$(vm_dmesg_warn_count "init_ns")
name=$(echo "${1}" | awk '{ print $1 }')
eval test_"${name}"
@@ -537,13 +1446,13 @@ run_shared_vm_test() {
rc=$KSFT_FAIL
fi
- vm_oops_cnt_after=$(vm_ssh -- dmesg | grep -i 'Oops' | wc -l)
+ vm_oops_cnt_after=$(vm_dmesg_oops_count "init_ns")
if [[ ${vm_oops_cnt_after} -gt ${vm_oops_cnt_before} ]]; then
echo "FAIL: kernel oops detected on vm" | log_host
rc=$KSFT_FAIL
fi
- vm_warn_cnt_after=$(vm_ssh -- dmesg --level=warn | grep -c -i 'vsock')
+ vm_warn_cnt_after=$(vm_dmesg_warn_count "init_ns")
if [[ ${vm_warn_cnt_after} -gt ${vm_warn_cnt_before} ]]; then
echo "FAIL: kernel warning detected on vm" | log_host
rc=$KSFT_FAIL
@@ -552,6 +1461,49 @@ run_shared_vm_test() {
return "${rc}"
}
+run_ns_tests() {
+ for arg in "${ARGS[@]}"; do
+ if shared_vm_test "${arg}"; then
+ continue
+ fi
+
+ if ! check_netns "${arg}"; then
+ check_result "${KSFT_SKIP}" "${arg}"
+ continue
+ fi
+
+ add_namespaces
+
+ name=$(echo "${arg}" | awk '{ print $1 }')
+ log_host "Executing test_${name}"
+
+ host_oops_before=$(dmesg 2>/dev/null | grep -c -i 'Oops')
+ host_warn_before=$(dmesg --level=warn 2>/dev/null | grep -c -i 'vsock')
+ eval test_"${name}"
+ rc=$?
+
+ host_oops_after=$(dmesg 2>/dev/null | grep -c -i 'Oops')
+ if [[ "${host_oops_after}" -gt "${host_oops_before}" ]]; then
+ echo "FAIL: kernel oops detected on host" | log_host
+ check_result "${KSFT_FAIL}" "${name}"
+ del_namespaces
+ continue
+ fi
+
+ host_warn_after=$(dmesg --level=warn 2>/dev/null | grep -c -i 'vsock')
+ if [[ "${host_warn_after}" -gt "${host_warn_before}" ]]; then
+ echo "FAIL: kernel warning detected on host" | log_host
+ check_result "${KSFT_FAIL}" "${name}"
+ del_namespaces
+ continue
+ fi
+
+ check_result "${rc}" "${name}"
+
+ del_namespaces
+ done
+}
+
BUILD=0
QEMU="qemu-system-$(uname -m)"
@@ -577,6 +1529,7 @@ fi
check_args "${ARGS[@]}"
check_deps
check_vng
+check_socat
handle_build
echo "1..${#ARGS[@]}"
@@ -589,14 +1542,16 @@ cnt_total=0
if shared_vm_tests_requested "${ARGS[@]}"; then
log_host "Booting up VM"
pidfile="$(create_pidfile)"
- vm_start "${pidfile}"
- vm_wait_for_ssh
+ vm_start "${pidfile}" "init_ns"
+ vm_wait_for_ssh "init_ns"
log_host "VM booted up"
run_shared_vm_tests "${ARGS[@]}"
terminate_pidfiles "${pidfile}"
fi
+run_ns_tests "${ARGS[@]}"
+
echo "SUMMARY: PASS=${cnt_pass} SKIP=${cnt_skip} FAIL=${cnt_fail}"
echo "Log: ${LOG}"