summaryrefslogtreecommitdiff
path: root/drivers/iommu/amd_iommu_v2.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/iommu/amd_iommu_v2.c')
-rw-r--r--drivers/iommu/amd_iommu_v2.c125
1 files changed, 78 insertions, 47 deletions
diff --git a/drivers/iommu/amd_iommu_v2.c b/drivers/iommu/amd_iommu_v2.c
index d4daa05efe60..5f578e850fc5 100644
--- a/drivers/iommu/amd_iommu_v2.c
+++ b/drivers/iommu/amd_iommu_v2.c
@@ -45,15 +45,17 @@ struct pri_queue {
struct pasid_state {
struct list_head list; /* For global state-list */
atomic_t count; /* Reference count */
- atomic_t mmu_notifier_count; /* Counting nested mmu_notifier
+ unsigned mmu_notifier_count; /* Counting nested mmu_notifier
calls */
- struct task_struct *task; /* Task bound to this PASID */
struct mm_struct *mm; /* mm_struct for the faults */
- struct mmu_notifier mn; /* mmu_otifier handle */
+ struct mmu_notifier mn; /* mmu_notifier handle */
struct pri_queue pri[PRI_QUEUE_SIZE]; /* PRI tag states */
struct device_state *device_state; /* Link to our device_state */
int pasid; /* PASID index */
- spinlock_t lock; /* Protect pri_queues */
+ bool invalid; /* Used during setup and
+ teardown of the pasid */
+ spinlock_t lock; /* Protect pri_queues and
+ mmu_notifer_count */
wait_queue_head_t wq; /* To wait for count == 0 */
};
@@ -98,7 +100,6 @@ static struct workqueue_struct *iommu_wq;
static u64 *empty_page_table;
static void free_pasid_states(struct device_state *dev_state);
-static void unbind_pasid(struct device_state *dev_state, int pasid);
static u16 device_id(struct pci_dev *pdev)
{
@@ -296,37 +297,29 @@ static void put_pasid_state_wait(struct pasid_state *pasid_state)
schedule();
finish_wait(&pasid_state->wq, &wait);
- mmput(pasid_state->mm);
free_pasid_state(pasid_state);
}
-static void __unbind_pasid(struct pasid_state *pasid_state)
+static void unbind_pasid(struct pasid_state *pasid_state)
{
struct iommu_domain *domain;
domain = pasid_state->device_state->domain;
+ /*
+ * Mark pasid_state as invalid, no more faults will we added to the
+ * work queue after this is visible everywhere.
+ */
+ pasid_state->invalid = true;
+
+ /* Make sure this is visible */
+ smp_wmb();
+
+ /* After this the device/pasid can't access the mm anymore */
amd_iommu_domain_clear_gcr3(domain, pasid_state->pasid);
- clear_pasid_state(pasid_state->device_state, pasid_state->pasid);
/* Make sure no more pending faults are in the queue */
flush_workqueue(iommu_wq);
-
- mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);
-
- put_pasid_state(pasid_state); /* Reference taken in bind() function */
-}
-
-static void unbind_pasid(struct device_state *dev_state, int pasid)
-{
- struct pasid_state *pasid_state;
-
- pasid_state = get_pasid_state(dev_state, pasid);
- if (pasid_state == NULL)
- return;
-
- __unbind_pasid(pasid_state);
- put_pasid_state_wait(pasid_state); /* Reference taken in this function */
}
static void free_pasid_states_level1(struct pasid_state **tbl)
@@ -372,6 +365,12 @@ static void free_pasid_states(struct device_state *dev_state)
* unbind the PASID
*/
mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);
+
+ put_pasid_state_wait(pasid_state); /* Reference taken in
+ amd_iommu_bind_pasid */
+
+ /* Drop reference taken in amd_iommu_bind_pasid */
+ put_device_state(dev_state);
}
if (dev_state->pasid_levels == 2)
@@ -410,14 +409,6 @@ static int mn_clear_flush_young(struct mmu_notifier *mn,
return 0;
}
-static void mn_change_pte(struct mmu_notifier *mn,
- struct mm_struct *mm,
- unsigned long address,
- pte_t pte)
-{
- __mn_flush_page(mn, address);
-}
-
static void mn_invalidate_page(struct mmu_notifier *mn,
struct mm_struct *mm,
unsigned long address)
@@ -431,15 +422,19 @@ static void mn_invalidate_range_start(struct mmu_notifier *mn,
{
struct pasid_state *pasid_state;
struct device_state *dev_state;
+ unsigned long flags;
pasid_state = mn_to_state(mn);
dev_state = pasid_state->device_state;
- if (atomic_add_return(1, &pasid_state->mmu_notifier_count) == 1) {
+ spin_lock_irqsave(&pasid_state->lock, flags);
+ if (pasid_state->mmu_notifier_count == 0) {
amd_iommu_domain_set_gcr3(dev_state->domain,
pasid_state->pasid,
__pa(empty_page_table));
}
+ pasid_state->mmu_notifier_count += 1;
+ spin_unlock_irqrestore(&pasid_state->lock, flags);
}
static void mn_invalidate_range_end(struct mmu_notifier *mn,
@@ -448,37 +443,42 @@ static void mn_invalidate_range_end(struct mmu_notifier *mn,
{
struct pasid_state *pasid_state;
struct device_state *dev_state;
+ unsigned long flags;
pasid_state = mn_to_state(mn);
dev_state = pasid_state->device_state;
- if (atomic_dec_and_test(&pasid_state->mmu_notifier_count)) {
+ spin_lock_irqsave(&pasid_state->lock, flags);
+ pasid_state->mmu_notifier_count -= 1;
+ if (pasid_state->mmu_notifier_count == 0) {
amd_iommu_domain_set_gcr3(dev_state->domain,
pasid_state->pasid,
__pa(pasid_state->mm->pgd));
}
+ spin_unlock_irqrestore(&pasid_state->lock, flags);
}
static void mn_release(struct mmu_notifier *mn, struct mm_struct *mm)
{
struct pasid_state *pasid_state;
struct device_state *dev_state;
+ bool run_inv_ctx_cb;
might_sleep();
- pasid_state = mn_to_state(mn);
- dev_state = pasid_state->device_state;
+ pasid_state = mn_to_state(mn);
+ dev_state = pasid_state->device_state;
+ run_inv_ctx_cb = !pasid_state->invalid;
- if (pasid_state->device_state->inv_ctx_cb)
+ if (run_inv_ctx_cb && pasid_state->device_state->inv_ctx_cb)
dev_state->inv_ctx_cb(dev_state->pdev, pasid_state->pasid);
- unbind_pasid(dev_state, pasid_state->pasid);
+ unbind_pasid(pasid_state);
}
static struct mmu_notifier_ops iommu_mn = {
.release = mn_release,
.clear_flush_young = mn_clear_flush_young,
- .change_pte = mn_change_pte,
.invalidate_page = mn_invalidate_page,
.invalidate_range_start = mn_invalidate_range_start,
.invalidate_range_end = mn_invalidate_range_end,
@@ -520,7 +520,7 @@ static void do_fault(struct work_struct *work)
write = !!(fault->flags & PPR_FAULT_WRITE);
down_read(&fault->state->mm->mmap_sem);
- npages = get_user_pages(fault->state->task, fault->state->mm,
+ npages = get_user_pages(NULL, fault->state->mm,
fault->address, 1, write, 0, &page, NULL);
up_read(&fault->state->mm->mmap_sem);
@@ -578,7 +578,7 @@ static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data)
goto out;
pasid_state = get_pasid_state(dev_state, iommu_fault->pasid);
- if (pasid_state == NULL) {
+ if (pasid_state == NULL || pasid_state->invalid) {
/* We know the device but not the PASID -> send INVALID */
amd_iommu_complete_ppr(dev_state->pdev, iommu_fault->pasid,
PPR_INVALID, tag);
@@ -603,6 +603,7 @@ static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data)
fault->state = pasid_state;
fault->tag = tag;
fault->finish = finish;
+ fault->pasid = iommu_fault->pasid;
fault->flags = iommu_fault->flags;
INIT_WORK(&fault->work, do_fault);
@@ -611,6 +612,10 @@ static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data)
ret = NOTIFY_OK;
out_drop_state:
+
+ if (ret != NOTIFY_OK && pasid_state)
+ put_pasid_state(pasid_state);
+
put_device_state(dev_state);
out:
@@ -626,6 +631,7 @@ int amd_iommu_bind_pasid(struct pci_dev *pdev, int pasid,
{
struct pasid_state *pasid_state;
struct device_state *dev_state;
+ struct mm_struct *mm;
u16 devid;
int ret;
@@ -649,21 +655,23 @@ int amd_iommu_bind_pasid(struct pci_dev *pdev, int pasid,
if (pasid_state == NULL)
goto out;
+
atomic_set(&pasid_state->count, 1);
- atomic_set(&pasid_state->mmu_notifier_count, 0);
init_waitqueue_head(&pasid_state->wq);
spin_lock_init(&pasid_state->lock);
- pasid_state->task = task;
- pasid_state->mm = get_task_mm(task);
+ mm = get_task_mm(task);
+ pasid_state->mm = mm;
pasid_state->device_state = dev_state;
pasid_state->pasid = pasid;
+ pasid_state->invalid = true; /* Mark as valid only if we are
+ done with setting up the pasid */
pasid_state->mn.ops = &iommu_mn;
if (pasid_state->mm == NULL)
goto out_free;
- mmu_notifier_register(&pasid_state->mn, pasid_state->mm);
+ mmu_notifier_register(&pasid_state->mn, mm);
ret = set_pasid_state(dev_state, pasid_state, pasid);
if (ret)
@@ -674,15 +682,26 @@ int amd_iommu_bind_pasid(struct pci_dev *pdev, int pasid,
if (ret)
goto out_clear_state;
+ /* Now we are ready to handle faults */
+ pasid_state->invalid = false;
+
+ /*
+ * Drop the reference to the mm_struct here. We rely on the
+ * mmu_notifier release call-back to inform us when the mm
+ * is going away.
+ */
+ mmput(mm);
+
return 0;
out_clear_state:
clear_pasid_state(dev_state, pasid);
out_unregister:
- mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);
+ mmu_notifier_unregister(&pasid_state->mn, mm);
out_free:
+ mmput(mm);
free_pasid_state(pasid_state);
out:
@@ -720,10 +739,22 @@ void amd_iommu_unbind_pasid(struct pci_dev *pdev, int pasid)
*/
put_pasid_state(pasid_state);
- /* This will call the mn_release function and unbind the PASID */
+ /* Clear the pasid state so that the pasid can be re-used */
+ clear_pasid_state(dev_state, pasid_state->pasid);
+
+ /*
+ * Call mmu_notifier_unregister to drop our reference
+ * to pasid_state->mm
+ */
mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);
+ put_pasid_state_wait(pasid_state); /* Reference taken in
+ amd_iommu_bind_pasid */
out:
+ /* Drop reference taken in this function */
+ put_device_state(dev_state);
+
+ /* Drop reference taken in amd_iommu_bind_pasid */
put_device_state(dev_state);
}
EXPORT_SYMBOL(amd_iommu_unbind_pasid);