summaryrefslogtreecommitdiff
path: root/drivers/vhost/vhost.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/vhost/vhost.c')
-rw-r--r--drivers/vhost/vhost.c850
1 files changed, 727 insertions, 123 deletions
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index ff8892c38666..0536f8526359 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -298,6 +298,160 @@ static void vhost_vq_meta_reset(struct vhost_dev *d)
__vhost_vq_meta_reset(d->vqs[i]);
}
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+static void vhost_map_unprefetch(struct vhost_map *map)
+{
+ kfree(map->pages);
+ map->pages = NULL;
+ map->npages = 0;
+ map->addr = NULL;
+}
+
+static void vhost_uninit_vq_maps(struct vhost_virtqueue *vq)
+{
+ struct vhost_map *map[VHOST_NUM_ADDRS];
+ int i;
+
+ spin_lock(&vq->mmu_lock);
+ for (i = 0; i < VHOST_NUM_ADDRS; i++) {
+ map[i] = rcu_dereference_protected(vq->maps[i],
+ lockdep_is_held(&vq->mmu_lock));
+ if (map[i])
+ rcu_assign_pointer(vq->maps[i], NULL);
+ }
+ spin_unlock(&vq->mmu_lock);
+
+ synchronize_rcu();
+
+ for (i = 0; i < VHOST_NUM_ADDRS; i++)
+ if (map[i])
+ vhost_map_unprefetch(map[i]);
+
+}
+
+static void vhost_reset_vq_maps(struct vhost_virtqueue *vq)
+{
+ int i;
+
+ vhost_uninit_vq_maps(vq);
+ for (i = 0; i < VHOST_NUM_ADDRS; i++)
+ vq->uaddrs[i].size = 0;
+}
+
+static bool vhost_map_range_overlap(struct vhost_uaddr *uaddr,
+ unsigned long start,
+ unsigned long end)
+{
+ if (unlikely(!uaddr->size))
+ return false;
+
+ return !(end < uaddr->uaddr || start > uaddr->uaddr - 1 + uaddr->size);
+}
+
+static void vhost_invalidate_vq_start(struct vhost_virtqueue *vq,
+ int index,
+ unsigned long start,
+ unsigned long end)
+{
+ struct vhost_uaddr *uaddr = &vq->uaddrs[index];
+ struct vhost_map *map;
+ int i;
+
+ if (!vhost_map_range_overlap(uaddr, start, end))
+ return;
+
+ spin_lock(&vq->mmu_lock);
+ ++vq->invalidate_count;
+
+ map = rcu_dereference_protected(vq->maps[index],
+ lockdep_is_held(&vq->mmu_lock));
+ if (map) {
+ if (uaddr->write) {
+ for (i = 0; i < map->npages; i++)
+ set_page_dirty(map->pages[i]);
+ }
+ rcu_assign_pointer(vq->maps[index], NULL);
+ }
+ spin_unlock(&vq->mmu_lock);
+
+ if (map) {
+ synchronize_rcu();
+ vhost_map_unprefetch(map);
+ }
+}
+
+static void vhost_invalidate_vq_end(struct vhost_virtqueue *vq,
+ int index,
+ unsigned long start,
+ unsigned long end)
+{
+ if (!vhost_map_range_overlap(&vq->uaddrs[index], start, end))
+ return;
+
+ spin_lock(&vq->mmu_lock);
+ --vq->invalidate_count;
+ spin_unlock(&vq->mmu_lock);
+}
+
+static int vhost_invalidate_range_start(struct mmu_notifier *mn,
+ const struct mmu_notifier_range *range)
+{
+ struct vhost_dev *dev = container_of(mn, struct vhost_dev,
+ mmu_notifier);
+ int i, j;
+
+ if (!mmu_notifier_range_blockable(range))
+ return -EAGAIN;
+
+ for (i = 0; i < dev->nvqs; i++) {
+ struct vhost_virtqueue *vq = dev->vqs[i];
+
+ for (j = 0; j < VHOST_NUM_ADDRS; j++)
+ vhost_invalidate_vq_start(vq, j,
+ range->start,
+ range->end);
+ }
+
+ return 0;
+}
+
+static void vhost_invalidate_range_end(struct mmu_notifier *mn,
+ const struct mmu_notifier_range *range)
+{
+ struct vhost_dev *dev = container_of(mn, struct vhost_dev,
+ mmu_notifier);
+ int i, j;
+
+ for (i = 0; i < dev->nvqs; i++) {
+ struct vhost_virtqueue *vq = dev->vqs[i];
+
+ for (j = 0; j < VHOST_NUM_ADDRS; j++)
+ vhost_invalidate_vq_end(vq, j,
+ range->start,
+ range->end);
+ }
+}
+
+static const struct mmu_notifier_ops vhost_mmu_notifier_ops = {
+ .invalidate_range_start = vhost_invalidate_range_start,
+ .invalidate_range_end = vhost_invalidate_range_end,
+};
+
+static void vhost_init_maps(struct vhost_dev *dev)
+{
+ struct vhost_virtqueue *vq;
+ int i, j;
+
+ dev->mmu_notifier.ops = &vhost_mmu_notifier_ops;
+
+ for (i = 0; i < dev->nvqs; ++i) {
+ vq = dev->vqs[i];
+ for (j = 0; j < VHOST_NUM_ADDRS; j++)
+ RCU_INIT_POINTER(vq->maps[j], NULL);
+ }
+}
+#endif
+
static void vhost_vq_reset(struct vhost_dev *dev,
struct vhost_virtqueue *vq)
{
@@ -326,7 +480,11 @@ static void vhost_vq_reset(struct vhost_dev *dev,
vq->busyloop_timeout = 0;
vq->umem = NULL;
vq->iotlb = NULL;
+ vq->invalidate_count = 0;
__vhost_vq_meta_reset(vq);
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+ vhost_reset_vq_maps(vq);
+#endif
}
static int vhost_worker(void *data)
@@ -427,6 +585,32 @@ bool vhost_exceeds_weight(struct vhost_virtqueue *vq,
}
EXPORT_SYMBOL_GPL(vhost_exceeds_weight);
+static size_t vhost_get_avail_size(struct vhost_virtqueue *vq,
+ unsigned int num)
+{
+ size_t event __maybe_unused =
+ vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
+
+ return sizeof(*vq->avail) +
+ sizeof(*vq->avail->ring) * num + event;
+}
+
+static size_t vhost_get_used_size(struct vhost_virtqueue *vq,
+ unsigned int num)
+{
+ size_t event __maybe_unused =
+ vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
+
+ return sizeof(*vq->used) +
+ sizeof(*vq->used->ring) * num + event;
+}
+
+static size_t vhost_get_desc_size(struct vhost_virtqueue *vq,
+ unsigned int num)
+{
+ return sizeof(*vq->desc) * num;
+}
+
void vhost_dev_init(struct vhost_dev *dev,
struct vhost_virtqueue **vqs, int nvqs,
int iov_limit, int weight, int byte_weight)
@@ -450,7 +634,9 @@ void vhost_dev_init(struct vhost_dev *dev,
INIT_LIST_HEAD(&dev->read_list);
INIT_LIST_HEAD(&dev->pending_list);
spin_lock_init(&dev->iotlb_lock);
-
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+ vhost_init_maps(dev);
+#endif
for (i = 0; i < dev->nvqs; ++i) {
vq = dev->vqs[i];
@@ -459,6 +645,7 @@ void vhost_dev_init(struct vhost_dev *dev,
vq->heads = NULL;
vq->dev = dev;
mutex_init(&vq->mutex);
+ spin_lock_init(&vq->mmu_lock);
vhost_vq_reset(dev, vq);
if (vq->handle_kick)
vhost_poll_init(&vq->poll, vq->handle_kick,
@@ -538,7 +725,18 @@ long vhost_dev_set_owner(struct vhost_dev *dev)
if (err)
goto err_cgroup;
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+ err = mmu_notifier_register(&dev->mmu_notifier, dev->mm);
+ if (err)
+ goto err_mmu_notifier;
+#endif
+
return 0;
+
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+err_mmu_notifier:
+ vhost_dev_free_iovecs(dev);
+#endif
err_cgroup:
kthread_stop(worker);
dev->worker = NULL;
@@ -629,6 +827,107 @@ static void vhost_clear_msg(struct vhost_dev *dev)
spin_unlock(&dev->iotlb_lock);
}
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+static void vhost_setup_uaddr(struct vhost_virtqueue *vq,
+ int index, unsigned long uaddr,
+ size_t size, bool write)
+{
+ struct vhost_uaddr *addr = &vq->uaddrs[index];
+
+ addr->uaddr = uaddr;
+ addr->size = size;
+ addr->write = write;
+}
+
+static void vhost_setup_vq_uaddr(struct vhost_virtqueue *vq)
+{
+ vhost_setup_uaddr(vq, VHOST_ADDR_DESC,
+ (unsigned long)vq->desc,
+ vhost_get_desc_size(vq, vq->num),
+ false);
+ vhost_setup_uaddr(vq, VHOST_ADDR_AVAIL,
+ (unsigned long)vq->avail,
+ vhost_get_avail_size(vq, vq->num),
+ false);
+ vhost_setup_uaddr(vq, VHOST_ADDR_USED,
+ (unsigned long)vq->used,
+ vhost_get_used_size(vq, vq->num),
+ true);
+}
+
+static int vhost_map_prefetch(struct vhost_virtqueue *vq,
+ int index)
+{
+ struct vhost_map *map;
+ struct vhost_uaddr *uaddr = &vq->uaddrs[index];
+ struct page **pages;
+ int npages = DIV_ROUND_UP(uaddr->size, PAGE_SIZE);
+ int npinned;
+ void *vaddr, *v;
+ int err;
+ int i;
+
+ spin_lock(&vq->mmu_lock);
+
+ err = -EFAULT;
+ if (vq->invalidate_count)
+ goto err;
+
+ err = -ENOMEM;
+ map = kmalloc(sizeof(*map), GFP_ATOMIC);
+ if (!map)
+ goto err;
+
+ pages = kmalloc_array(npages, sizeof(struct page *), GFP_ATOMIC);
+ if (!pages)
+ goto err_pages;
+
+ err = EFAULT;
+ npinned = __get_user_pages_fast(uaddr->uaddr, npages,
+ uaddr->write, pages);
+ if (npinned > 0)
+ release_pages(pages, npinned);
+ if (npinned != npages)
+ goto err_gup;
+
+ for (i = 0; i < npinned; i++)
+ if (PageHighMem(pages[i]))
+ goto err_gup;
+
+ vaddr = v = page_address(pages[0]);
+
+ /* For simplicity, fallback to userspace address if VA is not
+ * contigious.
+ */
+ for (i = 1; i < npinned; i++) {
+ v += PAGE_SIZE;
+ if (v != page_address(pages[i]))
+ goto err_gup;
+ }
+
+ map->addr = vaddr + (uaddr->uaddr & (PAGE_SIZE - 1));
+ map->npages = npages;
+ map->pages = pages;
+
+ rcu_assign_pointer(vq->maps[index], map);
+ /* No need for a synchronize_rcu(). This function should be
+ * called by dev->worker so we are serialized with all
+ * readers.
+ */
+ spin_unlock(&vq->mmu_lock);
+
+ return 0;
+
+err_gup:
+ kfree(pages);
+err_pages:
+ kfree(map);
+err:
+ spin_unlock(&vq->mmu_lock);
+ return err;
+}
+#endif
+
void vhost_dev_cleanup(struct vhost_dev *dev)
{
int i;
@@ -658,8 +957,16 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
kthread_stop(dev->worker);
dev->worker = NULL;
}
- if (dev->mm)
+ if (dev->mm) {
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+ mmu_notifier_unregister(&dev->mmu_notifier, dev->mm);
+#endif
mmput(dev->mm);
+ }
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+ for (i = 0; i < dev->nvqs; i++)
+ vhost_uninit_vq_maps(dev->vqs[i]);
+#endif
dev->mm = NULL;
}
EXPORT_SYMBOL_GPL(vhost_dev_cleanup);
@@ -886,6 +1193,113 @@ static inline void __user *__vhost_get_user(struct vhost_virtqueue *vq,
ret; \
})
+static inline int vhost_put_avail_event(struct vhost_virtqueue *vq)
+{
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+ struct vhost_map *map;
+ struct vring_used *used;
+
+ if (!vq->iotlb) {
+ rcu_read_lock();
+
+ map = rcu_dereference(vq->maps[VHOST_ADDR_USED]);
+ if (likely(map)) {
+ used = map->addr;
+ *((__virtio16 *)&used->ring[vq->num]) =
+ cpu_to_vhost16(vq, vq->avail_idx);
+ rcu_read_unlock();
+ return 0;
+ }
+
+ rcu_read_unlock();
+ }
+#endif
+
+ return vhost_put_user(vq, cpu_to_vhost16(vq, vq->avail_idx),
+ vhost_avail_event(vq));
+}
+
+static inline int vhost_put_used(struct vhost_virtqueue *vq,
+ struct vring_used_elem *head, int idx,
+ int count)
+{
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+ struct vhost_map *map;
+ struct vring_used *used;
+ size_t size;
+
+ if (!vq->iotlb) {
+ rcu_read_lock();
+
+ map = rcu_dereference(vq->maps[VHOST_ADDR_USED]);
+ if (likely(map)) {
+ used = map->addr;
+ size = count * sizeof(*head);
+ memcpy(used->ring + idx, head, size);
+ rcu_read_unlock();
+ return 0;
+ }
+
+ rcu_read_unlock();
+ }
+#endif
+
+ return vhost_copy_to_user(vq, vq->used->ring + idx, head,
+ count * sizeof(*head));
+}
+
+static inline int vhost_put_used_flags(struct vhost_virtqueue *vq)
+
+{
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+ struct vhost_map *map;
+ struct vring_used *used;
+
+ if (!vq->iotlb) {
+ rcu_read_lock();
+
+ map = rcu_dereference(vq->maps[VHOST_ADDR_USED]);
+ if (likely(map)) {
+ used = map->addr;
+ used->flags = cpu_to_vhost16(vq, vq->used_flags);
+ rcu_read_unlock();
+ return 0;
+ }
+
+ rcu_read_unlock();
+ }
+#endif
+
+ return vhost_put_user(vq, cpu_to_vhost16(vq, vq->used_flags),
+ &vq->used->flags);
+}
+
+static inline int vhost_put_used_idx(struct vhost_virtqueue *vq)
+
+{
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+ struct vhost_map *map;
+ struct vring_used *used;
+
+ if (!vq->iotlb) {
+ rcu_read_lock();
+
+ map = rcu_dereference(vq->maps[VHOST_ADDR_USED]);
+ if (likely(map)) {
+ used = map->addr;
+ used->idx = cpu_to_vhost16(vq, vq->last_used_idx);
+ rcu_read_unlock();
+ return 0;
+ }
+
+ rcu_read_unlock();
+ }
+#endif
+
+ return vhost_put_user(vq, cpu_to_vhost16(vq, vq->last_used_idx),
+ &vq->used->idx);
+}
+
#define vhost_get_user(vq, x, ptr, type) \
({ \
int ret; \
@@ -924,6 +1338,155 @@ static void vhost_dev_unlock_vqs(struct vhost_dev *d)
mutex_unlock(&d->vqs[i]->mutex);
}
+static inline int vhost_get_avail_idx(struct vhost_virtqueue *vq,
+ __virtio16 *idx)
+{
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+ struct vhost_map *map;
+ struct vring_avail *avail;
+
+ if (!vq->iotlb) {
+ rcu_read_lock();
+
+ map = rcu_dereference(vq->maps[VHOST_ADDR_AVAIL]);
+ if (likely(map)) {
+ avail = map->addr;
+ *idx = avail->idx;
+ rcu_read_unlock();
+ return 0;
+ }
+
+ rcu_read_unlock();
+ }
+#endif
+
+ return vhost_get_avail(vq, *idx, &vq->avail->idx);
+}
+
+static inline int vhost_get_avail_head(struct vhost_virtqueue *vq,
+ __virtio16 *head, int idx)
+{
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+ struct vhost_map *map;
+ struct vring_avail *avail;
+
+ if (!vq->iotlb) {
+ rcu_read_lock();
+
+ map = rcu_dereference(vq->maps[VHOST_ADDR_AVAIL]);
+ if (likely(map)) {
+ avail = map->addr;
+ *head = avail->ring[idx & (vq->num - 1)];
+ rcu_read_unlock();
+ return 0;
+ }
+
+ rcu_read_unlock();
+ }
+#endif
+
+ return vhost_get_avail(vq, *head,
+ &vq->avail->ring[idx & (vq->num - 1)]);
+}
+
+static inline int vhost_get_avail_flags(struct vhost_virtqueue *vq,
+ __virtio16 *flags)
+{
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+ struct vhost_map *map;
+ struct vring_avail *avail;
+
+ if (!vq->iotlb) {
+ rcu_read_lock();
+
+ map = rcu_dereference(vq->maps[VHOST_ADDR_AVAIL]);
+ if (likely(map)) {
+ avail = map->addr;
+ *flags = avail->flags;
+ rcu_read_unlock();
+ return 0;
+ }
+
+ rcu_read_unlock();
+ }
+#endif
+
+ return vhost_get_avail(vq, *flags, &vq->avail->flags);
+}
+
+static inline int vhost_get_used_event(struct vhost_virtqueue *vq,
+ __virtio16 *event)
+{
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+ struct vhost_map *map;
+ struct vring_avail *avail;
+
+ if (!vq->iotlb) {
+ rcu_read_lock();
+ map = rcu_dereference(vq->maps[VHOST_ADDR_AVAIL]);
+ if (likely(map)) {
+ avail = map->addr;
+ *event = (__virtio16)avail->ring[vq->num];
+ rcu_read_unlock();
+ return 0;
+ }
+ rcu_read_unlock();
+ }
+#endif
+
+ return vhost_get_avail(vq, *event, vhost_used_event(vq));
+}
+
+static inline int vhost_get_used_idx(struct vhost_virtqueue *vq,
+ __virtio16 *idx)
+{
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+ struct vhost_map *map;
+ struct vring_used *used;
+
+ if (!vq->iotlb) {
+ rcu_read_lock();
+
+ map = rcu_dereference(vq->maps[VHOST_ADDR_USED]);
+ if (likely(map)) {
+ used = map->addr;
+ *idx = used->idx;
+ rcu_read_unlock();
+ return 0;
+ }
+
+ rcu_read_unlock();
+ }
+#endif
+
+ return vhost_get_used(vq, *idx, &vq->used->idx);
+}
+
+static inline int vhost_get_desc(struct vhost_virtqueue *vq,
+ struct vring_desc *desc, int idx)
+{
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+ struct vhost_map *map;
+ struct vring_desc *d;
+
+ if (!vq->iotlb) {
+ rcu_read_lock();
+
+ map = rcu_dereference(vq->maps[VHOST_ADDR_DESC]);
+ if (likely(map)) {
+ d = map->addr;
+ *desc = *(d + idx);
+ rcu_read_unlock();
+ return 0;
+ }
+
+ rcu_read_unlock();
+ }
+#endif
+
+ return vhost_copy_from_user(vq, desc, vq->desc + idx, sizeof(*desc));
+}
+
static int vhost_new_umem_range(struct vhost_umem *umem,
u64 start, u64 size, u64 end,
u64 userspace_addr, int perm)
@@ -1209,13 +1772,9 @@ static bool vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
struct vring_used __user *used)
{
- size_t s __maybe_unused = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
-
- return access_ok(desc, num * sizeof *desc) &&
- access_ok(avail,
- sizeof *avail + num * sizeof *avail->ring + s) &&
- access_ok(used,
- sizeof *used + num * sizeof *used->ring + s);
+ return access_ok(desc, vhost_get_desc_size(vq, num)) &&
+ access_ok(avail, vhost_get_avail_size(vq, num)) &&
+ access_ok(used, vhost_get_used_size(vq, num));
}
static void vhost_vq_meta_update(struct vhost_virtqueue *vq,
@@ -1265,26 +1824,42 @@ static bool iotlb_access_ok(struct vhost_virtqueue *vq,
return true;
}
-int vq_iotlb_prefetch(struct vhost_virtqueue *vq)
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+static void vhost_vq_map_prefetch(struct vhost_virtqueue *vq)
+{
+ struct vhost_map __rcu *map;
+ int i;
+
+ for (i = 0; i < VHOST_NUM_ADDRS; i++) {
+ rcu_read_lock();
+ map = rcu_dereference(vq->maps[i]);
+ rcu_read_unlock();
+ if (unlikely(!map))
+ vhost_map_prefetch(vq, i);
+ }
+}
+#endif
+
+int vq_meta_prefetch(struct vhost_virtqueue *vq)
{
- size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
unsigned int num = vq->num;
- if (!vq->iotlb)
+ if (!vq->iotlb) {
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+ vhost_vq_map_prefetch(vq);
+#endif
return 1;
+ }
return iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->desc,
- num * sizeof(*vq->desc), VHOST_ADDR_DESC) &&
+ vhost_get_desc_size(vq, num), VHOST_ADDR_DESC) &&
iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->avail,
- sizeof *vq->avail +
- num * sizeof(*vq->avail->ring) + s,
+ vhost_get_avail_size(vq, num),
VHOST_ADDR_AVAIL) &&
iotlb_access_ok(vq, VHOST_ACCESS_WO, (u64)(uintptr_t)vq->used,
- sizeof *vq->used +
- num * sizeof(*vq->used->ring) + s,
- VHOST_ADDR_USED);
+ vhost_get_used_size(vq, num), VHOST_ADDR_USED);
}
-EXPORT_SYMBOL_GPL(vq_iotlb_prefetch);
+EXPORT_SYMBOL_GPL(vq_meta_prefetch);
/* Can we log writes? */
/* Caller should have device mutex but not vq mutex */
@@ -1299,13 +1874,10 @@ EXPORT_SYMBOL_GPL(vhost_log_access_ok);
static bool vq_log_access_ok(struct vhost_virtqueue *vq,
void __user *log_base)
{
- size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
-
return vq_memory_access_ok(log_base, vq->umem,
vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
(!vq->log_used || log_access_ok(log_base, vq->log_addr,
- sizeof *vq->used +
- vq->num * sizeof *vq->used->ring + s));
+ vhost_get_used_size(vq, vq->num)));
}
/* Can we start vq? */
@@ -1405,6 +1977,121 @@ err:
return -EFAULT;
}
+static long vhost_vring_set_num(struct vhost_dev *d,
+ struct vhost_virtqueue *vq,
+ void __user *argp)
+{
+ struct vhost_vring_state s;
+
+ /* Resizing ring with an active backend?
+ * You don't want to do that. */
+ if (vq->private_data)
+ return -EBUSY;
+
+ if (copy_from_user(&s, argp, sizeof s))
+ return -EFAULT;
+
+ if (!s.num || s.num > 0xffff || (s.num & (s.num - 1)))
+ return -EINVAL;
+ vq->num = s.num;
+
+ return 0;
+}
+
+static long vhost_vring_set_addr(struct vhost_dev *d,
+ struct vhost_virtqueue *vq,
+ void __user *argp)
+{
+ struct vhost_vring_addr a;
+
+ if (copy_from_user(&a, argp, sizeof a))
+ return -EFAULT;
+ if (a.flags & ~(0x1 << VHOST_VRING_F_LOG))
+ return -EOPNOTSUPP;
+
+ /* For 32bit, verify that the top 32bits of the user
+ data are set to zero. */
+ if ((u64)(unsigned long)a.desc_user_addr != a.desc_user_addr ||
+ (u64)(unsigned long)a.used_user_addr != a.used_user_addr ||
+ (u64)(unsigned long)a.avail_user_addr != a.avail_user_addr)
+ return -EFAULT;
+
+ /* Make sure it's safe to cast pointers to vring types. */
+ BUILD_BUG_ON(__alignof__ *vq->avail > VRING_AVAIL_ALIGN_SIZE);
+ BUILD_BUG_ON(__alignof__ *vq->used > VRING_USED_ALIGN_SIZE);
+ if ((a.avail_user_addr & (VRING_AVAIL_ALIGN_SIZE - 1)) ||
+ (a.used_user_addr & (VRING_USED_ALIGN_SIZE - 1)) ||
+ (a.log_guest_addr & (VRING_USED_ALIGN_SIZE - 1)))
+ return -EINVAL;
+
+ /* We only verify access here if backend is configured.
+ * If it is not, we don't as size might not have been setup.
+ * We will verify when backend is configured. */
+ if (vq->private_data) {
+ if (!vq_access_ok(vq, vq->num,
+ (void __user *)(unsigned long)a.desc_user_addr,
+ (void __user *)(unsigned long)a.avail_user_addr,
+ (void __user *)(unsigned long)a.used_user_addr))
+ return -EINVAL;
+
+ /* Also validate log access for used ring if enabled. */
+ if ((a.flags & (0x1 << VHOST_VRING_F_LOG)) &&
+ !log_access_ok(vq->log_base, a.log_guest_addr,
+ sizeof *vq->used +
+ vq->num * sizeof *vq->used->ring))
+ return -EINVAL;
+ }
+
+ vq->log_used = !!(a.flags & (0x1 << VHOST_VRING_F_LOG));
+ vq->desc = (void __user *)(unsigned long)a.desc_user_addr;
+ vq->avail = (void __user *)(unsigned long)a.avail_user_addr;
+ vq->log_addr = a.log_guest_addr;
+ vq->used = (void __user *)(unsigned long)a.used_user_addr;
+
+ return 0;
+}
+
+static long vhost_vring_set_num_addr(struct vhost_dev *d,
+ struct vhost_virtqueue *vq,
+ unsigned int ioctl,
+ void __user *argp)
+{
+ long r;
+
+ mutex_lock(&vq->mutex);
+
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+ /* Unregister MMU notifer to allow invalidation callback
+ * can access vq->uaddrs[] without holding a lock.
+ */
+ if (d->mm)
+ mmu_notifier_unregister(&d->mmu_notifier, d->mm);
+
+ vhost_uninit_vq_maps(vq);
+#endif
+
+ switch (ioctl) {
+ case VHOST_SET_VRING_NUM:
+ r = vhost_vring_set_num(d, vq, argp);
+ break;
+ case VHOST_SET_VRING_ADDR:
+ r = vhost_vring_set_addr(d, vq, argp);
+ break;
+ default:
+ BUG();
+ }
+
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+ vhost_setup_vq_uaddr(vq);
+
+ if (d->mm)
+ mmu_notifier_register(&d->mmu_notifier, d->mm);
+#endif
+
+ mutex_unlock(&vq->mutex);
+
+ return r;
+}
long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
{
struct file *eventfp, *filep = NULL;
@@ -1414,7 +2101,6 @@ long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *arg
struct vhost_virtqueue *vq;
struct vhost_vring_state s;
struct vhost_vring_file f;
- struct vhost_vring_addr a;
u32 idx;
long r;
@@ -1427,26 +2113,14 @@ long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *arg
idx = array_index_nospec(idx, d->nvqs);
vq = d->vqs[idx];
+ if (ioctl == VHOST_SET_VRING_NUM ||
+ ioctl == VHOST_SET_VRING_ADDR) {
+ return vhost_vring_set_num_addr(d, vq, ioctl, argp);
+ }
+
mutex_lock(&vq->mutex);
switch (ioctl) {
- case VHOST_SET_VRING_NUM:
- /* Resizing ring with an active backend?
- * You don't want to do that. */
- if (vq->private_data) {
- r = -EBUSY;
- break;
- }
- if (copy_from_user(&s, argp, sizeof s)) {
- r = -EFAULT;
- break;
- }
- if (!s.num || s.num > 0xffff || (s.num & (s.num - 1))) {
- r = -EINVAL;
- break;
- }
- vq->num = s.num;
- break;
case VHOST_SET_VRING_BASE:
/* Moving base with an active backend?
* You don't want to do that. */
@@ -1472,62 +2146,6 @@ long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *arg
if (copy_to_user(argp, &s, sizeof s))
r = -EFAULT;
break;
- case VHOST_SET_VRING_ADDR:
- if (copy_from_user(&a, argp, sizeof a)) {
- r = -EFAULT;
- break;
- }
- if (a.flags & ~(0x1 << VHOST_VRING_F_LOG)) {
- r = -EOPNOTSUPP;
- break;
- }
- /* For 32bit, verify that the top 32bits of the user
- data are set to zero. */
- if ((u64)(unsigned long)a.desc_user_addr != a.desc_user_addr ||
- (u64)(unsigned long)a.used_user_addr != a.used_user_addr ||
- (u64)(unsigned long)a.avail_user_addr != a.avail_user_addr) {
- r = -EFAULT;
- break;
- }
-
- /* Make sure it's safe to cast pointers to vring types. */
- BUILD_BUG_ON(__alignof__ *vq->avail > VRING_AVAIL_ALIGN_SIZE);
- BUILD_BUG_ON(__alignof__ *vq->used > VRING_USED_ALIGN_SIZE);
- if ((a.avail_user_addr & (VRING_AVAIL_ALIGN_SIZE - 1)) ||
- (a.used_user_addr & (VRING_USED_ALIGN_SIZE - 1)) ||
- (a.log_guest_addr & (VRING_USED_ALIGN_SIZE - 1))) {
- r = -EINVAL;
- break;
- }
-
- /* We only verify access here if backend is configured.
- * If it is not, we don't as size might not have been setup.
- * We will verify when backend is configured. */
- if (vq->private_data) {
- if (!vq_access_ok(vq, vq->num,
- (void __user *)(unsigned long)a.desc_user_addr,
- (void __user *)(unsigned long)a.avail_user_addr,
- (void __user *)(unsigned long)a.used_user_addr)) {
- r = -EINVAL;
- break;
- }
-
- /* Also validate log access for used ring if enabled. */
- if ((a.flags & (0x1 << VHOST_VRING_F_LOG)) &&
- !log_access_ok(vq->log_base, a.log_guest_addr,
- sizeof *vq->used +
- vq->num * sizeof *vq->used->ring)) {
- r = -EINVAL;
- break;
- }
- }
-
- vq->log_used = !!(a.flags & (0x1 << VHOST_VRING_F_LOG));
- vq->desc = (void __user *)(unsigned long)a.desc_user_addr;
- vq->avail = (void __user *)(unsigned long)a.avail_user_addr;
- vq->log_addr = a.log_guest_addr;
- vq->used = (void __user *)(unsigned long)a.used_user_addr;
- break;
case VHOST_SET_VRING_KICK:
if (copy_from_user(&f, argp, sizeof f)) {
r = -EFAULT;
@@ -1861,8 +2479,7 @@ EXPORT_SYMBOL_GPL(vhost_log_write);
static int vhost_update_used_flags(struct vhost_virtqueue *vq)
{
void __user *used;
- if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->used_flags),
- &vq->used->flags) < 0)
+ if (vhost_put_used_flags(vq))
return -EFAULT;
if (unlikely(vq->log_used)) {
/* Make sure the flag is seen before log. */
@@ -1879,8 +2496,7 @@ static int vhost_update_used_flags(struct vhost_virtqueue *vq)
static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event)
{
- if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->avail_idx),
- vhost_avail_event(vq)))
+ if (vhost_put_avail_event(vq))
return -EFAULT;
if (unlikely(vq->log_used)) {
void __user *used;
@@ -1916,7 +2532,7 @@ int vhost_vq_init_access(struct vhost_virtqueue *vq)
r = -EFAULT;
goto err;
}
- r = vhost_get_used(vq, last_used_idx, &vq->used->idx);
+ r = vhost_get_used_idx(vq, &last_used_idx);
if (r) {
vq_err(vq, "Can't access used idx at %p\n",
&vq->used->idx);
@@ -2115,7 +2731,7 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
last_avail_idx = vq->last_avail_idx;
if (vq->avail_idx == vq->last_avail_idx) {
- if (unlikely(vhost_get_avail(vq, avail_idx, &vq->avail->idx))) {
+ if (unlikely(vhost_get_avail_idx(vq, &avail_idx))) {
vq_err(vq, "Failed to access avail idx at %p\n",
&vq->avail->idx);
return -EFAULT;
@@ -2142,8 +2758,7 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
/* Grab the next descriptor number they're advertising, and increment
* the index we've seen. */
- if (unlikely(vhost_get_avail(vq, ring_head,
- &vq->avail->ring[last_avail_idx & (vq->num - 1)]))) {
+ if (unlikely(vhost_get_avail_head(vq, &ring_head, last_avail_idx))) {
vq_err(vq, "Failed to read head: idx %d address %p\n",
last_avail_idx,
&vq->avail->ring[last_avail_idx % vq->num]);
@@ -2178,8 +2793,7 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
i, vq->num, head);
return -EINVAL;
}
- ret = vhost_copy_from_user(vq, &desc, vq->desc + i,
- sizeof desc);
+ ret = vhost_get_desc(vq, &desc, i);
if (unlikely(ret)) {
vq_err(vq, "Failed to get descriptor: idx %d addr %p\n",
i, vq->desc + i);
@@ -2272,16 +2886,7 @@ static int __vhost_add_used_n(struct vhost_virtqueue *vq,
start = vq->last_used_idx & (vq->num - 1);
used = vq->used->ring + start;
- if (count == 1) {
- if (vhost_put_user(vq, heads[0].id, &used->id)) {
- vq_err(vq, "Failed to write used id");
- return -EFAULT;
- }
- if (vhost_put_user(vq, heads[0].len, &used->len)) {
- vq_err(vq, "Failed to write used len");
- return -EFAULT;
- }
- } else if (vhost_copy_to_user(vq, used, heads, count * sizeof *used)) {
+ if (vhost_put_used(vq, heads, start, count)) {
vq_err(vq, "Failed to write used");
return -EFAULT;
}
@@ -2323,8 +2928,7 @@ int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
/* Make sure buffer is written before we update index. */
smp_wmb();
- if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->last_used_idx),
- &vq->used->idx)) {
+ if (vhost_put_used_idx(vq)) {
vq_err(vq, "Failed to increment used idx");
return -EFAULT;
}
@@ -2357,7 +2961,7 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
__virtio16 flags;
- if (vhost_get_avail(vq, flags, &vq->avail->flags)) {
+ if (vhost_get_avail_flags(vq, &flags)) {
vq_err(vq, "Failed to get flags");
return true;
}
@@ -2371,7 +2975,7 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
if (unlikely(!v))
return true;
- if (vhost_get_avail(vq, event, vhost_used_event(vq))) {
+ if (vhost_get_used_event(vq, &event)) {
vq_err(vq, "Failed to get used event idx");
return true;
}
@@ -2416,7 +3020,7 @@ bool vhost_vq_avail_empty(struct vhost_dev *dev, struct vhost_virtqueue *vq)
if (vq->avail_idx != vq->last_avail_idx)
return false;
- r = vhost_get_avail(vq, avail_idx, &vq->avail->idx);
+ r = vhost_get_avail_idx(vq, &avail_idx);
if (unlikely(r))
return false;
vq->avail_idx = vhost16_to_cpu(vq, avail_idx);
@@ -2452,7 +3056,7 @@ bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
/* They could have slipped one in as we were doing that: make
* sure it's written, then check again. */
smp_mb();
- r = vhost_get_avail(vq, avail_idx, &vq->avail->idx);
+ r = vhost_get_avail_idx(vq, &avail_idx);
if (r) {
vq_err(vq, "Failed to check avail idx at %p: %d\n",
&vq->avail->idx, r);