diff options
Diffstat (limited to 'drivers/vhost')
-rw-r--r-- | drivers/vhost/net.c | 64 | ||||
-rw-r--r-- | drivers/vhost/vhost.c | 26 | ||||
-rw-r--r-- | drivers/vhost/vsock.c | 79 |
3 files changed, 126 insertions, 43 deletions
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c index ab11b2bee273..36f3d0f49e60 100644 --- a/drivers/vhost/net.c +++ b/drivers/vhost/net.c @@ -141,6 +141,10 @@ struct vhost_net { unsigned tx_zcopy_err; /* Flush in progress. Protected by tx vq lock. */ bool tx_flush; + /* Private page frag */ + struct page_frag page_frag; + /* Refcount bias of page frag */ + int refcnt_bias; }; static unsigned vhost_net_zcopy_mask __read_mostly; @@ -513,7 +517,13 @@ static void vhost_net_busy_poll(struct vhost_net *net, struct socket *sock; struct vhost_virtqueue *vq = poll_rx ? tvq : rvq; - mutex_lock_nested(&vq->mutex, poll_rx ? VHOST_NET_VQ_TX: VHOST_NET_VQ_RX); + /* Try to hold the vq mutex of the paired virtqueue. We can't + * use mutex_lock() here since we could not guarantee a + * consistenet lock ordering. + */ + if (!mutex_trylock(&vq->mutex)) + return; + vhost_disable_notify(&net->dev, vq); sock = rvq->private_data; @@ -637,14 +647,53 @@ static bool tx_can_batch(struct vhost_virtqueue *vq, size_t total_len) !vhost_vq_avail_empty(vq->dev, vq); } +#define SKB_FRAG_PAGE_ORDER get_order(32768) + +static bool vhost_net_page_frag_refill(struct vhost_net *net, unsigned int sz, + struct page_frag *pfrag, gfp_t gfp) +{ + if (pfrag->page) { + if (pfrag->offset + sz <= pfrag->size) + return true; + __page_frag_cache_drain(pfrag->page, net->refcnt_bias); + } + + pfrag->offset = 0; + net->refcnt_bias = 0; + if (SKB_FRAG_PAGE_ORDER) { + /* Avoid direct reclaim but allow kswapd to wake */ + pfrag->page = alloc_pages((gfp & ~__GFP_DIRECT_RECLAIM) | + __GFP_COMP | __GFP_NOWARN | + __GFP_NORETRY, + SKB_FRAG_PAGE_ORDER); + if (likely(pfrag->page)) { + pfrag->size = PAGE_SIZE << SKB_FRAG_PAGE_ORDER; + goto done; + } + } + pfrag->page = alloc_page(gfp); + if (likely(pfrag->page)) { + pfrag->size = PAGE_SIZE; + goto done; + } + return false; + +done: + net->refcnt_bias = USHRT_MAX; + page_ref_add(pfrag->page, USHRT_MAX - 1); + return true; +} + #define VHOST_NET_RX_PAD (NET_IP_ALIGN + NET_SKB_PAD) static int vhost_net_build_xdp(struct vhost_net_virtqueue *nvq, struct iov_iter *from) { struct vhost_virtqueue *vq = &nvq->vq; + struct vhost_net *net = container_of(vq->dev, struct vhost_net, + dev); struct socket *sock = vq->private_data; - struct page_frag *alloc_frag = ¤t->task_frag; + struct page_frag *alloc_frag = &net->page_frag; struct virtio_net_hdr *gso; struct xdp_buff *xdp = &nvq->xdp[nvq->batched_xdp]; struct tun_xdp_hdr *hdr; @@ -665,7 +714,8 @@ static int vhost_net_build_xdp(struct vhost_net_virtqueue *nvq, buflen += SKB_DATA_ALIGN(len + pad); alloc_frag->offset = ALIGN((u64)alloc_frag->offset, SMP_CACHE_BYTES); - if (unlikely(!skb_page_frag_refill(buflen, alloc_frag, GFP_KERNEL))) + if (unlikely(!vhost_net_page_frag_refill(net, buflen, + alloc_frag, GFP_KERNEL))) return -ENOMEM; buf = (char *)page_address(alloc_frag->page) + alloc_frag->offset; @@ -703,7 +753,7 @@ static int vhost_net_build_xdp(struct vhost_net_virtqueue *nvq, xdp->data_end = xdp->data + len; hdr->buflen = buflen; - get_page(alloc_frag->page); + --net->refcnt_bias; alloc_frag->offset += buflen; ++nvq->batched_xdp; @@ -1292,6 +1342,8 @@ static int vhost_net_open(struct inode *inode, struct file *f) vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, EPOLLIN, dev); f->private_data = n; + n->page_frag.page = NULL; + n->refcnt_bias = 0; return 0; } @@ -1359,13 +1411,15 @@ static int vhost_net_release(struct inode *inode, struct file *f) if (rx_sock) sockfd_put(rx_sock); /* Make sure no callbacks are outstanding */ - synchronize_rcu_bh(); + synchronize_rcu(); /* We do an extra flush before freeing memory, * since jobs can re-queue themselves. */ vhost_net_flush(n); kfree(n->vqs[VHOST_NET_VQ_RX].rxq.queue); kfree(n->vqs[VHOST_NET_VQ_TX].xdp); kfree(n->dev.vqs); + if (n->page_frag.page) + __page_frag_cache_drain(n->page_frag.page, n->refcnt_bias); kvfree(n); return 0; } diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index 3a5f81a66d34..55e5aa662ad5 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c @@ -295,11 +295,8 @@ static void vhost_vq_meta_reset(struct vhost_dev *d) { int i; - for (i = 0; i < d->nvqs; ++i) { - mutex_lock(&d->vqs[i]->mutex); + for (i = 0; i < d->nvqs; ++i) __vhost_vq_meta_reset(d->vqs[i]); - mutex_unlock(&d->vqs[i]->mutex); - } } static void vhost_vq_reset(struct vhost_dev *dev, @@ -895,6 +892,20 @@ static inline void __user *__vhost_get_user(struct vhost_virtqueue *vq, #define vhost_get_used(vq, x, ptr) \ vhost_get_user(vq, x, ptr, VHOST_ADDR_USED) +static void vhost_dev_lock_vqs(struct vhost_dev *d) +{ + int i = 0; + for (i = 0; i < d->nvqs; ++i) + mutex_lock_nested(&d->vqs[i]->mutex, i); +} + +static void vhost_dev_unlock_vqs(struct vhost_dev *d) +{ + int i = 0; + for (i = 0; i < d->nvqs; ++i) + mutex_unlock(&d->vqs[i]->mutex); +} + static int vhost_new_umem_range(struct vhost_umem *umem, u64 start, u64 size, u64 end, u64 userspace_addr, int perm) @@ -944,10 +955,7 @@ static void vhost_iotlb_notify_vq(struct vhost_dev *d, if (msg->iova <= vq_msg->iova && msg->iova + msg->size - 1 >= vq_msg->iova && vq_msg->type == VHOST_IOTLB_MISS) { - mutex_lock(&node->vq->mutex); vhost_poll_queue(&node->vq->poll); - mutex_unlock(&node->vq->mutex); - list_del(&node->node); kfree(node); } @@ -979,6 +987,7 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev, int ret = 0; mutex_lock(&dev->mutex); + vhost_dev_lock_vqs(dev); switch (msg->type) { case VHOST_IOTLB_UPDATE: if (!dev->iotlb) { @@ -1012,6 +1021,7 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev, break; } + vhost_dev_unlock_vqs(dev); mutex_unlock(&dev->mutex); return ret; @@ -2223,6 +2233,8 @@ int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads, return -EFAULT; } if (unlikely(vq->log_used)) { + /* Make sure used idx is seen before log. */ + smp_wmb(); /* Log used index update. */ log_write(vq->log_base, vq->log_addr + offsetof(struct vring_used, idx), diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c index 34bc3ab40c6d..98ed5be132c6 100644 --- a/drivers/vhost/vsock.c +++ b/drivers/vhost/vsock.c @@ -15,6 +15,7 @@ #include <net/sock.h> #include <linux/virtio_vsock.h> #include <linux/vhost.h> +#include <linux/hashtable.h> #include <net/af_vsock.h> #include "vhost.h" @@ -27,14 +28,14 @@ enum { /* Used to track all the vhost_vsock instances on the system. */ static DEFINE_SPINLOCK(vhost_vsock_lock); -static LIST_HEAD(vhost_vsock_list); +static DEFINE_READ_MOSTLY_HASHTABLE(vhost_vsock_hash, 8); struct vhost_vsock { struct vhost_dev dev; struct vhost_virtqueue vqs[2]; - /* Link to global vhost_vsock_list, protected by vhost_vsock_lock */ - struct list_head list; + /* Link to global vhost_vsock_hash, writes use vhost_vsock_lock */ + struct hlist_node hash; struct vhost_work send_pkt_work; spinlock_t send_pkt_list_lock; @@ -50,11 +51,14 @@ static u32 vhost_transport_get_local_cid(void) return VHOST_VSOCK_DEFAULT_HOST_CID; } -static struct vhost_vsock *__vhost_vsock_get(u32 guest_cid) +/* Callers that dereference the return value must hold vhost_vsock_lock or the + * RCU read lock. + */ +static struct vhost_vsock *vhost_vsock_get(u32 guest_cid) { struct vhost_vsock *vsock; - list_for_each_entry(vsock, &vhost_vsock_list, list) { + hash_for_each_possible_rcu(vhost_vsock_hash, vsock, hash, guest_cid) { u32 other_cid = vsock->guest_cid; /* Skip instances that have no CID yet */ @@ -69,17 +73,6 @@ static struct vhost_vsock *__vhost_vsock_get(u32 guest_cid) return NULL; } -static struct vhost_vsock *vhost_vsock_get(u32 guest_cid) -{ - struct vhost_vsock *vsock; - - spin_lock_bh(&vhost_vsock_lock); - vsock = __vhost_vsock_get(guest_cid); - spin_unlock_bh(&vhost_vsock_lock); - - return vsock; -} - static void vhost_transport_do_send_pkt(struct vhost_vsock *vsock, struct vhost_virtqueue *vq) @@ -210,9 +203,12 @@ vhost_transport_send_pkt(struct virtio_vsock_pkt *pkt) struct vhost_vsock *vsock; int len = pkt->len; + rcu_read_lock(); + /* Find the vhost_vsock according to guest context id */ vsock = vhost_vsock_get(le64_to_cpu(pkt->hdr.dst_cid)); if (!vsock) { + rcu_read_unlock(); virtio_transport_free_pkt(pkt); return -ENODEV; } @@ -225,6 +221,8 @@ vhost_transport_send_pkt(struct virtio_vsock_pkt *pkt) spin_unlock_bh(&vsock->send_pkt_list_lock); vhost_work_queue(&vsock->dev, &vsock->send_pkt_work); + + rcu_read_unlock(); return len; } @@ -234,12 +232,15 @@ vhost_transport_cancel_pkt(struct vsock_sock *vsk) struct vhost_vsock *vsock; struct virtio_vsock_pkt *pkt, *n; int cnt = 0; + int ret = -ENODEV; LIST_HEAD(freeme); + rcu_read_lock(); + /* Find the vhost_vsock according to guest context id */ vsock = vhost_vsock_get(vsk->remote_addr.svm_cid); if (!vsock) - return -ENODEV; + goto out; spin_lock_bh(&vsock->send_pkt_list_lock); list_for_each_entry_safe(pkt, n, &vsock->send_pkt_list, list) { @@ -265,7 +266,10 @@ vhost_transport_cancel_pkt(struct vsock_sock *vsk) vhost_poll_queue(&tx_vq->poll); } - return 0; + ret = 0; +out: + rcu_read_unlock(); + return ret; } static struct virtio_vsock_pkt * @@ -533,10 +537,6 @@ static int vhost_vsock_dev_open(struct inode *inode, struct file *file) spin_lock_init(&vsock->send_pkt_list_lock); INIT_LIST_HEAD(&vsock->send_pkt_list); vhost_work_init(&vsock->send_pkt_work, vhost_transport_send_pkt_work); - - spin_lock_bh(&vhost_vsock_lock); - list_add_tail(&vsock->list, &vhost_vsock_list); - spin_unlock_bh(&vhost_vsock_lock); return 0; out: @@ -563,13 +563,21 @@ static void vhost_vsock_reset_orphans(struct sock *sk) * executing. */ - if (!vhost_vsock_get(vsk->remote_addr.svm_cid)) { - sock_set_flag(sk, SOCK_DONE); - vsk->peer_shutdown = SHUTDOWN_MASK; - sk->sk_state = SS_UNCONNECTED; - sk->sk_err = ECONNRESET; - sk->sk_error_report(sk); - } + /* If the peer is still valid, no need to reset connection */ + if (vhost_vsock_get(vsk->remote_addr.svm_cid)) + return; + + /* If the close timeout is pending, let it expire. This avoids races + * with the timeout callback. + */ + if (vsk->close_work_scheduled) + return; + + sock_set_flag(sk, SOCK_DONE); + vsk->peer_shutdown = SHUTDOWN_MASK; + sk->sk_state = SS_UNCONNECTED; + sk->sk_err = ECONNRESET; + sk->sk_error_report(sk); } static int vhost_vsock_dev_release(struct inode *inode, struct file *file) @@ -577,9 +585,13 @@ static int vhost_vsock_dev_release(struct inode *inode, struct file *file) struct vhost_vsock *vsock = file->private_data; spin_lock_bh(&vhost_vsock_lock); - list_del(&vsock->list); + if (vsock->guest_cid) + hash_del_rcu(&vsock->hash); spin_unlock_bh(&vhost_vsock_lock); + /* Wait for other CPUs to finish using vsock */ + synchronize_rcu(); + /* Iterating over all connections for all CIDs to find orphans is * inefficient. Room for improvement here. */ vsock_for_each_connected_socket(vhost_vsock_reset_orphans); @@ -620,12 +632,17 @@ static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid) /* Refuse if CID is already in use */ spin_lock_bh(&vhost_vsock_lock); - other = __vhost_vsock_get(guest_cid); + other = vhost_vsock_get(guest_cid); if (other && other != vsock) { spin_unlock_bh(&vhost_vsock_lock); return -EADDRINUSE; } + + if (vsock->guest_cid) + hash_del_rcu(&vsock->hash); + vsock->guest_cid = guest_cid; + hash_add_rcu(vhost_vsock_hash, &vsock->hash, guest_cid); spin_unlock_bh(&vhost_vsock_lock); return 0; |