diff options
Diffstat (limited to 'drivers/vhost/vhost.c')
-rw-r--r-- | drivers/vhost/vhost.c | 74 |
1 files changed, 56 insertions, 18 deletions
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index 74d135ee7e26..9ad45e1d27f0 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c @@ -298,6 +298,13 @@ static void vhost_vq_meta_reset(struct vhost_dev *d) __vhost_vq_meta_reset(d->vqs[i]); } +static void vhost_vring_call_reset(struct vhost_vring_call *call_ctx) +{ + call_ctx->ctx = NULL; + memset(&call_ctx->producer, 0x0, sizeof(struct irq_bypass_producer)); + spin_lock_init(&call_ctx->ctx_lock); +} + static void vhost_vq_reset(struct vhost_dev *dev, struct vhost_virtqueue *vq) { @@ -319,13 +326,13 @@ static void vhost_vq_reset(struct vhost_dev *dev, vq->log_base = NULL; vq->error_ctx = NULL; vq->kick = NULL; - vq->call_ctx = NULL; vq->log_ctx = NULL; vhost_reset_is_le(vq); vhost_disable_cross_endian(vq); vq->busyloop_timeout = 0; vq->umem = NULL; vq->iotlb = NULL; + vhost_vring_call_reset(&vq->call_ctx); __vhost_vq_meta_reset(vq); } @@ -685,8 +692,8 @@ void vhost_dev_cleanup(struct vhost_dev *dev) eventfd_ctx_put(dev->vqs[i]->error_ctx); if (dev->vqs[i]->kick) fput(dev->vqs[i]->kick); - if (dev->vqs[i]->call_ctx) - eventfd_ctx_put(dev->vqs[i]->call_ctx); + if (dev->vqs[i]->call_ctx.ctx) + eventfd_ctx_put(dev->vqs[i]->call_ctx.ctx); vhost_vq_reset(dev, dev->vqs[i]); } vhost_dev_free_iovecs(dev); @@ -1283,6 +1290,11 @@ static bool vq_access_ok(struct vhost_virtqueue *vq, unsigned int num, vring_used_t __user *used) { + /* If an IOTLB device is present, the vring addresses are + * GIOVAs. Access validation occurs at prefetch time. */ + if (vq->iotlb) + return true; + return access_ok(desc, vhost_get_desc_size(vq, num)) && access_ok(avail, vhost_get_avail_size(vq, num)) && access_ok(used, vhost_get_used_size(vq, num)); @@ -1358,6 +1370,20 @@ bool vhost_log_access_ok(struct vhost_dev *dev) } EXPORT_SYMBOL_GPL(vhost_log_access_ok); +static bool vq_log_used_access_ok(struct vhost_virtqueue *vq, + void __user *log_base, + bool log_used, + u64 log_addr) +{ + /* If an IOTLB device is present, log_addr is a GIOVA that + * will never be logged by log_used(). */ + if (vq->iotlb) + return true; + + return !log_used || log_access_ok(log_base, log_addr, + vhost_get_used_size(vq, vq->num)); +} + /* Verify access for write logging. */ /* Caller should have vq mutex and device mutex */ static bool vq_log_access_ok(struct vhost_virtqueue *vq, @@ -1365,8 +1391,7 @@ static bool vq_log_access_ok(struct vhost_virtqueue *vq, { return vq_memory_access_ok(log_base, vq->umem, vhost_has_feature(vq, VHOST_F_LOG_ALL)) && - (!vq->log_used || log_access_ok(log_base, vq->log_addr, - vhost_get_used_size(vq, vq->num))); + vq_log_used_access_ok(vq, log_base, vq->log_used, vq->log_addr); } /* Can we start vq? */ @@ -1376,10 +1401,6 @@ bool vhost_vq_access_ok(struct vhost_virtqueue *vq) if (!vq_log_access_ok(vq, vq->log_base)) return false; - /* Access validation occurs at prefetch time with IOTLB */ - if (vq->iotlb) - return true; - return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used); } EXPORT_SYMBOL_GPL(vhost_vq_access_ok); @@ -1405,7 +1426,7 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m) memcpy(newmem, &mem, size); if (copy_from_user(newmem->regions, m->regions, - mem.nregions * sizeof *m->regions)) { + flex_array_size(newmem, regions, mem.nregions))) { kvfree(newmem); return -EFAULT; } @@ -1509,10 +1530,9 @@ static long vhost_vring_set_addr(struct vhost_dev *d, return -EINVAL; /* Also validate log access for used ring if enabled. */ - if ((a.flags & (0x1 << VHOST_VRING_F_LOG)) && - !log_access_ok(vq->log_base, a.log_guest_addr, - sizeof *vq->used + - vq->num * sizeof *vq->used->ring)) + if (!vq_log_used_access_ok(vq, vq->log_base, + a.flags & (0x1 << VHOST_VRING_F_LOG), + a.log_guest_addr)) return -EINVAL; } @@ -1629,7 +1649,10 @@ long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *arg r = PTR_ERR(ctx); break; } - swap(ctx, vq->call_ctx); + + spin_lock(&vq->call_ctx.ctx_lock); + swap(ctx, vq->call_ctx.ctx); + spin_unlock(&vq->call_ctx.ctx_lock); break; case VHOST_SET_VRING_ERR: if (copy_from_user(&f, argp, sizeof f)) { @@ -2435,8 +2458,8 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq) { /* Signal the Guest tell them we used something up. */ - if (vq->call_ctx && vhost_notify(dev, vq)) - eventfd_signal(vq->call_ctx, 1); + if (vq->call_ctx.ctx && vhost_notify(dev, vq)) + eventfd_signal(vq->call_ctx.ctx, 1); } EXPORT_SYMBOL_GPL(vhost_signal); @@ -2527,7 +2550,7 @@ void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) { r = vhost_update_used_flags(vq); if (r) - vq_err(vq, "Failed to enable notification at %p: %d\n", + vq_err(vq, "Failed to disable notification at %p: %d\n", &vq->used->flags, r); } } @@ -2576,6 +2599,21 @@ struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev, } EXPORT_SYMBOL_GPL(vhost_dequeue_msg); +void vhost_set_backend_features(struct vhost_dev *dev, u64 features) +{ + struct vhost_virtqueue *vq; + int i; + + mutex_lock(&dev->mutex); + for (i = 0; i < dev->nvqs; ++i) { + vq = dev->vqs[i]; + mutex_lock(&vq->mutex); + vq->acked_backend_features = features; + mutex_unlock(&vq->mutex); + } + mutex_unlock(&dev->mutex); +} +EXPORT_SYMBOL_GPL(vhost_set_backend_features); static int __init vhost_init(void) { |