summaryrefslogtreecommitdiff
path: root/drivers/vhost/vsock.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/vhost/vsock.c')
-rw-r--r--drivers/vhost/vsock.c79
1 files changed, 48 insertions, 31 deletions
diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
index 34bc3ab40c6d..98ed5be132c6 100644
--- a/drivers/vhost/vsock.c
+++ b/drivers/vhost/vsock.c
@@ -15,6 +15,7 @@
#include <net/sock.h>
#include <linux/virtio_vsock.h>
#include <linux/vhost.h>
+#include <linux/hashtable.h>
#include <net/af_vsock.h>
#include "vhost.h"
@@ -27,14 +28,14 @@ enum {
/* Used to track all the vhost_vsock instances on the system. */
static DEFINE_SPINLOCK(vhost_vsock_lock);
-static LIST_HEAD(vhost_vsock_list);
+static DEFINE_READ_MOSTLY_HASHTABLE(vhost_vsock_hash, 8);
struct vhost_vsock {
struct vhost_dev dev;
struct vhost_virtqueue vqs[2];
- /* Link to global vhost_vsock_list, protected by vhost_vsock_lock */
- struct list_head list;
+ /* Link to global vhost_vsock_hash, writes use vhost_vsock_lock */
+ struct hlist_node hash;
struct vhost_work send_pkt_work;
spinlock_t send_pkt_list_lock;
@@ -50,11 +51,14 @@ static u32 vhost_transport_get_local_cid(void)
return VHOST_VSOCK_DEFAULT_HOST_CID;
}
-static struct vhost_vsock *__vhost_vsock_get(u32 guest_cid)
+/* Callers that dereference the return value must hold vhost_vsock_lock or the
+ * RCU read lock.
+ */
+static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
{
struct vhost_vsock *vsock;
- list_for_each_entry(vsock, &vhost_vsock_list, list) {
+ hash_for_each_possible_rcu(vhost_vsock_hash, vsock, hash, guest_cid) {
u32 other_cid = vsock->guest_cid;
/* Skip instances that have no CID yet */
@@ -69,17 +73,6 @@ static struct vhost_vsock *__vhost_vsock_get(u32 guest_cid)
return NULL;
}
-static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
-{
- struct vhost_vsock *vsock;
-
- spin_lock_bh(&vhost_vsock_lock);
- vsock = __vhost_vsock_get(guest_cid);
- spin_unlock_bh(&vhost_vsock_lock);
-
- return vsock;
-}
-
static void
vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
struct vhost_virtqueue *vq)
@@ -210,9 +203,12 @@ vhost_transport_send_pkt(struct virtio_vsock_pkt *pkt)
struct vhost_vsock *vsock;
int len = pkt->len;
+ rcu_read_lock();
+
/* Find the vhost_vsock according to guest context id */
vsock = vhost_vsock_get(le64_to_cpu(pkt->hdr.dst_cid));
if (!vsock) {
+ rcu_read_unlock();
virtio_transport_free_pkt(pkt);
return -ENODEV;
}
@@ -225,6 +221,8 @@ vhost_transport_send_pkt(struct virtio_vsock_pkt *pkt)
spin_unlock_bh(&vsock->send_pkt_list_lock);
vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);
+
+ rcu_read_unlock();
return len;
}
@@ -234,12 +232,15 @@ vhost_transport_cancel_pkt(struct vsock_sock *vsk)
struct vhost_vsock *vsock;
struct virtio_vsock_pkt *pkt, *n;
int cnt = 0;
+ int ret = -ENODEV;
LIST_HEAD(freeme);
+ rcu_read_lock();
+
/* Find the vhost_vsock according to guest context id */
vsock = vhost_vsock_get(vsk->remote_addr.svm_cid);
if (!vsock)
- return -ENODEV;
+ goto out;
spin_lock_bh(&vsock->send_pkt_list_lock);
list_for_each_entry_safe(pkt, n, &vsock->send_pkt_list, list) {
@@ -265,7 +266,10 @@ vhost_transport_cancel_pkt(struct vsock_sock *vsk)
vhost_poll_queue(&tx_vq->poll);
}
- return 0;
+ ret = 0;
+out:
+ rcu_read_unlock();
+ return ret;
}
static struct virtio_vsock_pkt *
@@ -533,10 +537,6 @@ static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
spin_lock_init(&vsock->send_pkt_list_lock);
INIT_LIST_HEAD(&vsock->send_pkt_list);
vhost_work_init(&vsock->send_pkt_work, vhost_transport_send_pkt_work);
-
- spin_lock_bh(&vhost_vsock_lock);
- list_add_tail(&vsock->list, &vhost_vsock_list);
- spin_unlock_bh(&vhost_vsock_lock);
return 0;
out:
@@ -563,13 +563,21 @@ static void vhost_vsock_reset_orphans(struct sock *sk)
* executing.
*/
- if (!vhost_vsock_get(vsk->remote_addr.svm_cid)) {
- sock_set_flag(sk, SOCK_DONE);
- vsk->peer_shutdown = SHUTDOWN_MASK;
- sk->sk_state = SS_UNCONNECTED;
- sk->sk_err = ECONNRESET;
- sk->sk_error_report(sk);
- }
+ /* If the peer is still valid, no need to reset connection */
+ if (vhost_vsock_get(vsk->remote_addr.svm_cid))
+ return;
+
+ /* If the close timeout is pending, let it expire. This avoids races
+ * with the timeout callback.
+ */
+ if (vsk->close_work_scheduled)
+ return;
+
+ sock_set_flag(sk, SOCK_DONE);
+ vsk->peer_shutdown = SHUTDOWN_MASK;
+ sk->sk_state = SS_UNCONNECTED;
+ sk->sk_err = ECONNRESET;
+ sk->sk_error_report(sk);
}
static int vhost_vsock_dev_release(struct inode *inode, struct file *file)
@@ -577,9 +585,13 @@ static int vhost_vsock_dev_release(struct inode *inode, struct file *file)
struct vhost_vsock *vsock = file->private_data;
spin_lock_bh(&vhost_vsock_lock);
- list_del(&vsock->list);
+ if (vsock->guest_cid)
+ hash_del_rcu(&vsock->hash);
spin_unlock_bh(&vhost_vsock_lock);
+ /* Wait for other CPUs to finish using vsock */
+ synchronize_rcu();
+
/* Iterating over all connections for all CIDs to find orphans is
* inefficient. Room for improvement here. */
vsock_for_each_connected_socket(vhost_vsock_reset_orphans);
@@ -620,12 +632,17 @@ static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid)
/* Refuse if CID is already in use */
spin_lock_bh(&vhost_vsock_lock);
- other = __vhost_vsock_get(guest_cid);
+ other = vhost_vsock_get(guest_cid);
if (other && other != vsock) {
spin_unlock_bh(&vhost_vsock_lock);
return -EADDRINUSE;
}
+
+ if (vsock->guest_cid)
+ hash_del_rcu(&vsock->hash);
+
vsock->guest_cid = guest_cid;
+ hash_add_rcu(vhost_vsock_hash, &vsock->hash, guest_cid);
spin_unlock_bh(&vhost_vsock_lock);
return 0;