[v3,08/48] mm/mmap: convert brk to use vma iterator

Message ID 20230117023335.1690727-9-Liam.Howlett@oracle.com
State New
Headers
Series VMA tree type safety and remove __vma_adjust() |

Commit Message

Liam R. Howlett Jan. 17, 2023, 2:34 a.m. UTC
  From: "Liam R. Howlett" <Liam.Howlett@Oracle.com>

Use the vma iterator API for the brk() system call.  This will provide
type safety at compile time.

Signed-off-by: Liam R. Howlett <Liam.Howlett@oracle.com>
---
 mm/mmap.c | 46 ++++++++++++++++++++++------------------------
 1 file changed, 22 insertions(+), 24 deletions(-)
  

Patch

diff --git a/mm/mmap.c b/mm/mmap.c
index 024fb46251e2..81ff47147317 100644
--- a/mm/mmap.c
+++ b/mm/mmap.c
@@ -180,10 +180,10 @@  static int check_brk_limits(unsigned long addr, unsigned long len)
 
 	return mlock_future_check(current->mm, current->mm->def_flags, len);
 }
-static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma,
+static int do_brk_munmap(struct vma_iterator *vmi, struct vm_area_struct *vma,
 			 unsigned long newbrk, unsigned long oldbrk,
 			 struct list_head *uf);
