summaryrefslogtreecommitdiff
path: root/drivers/vhost/vhost.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/vhost/vhost.c')
-rw-r--r--drivers/vhost/vhost.c29
1 files changed, 20 insertions, 9 deletions
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index c579dcc9200c..8b5a1b33d0fe 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -320,7 +320,7 @@ long vhost_dev_reset_owner(struct vhost_dev *dev)
vhost_dev_cleanup(dev);
memory->nregions = 0;
- dev->memory = memory;
+ RCU_INIT_POINTER(dev->memory, memory);
return 0;
}
@@ -352,8 +352,9 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
fput(dev->log_file);
dev->log_file = NULL;
/* No one will access memory at this point */
- kfree(dev->memory);
- dev->memory = NULL;
+ kfree(rcu_dereference_protected(dev->memory,
+ lockdep_is_held(&dev->mutex)));
+ RCU_INIT_POINTER(dev->memory, NULL);
if (dev->mm)
mmput(dev->mm);
dev->mm = NULL;
@@ -440,14 +441,22 @@ static int vq_access_ok(unsigned int num,
/* Caller should have device mutex but not vq mutex */
int vhost_log_access_ok(struct vhost_dev *dev)
{
- return memory_access_ok(dev, dev->memory, 1);
+ struct vhost_memory *mp;
+
+ mp = rcu_dereference_protected(dev->memory,
+ lockdep_is_held(&dev->mutex));
+ return memory_access_ok(dev, mp, 1);
}
/* Verify access for write logging. */
/* Caller should have vq mutex and device mutex */
static int vq_log_access_ok(struct vhost_virtqueue *vq, void __user *log_base)
{
- return vq_memory_access_ok(log_base, vq->dev->memory,
+ struct vhost_memory *mp;
+
+ mp = rcu_dereference_protected(vq->dev->memory,
+ lockdep_is_held(&vq->mutex));
+ return vq_memory_access_ok(log_base, mp,
vhost_has_feature(vq->dev, VHOST_F_LOG_ALL)) &&
(!vq->log_used || log_access_ok(log_base, vq->log_addr,
sizeof *vq->used +
@@ -487,7 +496,8 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
kfree(newmem);
return -EFAULT;
}
- oldmem = d->memory;
+ oldmem = rcu_dereference_protected(d->memory,
+ lockdep_is_held(&d->mutex));
rcu_assign_pointer(d->memory, newmem);
synchronize_rcu();
kfree(oldmem);
@@ -858,11 +868,12 @@ int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
if (r < 0)
return r;
len -= l;
- if (!len)
+ if (!len) {
+ if (vq->log_ctx)
+ eventfd_signal(vq->log_ctx, 1);
return 0;
+ }
}
- if (vq->log_ctx)
- eventfd_signal(vq->log_ctx, 1);
/* Length written exceeds what we have stored. This is a bug. */
BUG();
return 0;