[v1,4/6] mm/autonuma: use can_change_(pte|pmd)_writable() to replace savedwrite

Message ID 20221102191209.289237-5-david@redhat.com
State New
Headers
Series mm/autonuma: replace savedwrite infrastructure |

Commit Message

David Hildenbrand Nov. 2, 2022, 7:12 p.m. UTC
  commit b191f9b106ea ("mm: numa: preserve PTE write permissions across a
NUMA hinting fault") added remembering write permissions using ordinary
pte_write() for PROT_NONE mapped pages to avoid write faults when
remapping the page !PROT_NONE on NUMA hinting faults.

That commit noted:

    The patch looks hacky but the alternatives looked worse. The tidest was
    to rewalk the page tables after a hinting fault but it was more complex
    than this approach and the performance was worse. It's not generally
    safe to just mark the page writable during the fault if it's a write
    fault as it may have been read-only for COW so that approach was
    discarded.

Later, commit 288bc54949fc ("mm/autonuma: let architecture override how
the write bit should be stashed in a protnone pte.") introduced a family
of savedwrite PTE functions that didn't necessarily improve the whole
situation.

One confusing thing is that nowadays, if a page is pte_protnone()
and pte_savedwrite() then also pte_write() is true. Another source of
confusion is that there is only a single pte_mk_savedwrite() call in the
kernel. All other write-protection code seems to silently rely on
pte_wrprotect().

Ever since PageAnonExclusive was introduced and we started using it in
mprotect context via commit 64fe24a3e05e ("mm/mprotect: try avoiding write
faults for exclusive anonymous pages when changing protection"), we do
have machinery in place to avoid write faults when changing protection,
which is exactly what we want to do here.

Let's similarly do what ordinary mprotect() does nowadays when upgrading
write permissions and reuse can_change_pte_writable() and
can_change_pmd_writable() to detect if we can upgrade PTE permissions to be
writable.

For anonymous pages there should be absolutely no change: if an
anonymous page is not exclusive, it could not have been mapped writable --
because only exclusive anonymous pages can be mapped writable.

However, there *might* be a change for writable shared mappings that
require writenotify: if they are not dirty, we cannot map them writable.
While it might not matter in practice, we'd need a different way to
identify whether writenotify is actually required -- and ordinary mprotect
would benefit from that as well.

We'll remove all savedwrite leftovers next.

Signed-off-by: David Hildenbrand <david@redhat.com>
---
 include/linux/mm.h |  2 ++
 mm/huge_memory.c   | 28 +++++++++++++++++-----------
 mm/ksm.c           |  9 ++++-----
 mm/memory.c        | 19 ++++++++++++++++---
 mm/mprotect.c      |  7 ++-----
 5 files changed, 41 insertions(+), 24 deletions(-)
  

Comments

Nadav Amit Nov. 2, 2022, 9:22 p.m. UTC | #1
On Nov 2, 2022, at 12:12 PM, David Hildenbrand <david@redhat.com> wrote:

> !! External Email
> 
> commit b191f9b106ea ("mm: numa: preserve PTE write permissions across a
> NUMA hinting fault") added remembering write permissions using ordinary
> pte_write() for PROT_NONE mapped pages to avoid write faults when
> remapping the page !PROT_NONE on NUMA hinting faults.
> 

[ snip ]

Here’s a very shallow review with some minor points...

> ---
> include/linux/mm.h |  2 ++
> mm/huge_memory.c   | 28 +++++++++++++++++-----------
> mm/ksm.c           |  9 ++++-----
> mm/memory.c        | 19 ++++++++++++++++---
> mm/mprotect.c      |  7 ++-----
> 5 files changed, 41 insertions(+), 24 deletions(-)
> 
> diff --git a/include/linux/mm.h b/include/linux/mm.h
> index 25ff9a14a777..a0deeece5e87 100644
> --- a/include/linux/mm.h
> +++ b/include/linux/mm.h
> @@ -1975,6 +1975,8 @@ extern unsigned long move_page_tables(struct vm_area_struct *vma,
> #define  MM_CP_UFFD_WP_ALL                 (MM_CP_UFFD_WP | \
>                                            MM_CP_UFFD_WP_RESOLVE)
> 
> +bool can_change_pte_writable(struct vm_area_struct *vma, unsigned long addr,
> +                            pte_t pte);

It might not be customary, but how about marking it as __pure?

> extern unsigned long change_protection(struct mmu_gather *tlb,
>                              struct vm_area_struct *vma, unsigned long start,
>                              unsigned long end, pgprot_t newprot,
> diff --git a/mm/huge_memory.c b/mm/huge_memory.c
> index 2ad68e91896a..45abd27d75a0 100644
> --- a/mm/huge_memory.c
> +++ b/mm/huge_memory.c
> @@ -1462,8 +1462,7 @@ vm_fault_t do_huge_pmd_numa_page(struct vm_fault *vmf)
>        unsigned long haddr = vmf->address & HPAGE_PMD_MASK;
>        int page_nid = NUMA_NO_NODE;
>        int target_nid, last_cpupid = (-1 & LAST_CPUPID_MASK);
> -       bool migrated = false;
> -       bool was_writable = pmd_savedwrite(oldpmd);
> +       bool try_change_writable, migrated = false;
>        int flags = 0;
> 
>        vmf->ptl = pmd_lock(vma->vm_mm, vmf->pmd);
> @@ -1472,13 +1471,22 @@ vm_fault_t do_huge_pmd_numa_page(struct vm_fault *vmf)
>                goto out;
>        }
> 
> +       /* See mprotect_fixup(). */
> +       if (vma->vm_flags & VM_SHARED)
> +               try_change_writable = vma_wants_writenotify(vma, vma->vm_page_prot);
> +       else
> +               try_change_writable = !!(vma->vm_flags & VM_WRITE);

Do you find it better to copy the code instead of extracting it to a
separate function?

> +
>        pmd = pmd_modify(oldpmd, vma->vm_page_prot);
>        page = vm_normal_page_pmd(vma, haddr, pmd);
>        if (!page)
>                goto out_map;
> 
>        /* See similar comment in do_numa_page for explanation */
> -       if (!was_writable)
> +       if (try_change_writable && !pmd_write(pmd) &&
> +            can_change_pmd_writable(vma, vmf->address, pmd))
> +               pmd = pmd_mkwrite(pmd);
> +       if (!pmd_write(pmd))
>                flags |= TNF_NO_GROUP;
> 
>        page_nid = page_to_nid(page);
> @@ -1523,8 +1531,12 @@ vm_fault_t do_huge_pmd_numa_page(struct vm_fault *vmf)
>        /* Restore the PMD */
>        pmd = pmd_modify(oldpmd, vma->vm_page_prot);
>        pmd = pmd_mkyoung(pmd);
> -       if (was_writable)
> +
> +       /* Similar to mprotect() protection updates, avoid write faults. */
> +       if (try_change_writable && !pmd_write(pmd) &&
> +            can_change_pmd_writable(vma, vmf->address, pmd))

Why do I have a deja-vu? :)

There must be a way to avoid the redundant code and specifically the call to
can_change_pmd_writable(), no?

>                pmd = pmd_mkwrite(pmd);
> +
>        set_pmd_at(vma->vm_mm, haddr, vmf->pmd, pmd);
>        update_mmu_cache_pmd(vma, vmf->address, vmf->pmd);
>        spin_unlock(vmf->ptl);
> @@ -1764,11 +1776,10 @@ int change_huge_pmd(struct mmu_gather *tlb, struct vm_area_struct *vma,
>        struct mm_struct *mm = vma->vm_mm;
>        spinlock_t *ptl;
>        pmd_t oldpmd, entry;
> -       bool preserve_write;
> -       int ret;
>        bool prot_numa = cp_flags & MM_CP_PROT_NUMA;
>        bool uffd_wp = cp_flags & MM_CP_UFFD_WP;
>        bool uffd_wp_resolve = cp_flags & MM_CP_UFFD_WP_RESOLVE;
> +       int ret = 1;
> 
>        tlb_change_page_size(tlb, HPAGE_PMD_SIZE);
> 
> @@ -1779,9 +1790,6 @@ int change_huge_pmd(struct mmu_gather *tlb, struct vm_area_struct *vma,
>        if (!ptl)
>                return 0;
> 
> -       preserve_write = prot_numa && pmd_write(*pmd);
> -       ret = 1;
> -
> #ifdef CONFIG_ARCH_ENABLE_THP_MIGRATION
>        if (is_swap_pmd(*pmd)) {
>                swp_entry_t entry = pmd_to_swp_entry(*pmd);
> @@ -1861,8 +1869,6 @@ int change_huge_pmd(struct mmu_gather *tlb, struct vm_area_struct *vma,
>        oldpmd = pmdp_invalidate_ad(vma, addr, pmd);
> 
>        entry = pmd_modify(oldpmd, newprot);
> -       if (preserve_write)
> -               entry = pmd_mk_savedwrite(entry);
>        if (uffd_wp) {
>                entry = pmd_wrprotect(entry);
>                entry = pmd_mkuffd_wp(entry);
> diff --git a/mm/ksm.c b/mm/ksm.c
> index dc15c4a2a6ff..dd02780c387f 100644
> --- a/mm/ksm.c
> +++ b/mm/ksm.c
> @@ -1069,7 +1069,6 @@ static int write_protect_page(struct vm_area_struct *vma, struct page *page,
> 
>        anon_exclusive = PageAnonExclusive(page);
>        if (pte_write(*pvmw.pte) || pte_dirty(*pvmw.pte) ||
> -           (pte_protnone(*pvmw.pte) && pte_savedwrite(*pvmw.pte)) ||

Not related to your code, but it does not make me comfortable that PTE’s
status bits (which are volatile) are not accessed in this manner.

Especially since the PTE is later saved into orig_pte. It would feel safer
to do READ_ONCE(*pvmw.pte) and work on it (probably in a separate patch).

>            anon_exclusive || mm_tlb_flush_pending(mm)) {
>                pte_t entry;
> 
> @@ -1107,11 +1106,11 @@ static int write_protect_page(struct vm_area_struct *vma, struct page *page,
> 
>                if (pte_dirty(entry))
>                        set_page_dirty(page);
> +               entry = pte_mkclean(entry);
> +
> +               if (pte_write(entry))
> +                       entry = pte_wrprotect(entry);
> 
> -               if (pte_protnone(entry))
> -                       entry = pte_mkclean(pte_clear_savedwrite(entry));
> -               else
> -                       entry = pte_mkclean(pte_wrprotect(entry));
>                set_pte_at_notify(mm, pvmw.address, pvmw.pte, entry);
>        }
>        *orig_pte = *pvmw.pte;
> diff --git a/mm/memory.c b/mm/memory.c
> index c5599a9279b1..286c29ee3aba 100644
> --- a/mm/memory.c
> +++ b/mm/memory.c
> @@ -4672,12 +4672,12 @@ int numa_migrate_prep(struct page *page, struct vm_area_struct *vma,
> static vm_fault_t do_numa_page(struct vm_fault *vmf)
> {
>        struct vm_area_struct *vma = vmf->vma;
> +       bool try_change_writable;
>        struct page *page = NULL;
>        int page_nid = NUMA_NO_NODE;
>        int last_cpupid;
>        int target_nid;
>        pte_t pte, old_pte;
> -       bool was_writable = pte_savedwrite(vmf->orig_pte);
>        int flags = 0;
> 
>        /*
> @@ -4692,6 +4692,12 @@ static vm_fault_t do_numa_page(struct vm_fault *vmf)
>                goto out;
>        }
> 
> +       /* See mprotect_fixup(). */
> +       if (vma->vm_flags & VM_SHARED)
> +               try_change_writable = vma_wants_writenotify(vma, vma->vm_page_prot);
> +       else
> +               try_change_writable = !!(vma->vm_flags & VM_WRITE);

It really cannot be extracted into a separate function?
  
David Hildenbrand Nov. 3, 2022, 10:45 a.m. UTC | #2
On 02.11.22 22:22, Nadav Amit wrote:
> On Nov 2, 2022, at 12:12 PM, David Hildenbrand <david@redhat.com> wrote:
> 
>> !! External Email
>>
>> commit b191f9b106ea ("mm: numa: preserve PTE write permissions across a
>> NUMA hinting fault") added remembering write permissions using ordinary
>> pte_write() for PROT_NONE mapped pages to avoid write faults when
>> remapping the page !PROT_NONE on NUMA hinting faults.
>>
> 
> [ snip ]
> 
> Here’s a very shallow reviewed with some minor points...

Appreciated.

> 
>> ---
>> include/linux/mm.h |  2 ++
>> mm/huge_memory.c   | 28 +++++++++++++++++-----------
>> mm/ksm.c           |  9 ++++-----
>> mm/memory.c        | 19 ++++++++++++++++---
>> mm/mprotect.c      |  7 ++-----
>> 5 files changed, 41 insertions(+), 24 deletions(-)
>>
>> diff --git a/include/linux/mm.h b/include/linux/mm.h
>> index 25ff9a14a777..a0deeece5e87 100644
>> --- a/include/linux/mm.h
>> +++ b/include/linux/mm.h
>> @@ -1975,6 +1975,8 @@ extern unsigned long move_page_tables(struct vm_area_struct *vma,
>> #define  MM_CP_UFFD_WP_ALL                 (MM_CP_UFFD_WP | \
>>                                             MM_CP_UFFD_WP_RESOLVE)
>>
>> +bool can_change_pte_writable(struct vm_area_struct *vma, unsigned long addr,
>> +                            pte_t pte);
> 
> It might not be customary, but how about marking it as __pure?

Right, there is not a single use of __pure in the mm domain.

> 
>> extern unsigned long change_protection(struct mmu_gather *tlb,
>>                               struct vm_area_struct *vma, unsigned long start,
>>                               unsigned long end, pgprot_t newprot,
>> diff --git a/mm/huge_memory.c b/mm/huge_memory.c
>> index 2ad68e91896a..45abd27d75a0 100644
>> --- a/mm/huge_memory.c
>> +++ b/mm/huge_memory.c
>> @@ -1462,8 +1462,7 @@ vm_fault_t do_huge_pmd_numa_page(struct vm_fault *vmf)
>>         unsigned long haddr = vmf->address & HPAGE_PMD_MASK;
>>         int page_nid = NUMA_NO_NODE;
>>         int target_nid, last_cpupid = (-1 & LAST_CPUPID_MASK);
>> -       bool migrated = false;
>> -       bool was_writable = pmd_savedwrite(oldpmd);
>> +       bool try_change_writable, migrated = false;
>>         int flags = 0;
>>
>>         vmf->ptl = pmd_lock(vma->vm_mm, vmf->pmd);
>> @@ -1472,13 +1471,22 @@ vm_fault_t do_huge_pmd_numa_page(struct vm_fault *vmf)
>>                 goto out;
>>         }
>>
>> +       /* See mprotect_fixup(). */
>> +       if (vma->vm_flags & VM_SHARED)
>> +               try_change_writable = vma_wants_writenotify(vma, vma->vm_page_prot);
>> +       else
>> +               try_change_writable = !!(vma->vm_flags & VM_WRITE);
> 
> Do you find it better to copy the code instead of extracting it to a
> separate function?

Yeah, you're right ;) usually the issue is coming up with a suitable name. Let me try.

vma_wants_manual_writability_change() hmm ...

> 
>> +
>>         pmd = pmd_modify(oldpmd, vma->vm_page_prot);
>>         page = vm_normal_page_pmd(vma, haddr, pmd);
>>         if (!page)
>>                 goto out_map;
>>
>>         /* See similar comment in do_numa_page for explanation */
>> -       if (!was_writable)
>> +       if (try_change_writable && !pmd_write(pmd) &&
>> +            can_change_pmd_writable(vma, vmf->address, pmd))
>> +               pmd = pmd_mkwrite(pmd);
>> +       if (!pmd_write(pmd))
>>                 flags |= TNF_NO_GROUP;
>>
>>         page_nid = page_to_nid(page);
>> @@ -1523,8 +1531,12 @@ vm_fault_t do_huge_pmd_numa_page(struct vm_fault *vmf)
>>         /* Restore the PMD */
>>         pmd = pmd_modify(oldpmd, vma->vm_page_prot);
>>         pmd = pmd_mkyoung(pmd);
>> -       if (was_writable)
>> +
>> +       /* Similar to mprotect() protection updates, avoid write faults. */
>> +       if (try_change_writable && !pmd_write(pmd) &&
>> +            can_change_pmd_writable(vma, vmf->address, pmd))
> 
> Why do I have a deja-vu? :)
> 
> There must be a way to avoid the redundant code and specifically the call to
> can_change_pmd_writable(), no?

The issue is that as soon as we drop the page table lock, that information is stale.
Especially, after we fail migration.

So the following should work, however, if we fail migration we wouldn't map the
page writable and would have to re-calculate:

diff --git a/mm/memory.c b/mm/memory.c
index c5599a9279b1..a997625641e4 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -4674,10 +4674,10 @@ static vm_fault_t do_numa_page(struct vm_fault *vmf)
         struct vm_area_struct *vma = vmf->vma;
         struct page *page = NULL;
         int page_nid = NUMA_NO_NODE;
+       bool writable = false;
         int last_cpupid;
         int target_nid;
         pte_t pte, old_pte;
-       bool was_writable = pte_savedwrite(vmf->orig_pte);
         int flags = 0;
  
         /*
@@ -4696,6 +4696,17 @@ static vm_fault_t do_numa_page(struct vm_fault *vmf)
         old_pte = ptep_get(vmf->pte);
         pte = pte_modify(old_pte, vma->vm_page_prot);
  
+       /*
+        * Detect now whether the PTE is or can be writable. Note that this
+        * information is valid as long as we're holding the PT lock, so also on
+        * the remap path below.
+        */
+       writable = pte_write(pte);
+       if (!writable && vma_wants_manual_writability_change(vma) &&
+           can_change_pte_writable(vma, vmf->address, pte);
+           writable = true;
+       }
+
         page = vm_normal_page(vma, vmf->address, pte);
         if (!page || is_zone_device_page(page))
                 goto out_map;
@@ -4712,7 +4723,7 @@ static vm_fault_t do_numa_page(struct vm_fault *vmf)
          * pte_dirty has unpredictable behaviour between PTE scan updates,
          * background writeback, dirty balancing and application behaviour.
          */
-       if (!was_writable)
+       if (!writable)
                 flags |= TNF_NO_GROUP;
  
         /*
@@ -4738,6 +4749,7 @@ static vm_fault_t do_numa_page(struct vm_fault *vmf)
                 put_page(page);
                 goto out_map;
         }
+       writable = false;
         pte_unmap_unlock(vmf->pte, vmf->ptl);
  
         /* Migrate to the requested node */
@@ -4767,7 +4779,7 @@ static vm_fault_t do_numa_page(struct vm_fault *vmf)
         old_pte = ptep_modify_prot_start(vma, vmf->address, vmf->pte);
         pte = pte_modify(old_pte, vma->vm_page_prot);
         pte = pte_mkyoung(pte);
-       if (was_writable)
+       if (writable)
                 pte = pte_mkwrite(pte);
         ptep_modify_prot_commit(vma, vmf->address, vmf->pte, old_pte, pte);
         update_mmu_cache(vma, vmf->address, vmf->pte);


To me, the less error-prone approach is to re-calculate.


[...]
>> --- a/mm/ksm.c
>> +++ b/mm/ksm.c
>> @@ -1069,7 +1069,6 @@ static int write_protect_page(struct vm_area_struct *vma, struct page *page,
>>
>>         anon_exclusive = PageAnonExclusive(page);
>>         if (pte_write(*pvmw.pte) || pte_dirty(*pvmw.pte) ||
>> -           (pte_protnone(*pvmw.pte) && pte_savedwrite(*pvmw.pte)) ||
> 
> Not related to your code, but it does not make me comfortable that PTE’s
> status bits (which are volatile) are not accessed in this manner.
> 
> Especially since the PTE is later saved into orig_pte. It would feel safer
> to do READ_ONCE(*pvmw.pte) and work on it (probably in a separate patch).

I assume you are talking about the dirty bit. I don't immediately see how something
could go wrong here, but I agree that it might look cleaner that way.

Anyhow, independent of this series, so I'll leave that alone for now but add a
note for the future.
  
David Hildenbrand Nov. 3, 2022, 10:51 a.m. UTC | #3
On 03.11.22 11:45, David Hildenbrand wrote:
> On 02.11.22 22:22, Nadav Amit wrote:
>> On Nov 2, 2022, at 12:12 PM, David Hildenbrand <david@redhat.com> wrote:
>>
>>> !! External Email
>>>
>>> commit b191f9b106ea ("mm: numa: preserve PTE write permissions across a
>>> NUMA hinting fault") added remembering write permissions using ordinary
>>> pte_write() for PROT_NONE mapped pages to avoid write faults when
>>> remapping the page !PROT_NONE on NUMA hinting faults.
>>>
>>
>> [ snip ]
>>
>> Here’s a very shallow reviewed with some minor points...
> 
> Appreciated.
> 
>>
>>> ---
>>> include/linux/mm.h |  2 ++
>>> mm/huge_memory.c   | 28 +++++++++++++++++-----------
>>> mm/ksm.c           |  9 ++++-----
>>> mm/memory.c        | 19 ++++++++++++++++---
>>> mm/mprotect.c      |  7 ++-----
>>> 5 files changed, 41 insertions(+), 24 deletions(-)
>>>
>>> diff --git a/include/linux/mm.h b/include/linux/mm.h
>>> index 25ff9a14a777..a0deeece5e87 100644
>>> --- a/include/linux/mm.h
>>> +++ b/include/linux/mm.h
>>> @@ -1975,6 +1975,8 @@ extern unsigned long move_page_tables(struct vm_area_struct *vma,
>>> #define  MM_CP_UFFD_WP_ALL                 (MM_CP_UFFD_WP | \
>>>                                              MM_CP_UFFD_WP_RESOLVE)
>>>
>>> +bool can_change_pte_writable(struct vm_area_struct *vma, unsigned long addr,
>>> +                            pte_t pte);
>>
>> It might not be customary, but how about marking it as __pure?
> 
> Right, there is no a single use of __pure in the mm domain.
> 
>>
>>> extern unsigned long change_protection(struct mmu_gather *tlb,
>>>                                struct vm_area_struct *vma, unsigned long start,
>>>                                unsigned long end, pgprot_t newprot,
>>> diff --git a/mm/huge_memory.c b/mm/huge_memory.c
>>> index 2ad68e91896a..45abd27d75a0 100644
>>> --- a/mm/huge_memory.c
>>> +++ b/mm/huge_memory.c
>>> @@ -1462,8 +1462,7 @@ vm_fault_t do_huge_pmd_numa_page(struct vm_fault *vmf)
>>>          unsigned long haddr = vmf->address & HPAGE_PMD_MASK;
>>>          int page_nid = NUMA_NO_NODE;
>>>          int target_nid, last_cpupid = (-1 & LAST_CPUPID_MASK);
>>> -       bool migrated = false;
>>> -       bool was_writable = pmd_savedwrite(oldpmd);
>>> +       bool try_change_writable, migrated = false;
>>>          int flags = 0;
>>>
>>>          vmf->ptl = pmd_lock(vma->vm_mm, vmf->pmd);
>>> @@ -1472,13 +1471,22 @@ vm_fault_t do_huge_pmd_numa_page(struct vm_fault *vmf)
>>>                  goto out;
>>>          }
>>>
>>> +       /* See mprotect_fixup(). */
>>> +       if (vma->vm_flags & VM_SHARED)
>>> +               try_change_writable = vma_wants_writenotify(vma, vma->vm_page_prot);
>>> +       else
>>> +               try_change_writable = !!(vma->vm_flags & VM_WRITE);
>>
>> Do you find it better to copy the code instead of extracting it to a
>> separate function?
> 
> Yeah, you're right ;) usually the issue is coming up with a suitable name. Let me try.
> 
> vma_wants_manual_writability_change() hmm ...
> 
>>
>>> +
>>>          pmd = pmd_modify(oldpmd, vma->vm_page_prot);
>>>          page = vm_normal_page_pmd(vma, haddr, pmd);
>>>          if (!page)
>>>                  goto out_map;
>>>
>>>          /* See similar comment in do_numa_page for explanation */
>>> -       if (!was_writable)
>>> +       if (try_change_writable && !pmd_write(pmd) &&
>>> +            can_change_pmd_writable(vma, vmf->address, pmd))
>>> +               pmd = pmd_mkwrite(pmd);
>>> +       if (!pmd_write(pmd))
>>>                  flags |= TNF_NO_GROUP;
>>>
>>>          page_nid = page_to_nid(page);
>>> @@ -1523,8 +1531,12 @@ vm_fault_t do_huge_pmd_numa_page(struct vm_fault *vmf)
>>>          /* Restore the PMD */
>>>          pmd = pmd_modify(oldpmd, vma->vm_page_prot);
>>>          pmd = pmd_mkyoung(pmd);
>>> -       if (was_writable)
>>> +
>>> +       /* Similar to mprotect() protection updates, avoid write faults. */
>>> +       if (try_change_writable && !pmd_write(pmd) &&
>>> +            can_change_pmd_writable(vma, vmf->address, pmd))
>>
>> Why do I have a deja-vu? :)
>>
>> There must be a way to avoid the redundant code and specifically the call to
>> can_change_pmd_writable(), no?
> 
> The issue is that as soon as we drop the page table lock, that information is stale.
> Especially, after we fail migration.
> 
> So the following should work, however, if we fail migration we wouldn't map the
> page writable and would have to re-calculate:
> 
> diff --git a/mm/memory.c b/mm/memory.c
> index c5599a9279b1..a997625641e4 100644
> --- a/mm/memory.c
> +++ b/mm/memory.c
> @@ -4674,10 +4674,10 @@ static vm_fault_t do_numa_page(struct vm_fault *vmf)
>           struct vm_area_struct *vma = vmf->vma;
>           struct page *page = NULL;
>           int page_nid = NUMA_NO_NODE;
> +       bool writable = false;
>           int last_cpupid;
>           int target_nid;
>           pte_t pte, old_pte;
> -       bool was_writable = pte_savedwrite(vmf->orig_pte);
>           int flags = 0;
>    
>           /*
> @@ -4696,6 +4696,17 @@ static vm_fault_t do_numa_page(struct vm_fault *vmf)
>           old_pte = ptep_get(vmf->pte);
>           pte = pte_modify(old_pte, vma->vm_page_prot);
>    
> +       /*
> +        * Detect now whether the PTE is or can be writable. Note that this
> +        * information is valid as long as we're holding the PT lock, so also on
> +        * the remap path below.
> +        */
> +       writable = pte_write(pte);
> +       if (!writable && vma_wants_manual_writability_change(vma) &&
> +           can_change_pte_writable(vma, vmf->address, pte);
> +           writable = true;
> +       }
> +
>           page = vm_normal_page(vma, vmf->address, pte);
>           if (!page || is_zone_device_page(page))
>                   goto out_map;
> @@ -4712,7 +4723,7 @@ static vm_fault_t do_numa_page(struct vm_fault *vmf)
>            * pte_dirty has unpredictable behaviour between PTE scan updates,
>            * background writeback, dirty balancing and application behaviour.
>            */
> -       if (!was_writable)
> +       if (!writable)
>                   flags |= TNF_NO_GROUP;
>    
>           /*
> @@ -4738,6 +4749,7 @@ static vm_fault_t do_numa_page(struct vm_fault *vmf)
>                   put_page(page);
>                   goto out_map;
>           }
> +       writable = false;
>           pte_unmap_unlock(vmf->pte, vmf->ptl);
>    
>           /* Migrate to the requested node */
> @@ -4767,7 +4779,7 @@ static vm_fault_t do_numa_page(struct vm_fault *vmf)
>           old_pte = ptep_modify_prot_start(vma, vmf->address, vmf->pte);
>           pte = pte_modify(old_pte, vma->vm_page_prot);
>           pte = pte_mkyoung(pte);
> -       if (was_writable)
> +       if (writable)
>                   pte = pte_mkwrite(pte);
>           ptep_modify_prot_commit(vma, vmf->address, vmf->pte, old_pte, pte);
>           update_mmu_cache(vma, vmf->address, vmf->pte);
> 
> 
> To me, the less error-prone approach is to re-calculate.

Hmm, thinking again, the "if (unlikely(!pte_same(*vmf->pte, 
vmf->orig_pte))) {" check might actually not require us to recalculate.
  

Patch

diff --git a/include/linux/mm.h b/include/linux/mm.h
index 25ff9a14a777..a0deeece5e87 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -1975,6 +1975,8 @@  extern unsigned long move_page_tables(struct vm_area_struct *vma,
 #define  MM_CP_UFFD_WP_ALL                 (MM_CP_UFFD_WP | \
 					    MM_CP_UFFD_WP_RESOLVE)
 
+bool can_change_pte_writable(struct vm_area_struct *vma, unsigned long addr,
+			     pte_t pte);
 extern unsigned long change_protection(struct mmu_gather *tlb,
 			      struct vm_area_struct *vma, unsigned long start,
 			      unsigned long end, pgprot_t newprot,
diff --git a/mm/huge_memory.c b/mm/huge_memory.c
index 2ad68e91896a..45abd27d75a0 100644
--- a/mm/huge_memory.c
+++ b/mm/huge_memory.c
@@ -1462,8 +1462,7 @@  vm_fault_t do_huge_pmd_numa_page(struct vm_fault *vmf)
 	unsigned long haddr = vmf->address & HPAGE_PMD_MASK;
 	int page_nid = NUMA_NO_NODE;
 	int target_nid, last_cpupid = (-1 & LAST_CPUPID_MASK);
-	bool migrated = false;
-	bool was_writable = pmd_savedwrite(oldpmd);
+	bool try_change_writable, migrated = false;
 	int flags = 0;
 
 	vmf->ptl = pmd_lock(vma->vm_mm, vmf->pmd);
@@ -1472,13 +1471,22 @@  vm_fault_t do_huge_pmd_numa_page(struct vm_fault *vmf)
 		goto out;
 	}
 
+	/* See mprotect_fixup(). */
+	if (vma->vm_flags & VM_SHARED)
+		try_change_writable = vma_wants_writenotify(vma, vma->vm_page_prot);
+	else
+		try_change_writable = !!(vma->vm_flags & VM_WRITE);
+
 	pmd = pmd_modify(oldpmd, vma->vm_page_prot);
 	page = vm_normal_page_pmd(vma, haddr, pmd);
 	if (!page)
 		goto out_map;
 
 	/* See similar comment in do_numa_page for explanation */
-	if (!was_writable)
+	if (try_change_writable && !pmd_write(pmd) &&
+	     can_change_pmd_writable(vma, vmf->address, pmd))
+		pmd = pmd_mkwrite(pmd);
+	if (!pmd_write(pmd))
 		flags |= TNF_NO_GROUP;
 
 	page_nid = page_to_nid(page);
@@ -1523,8 +1531,12 @@  vm_fault_t do_huge_pmd_numa_page(struct vm_fault *vmf)
 	/* Restore the PMD */
 	pmd = pmd_modify(oldpmd, vma->vm_page_prot);
 	pmd = pmd_mkyoung(pmd);
-	if (was_writable)
+
+	/* Similar to mprotect() protection updates, avoid write faults. */
+	if (try_change_writable && !pmd_write(pmd) &&
+	     can_change_pmd_writable(vma, vmf->address, pmd))
 		pmd = pmd_mkwrite(pmd);
+
 	set_pmd_at(vma->vm_mm, haddr, vmf->pmd, pmd);
 	update_mmu_cache_pmd(vma, vmf->address, vmf->pmd);
 	spin_unlock(vmf->ptl);
@@ -1764,11 +1776,10 @@  int change_huge_pmd(struct mmu_gather *tlb, struct vm_area_struct *vma,
 	struct mm_struct *mm = vma->vm_mm;
 	spinlock_t *ptl;
 	pmd_t oldpmd, entry;
-	bool preserve_write;
-	int ret;
 	bool prot_numa = cp_flags & MM_CP_PROT_NUMA;
 	bool uffd_wp = cp_flags & MM_CP_UFFD_WP;
 	bool uffd_wp_resolve = cp_flags & MM_CP_UFFD_WP_RESOLVE;
+	int ret = 1;
 
 	tlb_change_page_size(tlb, HPAGE_PMD_SIZE);
 
@@ -1779,9 +1790,6 @@  int change_huge_pmd(struct mmu_gather *tlb, struct vm_area_struct *vma,
 	if (!ptl)
 		return 0;
 
-	preserve_write = prot_numa && pmd_write(*pmd);
-	ret = 1;
-
 #ifdef CONFIG_ARCH_ENABLE_THP_MIGRATION
 	if (is_swap_pmd(*pmd)) {
 		swp_entry_t entry = pmd_to_swp_entry(*pmd);
@@ -1861,8 +1869,6 @@  int change_huge_pmd(struct mmu_gather *tlb, struct vm_area_struct *vma,
 	oldpmd = pmdp_invalidate_ad(vma, addr, pmd);
 
 	entry = pmd_modify(oldpmd, newprot);
-	if (preserve_write)
-		entry = pmd_mk_savedwrite(entry);
 	if (uffd_wp) {
 		entry = pmd_wrprotect(entry);
 		entry = pmd_mkuffd_wp(entry);
diff --git a/mm/ksm.c b/mm/ksm.c
index dc15c4a2a6ff..dd02780c387f 100644
--- a/mm/ksm.c
+++ b/mm/ksm.c
@@ -1069,7 +1069,6 @@  static int write_protect_page(struct vm_area_struct *vma, struct page *page,
 
 	anon_exclusive = PageAnonExclusive(page);
 	if (pte_write(*pvmw.pte) || pte_dirty(*pvmw.pte) ||
-	    (pte_protnone(*pvmw.pte) && pte_savedwrite(*pvmw.pte)) ||
 	    anon_exclusive || mm_tlb_flush_pending(mm)) {
 		pte_t entry;
 
@@ -1107,11 +1106,11 @@  static int write_protect_page(struct vm_area_struct *vma, struct page *page,
 
 		if (pte_dirty(entry))
 			set_page_dirty(page);
+		entry = pte_mkclean(entry);
+
+		if (pte_write(entry))
+			entry = pte_wrprotect(entry);
 
-		if (pte_protnone(entry))
-			entry = pte_mkclean(pte_clear_savedwrite(entry));
-		else
-			entry = pte_mkclean(pte_wrprotect(entry));
 		set_pte_at_notify(mm, pvmw.address, pvmw.pte, entry);
 	}
 	*orig_pte = *pvmw.pte;
diff --git a/mm/memory.c b/mm/memory.c
index c5599a9279b1..286c29ee3aba 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -4672,12 +4672,12 @@  int numa_migrate_prep(struct page *page, struct vm_area_struct *vma,
 static vm_fault_t do_numa_page(struct vm_fault *vmf)
 {
 	struct vm_area_struct *vma = vmf->vma;
+	bool try_change_writable;
 	struct page *page = NULL;
 	int page_nid = NUMA_NO_NODE;
 	int last_cpupid;
 	int target_nid;
 	pte_t pte, old_pte;
-	bool was_writable = pte_savedwrite(vmf->orig_pte);
 	int flags = 0;
 
 	/*
@@ -4692,6 +4692,12 @@  static vm_fault_t do_numa_page(struct vm_fault *vmf)
 		goto out;
 	}
 
+	/* See mprotect_fixup(). */
+	if (vma->vm_flags & VM_SHARED)
+		try_change_writable = vma_wants_writenotify(vma, vma->vm_page_prot);
+	else
+		try_change_writable = !!(vma->vm_flags & VM_WRITE);
+
 	/* Get the normal PTE  */
 	old_pte = ptep_get(vmf->pte);
 	pte = pte_modify(old_pte, vma->vm_page_prot);
@@ -4712,7 +4718,10 @@  static vm_fault_t do_numa_page(struct vm_fault *vmf)
 	 * pte_dirty has unpredictable behaviour between PTE scan updates,
 	 * background writeback, dirty balancing and application behaviour.
 	 */
-	if (!was_writable)
+	if (try_change_writable && !pte_write(pte) &&
+	     can_change_pte_writable(vma, vmf->address, pte))
+		pte = pte_mkwrite(pte);
+	if (!pte_write(pte))
 		flags |= TNF_NO_GROUP;
 
 	/*
@@ -4767,8 +4776,12 @@  static vm_fault_t do_numa_page(struct vm_fault *vmf)
 	old_pte = ptep_modify_prot_start(vma, vmf->address, vmf->pte);
 	pte = pte_modify(old_pte, vma->vm_page_prot);
 	pte = pte_mkyoung(pte);
-	if (was_writable)
+
+	/* Similar to mprotect() protection updates, avoid write faults. */
+	if (try_change_writable && !pte_write(pte) &&
+	     can_change_pte_writable(vma, vmf->address, pte))
 		pte = pte_mkwrite(pte);
+
 	ptep_modify_prot_commit(vma, vmf->address, vmf->pte, old_pte, pte);
 	update_mmu_cache(vma, vmf->address, vmf->pte);
 	pte_unmap_unlock(vmf->pte, vmf->ptl);
diff --git a/mm/mprotect.c b/mm/mprotect.c
index 72aabffb7871..6c6248b65fd5 100644
--- a/mm/mprotect.c
+++ b/mm/mprotect.c
@@ -39,8 +39,8 @@ 
 
 #include "internal.h"
 
-static inline bool can_change_pte_writable(struct vm_area_struct *vma,
-					   unsigned long addr, pte_t pte)
+bool can_change_pte_writable(struct vm_area_struct *vma, unsigned long addr,
+			     pte_t pte)
 {
 	struct page *page;
 
@@ -121,7 +121,6 @@  static unsigned long change_pte_range(struct mmu_gather *tlb,
 		oldpte = *pte;
 		if (pte_present(oldpte)) {
 			pte_t ptent;
-			bool preserve_write = prot_numa && pte_write(oldpte);
 
 			/*
 			 * Avoid trapping faults against the zero or KSM
@@ -177,8 +176,6 @@  static unsigned long change_pte_range(struct mmu_gather *tlb,
 
 			oldpte = ptep_modify_prot_start(vma, addr, pte);
 			ptent = pte_modify(oldpte, newprot);
-			if (preserve_write)
-				ptent = pte_mk_savedwrite(ptent);
 
 			if (uffd_wp) {
 				ptent = pte_wrprotect(ptent);