summaryrefslogtreecommitdiff
path: root/drivers/vhost
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/vhost')
-rw-r--r--drivers/vhost/net.c64
-rw-r--r--drivers/vhost/vhost.c26
-rw-r--r--drivers/vhost/vsock.c79
3 files changed, 126 insertions, 43 deletions
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index ab11b2bee273..36f3d0f49e60 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -141,6 +141,10 @@ struct vhost_net {
unsigned tx_zcopy_err;
/* Flush in progress. Protected by tx vq lock. */
bool tx_flush;
+ /* Private page frag */
+ struct page_frag page_frag;
+ /* Refcount bias of page frag */
+ int refcnt_bias;
};
static unsigned vhost_net_zcopy_mask __read_mostly;
@@ -513,7 +517,13 @@ static void vhost_net_busy_poll(struct vhost_net *net,
struct socket *sock;
struct vhost_virtqueue *vq = poll_rx ? tvq : rvq;
- mutex_lock_nested(&vq->mutex, poll_rx ? VHOST_NET_VQ_TX: VHOST_NET_VQ_RX);
+ /* Try to hold the vq mutex of the paired virtqueue. We can't
+ * use mutex_lock() here since we could not guarantee a
+ * consistenet lock ordering.
+ */
+ if (!mutex_trylock(&vq->mutex))
+ return;
+
vhost_disable_notify(&net->dev, vq);
sock = rvq->private_data;
@@ -637,14 +647,53 @@ static bool tx_can_batch(struct vhost_virtqueue *vq, size_t total_len)
!vhost_vq_avail_empty(vq->dev, vq);
}
+#define SKB_FRAG_PAGE_ORDER get_order(32768)
+
+static bool vhost_net_page_frag_refill(struct vhost_net *net, unsigned int sz,
+ struct page_frag *pfrag, gfp_t gfp)
+{
+ if (pfrag->page) {
+ if (pfrag->offset + sz <= pfrag->size)
+ return true;
+ __page_frag_cache_drain(pfrag->page, net->refcnt_bias);
+ }
+
+ pfrag->offset = 0;
+ net->refcnt_bias = 0;
+ if (SKB_FRAG_PAGE_ORDER) {
+ /* Avoid direct reclaim but allow kswapd to wake */
+ pfrag->page = alloc_pages((gfp & ~__GFP_DIRECT_RECLAIM) |
+ __GFP_COMP | __GFP_NOWARN |
+ __GFP_NORETRY,
+ SKB_FRAG_PAGE_ORDER);
+ if (likely(pfrag->page)) {
+ pfrag->size = PAGE_SIZE << SKB_FRAG_PAGE_ORDER;
+ goto done;
+ }
+ }
+ pfrag->page = alloc_page(gfp);
+ if (likely(pfrag->page)) {
+ pfrag->size = PAGE_SIZE;
+ goto done;
+ }
+ return false;
+
+done:
+ net->refcnt_bias = USHRT_MAX;
+ page_ref_add(pfrag->page, USHRT_MAX - 1);
+ return true;
+}
+
#define VHOST_NET_RX_PAD (NET_IP_ALIGN + NET_SKB_PAD)
static int vhost_net_build_xdp(struct vhost_net_virtqueue *nvq,
struct iov_iter *from)
{
struct vhost_virtqueue *vq = &nvq->vq;
+ struct vhost_net *net = container_of(vq->dev, struct vhost_net,
+ dev);
struct socket *sock = vq->private_data;
- struct page_frag *alloc_frag = &current->task_frag;
+ struct page_frag *alloc_frag = &net->page_frag;
struct virtio_net_hdr *gso;
struct xdp_buff *xdp = &nvq->xdp[nvq->batched_xdp];
struct tun_xdp_hdr *hdr;
@@ -665,7 +714,8 @@ static int vhost_net_build_xdp(struct vhost_net_virtqueue *nvq,
buflen += SKB_DATA_ALIGN(len + pad);
alloc_frag->offset = ALIGN((u64)alloc_frag->offset, SMP_CACHE_BYTES);
- if (unlikely(!skb_page_frag_refill(buflen, alloc_frag, GFP_KERNEL)))
+ if (unlikely(!vhost_net_page_frag_refill(net, buflen,
+ alloc_frag, GFP_KERNEL)))
return -ENOMEM;
buf = (char *)page_address(alloc_frag->page) + alloc_frag->offset;
@@ -703,7 +753,7 @@ static int vhost_net_build_xdp(struct vhost_net_virtqueue *nvq,
xdp->data_end = xdp->data + len;
hdr->buflen = buflen;
- get_page(alloc_frag->page);
+ --net->refcnt_bias;
alloc_frag->offset += buflen;
++nvq->batched_xdp;
@@ -1292,6 +1342,8 @@ static int vhost_net_open(struct inode *inode, struct file *f)
vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, EPOLLIN, dev);
f->private_data = n;
+ n->page_frag.page = NULL;
+ n->refcnt_bias = 0;
return 0;
}
@@ -1359,13 +1411,15 @@ static int vhost_net_release(struct inode *inode, struct file *f)
if (rx_sock)
sockfd_put(rx_sock);
/* Make sure no callbacks are outstanding */
- synchronize_rcu_bh();
+ synchronize_rcu();
/* We do an extra flush before freeing memory,
* since jobs can re-queue themselves. */
vhost_net_flush(n);
kfree(n->vqs[VHOST_NET_VQ_RX].rxq.queue);
kfree(n->vqs[VHOST_NET_VQ_TX].xdp);
kfree(n->dev.vqs);
+ if (n->page_frag.page)
+ __page_frag_cache_drain(n->page_frag.page, n->refcnt_bias);
kvfree(n);
return 0;
}
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 3a5f81a66d34..55e5aa662ad5 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -295,11 +295,8 @@ static void vhost_vq_meta_reset(struct vhost_dev *d)
{
int i;
- for (i = 0; i < d->nvqs; ++i) {
- mutex_lock(&d->vqs[i]->mutex);
+ for (i = 0; i < d->nvqs; ++i)
__vhost_vq_meta_reset(d->vqs[i]);
- mutex_unlock(&d->vqs[i]->mutex);
- }
}
static void vhost_vq_reset(struct vhost_dev *dev,
@@ -895,6 +892,20 @@ static inline void __user *__vhost_get_user(struct vhost_virtqueue *vq,
#define vhost_get_used(vq, x, ptr) \
vhost_get_user(vq, x, ptr, VHOST_ADDR_USED)
+static void vhost_dev_lock_vqs(struct vhost_dev *d)
+{
+ int i = 0;
+ for (i = 0; i < d->nvqs; ++i)
+ mutex_lock_nested(&d->vqs[i]->mutex, i);
+}
+
+static void vhost_dev_unlock_vqs(struct vhost_dev *d)
+{
+ int i = 0;
+ for (i = 0; i < d->nvqs; ++i)
+ mutex_unlock(&d->vqs[i]->mutex);
+}
+
static int vhost_new_umem_range(struct vhost_umem *umem,
u64 start, u64 size, u64 end,
u64 userspace_addr, int perm)
@@ -944,10 +955,7 @@ static void vhost_iotlb_notify_vq(struct vhost_dev *d,
if (msg->iova <= vq_msg->iova &&
msg->iova + msg->size - 1 >= vq_msg->iova &&
vq_msg->type == VHOST_IOTLB_MISS) {
- mutex_lock(&node->vq->mutex);
vhost_poll_queue(&node->vq->poll);
- mutex_unlock(&node->vq->mutex);
-
list_del(&node->node);
kfree(node);
}
@@ -979,6 +987,7 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev,
int ret = 0;
mutex_lock(&dev->mutex);
+ vhost_dev_lock_vqs(dev);
switch (msg->type) {
case VHOST_IOTLB_UPDATE:
if (!dev->iotlb) {
@@ -1012,6 +1021,7 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev,
break;
}
+ vhost_dev_unlock_vqs(dev);
mutex_unlock(&dev->mutex);
return ret;
@@ -2223,6 +2233,8 @@ int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
return -EFAULT;
}
if (unlikely(vq->log_used)) {
+ /* Make sure used idx is seen before log. */
+ smp_wmb();
/* Log used index update. */
log_write(vq->log_base,
vq->log_addr + offsetof(struct vring_used, idx),
diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
index 34bc3ab40c6d..98ed5be132c6 100644
--- a/drivers/vhost/vsock.c
+++ b/drivers/vhost/vsock.c
@@ -15,6 +15,7 @@
#include <net/sock.h>
#include <linux/virtio_vsock.h>
#include <linux/vhost.h>
+#include <linux/hashtable.h>
#include <net/af_vsock.h>
#include "vhost.h"
@@ -27,14 +28,14 @@ enum {
/* Used to track all the vhost_vsock instances on the system. */
static DEFINE_SPINLOCK(vhost_vsock_lock);
-static LIST_HEAD(vhost_vsock_list);
+static DEFINE_READ_MOSTLY_HASHTABLE(vhost_vsock_hash, 8);
struct vhost_vsock {
struct vhost_dev dev;
struct vhost_virtqueue vqs[2];
- /* Link to global vhost_vsock_list, protected by vhost_vsock_lock */
- struct list_head list;
+ /* Link to global vhost_vsock_hash, writes use vhost_vsock_lock */
+ struct hlist_node hash;
struct vhost_work send_pkt_work;
spinlock_t send_pkt_list_lock;
@@ -50,11 +51,14 @@ static u32 vhost_transport_get_local_cid(void)
return VHOST_VSOCK_DEFAULT_HOST_CID;
}
-static struct vhost_vsock *__vhost_vsock_get(u32 guest_cid)
+/* Callers that dereference the return value must hold vhost_vsock_lock or the
+ * RCU read lock.
+ */
+static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
{
struct vhost_vsock *vsock;
- list_for_each_entry(vsock, &vhost_vsock_list, list) {
+ hash_for_each_possible_rcu(vhost_vsock_hash, vsock, hash, guest_cid) {
u32 other_cid = vsock->guest_cid;
/* Skip instances that have no CID yet */
@@ -69,17 +73,6 @@ static struct vhost_vsock *__vhost_vsock_get(u32 guest_cid)
return NULL;
}
-static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
-{
- struct vhost_vsock *vsock;
-
- spin_lock_bh(&vhost_vsock_lock);
- vsock = __vhost_vsock_get(guest_cid);
- spin_unlock_bh(&vhost_vsock_lock);
-
- return vsock;
-}
-
static void
vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
struct vhost_virtqueue *vq)
@@ -210,9 +203,12 @@ vhost_transport_send_pkt(struct virtio_vsock_pkt *pkt)
struct vhost_vsock *vsock;
int len = pkt->len;
+ rcu_read_lock();
+
/* Find the vhost_vsock according to guest context id */
vsock = vhost_vsock_get(le64_to_cpu(pkt->hdr.dst_cid));
if (!vsock) {
+ rcu_read_unlock();
virtio_transport_free_pkt(pkt);
return -ENODEV;
}
@@ -225,6 +221,8 @@ vhost_transport_send_pkt(struct virtio_vsock_pkt *pkt)
spin_unlock_bh(&vsock->send_pkt_list_lock);
vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);
+
+ rcu_read_unlock();
return len;
}
@@ -234,12 +232,15 @@ vhost_transport_cancel_pkt(struct vsock_sock *vsk)
struct vhost_vsock *vsock;
struct virtio_vsock_pkt *pkt, *n;
int cnt = 0;
+ int ret = -ENODEV;
LIST_HEAD(freeme);
+ rcu_read_lock();
+
/* Find the vhost_vsock according to guest context id */
vsock = vhost_vsock_get(vsk->remote_addr.svm_cid);
if (!vsock)
- return -ENODEV;
+ goto out;
spin_lock_bh(&vsock->send_pkt_list_lock);
list_for_each_entry_safe(pkt, n, &vsock->send_pkt_list, list) {
@@ -265,7 +266,10 @@ vhost_transport_cancel_pkt(struct vsock_sock *vsk)
vhost_poll_queue(&tx_vq->poll);
}
- return 0;
+ ret = 0;
+out:
+ rcu_read_unlock();
+ return ret;
}
static struct virtio_vsock_pkt *
@@ -533,10 +537,6 @@ static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
spin_lock_init(&vsock->send_pkt_list_lock);
INIT_LIST_HEAD(&vsock->send_pkt_list);
vhost_work_init(&vsock->send_pkt_work, vhost_transport_send_pkt_work);
-
- spin_lock_bh(&vhost_vsock_lock);
- list_add_tail(&vsock->list, &vhost_vsock_list);
- spin_unlock_bh(&vhost_vsock_lock);
return 0;
out:
@@ -563,13 +563,21 @@ static void vhost_vsock_reset_orphans(struct sock *sk)
* executing.
*/
- if (!vhost_vsock_get(vsk->remote_addr.svm_cid)) {
- sock_set_flag(sk, SOCK_DONE);
- vsk->peer_shutdown = SHUTDOWN_MASK;
- sk->sk_state = SS_UNCONNECTED;
- sk->sk_err = ECONNRESET;
- sk->sk_error_report(sk);
- }
+ /* If the peer is still valid, no need to reset connection */
+ if (vhost_vsock_get(vsk->remote_addr.svm_cid))
+ return;
+
+ /* If the close timeout is pending, let it expire. This avoids races
+ * with the timeout callback.
+ */
+ if (vsk->close_work_scheduled)
+ return;
+
+ sock_set_flag(sk, SOCK_DONE);
+ vsk->peer_shutdown = SHUTDOWN_MASK;
+ sk->sk_state = SS_UNCONNECTED;
+ sk->sk_err = ECONNRESET;
+ sk->sk_error_report(sk);
}
static int vhost_vsock_dev_release(struct inode *inode, struct file *file)
@@ -577,9 +585,13 @@ static int vhost_vsock_dev_release(struct inode *inode, struct file *file)
struct vhost_vsock *vsock = file->private_data;
spin_lock_bh(&vhost_vsock_lock);
- list_del(&vsock->list);
+ if (vsock->guest_cid)
+ hash_del_rcu(&vsock->hash);
spin_unlock_bh(&vhost_vsock_lock);
+ /* Wait for other CPUs to finish using vsock */
+ synchronize_rcu();
+
/* Iterating over all connections for all CIDs to find orphans is
* inefficient. Room for improvement here. */
vsock_for_each_connected_socket(vhost_vsock_reset_orphans);
@@ -620,12 +632,17 @@ static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid)
/* Refuse if CID is already in use */
spin_lock_bh(&vhost_vsock_lock);
- other = __vhost_vsock_get(guest_cid);
+ other = vhost_vsock_get(guest_cid);
if (other && other != vsock) {
spin_unlock_bh(&vhost_vsock_lock);
return -EADDRINUSE;
}
+
+ if (vsock->guest_cid)
+ hash_del_rcu(&vsock->hash);
+
vsock->guest_cid = guest_cid;
+ hash_add_rcu(vhost_vsock_hash, &vsock->hash, guest_cid);
spin_unlock_bh(&vhost_vsock_lock);
return 0;