diff options
Diffstat (limited to 'drivers/vhost/vhost.c')
-rw-r--r-- | drivers/vhost/vhost.c | 164 |
1 files changed, 69 insertions, 95 deletions
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index f11bdbe4c2c5..60c9ebd629dd 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c @@ -22,11 +22,11 @@ #include <linux/slab.h> #include <linux/vmalloc.h> #include <linux/kthread.h> -#include <linux/cgroup.h> #include <linux/module.h> #include <linux/sort.h> #include <linux/sched/mm.h> #include <linux/sched/signal.h> +#include <linux/sched/vhost_task.h> #include <linux/interval_tree_generic.h> #include <linux/nospec.h> #include <linux/kcov.h> @@ -235,7 +235,7 @@ void vhost_dev_flush(struct vhost_dev *dev) { struct vhost_flush_struct flush; - if (dev->worker) { + if (dev->worker.vtsk) { init_completion(&flush.wait_event); vhost_work_init(&flush.work, vhost_flush_work); @@ -247,7 +247,7 @@ EXPORT_SYMBOL_GPL(vhost_dev_flush); void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work) { - if (!dev->worker) + if (!dev->worker.vtsk) return; if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) { @@ -255,8 +255,8 @@ void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work) * sure it was not in the list. * test_and_set_bit() implies a memory barrier. */ - llist_add(&work->node, &dev->work_list); - wake_up_process(dev->worker); + llist_add(&work->node, &dev->worker.work_list); + vhost_task_wake(dev->worker.vtsk); } } EXPORT_SYMBOL_GPL(vhost_work_queue); @@ -264,7 +264,7 @@ EXPORT_SYMBOL_GPL(vhost_work_queue); /* A lockless hint for busy polling code to exit the loop */ bool vhost_has_work(struct vhost_dev *dev) { - return !llist_empty(&dev->work_list); + return !llist_empty(&dev->worker.work_list); } EXPORT_SYMBOL_GPL(vhost_has_work); @@ -333,42 +333,29 @@ static void vhost_vq_reset(struct vhost_dev *dev, __vhost_vq_meta_reset(vq); } -static int vhost_worker(void *data) +static bool vhost_worker(void *data) { - struct vhost_dev *dev = data; + struct vhost_worker *worker = data; struct vhost_work *work, *work_next; struct llist_node *node; - kthread_use_mm(dev->mm); - - for (;;) { - /* mb paired w/ kthread_stop */ - set_current_state(TASK_INTERRUPTIBLE); - - if (kthread_should_stop()) { - __set_current_state(TASK_RUNNING); - break; - } - - node = llist_del_all(&dev->work_list); - if (!node) - schedule(); + node = llist_del_all(&worker->work_list); + if (node) { + __set_current_state(TASK_RUNNING); node = llist_reverse_order(node); /* make sure flag is seen after deletion */ smp_wmb(); llist_for_each_entry_safe(work, work_next, node, node) { clear_bit(VHOST_WORK_QUEUED, &work->flags); - __set_current_state(TASK_RUNNING); - kcov_remote_start_common(dev->kcov_handle); + kcov_remote_start_common(worker->kcov_handle); work->fn(work); kcov_remote_stop(); - if (need_resched()) - schedule(); + cond_resched(); } } - kthread_unuse_mm(dev->mm); - return 0; + + return !!node; } static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq) @@ -436,8 +423,7 @@ static size_t vhost_get_avail_size(struct vhost_virtqueue *vq, size_t event __maybe_unused = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; - return sizeof(*vq->avail) + - sizeof(*vq->avail->ring) * num + event; + return size_add(struct_size(vq->avail, ring, num), event); } static size_t vhost_get_used_size(struct vhost_virtqueue *vq, @@ -446,8 +432,7 @@ static size_t vhost_get_used_size(struct vhost_virtqueue *vq, size_t event __maybe_unused = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; - return sizeof(*vq->used) + - sizeof(*vq->used->ring) * num + event; + return size_add(struct_size(vq->used, ring, num), event); } static size_t vhost_get_desc_size(struct vhost_virtqueue *vq, @@ -473,13 +458,13 @@ void vhost_dev_init(struct vhost_dev *dev, dev->umem = NULL; dev->iotlb = NULL; dev->mm = NULL; - dev->worker = NULL; + memset(&dev->worker, 0, sizeof(dev->worker)); + init_llist_head(&dev->worker.work_list); dev->iov_limit = iov_limit; dev->weight = weight; dev->byte_weight = byte_weight; dev->use_worker = use_worker; dev->msg_handler = msg_handler; - init_llist_head(&dev->work_list); init_waitqueue_head(&dev->wait); INIT_LIST_HEAD(&dev->read_list); INIT_LIST_HEAD(&dev->pending_list); @@ -509,31 +494,6 @@ long vhost_dev_check_owner(struct vhost_dev *dev) } EXPORT_SYMBOL_GPL(vhost_dev_check_owner); -struct vhost_attach_cgroups_struct { - struct vhost_work work; - struct task_struct *owner; - int ret; -}; - -static void vhost_attach_cgroups_work(struct vhost_work *work) -{ - struct vhost_attach_cgroups_struct *s; - - s = container_of(work, struct vhost_attach_cgroups_struct, work); - s->ret = cgroup_attach_task_all(s->owner, current); -} - -static int vhost_attach_cgroups(struct vhost_dev *dev) -{ - struct vhost_attach_cgroups_struct attach; - - attach.owner = current; - vhost_work_init(&attach.work, vhost_attach_cgroups_work); - vhost_work_queue(dev, &attach.work); - vhost_dev_flush(dev); - return attach.ret; -} - /* Caller should have device mutex */ bool vhost_dev_has_owner(struct vhost_dev *dev) { @@ -571,10 +531,37 @@ static void vhost_detach_mm(struct vhost_dev *dev) dev->mm = NULL; } +static void vhost_worker_free(struct vhost_dev *dev) +{ + if (!dev->worker.vtsk) + return; + + WARN_ON(!llist_empty(&dev->worker.work_list)); + vhost_task_stop(dev->worker.vtsk); + dev->worker.kcov_handle = 0; + dev->worker.vtsk = NULL; +} + +static int vhost_worker_create(struct vhost_dev *dev) +{ + struct vhost_task *vtsk; + char name[TASK_COMM_LEN]; + + snprintf(name, sizeof(name), "vhost-%d", current->pid); + + vtsk = vhost_task_create(vhost_worker, &dev->worker, name); + if (!vtsk) + return -ENOMEM; + + dev->worker.kcov_handle = kcov_common_handle(); + dev->worker.vtsk = vtsk; + vhost_task_start(vtsk); + return 0; +} + /* Caller should have device mutex */ long vhost_dev_set_owner(struct vhost_dev *dev) { - struct task_struct *worker; int err; /* Is there an owner already? */ @@ -585,36 +572,21 @@ long vhost_dev_set_owner(struct vhost_dev *dev) vhost_attach_mm(dev); - dev->kcov_handle = kcov_common_handle(); if (dev->use_worker) { - worker = kthread_create(vhost_worker, dev, - "vhost-%d", current->pid); - if (IS_ERR(worker)) { - err = PTR_ERR(worker); - goto err_worker; - } - - dev->worker = worker; - wake_up_process(worker); /* avoid contributing to loadavg */ - - err = vhost_attach_cgroups(dev); + err = vhost_worker_create(dev); if (err) - goto err_cgroup; + goto err_worker; } err = vhost_dev_alloc_iovecs(dev); if (err) - goto err_cgroup; + goto err_iovecs; return 0; -err_cgroup: - if (dev->worker) { - kthread_stop(dev->worker); - dev->worker = NULL; - } +err_iovecs: + vhost_worker_free(dev); err_worker: vhost_detach_mm(dev); - dev->kcov_handle = 0; err_mm: return err; } @@ -705,12 +677,7 @@ void vhost_dev_cleanup(struct vhost_dev *dev) dev->iotlb = NULL; vhost_clear_msg(dev); wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM); - WARN_ON(!llist_empty(&dev->work_list)); - if (dev->worker) { - kthread_stop(dev->worker); - dev->worker = NULL; - dev->kcov_handle = 0; - } + vhost_worker_free(dev); vhost_detach_mm(dev); } EXPORT_SYMBOL_GPL(vhost_dev_cleanup); @@ -1633,17 +1600,25 @@ long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *arg r = -EFAULT; break; } - if (s.num > 0xffff) { - r = -EINVAL; - break; + if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED)) { + vq->last_avail_idx = s.num & 0xffff; + vq->last_used_idx = (s.num >> 16) & 0xffff; + } else { + if (s.num > 0xffff) { + r = -EINVAL; + break; + } + vq->last_avail_idx = s.num; } - vq->last_avail_idx = s.num; /* Forget the cached index value. */ vq->avail_idx = vq->last_avail_idx; break; case VHOST_GET_VRING_BASE: s.index = idx; - s.num = vq->last_avail_idx; + if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED)) + s.num = (u32)vq->last_avail_idx | ((u32)vq->last_used_idx << 16); + else + s.num = vq->last_avail_idx; if (copy_to_user(argp, &s, sizeof s)) r = -EFAULT; break; @@ -1831,7 +1806,7 @@ EXPORT_SYMBOL_GPL(vhost_dev_ioctl); /* TODO: This is really inefficient. We need something like get_user() * (instruction directly accesses the data, with an exception table entry - * returning -EFAULT). See Documentation/x86/exception-tables.rst. + * returning -EFAULT). See Documentation/arch/x86/exception-tables.rst. */ static int set_bit_to_user(int nr, void __user *addr) { @@ -2582,12 +2557,11 @@ EXPORT_SYMBOL_GPL(vhost_disable_notify); /* Create a new message. */ struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type) { - struct vhost_msg_node *node = kmalloc(sizeof *node, GFP_KERNEL); + /* Make sure all padding within the structure is initialized. */ + struct vhost_msg_node *node = kzalloc(sizeof(*node), GFP_KERNEL); if (!node) return NULL; - /* Make sure all padding within the structure is initialized. */ - memset(&node->msg, 0, sizeof node->msg); node->vq = vq; node->msg.type = type; return node; |