summaryrefslogtreecommitdiff
path: root/fs/userfaultfd.c
diff options
context:
space:
mode:
Diffstat (limited to 'fs/userfaultfd.c')
-rw-r--r--fs/userfaultfd.c115
1 files changed, 52 insertions, 63 deletions
diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c
index cc694846617a..44d1ee429eb0 100644
--- a/fs/userfaultfd.c
+++ b/fs/userfaultfd.c
@@ -113,7 +113,7 @@ static void userfaultfd_set_vm_flags(struct vm_area_struct *vma,
{
const bool uffd_wp_changed = (vma->vm_flags ^ flags) & VM_UFFD_WP;
- vma->vm_flags = flags;
+ vm_flags_reset(vma, flags);
/*
* For shared mappings, we want to enable writenotify while
* userfaultfd-wp is enabled (see vma_wants_writenotify()). We'll simply
@@ -252,14 +252,12 @@ static inline bool userfaultfd_huge_must_wait(struct userfaultfd_ctx *ctx,
unsigned long flags,
unsigned long reason)
{
- struct mm_struct *mm = ctx->mm;
pte_t *ptep, pte;
bool ret = true;
- mmap_assert_locked(mm);
-
- ptep = huge_pte_offset(mm, address, vma_mmu_pagesize(vma));
+ mmap_assert_locked(ctx->mm);
+ ptep = hugetlb_walk(vma, address, vma_mmu_pagesize(vma));
if (!ptep)
goto out;
@@ -391,7 +389,8 @@ static inline unsigned int userfaultfd_get_blocking_state(unsigned int flags)
*/
vm_fault_t handle_userfault(struct vm_fault *vmf, unsigned long reason)
{
- struct mm_struct *mm = vmf->vma->vm_mm;
+ struct vm_area_struct *vma = vmf->vma;
+ struct mm_struct *mm = vma->vm_mm;
struct userfaultfd_ctx *ctx;
struct userfaultfd_wait_queue uwq;
vm_fault_t ret = VM_FAULT_SIGBUS;
@@ -418,7 +417,7 @@ vm_fault_t handle_userfault(struct vm_fault *vmf, unsigned long reason)
*/
mmap_assert_locked(mm);
- ctx = vmf->vma->vm_userfaultfd_ctx.ctx;
+ ctx = vma->vm_userfaultfd_ctx.ctx;
if (!ctx)
goto out;
@@ -508,6 +507,15 @@ vm_fault_t handle_userfault(struct vm_fault *vmf, unsigned long reason)
blocking_state = userfaultfd_get_blocking_state(vmf->flags);
+ /*
+ * Take the vma lock now, in order to safely call
+ * userfaultfd_huge_must_wait() later. Since acquiring the
+ * (sleepable) vma lock can modify the current task state, that
+ * must be before explicitly calling set_current_state().
+ */
+ if (is_vm_hugetlb_page(vma))
+ hugetlb_vma_lock_read(vma);
+
spin_lock_irq(&ctx->fault_pending_wqh.lock);
/*
* After the __add_wait_queue the uwq is visible to userland
@@ -522,13 +530,15 @@ vm_fault_t handle_userfault(struct vm_fault *vmf, unsigned long reason)
set_current_state(blocking_state);
spin_unlock_irq(&ctx->fault_pending_wqh.lock);
- if (!is_vm_hugetlb_page(vmf->vma))
+ if (!is_vm_hugetlb_page(vma))
must_wait = userfaultfd_must_wait(ctx, vmf->address, vmf->flags,
reason);
else
- must_wait = userfaultfd_huge_must_wait(ctx, vmf->vma,
+ must_wait = userfaultfd_huge_must_wait(ctx, vma,
vmf->address,
vmf->flags, reason);
+ if (is_vm_hugetlb_page(vma))
+ hugetlb_vma_unlock_read(vma);
mmap_read_unlock(mm);
if (likely(must_wait && !READ_ONCE(ctx->released))) {
@@ -873,7 +883,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
/* len == 0 means wake all */
struct userfaultfd_wake_range range = { .len = 0, };
unsigned long new_flags;
- MA_STATE(mas, &mm->mm_mt, 0, 0);
+ VMA_ITERATOR(vmi, mm, 0);
WRITE_ONCE(ctx->released, true);
@@ -890,7 +900,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
*/
mmap_write_lock(mm);
prev = NULL;
- mas_for_each(&mas, vma, ULONG_MAX) {
+ for_each_vma(vmi, vma) {
cond_resched();
BUG_ON(!!vma->vm_userfaultfd_ctx.ctx ^
!!(vma->vm_flags & __VM_UFFD_FLAGS));
@@ -899,13 +909,12 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
continue;
}
new_flags = vma->vm_flags & ~__VM_UFFD_FLAGS;
- prev = vma_merge(mm, prev, vma->vm_start, vma->vm_end,
+ prev = vma_merge(&vmi, mm, prev, vma->vm_start, vma->vm_end,
new_flags, vma->anon_vma,
vma->vm_file, vma->vm_pgoff,
vma_policy(vma),
NULL_VM_UFFD_CTX, anon_vma_name(vma));
if (prev) {
- mas_pause(&mas);
vma = prev;
} else {
prev = vma;
@@ -1292,7 +1301,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
bool found;
bool basic_ioctls;
unsigned long start, end, vma_end;
- MA_STATE(mas, &mm->mm_mt, 0, 0);
+ struct vma_iterator vmi;
user_uffdio_register = (struct uffdio_register __user *) arg;
@@ -1334,17 +1343,13 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
if (!mmget_not_zero(mm))
goto out;
+ ret = -EINVAL;
mmap_write_lock(mm);
- mas_set(&mas, start);
- vma = mas_find(&mas, ULONG_MAX);
+ vma_iter_init(&vmi, mm, start);
+ vma = vma_find(&vmi, end);
if (!vma)
goto out_unlock;
- /* check that there's at least one vma in the range */
- ret = -EINVAL;
- if (vma->vm_start >= end)
- goto out_unlock;
-
/*
* If the first vma contains huge pages, make sure start address
* is aligned to huge page size.
@@ -1361,7 +1366,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
*/
found = false;
basic_ioctls = false;
- for (cur = vma; cur; cur = mas_next(&mas, end - 1)) {
+ cur = vma;
+ do {
cond_resched();
BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^
@@ -1418,16 +1424,14 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
basic_ioctls = true;
found = true;
- }
+ } for_each_vma_range(vmi, cur, end);
BUG_ON(!found);
- mas_set(&mas, start);
- prev = mas_prev(&mas, 0);
- if (prev != vma)
- mas_next(&mas, ULONG_MAX);
+ vma_iter_set(&vmi, start);
+ prev = vma_prev(&vmi);
ret = 0;
- do {
+ for_each_vma_range(vmi, vma, end) {
cond_resched();
BUG_ON(!vma_can_userfault(vma, vm_flags));
@@ -1448,30 +1452,25 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
vma_end = min(end, vma->vm_end);
new_flags = (vma->vm_flags & ~__VM_UFFD_FLAGS) | vm_flags;
- prev = vma_merge(mm, prev, start, vma_end, new_flags,
+ prev = vma_merge(&vmi, mm, prev, start, vma_end, new_flags,
vma->anon_vma, vma->vm_file, vma->vm_pgoff,
vma_policy(vma),
((struct vm_userfaultfd_ctx){ ctx }),
anon_vma_name(vma));
if (prev) {
/* vma_merge() invalidated the mas */
- mas_pause(&mas);
vma = prev;
goto next;
}
if (vma->vm_start < start) {
- ret = split_vma(mm, vma, start, 1);
+ ret = split_vma(&vmi, vma, start, 1);
if (ret)
break;
- /* split_vma() invalidated the mas */
- mas_pause(&mas);
}
if (vma->vm_end > end) {
- ret = split_vma(mm, vma, end, 0);
+ ret = split_vma(&vmi, vma, end, 0);
if (ret)
break;
- /* split_vma() invalidated the mas */
- mas_pause(&mas);
}
next:
/*
@@ -1488,8 +1487,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
skip:
prev = vma;
start = vma->vm_end;
- vma = mas_next(&mas, end - 1);
- } while (vma);
+ }
+
out_unlock:
mmap_write_unlock(mm);
mmput(mm);
@@ -1533,7 +1532,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
bool found;
unsigned long start, end, vma_end;
const void __user *buf = (void __user *)arg;
- MA_STATE(mas, &mm->mm_mt, 0, 0);
+ struct vma_iterator vmi;
ret = -EFAULT;
if (copy_from_user(&uffdio_unregister, buf, sizeof(uffdio_unregister)))
@@ -1552,14 +1551,10 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
goto out;
mmap_write_lock(mm);
- mas_set(&mas, start);
- vma = mas_find(&mas, ULONG_MAX);
- if (!vma)
- goto out_unlock;
-
- /* check that there's at least one vma in the range */
ret = -EINVAL;
- if (vma->vm_start >= end)
+ vma_iter_init(&vmi, mm, start);
+ vma = vma_find(&vmi, end);
+ if (!vma)
goto out_unlock;
/*
@@ -1577,8 +1572,8 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
* Search for not compatible vmas.
*/
found = false;
- ret = -EINVAL;
- for (cur = vma; cur; cur = mas_next(&mas, end - 1)) {
+ cur = vma;
+ do {
cond_resched();
BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^
@@ -1595,16 +1590,13 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
goto out_unlock;
found = true;
- }
+ } for_each_vma_range(vmi, cur, end);
BUG_ON(!found);
- mas_set(&mas, start);
- prev = mas_prev(&mas, 0);
- if (prev != vma)
- mas_next(&mas, ULONG_MAX);
-
+ vma_iter_set(&vmi, start);
+ prev = vma_prev(&vmi);
ret = 0;
- do {
+ for_each_vma_range(vmi, vma, end) {
cond_resched();
BUG_ON(!vma_can_userfault(vma, vma->vm_flags));
@@ -1640,26 +1632,23 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
uffd_wp_range(mm, vma, start, vma_end - start, false);
new_flags = vma->vm_flags & ~__VM_UFFD_FLAGS;
- prev = vma_merge(mm, prev, start, vma_end, new_flags,
+ prev = vma_merge(&vmi, mm, prev, start, vma_end, new_flags,
vma->anon_vma, vma->vm_file, vma->vm_pgoff,
vma_policy(vma),
NULL_VM_UFFD_CTX, anon_vma_name(vma));
if (prev) {
vma = prev;
- mas_pause(&mas);
goto next;
}
if (vma->vm_start < start) {
- ret = split_vma(mm, vma, start, 1);
+ ret = split_vma(&vmi, vma, start, 1);
if (ret)
break;
- mas_pause(&mas);
}
if (vma->vm_end > end) {
- ret = split_vma(mm, vma, end, 0);
+ ret = split_vma(&vmi, vma, end, 0);
if (ret)
break;
- mas_pause(&mas);
}
next:
/*
@@ -1673,8 +1662,8 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
skip:
prev = vma;
start = vma->vm_end;
- vma = mas_next(&mas, end - 1);
- } while (vma);
+ }
+
out_unlock:
mmap_write_unlock(mm);
mmput(mm);