-static int do_brk_flags(struct ma_state *mas, struct vm_area_struct *brkvma,
+static int do_brk_flags(struct vma_iterator *vmi, struct vm_area_struct *brkvma,
 		unsigned long addr, unsigned long request, unsigned long flags);
 SYSCALL_DEFINE1(brk, unsigned long, brk)
 {
@@ -194,7 +194,7 @@  SYSCALL_DEFINE1(brk, unsigned long, brk)
 	bool populate;
 	bool downgraded = false;
 	LIST_HEAD(uf);
-	MA_STATE(mas, &mm->mm_mt, 0, 0);
+	struct vma_iterator vmi;
 
 	if (mmap_write_lock_killable(mm))
 		return -EINTR;
@@ -242,8 +242,8 @@  SYSCALL_DEFINE1(brk, unsigned long, brk)
 		int ret;
 
 		/* Search one past newbrk */
-		mas_set(&mas, newbrk);
-		brkvma = mas_find(&mas, oldbrk);
+		vma_iter_init(&vmi, mm, newbrk);
+		brkvma = vma_find(&vmi, oldbrk);
 		if (!brkvma || brkvma->vm_start >= oldbrk)
 			goto out; /* mapping intersects with an existing non-brk vma. */
 		/*
@@ -252,7 +252,7 @@  SYSCALL_DEFINE1(brk, unsigned long, brk)
 		 * before calling do_brk_munmap().
 		 */
 		mm->brk = brk;
-		ret = do_brk_munmap(&mas, brkvma, newbrk, oldbrk, &uf);
+		ret = do_brk_munmap(&vmi, brkvma, newbrk, oldbrk, &uf);
 		if (ret == 1)  {
 			downgraded = true;
 			goto success;
@@ -270,14 +270,14 @@  SYSCALL_DEFINE1(brk, unsigned long, brk)
 	 * Only check if the next VMA is within the stack_guard_gap of the
 	 * expansion area
 	 */
-	mas_set(&mas, oldbrk);
-	next = mas_find(&mas, newbrk - 1 + PAGE_SIZE + stack_guard_gap);
+	vma_iter_init(&vmi, mm, oldbrk);
+	next = vma_find(&vmi, newbrk + PAGE_SIZE + stack_guard_gap);
 	if (next && newbrk + PAGE_SIZE > vm_start_gap(next))
 		goto out;
 
-	brkvma = mas_prev(&mas, mm->start_brk);
+	brkvma = vma_prev_limit(&vmi, mm->start_brk);
 	/* Ok, looks good - let it rip. */
-	if (do_brk_flags(&mas, brkvma, oldbrk, newbrk - oldbrk, 0) < 0)
+	if (do_brk_flags(&vmi, brkvma, oldbrk, newbrk - oldbrk, 0) < 0)
 		goto out;
 
 	mm->brk = brk;
@@ -2904,7 +2904,7 @@  SYSCALL_DEFINE5(remap_file_pages, unsigned long, start, unsigned long, size,
 
 /*
  * brk_munmap() - Unmap a parital vma.
- * @mas: The maple tree state.
+ * @vmi: The vma iterator
  * @vma: The vma to be modified
  * @newbrk: the start of the address to unmap
  * @oldbrk: The end of the address to unmap
@@ -2914,7 +2914,7 @@  SYSCALL_DEFINE5(remap_file_pages, unsigned long, start, unsigned long, size,
  * unmaps a partial VMA mapping.  Does not handle alignment, downgrades lock if
  * possible.
  */
-static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma,
+static int do_brk_munmap(struct vma_iterator *vmi, struct vm_area_struct *vma,
 			 unsigned long newbrk, unsigned long oldbrk,
 			 struct list_head *uf)
 {
@@ -2922,14 +2922,14 @@  static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma,
 	int ret;
 
 	arch_unmap(mm, newbrk, oldbrk);
-	ret = do_mas_align_munmap(mas, vma, mm, newbrk, oldbrk, uf, true);
+	ret = do_mas_align_munmap(&vmi->mas, vma, mm, newbrk, oldbrk, uf, true);
 	validate_mm_mt(mm);
 	return ret;
 }
 
 /*
  * do_brk_flags() - Increase the brk vma if the flags match.
- * @mas: The maple tree state.
+ * @vmi: The vma iterator
  * @addr: The start address
  * @len: The length of the increase
  * @vma: The vma,
@@ -2939,7 +2939,7 @@  static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma,
  * do not match then create a new anonymous VMA.  Eventually we may be able to
  * do some brk-specific accounting here.
  */
-static int do_brk_flags(struct ma_state *mas, struct vm_area_struct *vma,
+static int do_brk_flags(struct vma_iterator *vmi, struct vm_area_struct *vma,
 		unsigned long addr, unsigned long len, unsigned long flags)
 {
 	struct mm_struct *mm = current->mm;
@@ -2966,8 +2966,7 @@  static int do_brk_flags(struct ma_state *mas, struct vm_area_struct *vma,
 	if (vma && vma->vm_end == addr && !vma_policy(vma) &&
 	    can_vma_merge_after(vma, flags, NULL, NULL,
 				addr >> PAGE_SHIFT, NULL_VM_UFFD_CTX, NULL)) {
-		mas_set_range(mas, vma->vm_start, addr + len - 1);
-		if (mas_preallocate(mas, vma, GFP_KERNEL))
+		if (vma_iter_prealloc(vmi, vma))
 			goto unacct_fail;
 
 		vma_adjust_trans_huge(vma, vma->vm_start, addr + len, 0);
@@ -2977,7 +2976,7 @@  static int do_brk_flags(struct ma_state *mas, struct vm_area_struct *vma,
 		}
 		vma->vm_end = addr + len;
 		vma->vm_flags |= VM_SOFTDIRTY;
-		mas_store_prealloc(mas, vma);
+		vma_iter_store(vmi, vma);
 
 		if (vma->anon_vma) {
 			anon_vma_interval_tree_post_update_vma(vma);
@@ -2998,8 +2997,7 @@  static int do_brk_flags(struct ma_state *mas, struct vm_area_struct *vma,
 	vma->vm_pgoff = addr >> PAGE_SHIFT;
 	vma->vm_flags = flags;
 	vma->vm_page_prot = vm_get_page_prot(flags);
-	mas_set_range(mas, vma->vm_start, addr + len - 1);
-	if (mas_store_gfp(mas, vma, GFP_KERNEL))
+	if (vma_iter_store_gfp(vmi, vma, GFP_KERNEL))
 		goto mas_store_fail;
 
 	mm->map_count++;
@@ -3028,7 +3026,7 @@  int vm_brk_flags(unsigned long addr, unsigned long request, unsigned long flags)
 	int ret;
 	bool populate;
 	LIST_HEAD(uf);
-	MA_STATE(mas, &mm->mm_mt, addr, addr);
+	VMA_ITERATOR(vmi, mm, addr);
 
 	len = PAGE_ALIGN(request);
 	if (len < request)
@@ -3047,12 +3045,12 @@  int vm_brk_flags(unsigned long addr, unsigned long request, unsigned long flags)
 	if (ret)
 		goto limits_failed;
 
-	ret = do_mas_munmap(&mas, mm, addr, len, &uf, 0);
+	ret = do_mas_munmap(&vmi.mas, mm, addr, len, &uf, 0);
 	if (ret)
 		goto munmap_failed;
 
-	vma = mas_prev(&mas, 0);
-	ret = do_brk_flags(&mas, vma, addr, len, flags);
+	vma = vma_prev(&vmi);
+	ret = do_brk_flags(&vmi, vma, addr, len, flags);
 	populate = ((mm->def_flags & VM_LOCKED) != 0);
 	mmap_write_unlock(mm);
 	userfaultfd_unmap_complete(mm, &uf);