diff options
Diffstat (limited to 'drivers/vfio/group.c')
-rw-r--r-- | drivers/vfio/group.c | 46 |
1 files changed, 38 insertions, 8 deletions
diff --git a/drivers/vfio/group.c b/drivers/vfio/group.c index e166ad7ce6e7..27d5ba7cf9dc 100644 --- a/drivers/vfio/group.c +++ b/drivers/vfio/group.c @@ -140,7 +140,7 @@ static int vfio_group_ioctl_set_container(struct vfio_group *group, ret = iommufd_vfio_compat_ioas_create(iommufd); if (ret) { - iommufd_ctx_put(group->iommufd); + iommufd_ctx_put(iommufd); goto out_unlock; } @@ -157,6 +157,18 @@ out_unlock: return ret; } +static void vfio_device_group_get_kvm_safe(struct vfio_device *device) +{ + spin_lock(&device->group->kvm_ref_lock); + if (!device->group->kvm) + goto unlock; + + _vfio_device_get_kvm_safe(device, device->group->kvm); + +unlock: + spin_unlock(&device->group->kvm_ref_lock); +} + static int vfio_device_group_open(struct vfio_device *device) { int ret; @@ -167,13 +179,23 @@ static int vfio_device_group_open(struct vfio_device *device) goto out_unlock; } + mutex_lock(&device->dev_set->lock); + /* - * Here we pass the KVM pointer with the group under the lock. If the - * device driver will use it, it must obtain a reference and release it - * during close_device. + * Before the first device open, get the KVM pointer currently + * associated with the group (if there is one) and obtain a reference + * now that will be held until the open_count reaches 0 again. Save + * the pointer in the device for use by drivers. */ - ret = vfio_device_open(device, device->group->iommufd, - device->group->kvm); + if (device->open_count == 0) + vfio_device_group_get_kvm_safe(device); + + ret = vfio_device_open(device, device->group->iommufd); + + if (device->open_count == 0) + vfio_device_put_kvm(device); + + mutex_unlock(&device->dev_set->lock); out_unlock: mutex_unlock(&device->group->group_lock); @@ -183,7 +205,14 @@ out_unlock: void vfio_device_group_close(struct vfio_device *device) { mutex_lock(&device->group->group_lock); + mutex_lock(&device->dev_set->lock); + vfio_device_close(device, device->group->iommufd); + + if (device->open_count == 0) + vfio_device_put_kvm(device); + + mutex_unlock(&device->dev_set->lock); mutex_unlock(&device->group->group_lock); } @@ -453,6 +482,7 @@ static struct vfio_group *vfio_group_alloc(struct iommu_group *iommu_group, refcount_set(&group->drivers, 1); mutex_init(&group->group_lock); + spin_lock_init(&group->kvm_ref_lock); INIT_LIST_HEAD(&group->device_list); mutex_init(&group->device_lock); group->iommu_group = iommu_group; @@ -806,9 +836,9 @@ void vfio_file_set_kvm(struct file *file, struct kvm *kvm) if (!vfio_file_is_group(file)) return; - mutex_lock(&group->group_lock); + spin_lock(&group->kvm_ref_lock); group->kvm = kvm; - mutex_unlock(&group->group_lock); + spin_unlock(&group->kvm_ref_lock); } EXPORT_SYMBOL_GPL(vfio_file_set_kvm); |