[v6,1/4] KVM: mmu: introduce new gfn_to_pfn_noref functions

Message ID 20230330085802.2414466-2-stevensd@google.com
State New
Series KVM: allow mapping non-refcounted pages

Commit Message

David Stevens March 30, 2023, 8:57 a.m. UTC
  From: David Stevens <stevensd@chromium.org>

Introduce new gfn_to_pfn_noref functions that parallel existing
gfn_to_pfn functions. These functions can be used when the caller does
not need to maintain a reference to the returned pfn (i.e. when usage is
guarded by a mmu_notifier). The noref functions take an out parameter
that is used to return the struct page if the hva was resolved via gup.
The caller needs to drop its reference to such a returned page.

Signed-off-by: David Stevens <stevensd@chromium.org>
---
 include/linux/kvm_host.h |  18 ++++
 virt/kvm/kvm_main.c      | 209 ++++++++++++++++++++++++++++-----------
 virt/kvm/kvm_mm.h        |   6 +-
 virt/kvm/pfncache.c      |  12 ++-
 4 files changed, 188 insertions(+), 57 deletions(-)
  

Comments

Sean Christopherson May 22, 2023, 8:46 p.m. UTC | #1
+Peter

On Thu, Mar 30, 2023, David Stevens wrote:
> From: David Stevens <stevensd@chromium.org>
> 
> Introduce new gfn_to_pfn_noref functions that parallel existing
> gfn_to_pfn functions. These functions can be used when the caller does
> not need to maintain a reference to the returned pfn (i.e. when usage is
> guarded by a mmu_notifier). The noref functions take an out parameter
> that is used to return the struct page if the hva was resolved via gup.
> The caller needs to drop its reference to such a returned page.

I dislike the "noref" name and the approach itself (of providing an entirely
separate set of APIs).  Using "noref" is confusing because the callers do actually
get a reference to the page (if a refcounted page is found).

As for the approach, I really, really don't want to end up with yet more APIs
for getting PFNs from GFNs.  We already have far too many.  In the short term,
I think we'll need to carry multiple sets of APIs, as converting all architectures
to any new API will be too much for a single series.  But I want to have a line of
sight to converging on a single, as-small-as-possible set of APIs, and I think/hope
it should be possible to make the old APIs, e.g. gfn_to_pfn(), shims around
the new APIs.

And since this series is essentially overhauling the gfn_to_pfn APIs, I think it's
the right series to take on refactoring the APIs to clean up the growing flag
problem.  There was a bit of discussion back when "interruptible" support was
added (https://lore.kernel.org/all/YrTbKaRe497n8M0o@xz-m1.loca), but it got punted
because it wasn't necessary, and because there wasn't immediate agreement on what
exactly the APIs should look like.

Overhauling the APIs would also let us clean up things like async #PF, specifically
replacing the unintuitive "*async = true" logic with something like this:

		if ((flags & FOLL_NOWAIT) && vma_is_valid(vma, flags & FOLL_WRITE))
			pfn = KVM_PFN_ERR_FAULT_MINOR;
		else
			pfn = KVM_PFN_ERR_FAULT;

Lastly, I think there's also an opportunity here to harden KVM's interaction with
mmu_notifiers, and to dedup arch code in KVM.  Specifically, even when the proposed
"allow_unsafe_kmap" is true, KVM should either (a) be "in" an mmu_notifier sequence
or (b) _want_ to grab a reference.  And when KVM does NOT want a reference, the core
API can/should immediately drop the reference even before returning.
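
The rule in the paragraph above can be modelled in userspace as follows. This is only a sketch, not kernel code: the toy_* types, TOY_FOLL_GET, TOY_PFN_ERR and the always-succeeding toy_lookup() are assumptions made up for illustration, and a plain int stands in for the real page refcount.

```c
#include <assert.h>

/* Toy stand-ins for kernel types; none of this is real KVM code. */
#define TOY_FOLL_GET 0x1u		/* caller wants to keep a reference */
#define TOY_PFN_ERR  (~0ull)		/* error sentinel for this model */

struct toy_page {
	int refcount;
	unsigned long long pfn;
};

struct toy_follow_pfn {
	unsigned int flags;		/* input */
	unsigned long long mmu_seq;	/* input: nonzero if in a notifier sequence */
	struct toy_page *page;		/* output: refcounted page, if any */
};

/* Pretend lookup: always succeeds and takes a reference, as gup would. */
static unsigned long long toy_lookup(struct toy_follow_pfn *foll,
				     struct toy_page *backing)
{
	backing->refcount++;
	foll->page = backing;
	return backing->pfn;
}

static unsigned long long toy_follow_pfn(struct toy_follow_pfn *foll,
					 struct toy_page *backing)
{
	unsigned long long pfn;

	/* Neither keeping a reference nor inside a notifier sequence: reject. */
	if (!(foll->flags & TOY_FOLL_GET) && !foll->mmu_seq)
		return TOY_PFN_ERR;

	pfn = toy_lookup(foll, backing);

	/* Caller relies on the notifier sequence, so drop the reference now. */
	if (foll->page && !(foll->flags & TOY_FOLL_GET)) {
		assert(foll->page->refcount > 0);
		foll->page->refcount--;
	}

	return pfn;
}
```

A caller that passes TOY_FOLL_GET sees the refcount stay elevated, a caller that passes only a nonzero mmu_seq sees it return to its previous value, and a caller that passes neither gets the error sentinel without the page being touched.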

My thought is to provide an "entirely" new API, named something like kvm_follow_pfn()
to somewhat mirror follow_{pfn,pte,phys}().  Ideally something to pair with gup()
would be nice, but having a dedicated KVM helper to get _only_ struct page memory
doesn't work well because KVM almost never wants only struct page memory.

As for the flags vs. bools debate (see link above), I think the best approach is
a mix of the two.  Specifically, reuse the FOLL_* flags as-is for inputs, and use
booleans for outputs.  I don't _think_ there are any input bools/flags that don't
map 1:1 with existing FOLL_* flags.

As a very, *very* rough sketch, provide APIs that look a bit like this.

  kvm_pfn_t __kvm_follow_pfn(struct kvm_follow_pfn *foll)
  {
	kvm_pfn_t pfn;

	if (WARN_ON_ONCE(!(foll->flags & FOLL_GET) && !foll->mmu_seq))
		return KVM_PFN_ERR_FAULT;

	pfn = ???;

	if (foll->page && !(foll->flags & FOLL_GET))
		put_page(foll->page);

	return pfn;
  }

  kvm_pfn_t kvm_follow_pfn(struct kvm_vcpu *vcpu, gfn_t gfn, struct page **page)
  {
	struct kvm_follow_pfn foll = {
		.flags = FOLL_GET | FOLL_WRITE,
	};

	<more stuff here?>

	foll.slot = ???;
	if (!foll.slot || foll.slot->flags & KVM_MEMSLOT_INVALID)
		return KVM_HVA_ERR_BAD;

	if (memslot_is_readonly(foll.slot))
		return KVM_HVA_ERR_RO_BAD;

	return __kvm_follow_pfn(&foll);
  }

and a few partially converted users

diff --git a/arch/x86/kvm/mmu/mmu.c b/arch/x86/kvm/mmu/mmu.c
index 67e2ac799aa7..5eaf0395ed87 100644
--- a/arch/x86/kvm/mmu/mmu.c
+++ b/arch/x86/kvm/mmu/mmu.c
@@ -550,12 +550,14 @@ static bool mmu_spte_update(u64 *sptep, u64 new_spte)
 
        if (is_accessed_spte(old_spte) && !is_accessed_spte(new_spte)) {
                flush = true;
-               kvm_set_pfn_accessed(spte_to_pfn(old_spte));
+               if (is_refcounted_page_pte(old_spte))
+                       kvm_set_page_accessed(pfn_to_page(spte_to_pfn(old_spte)));
        }
 
        if (is_dirty_spte(old_spte) && !is_dirty_spte(new_spte)) {
                flush = true;
-               kvm_set_pfn_dirty(spte_to_pfn(old_spte));
+               if (is_refcounted_page_pte(old_spte))
+                       kvm_set_page_dirty(pfn_to_page(spte_to_pfn(old_spte)));
        }
 
        return flush;
@@ -4278,6 +4280,10 @@ void kvm_arch_async_page_ready(struct kvm_vcpu *vcpu, struct kvm_async_pf *work)
 
 static int __kvm_faultin_pfn(struct kvm_vcpu *vcpu, struct kvm_page_fault *fault)
 {
+       struct kvm_follow_pfn foll = {
+               .mmu_seq = fault->mmu_seq,
+               .gfn = fault->gfn,
+       };
        struct kvm_memory_slot *slot = fault->slot;
        bool async;
 
@@ -4309,12 +4315,16 @@ static int __kvm_faultin_pfn(struct kvm_vcpu *vcpu, struct kvm_page_fault *fault
                        return RET_PF_EMULATE;
        }
 
-       async = false;
-       fault->pfn = __gfn_to_pfn_noref_memslot(slot, fault->gfn, false, false, &async,
-                                               fault->write, &fault->map_writable,
-                                               &fault->hva, &fault->page);
-       if (!async)
-               return RET_PF_CONTINUE; /* *pfn has correct page already */
+       foll.flags = FOLL_NOWAIT;
+       if (fault->write)
+               foll.flags |= FOLL_WRITE;
+
+       fault->pfn = __kvm_follow_pfn(&foll);
+       if (!is_error_noslot_pfn(fault->pfn))
+               goto success;
+
+       if (!is_fault_minor_pfn(fault->pfn))
+               return RET_PF_CONTINUE;
 
        if (!fault->prefetch && kvm_can_do_async_pf(vcpu)) {
                trace_kvm_try_async_get_page(fault->addr, fault->gfn);
@@ -4332,9 +4342,18 @@ static int __kvm_faultin_pfn(struct kvm_vcpu *vcpu, struct kvm_page_fault *fault
         * to wait for IO.  Note, gup always bails if it is unable to quickly
         * get a page and a fatal signal, i.e. SIGKILL, is pending.
         */
-       fault->pfn = __gfn_to_pfn_noref_memslot(slot, fault->gfn, false, true, NULL,
-                                               fault->write, &fault->map_writable,
-                                               &fault->hva, &fault->page);
+       foll.flags |= FOLL_INTERRUPTIBLE;
+       foll.flags &= ~FOLL_NOWAIT;
+
+       fault->pfn = __kvm_follow_pfn(&foll);
+       if (!is_error_noslot_pfn(fault->pfn))
+               goto success;
+
+       return RET_PF_CONTINUE;
+success:
+       fault->hva = foll.hva;
+       fault->page = foll.page;
+       fault->map_writable = foll.writable;
        return RET_PF_CONTINUE;
 }
 
diff --git a/virt/kvm/kvm_main.c b/virt/kvm/kvm_main.c
index 360eaa24456f..0bae253c88dd 100644
--- a/virt/kvm/kvm_main.c
+++ b/virt/kvm/kvm_main.c
@@ -2663,9 +2663,10 @@ kvm_pfn_t hva_to_pfn(unsigned long addr, bool atomic, bool interruptible,
                if (r < 0)
                        pfn = KVM_PFN_ERR_FAULT;
        } else {
-               if (async && vma_is_valid(vma, write_fault))
-                       *async = true;
-               pfn = KVM_PFN_ERR_FAULT;
+               if ((flags & FOLL_NOWAIT) && vma_is_valid(vma, flags & FOLL_WRITE))
+                       pfn = KVM_PFN_ERR_FAULT_MINOR;
+               else
+                       pfn = KVM_PFN_ERR_FAULT;
        }
 exit:
        mmap_read_unlock(current->mm);
@@ -2732,6 +2733,30 @@ kvm_pfn_t __gfn_to_pfn_noref_memslot(const struct kvm_memory_slot *slot, gfn_t g
 }
 EXPORT_SYMBOL_GPL(__gfn_to_pfn_noref_memslot);
 
+kvm_pfn_t __kvm_follow_pfn(struct kvm_follow_pfn *foll)
+{
+       kvm_pfn_t pfn;
+
+       if (WARN_ON_ONCE(!(foll->flags & FOLL_GET) && !foll->mmu_seq))
+               return KVM_PFN_ERR_FAULT;
+
+       pfn = __gfn_to_pfn_noref_memslot(...);
+
+       if (foll->page && !(foll->flags & FOLL_GET))
+               put_page(foll->page);
+
+       return pfn;
+}
+
+kvm_pfn_t kvm_follow_pfn(struct kvm_vcpu *vcpu, gfn_t gfn, struct page **page)
+{
+       struct kvm_follow_pfn foll = {
+               .flags = FOLL_GET | FOLL_WRITE,
+       };
+
+       return __kvm_follow_pfn(&foll);
+}
+
 kvm_pfn_t gfn_to_pfn_noref_prot(struct kvm *kvm, gfn_t gfn, bool write_fault,
                                bool *writable, struct page **page)
 {
@@ -2910,25 +2935,23 @@ void kvm_release_pfn(kvm_pfn_t pfn, bool dirty)
 
 int kvm_vcpu_map(struct kvm_vcpu *vcpu, gfn_t gfn, struct kvm_host_map *map)
 {
+       struct page *page;
        kvm_pfn_t pfn;
        void *hva = NULL;
-       struct page *page = KVM_UNMAPPED_PAGE;
 
        if (!map)
                return -EINVAL;
 
-       pfn = gfn_to_pfn(vcpu->kvm, gfn);
+       pfn = kvm_follow_pfn(vcpu->kvm, gfn, &page);
        if (is_error_noslot_pfn(pfn))
                return -EINVAL;
 
-       if (pfn_valid(pfn)) {
-               page = pfn_to_page(pfn);
+       if (page)
                hva = kmap(page);
 #ifdef CONFIG_HAS_IOMEM
-       } else {
+       else if (allow_unsafe_kmap)
                hva = memremap(pfn_to_hpa(pfn), PAGE_SIZE, MEMREMAP_WB);
 #endif
-       }
 
        if (!hva)
                return -EFAULT;
  
Peter Xu May 24, 2023, 4:22 p.m. UTC | #2
On Mon, May 22, 2023 at 01:46:41PM -0700, Sean Christopherson wrote:
> +Peter
> 
> On Thu, Mar 30, 2023, David Stevens wrote:
> > From: David Stevens <stevensd@chromium.org>
> > 
> > Introduce new gfn_to_pfn_noref functions that parallel existing
> > gfn_to_pfn functions. These functions can be used when the caller does
> > not need to maintain a reference to the returned pfn (i.e. when usage is
> > guarded by a mmu_notifier). The noref functions take an out parameter
> > that is used to return the struct page if the hva was resolved via gup.
> > The caller needs to drop its reference to such a returned page.
> 
> I dislike the "noref" name and the approach itself (of providing an entirely
> separate set of APIs).  Using "noref" is confusing because the callers do actually
> get a reference to the page (if a refcounted page is found).
> 
> As for the approach, I really, really don't want to end up with yet more APIs
> for getting PFNs from GFNs.  We already have far too many.  In the short term,
> I think we'll need to carry multiple sets of APIs, as converting all architectures
> to any new API will be too much for a single series.  But I want to have a line of
> sight to converging on a single, as-small-as-possible set of APIs, and I think/hope
> it should be possible to make the old APIs, e.g. gfn_to_pfn(), shims around
> the new APIs.
> 
> And since this series is essentially overhauling the gfn_to_pfn APIs, I think it's
> the right series to take on refactoring the APIs to clean up the growing flag
> problem.  There was a bit of discussion back when "interruptible" support was
> added (https://lore.kernel.org/all/YrTbKaRe497n8M0o@xz-m1.loca), but it got punted
> because it wasn't necessary, and because there wasn't immediate agreement on what
> exactly the APIs should look like.
> 
> Overhauling the APIs would also let us clean up things like async #PF, specifically
> replacing the unintuitive "*async = true" logic with something like this:
> 
> 		if ((flags & FOLL_NOWAIT) && vma_is_valid(vma, flags & FOLL_WRITE))
> 			pfn = KVM_PFN_ERR_FAULT_MINOR;
> 		else
> 			pfn = KVM_PFN_ERR_FAULT;
> 
> Lastly, I think there's also an opportunity here to harden KVM's interaction with
> mmu_notifiers, and to dedup arch code in KVM.  Specifically, even when the proposed
> "allow_unsafe_kmap" is true, KVM should either (a) be "in" an mmu_notifier sequence
> or (b) _want_ to grab a reference.  And when KVM does NOT want a reference, the core
> API can/should immediately drop the reference even before returning.
> 
> My thought is to provide an "entirely" new API, named something like kvm_follow_pfn()
> to somewhat mirror follow_{pfn,pte,phys}().  Ideally something to pair with gup()
> would be nice, but having a dedicated KVM helper to get _only_ struct page memory
> doesn't work well because KVM almost never wants only struct page memory.
> 
> As for the flags vs. bools debate (see link above), I think the best approach is
> a mix of the two.  Specifically, reuse the FOLL_* flags as-is for inputs, and use
> booleans for outputs.  I don't _think_ there are any input bools/flags that don't
> map 1:1 with existing FOLL_* flags.
> 
> As a very, *very* rough sketch, provide APIs that look a bit like this.

Unifying ref vs nonref cases does look a bit cleaner to me too.

> 
>   kvm_pfn_t __kvm_follow_pfn(struct kvm_follow_pfn *foll)
>   {
> 	kvm_pfn_t pfn;
> 
> 	if (WARN_ON_ONCE(!(foll->flags & FOLL_GET) && !foll->mmu_seq))

IMHO we may not want to rely on mmu_seq==0, either because an unlucky caller
could see the very initial mmu_seq being zero, or because the counter could
overflow back to zero?

I'd say we can stick with FOLL_GET in this case to identify ref vs nonref
and always treat mmu_seq as an arbitrary number.

> 		return KVM_PFN_ERR_FAULT;
> 
> 	pfn = ???;
> 
> 	if (foll->page && !(foll->flags & FOLL_GET))
> 		put_page(foll->page);
> 
> 	return pfn;
>   }
> 
>   kvm_pfn_t kvm_follow_pfn(struct kvm_vcpu *vcpu, gfn_t gfn, struct page **page)
>   {
> 	struct kvm_follow_pfn foll = {
> 		.flags = FOLL_GET | FOLL_WRITE,
> 	};
> 
> 	<more stuff here?>
> 
> 	foll.slot = ???;
> 	if (!foll.slot || foll.slot->flags & KVM_MEMSLOT_INVALID)
> 		return KVM_HVA_ERR_BAD;
> 
> 	if (memslot_is_readonly(foll.slot))
> 		return KVM_HVA_ERR_RO_BAD;
> 
> 	return __kvm_follow_pfn(&foll);
>   }
> 
> and a few partially converted users
> 
> diff --git a/arch/x86/kvm/mmu/mmu.c b/arch/x86/kvm/mmu/mmu.c
> index 67e2ac799aa7..5eaf0395ed87 100644
> --- a/arch/x86/kvm/mmu/mmu.c
> +++ b/arch/x86/kvm/mmu/mmu.c
> @@ -550,12 +550,14 @@ static bool mmu_spte_update(u64 *sptep, u64 new_spte)
>  
>         if (is_accessed_spte(old_spte) && !is_accessed_spte(new_spte)) {
>                 flush = true;
> -               kvm_set_pfn_accessed(spte_to_pfn(old_spte));
> +               if (is_refcounted_page_pte(old_spte))

One question is how to impl is_refcounted_page_pte() here to identify
non-refcountable pages.

IIUC those pages are mostly identical to a normal page (so !PG_reserved)
but it has page_ref_count(page)==0 always, am I right?  I got that roughly
from reading f8be156be1 only though, so I could miss a lot of things..

When thinking about that, I'm also wondering whether we can trivially allow
kvm to support such mapping (without overhaul of the kvm pfn API) by
something like this:

===8<===
diff --git a/virt/kvm/kvm_main.c b/virt/kvm/kvm_main.c
index 51e4882d0873..467acbac1a96 100644
--- a/virt/kvm/kvm_main.c
+++ b/virt/kvm/kvm_main.c
@@ -192,7 +192,13 @@ struct page *kvm_pfn_to_refcounted_page(kvm_pfn_t pfn)

        page = pfn_to_page(pfn);
        if (!PageReserved(page))
-               return page;
+               /*
+                * When page_ref_count(page)==0 it might be a special page
+                * that does not support refcounting.  Treat such pages the
+                * same as normal reserved (e.g. MMIO) pages by returning
+                * NULL, so they're exempt from refcounting.
+                */
+               return page_ref_count(page) == 0 ? NULL : page;

        /* The ZERO_PAGE(s) is marked PG_reserved, but is refcounted. */
        if (is_zero_pfn(pfn))
===8<===

So that we treat those special pages the same as normal PFNMAP ones by
skipping all refcountings on inc/dec.  This is based on the fact that kvm
should always hold at least 1 ref on a normal page so a normal page should
never hit ref==0 here, but again I could miss something somewhere..

> +                       kvm_set_page_accessed(pfn_to_page(spte_to_pfn(old_spte)));
>         }
>  
>         if (is_dirty_spte(old_spte) && !is_dirty_spte(new_spte)) {
>                 flush = true;
> -               kvm_set_pfn_dirty(spte_to_pfn(old_spte));
> +               if (is_refcounted_page_pte(old_spte))
> +                       kvm_set_page_dirty(pfn_to_page(spte_to_pfn(old_spte)));
>         }
>  
>         return flush;
> @@ -4278,6 +4280,10 @@ void kvm_arch_async_page_ready(struct kvm_vcpu *vcpu, struct kvm_async_pf *work)
>  
>  static int __kvm_faultin_pfn(struct kvm_vcpu *vcpu, struct kvm_page_fault *fault)
>  {
> +       struct kvm_follow_pfn foll = {
> +               .mmu_seq = fault->mmu_seq,
> +               .gfn = fault->gfn,
> +       };
>         struct kvm_memory_slot *slot = fault->slot;
>         bool async;
>  
> @@ -4309,12 +4315,16 @@ static int __kvm_faultin_pfn(struct kvm_vcpu *vcpu, struct kvm_page_fault *fault
>                         return RET_PF_EMULATE;
>         }
>  
> -       async = false;
> -       fault->pfn = __gfn_to_pfn_noref_memslot(slot, fault->gfn, false, false, &async,
> -                                               fault->write, &fault->map_writable,
> -                                               &fault->hva, &fault->page);
> -       if (!async)
> -               return RET_PF_CONTINUE; /* *pfn has correct page already */
> +       foll.flags = FOLL_NOWAIT;
> +       if (fault->write)
> +               foll.flags |= FOLL_WRITE;
> +
> +       fault->pfn = __kvm_follow_pfn(&foll);
> +       if (!is_error_noslot_pfn(fault->pfn))
> +               goto success;
> +
> +       if (!is_fault_minor_pfn(fault->pfn))
> +               return RET_PF_CONTINUE;
>  
>         if (!fault->prefetch && kvm_can_do_async_pf(vcpu)) {
>                 trace_kvm_try_async_get_page(fault->addr, fault->gfn);
> @@ -4332,9 +4342,18 @@ static int __kvm_faultin_pfn(struct kvm_vcpu *vcpu, struct kvm_page_fault *fault
>          * to wait for IO.  Note, gup always bails if it is unable to quickly
>          * get a page and a fatal signal, i.e. SIGKILL, is pending.
>          */
> -       fault->pfn = __gfn_to_pfn_noref_memslot(slot, fault->gfn, false, true, NULL,
> -                                               fault->write, &fault->map_writable,
> -                                               &fault->hva, &fault->page);
> +       foll.flags |= FOLL_INTERRUPTIBLE;
> +       foll.flags &= ~FOLL_NOWAIT;
> +
> +       fault->pfn = __kvm_follow_pfn(&foll);
> +       if (!is_error_noslot_pfn(fault->pfn))
> +               goto success;
> +
> +       return RET_PF_CONTINUE;
> +success:
> +       fault->hva = foll.hva;
> +       fault->page = foll.page;
> +       fault->map_writable = foll.writable;
>         return RET_PF_CONTINUE;
>  }
>  
> diff --git a/virt/kvm/kvm_main.c b/virt/kvm/kvm_main.c
> index 360eaa24456f..0bae253c88dd 100644
> --- a/virt/kvm/kvm_main.c
> +++ b/virt/kvm/kvm_main.c
> @@ -2663,9 +2663,10 @@ kvm_pfn_t hva_to_pfn(unsigned long addr, bool atomic, bool interruptible,
>                 if (r < 0)
>                         pfn = KVM_PFN_ERR_FAULT;
>         } else {
> -               if (async && vma_is_valid(vma, write_fault))
> -                       *async = true;
> -               pfn = KVM_PFN_ERR_FAULT;
> +               if ((flags & FOLL_NOWAIT) && vma_is_valid(vma, flags & FOLL_WRITE))
> +                       pfn = KVM_PFN_ERR_FAULT_MINOR;
> +               else
> +                       pfn = KVM_PFN_ERR_FAULT;
>         }
>  exit:
>         mmap_read_unlock(current->mm);
> @@ -2732,6 +2733,30 @@ kvm_pfn_t __gfn_to_pfn_noref_memslot(const struct kvm_memory_slot *slot, gfn_t g
>  }
>  EXPORT_SYMBOL_GPL(__gfn_to_pfn_noref_memslot);
>  
> +kvm_pfn_t __kvm_follow_pfn(struct kvm_follow_pfn *foll)
> +{
> +       kvm_pfn_t pfn;
> +
> +       if (WARN_ON_ONCE(!(foll->flags & FOLL_GET) && !foll->mmu_seq))
> +               return KVM_PFN_ERR_FAULT;
> +
> +       pfn = __gfn_to_pfn_noref_memslot(...);
> +
> +       if (foll->page && !(foll->flags & FOLL_GET))
> +               put_page(foll->page);
> +
> +       return pfn;
> +}
> +
> +kvm_pfn_t kvm_follow_pfn(struct kvm_vcpu *vcpu, gfn_t gfn, struct page **page)
> +{
> +       struct kvm_follow_pfn foll = {
> +               .flags = FOLL_GET | FOLL_WRITE,
> +       };
> +
> +       return __kvm_follow_pfn(&foll);
> +}
> +
>  kvm_pfn_t gfn_to_pfn_noref_prot(struct kvm *kvm, gfn_t gfn, bool write_fault,
>                                 bool *writable, struct page **page)
>  {
> @@ -2910,25 +2935,23 @@ void kvm_release_pfn(kvm_pfn_t pfn, bool dirty)
>  
>  int kvm_vcpu_map(struct kvm_vcpu *vcpu, gfn_t gfn, struct kvm_host_map *map)
>  {
> +       struct page *page;
>         kvm_pfn_t pfn;
>         void *hva = NULL;
> -       struct page *page = KVM_UNMAPPED_PAGE;
>  
>         if (!map)
>                 return -EINVAL;
>  
> -       pfn = gfn_to_pfn(vcpu->kvm, gfn);
> +       pfn = kvm_follow_pfn(vcpu->kvm, gfn, &page);
>         if (is_error_noslot_pfn(pfn))
>                 return -EINVAL;
>  
> -       if (pfn_valid(pfn)) {
> -               page = pfn_to_page(pfn);
> +       if (page)
>                 hva = kmap(page);
>  #ifdef CONFIG_HAS_IOMEM
> -       } else {
> +       else if (allow_unsafe_kmap)
>                 hva = memremap(pfn_to_hpa(pfn), PAGE_SIZE, MEMREMAP_WB);
>  #endif
> -       }
>  
>         if (!hva)
>                 return -EFAULT;
>
  
Sean Christopherson May 24, 2023, 4:46 p.m. UTC | #3
On Wed, May 24, 2023, Peter Xu wrote:
> On Mon, May 22, 2023 at 01:46:41PM -0700, Sean Christopherson wrote:
> > As for the flags vs. bools debate (see link above), I think the best approach is
> > a mix of the two.  Specifically, reuse the FOLL_* flags as-is for inputs, and use
> > booleans for outputs.  I don't _think_ there are any input bools/flags that don't
> > map 1:1 with existing FOLL_* flags.
> > 
> > As a very, *very* rough sketch, provide APIs that look a bit like this.
> 
> Unifying ref vs nonref cases does look a bit cleaner to me too.
> 
> > 
> >   kvm_pfn_t __kvm_follow_pfn(struct kvm_follow_pfn *foll)
> >   {
> > 	kvm_pfn_t pfn;
> > 
> > 	if (WARN_ON_ONCE(!(foll->flags & FOLL_GET) && !foll.mmu_seq))
> 
> IMHO we may not want to rely on mmu_seq==0, either because an unlucky caller
> could see the very initial mmu_seq being zero, or because the counter could
> overflow back to zero?

I was thinking we could initialize mmu_seq to '1' and make it a u64 to avoid
overflow.
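
For a rough sense of scale, the arithmetic behind "u64 avoids overflow" can be checked with a small userspace helper. This is only a back-of-the-envelope sketch: the helper names are made up, a 365-day year is used, and the increment rate is an assumed input rather than a measured invalidation rate.

```c
#include <assert.h>
#include <stdint.h>

/* Whole years until a 64-bit counter that starts at 1 wraps back to 0. */
static uint64_t years_until_wrap_u64(uint64_t incs_per_sec)
{
	const uint64_t secs_per_year = 365ull * 24 * 60 * 60;

	assert(incs_per_sec > 0);
	return (UINT64_MAX - 1) / incs_per_sec / secs_per_year;
}

/* Whole seconds until a 32-bit counter that starts at 1 wraps, for contrast. */
static uint64_t seconds_until_wrap_u32(uint64_t incs_per_sec)
{
	assert(incs_per_sec > 0);
	return (uint64_t)(UINT32_MAX - 1) / incs_per_sec;
}
```

Even at an implausible one billion increments per second, the 64-bit counter needs more than five centuries to wrap, whereas a 32-bit counter at the same rate would wrap within seconds.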

> I'd say we can stick with FOLL_GET in this case to identify ref vs nonref
> and always treat mmu_seq as an arbitrary number.

The intent of checking mmu_seq is to flag cases where the caller doesn't specify
FOLL_GET and isn't protected by mmu_invalidate_seq, i.e. isn't tapped into the
mmu_notifiers.  I.e. this is a sanity check, not functionally necessary.

> 
> > 		return KVM_PFN_ERR_FAULT;
> > 
> > 	pfn = ???;
> > 
> > 	if (foll->page && !(foll->flags & FOLL_GET))
> > 		put_page(foll->page);
> > 
> > 	return pfn;
> >   }
> > 
> >   kvm_pfn_t kvm_follow_pfn(struct kvm_vcpu *vcpu, gfn_t gfn, struct page **page)
> >   {
> > 	struct kvm_follow_pfn foll = {
> > 		.flags = FOLL_GET | FOLL_WRITE,
> > 	};
> > 
> > 	<more stuff here?>
> > 
> > 	foll.slot = ???;
> > 	if (!foll.slot || foll.slot->flags & KVM_MEMSLOT_INVALID)
> > 		return KVM_HVA_ERR_BAD;
> > 
> > 	if (memslot_is_readonly(foll.slot))
> > 		return KVM_HVA_ERR_RO_BAD;
> > 
> > 	return __kvm_follow_pfn(&foll);
> >   }
> > 
> > and a few partially converted users
> > 
> > diff --git a/arch/x86/kvm/mmu/mmu.c b/arch/x86/kvm/mmu/mmu.c
> > index 67e2ac799aa7..5eaf0395ed87 100644
> > --- a/arch/x86/kvm/mmu/mmu.c
> > +++ b/arch/x86/kvm/mmu/mmu.c
> > @@ -550,12 +550,14 @@ static bool mmu_spte_update(u64 *sptep, u64 new_spte)
> >  
> >         if (is_accessed_spte(old_spte) && !is_accessed_spte(new_spte)) {
> >                 flush = true;
> > -               kvm_set_pfn_accessed(spte_to_pfn(old_spte));
> > +               if (is_refcounted_page_pte(old_spte))
> 
> One question is how to impl is_refcounted_page_pte() here to identify
> non-refcountable pages.

KVM would use a software available bit in its PTEs to explicitly track which SPTEs
point at refcounted pages.  E.g. I think bit 59 is available for EPT and 64-bit
paging.  PAE paging doesn't have high available bits, which is why I called out
that this would have to be 64-bit only.
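
As a sketch of what such tracking could look like, the helpers below tag a 64-bit PTE value with a software bit and test it later. The macro name, the helper names and the choice of bit 59 are assumptions for illustration (bit 59 is only the candidate mentioned above), and this is standalone C rather than KVM code.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

/* Hypothetical name; bit 59 is only the candidate mentioned in the thread. */
#define TOY_SPTE_REFCOUNTED_PAGE (1ull << 59)

/* Set or clear the software bit when the SPTE is created. */
static inline uint64_t toy_mark_refcounted(uint64_t spte, bool refcounted)
{
	return refcounted ? (spte | TOY_SPTE_REFCOUNTED_PAGE)
			  : (spte & ~TOY_SPTE_REFCOUNTED_PAGE);
}

/* Later lookups consult the bit instead of guessing from the struct page. */
static inline bool toy_is_refcounted_page_pte(uint64_t spte)
{
	return !!(spte & TOY_SPTE_REFCOUNTED_PAGE);
}
```

The point of recording the answer in the PTE is that accessed/dirty updates can then decide whether to touch the struct page without inspecting the page's refcount after the fact.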

> IIUC those pages are mostly identical to a normal page (so !PG_reserved)
> but it has page_ref_count(page)==0 always, am I right?  I got that roughly
> from reading f8be156be1 only though, so I could miss a lot of things..
> 
> When thinking about that, I'm also wondering whether we can trivially allow
> kvm to support such mapping (without overhaul of the kvm pfn API) by
> something like this:
> 
> ===8<===
> diff --git a/virt/kvm/kvm_main.c b/virt/kvm/kvm_main.c
> index 51e4882d0873..467acbac1a96 100644
> --- a/virt/kvm/kvm_main.c
> +++ b/virt/kvm/kvm_main.c
> @@ -192,7 +192,13 @@ struct page *kvm_pfn_to_refcounted_page(kvm_pfn_t pfn)
> 
>         page = pfn_to_page(pfn);
>         if (!PageReserved(page))
> -               return page;
> +               /*
> +                * When page_ref_count(page)==0 it might be a special page
> +                * that does not support refcounting.  Treat such pages the
> +                * same as normal reserved (e.g. MMIO) pages by returning
> +                * NULL, so they're exempt from refcounting.
> +                */
> +               return page_ref_count(page) == 0 ? NULL : page;

Heh, because I got burned by this recently, using page_ref_count() is wrong.  This
needs to be page_count() so that tail pages of refcounted compound pages are
properly identified.
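
The difference can be illustrated with a toy model of a compound page. The struct layout and toy_* helpers below are assumptions that only mimic the behaviour described here: the tail page carries no refcount of its own, so a check on the tail's own counter reads zero even though the compound page is referenced.

```c
#include <assert.h>
#include <stddef.h>

struct toy_page {
	int refcount;			/* only meaningful on a head or order-0 page */
	struct toy_page *head;		/* NULL unless this is a tail page */
};

/* Models page_ref_count(): reads the page's own counter. */
static int toy_page_ref_count(const struct toy_page *page)
{
	return page->refcount;
}

/* Models page_count(): resolves a tail page to its head first. */
static int toy_page_count(const struct toy_page *page)
{
	const struct toy_page *p = page->head ? page->head : page;

	return p->refcount;
}
```

With a referenced head page and one of its tails, toy_page_ref_count() on the tail reports zero while toy_page_count() reports the head's count, which is why a "refcount == 0" heuristic built on the former would misclassify hugepage tails.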

> 
>         /* The ZERO_PAGE(s) is marked PG_reserved, but is refcounted. */
>         if (is_zero_pfn(pfn))
> ===8<===
> 
> So that we treat those special pages the same as normal PFNMAP ones by
> skipping all refcountings on inc/dec.  This is based on the fact that kvm
> should always hold at least 1 ref on a normal page so a normal page should
> never hit ref==0 here, but again I could miss something somewhere..

This would "work" from a functionality perspective, and might be acceptable as an
out-of-tree patch to unblock the ChromeOS use case, but I don't want to rely on
this heuristic on the backend in KVM because it will suppress any and all
use-after-free bugs in KVM's MMU (see patch 4 of this series).  I really want to
go in the opposite direction and harden KVM against MMU bugs, e.g. I'm planning
on posting the below (which is how I learned about page_count() vs. page_ref_count()).

Today, KVM gets partial protection from check_new_page_bad(), which detects *some*
cases where KVM marks a page dirty after the page is freed.  But it's racy, and
the detection occurs well after the fact since it fires only when the page is
re-allocated.

If we hack kvm_pfn_to_refcounted_page(), then all of those protections are lost
because KVM would drop its assertions and also skip dirtying pages, i.e. would
effectively suppress the latent detection by check_new_page_bad().

Author: Sean Christopherson <seanjc@google.com>
Date:   Wed May 17 13:26:54 2023 -0700

    KVM: Assert that a page's refcount is elevated when marking accessed/dirty
    
    Assert that a page's refcount is elevated, i.e. that _something_ holds a
    reference to the page, when KVM marks a page as accessed and/or dirty.
    KVM typically doesn't hold a reference to pages that are mapped into the
    guest, e.g. to allow page migration, compaction, swap, etc., and instead
    relies on mmu_notifiers to react to changes in the primary MMU.
    
    Incorrect handling of mmu_notifier events (or similar mechanisms) can
    result in KVM keeping a mapping beyond the lifetime of the backing page,
    i.e. can (and often does) result in use-after-free.  Yelling if KVM marks
    a freed page as accessed/dirty doesn't prevent badness as KVM usually
    only does A/D updates when unmapping memory from the guest, i.e. the
    assertion fires well after an underlying bug has occurred, but yelling
    does help detect, triage, and debug use-after-free bugs.
    
    Note, the assertion must use page_count(), NOT page_ref_count()!  For
    hugepages, the returned struct page may be a tailpage and thus not have
    its own refcount.
    
    Signed-off-by: Sean Christopherson <seanjc@google.com>

diff --git a/virt/kvm/kvm_main.c b/virt/kvm/kvm_main.c
index d1abb331ea68..64f18697096c 100644
--- a/virt/kvm/kvm_main.c
+++ b/virt/kvm/kvm_main.c
@@ -2882,6 +2882,19 @@ EXPORT_SYMBOL_GPL(kvm_vcpu_unmap);
 
 static bool kvm_is_ad_tracked_page(struct page *page)
 {
+       /*
+        * Assert that KVM isn't attempting to mark a freed page as Accessed or
+        * Dirty, i.e. that KVM's MMU doesn't have a use-after-free bug.  KVM
+        * (typically) doesn't pin pages that are mapped in KVM's MMU, and
+        * instead relies on mmu_notifiers to know when a mapping needs to be
+        * zapped/invalidated.  Unmapping from KVM's MMU must happen _before_
+        * KVM returns from its mmu_notifier, i.e. the page should have an
+        * elevated refcount at this point even though KVM doesn't hold a
+        * reference of its own.
+        */
+       if (WARN_ON_ONCE(!page_count(page)))
+               return false;
+
        /*
         * Per page-flags.h, pages tagged PG_reserved "should in general not be
         * touched (e.g. set dirty) except by its owner".
  
Peter Xu May 24, 2023, 5:14 p.m. UTC | #4
On Wed, May 24, 2023 at 09:46:13AM -0700, Sean Christopherson wrote:
> On Wed, May 24, 2023, Peter Xu wrote:
> > On Mon, May 22, 2023 at 01:46:41PM -0700, Sean Christopherson wrote:
> > > As for the flags vs. bools debate (see link above), I think the best approach is
> > > a mix of the two.  Specifically, reuse the FOLL_* flags as-is for inputs, and use
> > > booleans for outputs.  I don't _think_ there are any input bools/flags that don't
> > > map 1:1 with existing FOLL_* flags.
> > > 
> > > As a very, *very* rough sketch, provide APIs that look a bit like this.
> > 
> > Unifying ref vs nonref cases does look a bit cleaner to me too.
> > 
> > > 
> > >   kvm_pfn_t __kvm_follow_pfn(struct kvm_follow_pfn *foll)
> > >   {
> > > 	kvm_pfn_t pfn;
> > > 
> > > 	if (WARN_ON_ONCE(!(foll->flags & FOLL_GET) && !foll->mmu_seq))
> > 
> > IMHO we may not want to rely on mmu_seq==0, either because an unlucky
> > initial mmu_seq could be zero or because of overflow?
> 
> I was thinking we could initialize mmu_seq to '1' and make it a u64 to avoid
> overflow.

Yeah, that's fine too.
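
To make the sanity check above concrete, here is a minimal userspace sketch in
C.  The names (sketch_follow_pfn, SKETCH_FOLL_GET, SKETCH_MMU_SEQ_INIT) are
hypothetical stand-ins, not kernel definitions.  The only thing taken from the
thread is the idea that a 64-bit counter starting at 1 lets 0 mean "the caller
never snapshotted the sequence".

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

/* Hypothetical stand-in for FOLL_GET; the value is illustrative only. */
#define SKETCH_FOLL_GET		(1u << 0)

/* Start the sequence at 1 so that 0 is never a valid snapshot. */
#define SKETCH_MMU_SEQ_INIT	UINT64_C(1)

struct sketch_follow_pfn {
	unsigned int flags;
	uint64_t mmu_seq;	/* 0 means "no snapshot was taken" */
};

/* Advance the sequence once per (modelled) invalidation. */
static uint64_t sketch_mmu_seq_bump(uint64_t seq)
{
	/* A valid sequence never drops below its initial value. */
	assert(seq >= SKETCH_MMU_SEQ_INIT);
	return seq + 1;
}

/*
 * A caller is considered protected if it either takes a reference
 * (FOLL_GET) or supplies a non-zero sequence snapshot.
 */
static bool sketch_caller_is_protected(const struct sketch_follow_pfn *foll)
{
	return (foll->flags & SKETCH_FOLL_GET) || foll->mmu_seq != 0;
}
```

With this shape the check stays a pure sanity check: a caller that passes
neither FOLL_GET nor a snapshot is flagged, and nothing else changes.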

> 
> > I'd say we can stick with FOLL_GET in this case to identify ref vs nonref
> > and always assume mmu_seq a pure random number.
> 
> The intent of checking mmu_seq is to flag cases where the caller doesn't specify
> FOLL_GET and isn't protected by mmu_invalidate_seq, i.e. isn't tapped into the
> mmu_notifiers.  I.e. this is a sanity check, not functionally necessary.
> 
> > 
> > > 		return KVM_PFN_ERR_FAULT;
> > > 
> > > 	pfn = ???;
> > > 
> > > 	if (foll->page && !(foll->flags & FOLL_GET))
> > > 		put_page(foll->page);
> > > 
> > > 	return pfn;
> > >   }
> > > 
> > >   kvm_pfn_t kvm_follow_pfn(struct kvm_vcpu *vcpu, gfn_t gfn, struct page **page)
> > >   {
> > > 	struct kvm_follow_pfn foll = {
> > > 		.flags = FOLL_GET | FOLL_WRITE,
> > > 	};
> > > 
> > > 	<more stuff here?>
> > > 
> > > 	foll.slot = ???;
> > > 	if (!foll.slot || foll.slot->flags & KVM_MEMSLOT_INVALID)
> > > 		return KVM_HVA_ERR_BAD;
> > > 
> > > 	if (memslot_is_readonly(foll.slot))
> > > 		return KVM_HVA_ERR_RO_BAD;
> > > 
> > > 	return __kvm_follow_pfn(&foll);
> > >   }
> > > 
> > > and a few partially converted users
> > > 
> > > diff --git a/arch/x86/kvm/mmu/mmu.c b/arch/x86/kvm/mmu/mmu.c
> > > index 67e2ac799aa7..5eaf0395ed87 100644
> > > --- a/arch/x86/kvm/mmu/mmu.c
> > > +++ b/arch/x86/kvm/mmu/mmu.c
> > > @@ -550,12 +550,14 @@ static bool mmu_spte_update(u64 *sptep, u64 new_spte)
> > >  
> > >         if (is_accessed_spte(old_spte) && !is_accessed_spte(new_spte)) {
> > >                 flush = true;
> > > -               kvm_set_pfn_accessed(spte_to_pfn(old_spte));
> > > +               if (is_refcounted_page_pte(old_spte))
> > 
> > One question is how to impl is_refcounted_page_pte() here to identify
> > non-refcountable pages.
> 
> KVM would use a software available bit in its PTEs to explicitly track which SPTEs
> point at refcounted pages.  E.g. I think bit 59 is available for EPT and 64-bit
> paging.  PAE paging doesn't have high available bits, which is why I called out
> that this would have to be 64-bit only.
> 
> > IIUC those pages are mostly identical to a normal page (so !PG_reserved)
> > but it has page_ref_count(page)==0 always, am I right?  I got that roughly
> > from reading f8be156be1 only though, so I could miss a lot of things..
> > 
> > When thinking about that, I'm also wondering whether we can trivially allow
> > kvm to support such mapping (without overhaul of the kvm pfn API) by
> > something like this:
> > 
> > ===8<===
> > diff --git a/virt/kvm/kvm_main.c b/virt/kvm/kvm_main.c
> > index 51e4882d0873..467acbac1a96 100644
> > --- a/virt/kvm/kvm_main.c
> > +++ b/virt/kvm/kvm_main.c
> > @@ -192,7 +192,13 @@ struct page *kvm_pfn_to_refcounted_page(kvm_pfn_t pfn)
> > 
> >         page = pfn_to_page(pfn);
> >         if (!PageReserved(page))
> > -               return page;
> > +               /*
> > +                * When page_ref_count(page)==0 it might be a special page
> > +                * that does not support refcounting.  Treat it the same
> > +                * as normal reserved (e.g. MMIO) pages by returning NULL,
> > +                * so it's exempt from refcounting.
> > +                */
> > +               return page_ref_count(page) == 0 ? NULL : page;
> 
> Heh, because I got burned by this recently, using page_ref_count() is wrong.  This
> needs to be page_count() so that tail pages of refcounted compound pages are
> properly identified.

:-D

Actually, when I was replying I deliberately didn't use page_count(), to make
sure we read the tail page itself, but I just noticed that's exactly how we
tell the special page apart from a PageCompound()==true tail page.

Yeah, if we want that, it needs to be page_count()==0.
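
The difference between the two helpers can be modelled in a few lines of
userspace C.  This is only a sketch: sketch_page and its compound_head field
are hypothetical simplifications of struct page, and the helpers mimic the
behaviour described in this thread (page_ref_count() reads the page's own
counter, page_count() resolves a compound tail to its head first).

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical, heavily simplified model of struct page. */
struct sketch_page {
	int refcount;
	struct sketch_page *compound_head;	/* NULL unless a compound tail */
};

/* Models page_ref_count(): reads this page's own counter only. */
static int sketch_page_ref_count(const struct sketch_page *page)
{
	assert(page != NULL);
	return page->refcount;
}

/* Models page_count(): a compound tail reports its head's counter. */
static int sketch_page_count(const struct sketch_page *page)
{
	const struct sketch_page *head;

	assert(page != NULL);
	head = page->compound_head ? page->compound_head : page;
	return head->refcount;
}
```

In this model a tail of a refcounted compound page reports 0 from
sketch_page_ref_count() but a non-zero value from sketch_page_count(), while a
tail of a non-compound higher-order allocation reports 0 from both.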

> 
> > 
> >         /* The ZERO_PAGE(s) is marked PG_reserved, but is refcounted. */
> >         if (is_zero_pfn(pfn))
> > ===8<===
> > 
> > So that we treat those special pages the same as normal PFNMAP ones by
> > skipping all refcountings on inc/dec.  This is based on the fact that kvm
> > should always hold at least 1 ref on a normal page so a normal page should
> > never hit ref==0 here, but again I could miss something somewhere..
> 
> This would "work" from a functionality perspective, and might be acceptable as an
> out-of-tree patch to unblock the ChromeOS use case, but I don't want to rely on
> this heuristic on the backend in KVM because it will suppress any and all
> use-after-free bugs in KVM's MMU (see patch 4 of this series).  I really want to
> go in the opposite direction and harden KVM against MMU bugs, e.g. I'm planning
> on posting the below (which is how I learned about page_count() vs. page_ref_count()).
> 
> Today, KVM gets partial protection from check_new_page_bad(), which detects *some*
> cases where KVM marks a page dirty after the page is freed.  But it's racy, and
> the detection occurs well after the fact since it fires only when the page is
> re-allocated.
> 
> If we hack kvm_pfn_to_refcounted_page(), then all of those protections are lost
> because KVM would drop its assertions and also skip dirtying pages, i.e. would
> effectively suppress the latent detection by check_new_page_bad().

I probably just have no idea what the attributes of those special pages are,
so I don't understand well enough why we need to handle them differently
from e.g. PFNMAP pages, or what the benefits are.

What I can tell is that they're pages that don't have the PageCompound bits
set on either head or tails, yet they still form a higher-order (multi-page)
allocation.  Is there an example of how these pages are used and allocated?
Why would we need those pages, and do they need to be set dirty/accessed at
all?

> 
> Author: Sean Christopherson <seanjc@google.com>
> Date:   Wed May 17 13:26:54 2023 -0700
> 
>     KVM: Assert that a page's refcount is elevated when marking accessed/dirty
>     
>     Assert that a page's refcount is elevated, i.e. that _something_ holds a
>     reference to the page, when KVM marks a page as accessed and/or dirty.
>     KVM typically doesn't hold a reference to pages that are mapped into the
>     guest, e.g. to allow page migration, compaction, swap, etc., and instead
>     relies on mmu_notifiers to react to changes in the primary MMU.
>     
>     Incorrect handling of mmu_notifier events (or similar mechanisms) can
>     result in KVM keeping a mapping beyond the lifetime of the backing page,
>     i.e. can (and often does) result in use-after-free.  Yelling if KVM marks
>     a freed page as accessed/dirty doesn't prevent badness as KVM usually
>     only does A/D updates when unmapping memory from the guest, i.e. the
>     assertion fires well after an underlying bug has occurred, but yelling
>     does help detect, triage, and debug use-after-free bugs.
>     
>     Note, the assertion must use page_count(), NOT page_ref_count()!  For
>     hugepages, the returned struct page may be a tailpage and thus not have
>     its own refcount.
>     
>     Signed-off-by: Sean Christopherson <seanjc@google.com>
> 
> diff --git a/virt/kvm/kvm_main.c b/virt/kvm/kvm_main.c
> index d1abb331ea68..64f18697096c 100644
> --- a/virt/kvm/kvm_main.c
> +++ b/virt/kvm/kvm_main.c
> @@ -2882,6 +2882,19 @@ EXPORT_SYMBOL_GPL(kvm_vcpu_unmap);
>  
>  static bool kvm_is_ad_tracked_page(struct page *page)
>  {
> +       /*
> +        * Assert that KVM isn't attempting to mark a freed page as Accessed or
> +        * Dirty, i.e. that KVM's MMU doesn't have a use-after-free bug.  KVM
> +        * (typically) doesn't pin pages that are mapped in KVM's MMU, and
> +        * instead relies on mmu_notifiers to know when a mapping needs to be
> +        * zapped/invalidated.  Unmapping from KVM's MMU must happen _before_
> +        * KVM returns from its mmu_notifier, i.e. the page should have an
> +        * elevated refcount at this point even though KVM doesn't hold a
> +        * reference of its own.
> +        */
> +       if (WARN_ON_ONCE(!page_count(page)))
> +               return false;
> +
>         /*
>          * Per page-flags.h, pages tagged PG_reserved "should in general not be
>          * touched (e.g. set dirty) except by its owner".
> 

This looks like a good thing to have, indeed.  But again it doesn't seem
like anything special to the pages we're discussing here, say, !Compound &&
refcount==0 ones.
  
Sean Christopherson May 24, 2023, 6:29 p.m. UTC | #5
On Wed, May 24, 2023, Peter Xu wrote:
> On Wed, May 24, 2023 at 09:46:13AM -0700, Sean Christopherson wrote:
> > If we hack kvm_pfn_to_refcounted_page(), then all of those protections are lost
> > because KVM would drop its assertions and also skip dirtying pages, i.e. would
> > effectively suppress the latent detection by check_new_page_bad().
> 
> > I probably just have no idea what the attributes of those special pages are,
> > so I don't understand well enough why we need to handle them differently
> > from e.g. PFNMAP pages, or what the benefits are.
> > 
> > What I can tell is that they're pages that don't have the PageCompound bits
> > set on either head or tails, yet they still form a higher-order (multi-page)
> > allocation.  Is there an example of how these pages are used and allocated?
> > Why would we need those pages, and do they need to be set dirty/accessed at
> > all?

The use case David is interested in is where an AMD GPU driver kmalloc()s a
chunk of memory, lets it be mmap()'d by userspace, and userspace then maps it
into the guest for a virtual (passthrough?) GPU.  For all intents and purposes,
it's normal memory, just not refcounted.

> >  static bool kvm_is_ad_tracked_page(struct page *page)
> >  {
> > +       /*
> > +        * Assert that KVM isn't attempting to mark a freed page as Accessed or
> > +        * Dirty, i.e. that KVM's MMU doesn't have a use-after-free bug.  KVM
> > +        * (typically) doesn't pin pages that are mapped in KVM's MMU, and
> > +        * instead relies on mmu_notifiers to know when a mapping needs to be
> > +        * zapped/invalidated.  Unmapping from KVM's MMU must happen _before_
> > +        * KVM returns from its mmu_notifier, i.e. the page should have an
> > +        * elevated refcount at this point even though KVM doesn't hold a
> > +        * reference of its own.
> > +        */
> > +       if (WARN_ON_ONCE(!page_count(page)))
> > +               return false;
> > +
> >         /*
> >          * Per page-flags.h, pages tagged PG_reserved "should in general not be
> >          * touched (e.g. set dirty) except by its owner".
> > 
> 
> This looks like a good thing to have, indeed.  But again it doesn't seem
> like anything special to the pages we're discussing here, say, !Compound &&
> refcount==0 ones.

The problem is that if KVM ignores refcount==0 pages, then KVM can't distinguish
between the legitimate[*] refcount==0 AMD GPU case and a buggy refcount==0
use-after-free scenario.  I don't want to make that sacrifice as the legitimate
!refcounted use case is a very specific use case, whereas consuming refcounted
memory is ubiquitous (outside of maybe AWS).

[*] Consuming !refcounted pages is safe only for flows that are tied into the
    mmu_notifiers.  The current proposal/plan is to add an off-by-default module
    param that lets userspace opt in to kmap() use of !refcounted memory, e.g.
    this case and PFNMAP memory.
  
Peter Xu May 24, 2023, 7:09 p.m. UTC | #6
On Wed, May 24, 2023 at 11:29:45AM -0700, Sean Christopherson wrote:
> On Wed, May 24, 2023, Peter Xu wrote:
> > On Wed, May 24, 2023 at 09:46:13AM -0700, Sean Christopherson wrote:
> > > If we hack kvm_pfn_to_refcounted_page(), then all of those protections are lost
> > > because KVM would drop its assertions and also skip dirtying pages, i.e. would
> > > effectively suppress the latent detection by check_new_page_bad().
> > 
> > > I probably just have no idea what the attributes of those special pages are,
> > > so I don't understand well enough why we need to handle them differently
> > > from e.g. PFNMAP pages, or what the benefits are.
> > > 
> > > What I can tell is that they're pages that don't have the PageCompound bits
> > > set on either head or tails, yet they still form a higher-order (multi-page)
> > > allocation.  Is there an example of how these pages are used and allocated?
> > > Why would we need those pages, and do they need to be set dirty/accessed at
> > > all?
> 
> The use case David is interested in is where an AMD GPU driver kmalloc()s a
> chunk of memory, lets it be mmap()'d by userspace, and userspace then maps it
> into the guest for a virtual (passthrough?) GPU.  For all intents and purposes,
> it's normal memory, just not refcounted.

I'm not familiar enough with kmalloc, but I think kmalloc for large chunks
is the same as alloc_pages, and I thought the result would already be a
compound page.  If it needs to be mmap()ed to a userspace app, then I assume
it mostly goes through kmalloc_large().

kmalloc -> kmalloc_large -> __kmalloc_large_node:

	flags |= __GFP_COMP;

Then, when the new page is allocated and prepared (prep_new_page):

	if (order && (gfp_flags & __GFP_COMP))
		prep_compound_page(page, order);

I assume prep_compound_page() will make PageCompound return true for the
pages returned.  So I know I'm still missing something, but I'm not sure
where, because IIRC we're at least talking about !PageCompound pages.

> 
> > >  static bool kvm_is_ad_tracked_page(struct page *page)
> > >  {
> > > +       /*
> > > +        * Assert that KVM isn't attempting to mark a freed page as Accessed or
> > > +        * Dirty, i.e. that KVM's MMU doesn't have a use-after-free bug.  KVM
> > > +        * (typically) doesn't pin pages that are mapped in KVM's MMU, and
> > > +        * instead relies on mmu_notifiers to know when a mapping needs to be
> > > +        * zapped/invalidated.  Unmapping from KVM's MMU must happen _before_
> > > +        * KVM returns from its mmu_notifier, i.e. the page should have an
> > > +        * elevated refcount at this point even though KVM doesn't hold a
> > > +        * reference of its own.
> > > +        */
> > > +       if (WARN_ON_ONCE(!page_count(page)))
> > > +               return false;
> > > +
> > >         /*
> > >          * Per page-flags.h, pages tagged PG_reserved "should in general not be
> > >          * touched (e.g. set dirty) except by its owner".
> > > 
> > 
> > This looks like a good thing to have, indeed.  But again it doesn't seem
> > like anything special to the pages we're discussing here, say, !Compound &&
> > refcount==0 ones.
> 
> The problem is that if KVM ignores refcount==0 pages, then KVM can't distinguish
> between the legitimate[*] refcount==0 AMD GPU case and a buggy refcount==0
> use-after-free scenario.  I don't want to make that sacrifice as the legitimate
> !refcounted use case is a very specific use case, whereas consuming refcounted
> memory is ubiquitous (outside of maybe AWS).
> 
> [*] Consuming !refcounted pages is safe only for flows that are tied into the
>     mmu_notifiers.  The current proposal/plan is to add an off-by-default module
>     param that lets userspace opt in to kmap() use of !refcounted memory, e.g.
>     this case and PFNMAP memory.

I see.

I think you mentioned that we can use one special bit in the shadow pte to
mark such special pages.  Does that mean your patch above will still cover
what you wanted to protect even if we use that trick?  Because then
kvm_is_ad_tracked_page() should only be called when we're sure the special
bit is not set.  IOW, we can already rule out these pages, and the
page_count()==0 check here can still help track down kvm bugs?
  
Sean Christopherson May 24, 2023, 8:05 p.m. UTC | #7
On Wed, May 24, 2023, Peter Xu wrote:
> On Wed, May 24, 2023 at 11:29:45AM -0700, Sean Christopherson wrote:
> > On Wed, May 24, 2023, Peter Xu wrote:
> > > On Wed, May 24, 2023 at 09:46:13AM -0700, Sean Christopherson wrote:
> > > > If we hack kvm_pfn_to_refcounted_page(), then all of those protections are lost
> > > > because KVM would drop its assertions and also skip dirtying pages, i.e. would
> > > > effectively suppress the latent detection by check_new_page_bad().
> > > 
> > > > I probably just have no idea what the attributes of those special pages are,
> > > > so I don't understand well enough why we need to handle them differently
> > > > from e.g. PFNMAP pages, or what the benefits are.
> > > > 
> > > > What I can tell is that they're pages that don't have the PageCompound bits
> > > > set on either head or tails, yet they still form a higher-order (multi-page)
> > > > allocation.  Is there an example of how these pages are used and allocated?
> > > > Why would we need those pages, and do they need to be set dirty/accessed at
> > > > all?
> > 
> > > The use case David is interested in is where an AMD GPU driver kmalloc()s a
> > > chunk of memory, lets it be mmap()'d by userspace, and userspace then maps it
> > into the guest for a virtual (passthrough?) GPU.  For all intents and purposes,
> > it's normal memory, just not refcounted.
> 
> I'm not familiar enough with kmalloc, but I think kmalloc for large chunks
> is the same as alloc_pages, and I thought the result would already be a
> compound page.  If it needs to be mmap()ed to a userspace app, then I assume
> it mostly goes through kmalloc_large().

Sorry, by "kmalloc()" I was handwaving at all of the variations of kernel allocated
memory.  From a separate thread[*], looks like the actual usage is a direct call to
alloc_pages() that deliberately doesn't set __GFP_COMP.  Note, I'm pretty sure the
comment about "mapping pages directly into userspace" being illegal really means
something like "don't allow these pages to be gup()'d or mapped via standard mmap()".
IIUC, ttm_pool_alloc() fills tt->pages and then ttm_bo_vm_fault_reserved() does
vmf_insert_pfn_prot() to shove the pfn into userspace.

  static struct page *ttm_pool_alloc_page(struct ttm_pool *pool, gfp_t gfp_flags,
					unsigned int order)
  {
	unsigned long attr = DMA_ATTR_FORCE_CONTIGUOUS;
	struct ttm_pool_dma *dma;
	struct page *p;
	void *vaddr;

	/* Don't set the __GFP_COMP flag for higher order allocations.
	 * Mapping pages directly into an userspace process and calling
	 * put_page() on a TTM allocated page is illegal.
	 */
	if (order)
		gfp_flags |= __GFP_NOMEMALLOC | __GFP_NORETRY | __GFP_NOWARN |
			__GFP_KSWAPD_RECLAIM;

	if (!pool->use_dma_alloc) {
		p = alloc_pages(gfp_flags, order);
		if (p)
			p->private = order;
		return p;

	}

[*] https://lore.kernel.org/all/20220815095423.11131-1-dmitry.osipenko@collabora.com

> kmalloc -> kmalloc_large -> __kmalloc_large_node:
> 
> 	flags |= __GFP_COMP;
> 
> Then, when the new page is allocated and prepared (prep_new_page):
> 
> 	if (order && (gfp_flags & __GFP_COMP))
> 		prep_compound_page(page, order);
> 
> I assume prep_compound_page() will make PageCompound return true for the
> pages returned.  So I know I'm still missing something, but I'm not sure
> where, because IIRC we're at least talking about !PageCompound pages.

Yeah, they're !PageCompound().

> > > >  static bool kvm_is_ad_tracked_page(struct page *page)
> > > >  {
> > > > +       /*
> > > > +        * Assert that KVM isn't attempting to mark a freed page as Accessed or
> > > > +        * Dirty, i.e. that KVM's MMU doesn't have a use-after-free bug.  KVM
> > > > +        * (typically) doesn't pin pages that are mapped in KVM's MMU, and
> > > > +        * instead relies on mmu_notifiers to know when a mapping needs to be
> > > > +        * zapped/invalidated.  Unmapping from KVM's MMU must happen _before_
> > > > +        * KVM returns from its mmu_notifier, i.e. the page should have an
> > > > +        * elevated refcount at this point even though KVM doesn't hold a
> > > > +        * reference of its own.
> > > > +        */
> > > > +       if (WARN_ON_ONCE(!page_count(page)))
> > > > +               return false;
> > > > +
> > > >         /*
> > > >          * Per page-flags.h, pages tagged PG_reserved "should in general not be
> > > >          * touched (e.g. set dirty) except by its owner".
> > > > 
> > > 
> > > This looks like a good thing to have, indeed.  But again it doesn't seem
> > > like anything special to the pages we're discussing here, say, !Compound &&
> > > refcount==0 ones.
> > 
> > The problem is that if KVM ignores refcount==0 pages, then KVM can't distinguish
> > between the legitimate[*] refcount==0 AMD GPU case and a buggy refcount==0
> > > use-after-free scenario.  I don't want to make that sacrifice as the legitimate
> > > !refcounted use case is a very specific use case, whereas consuming refcounted
> > > memory is ubiquitous (outside of maybe AWS).
> > 
> > [*] Consuming !refcounted pages is safe only for flows that are tied into the
> >     mmu_notifiers.  The current proposal/plan is to add an off-by-default module
> > >     param that lets userspace opt in to kmap() use of !refcounted memory, e.g.
> >     this case and PFNMAP memory.
> 
> I see.
> 
> I think you mentioned that we can use one special bit in the shadow pte to
> mark such special pages.  Does that mean your patch above will still cover
> what you wanted to protect even if we use that trick?  Because then
> kvm_is_ad_tracked_page() should only be called when we're sure the special
> bit is not set.  IOW, we can already rule out these pages, and the
> page_count()==0 check here can still help track down kvm bugs?

Yep, exactly.  FWIW, I was thinking that the SPTE bit would flag refcounted pages,
not these "special" pages, but either way would work.  All that matters is that
KVM tracks whether or not the page was refcounted when KVM installed the SPTE.
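
As a rough illustration of that tracking, the following userspace C sketch
records "refcounted" in a software-available SPTE bit at install time and
consults it before any accessed/dirty update.  Everything here is
hypothetical: the names are invented, the SPTE layout is reduced to a shifted
pfn, and bit 59 is only the tentative choice mentioned earlier in the thread,
not a verified encoding.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

/* Assumption: bit 59 is software-available, as tentatively suggested above. */
#define SKETCH_SPTE_REFCOUNTED	(UINT64_C(1) << 59)

enum sketch_ad_result {
	SKETCH_AD_SKIPPED,	/* SPTE was installed for a !refcounted page */
	SKETCH_AD_MARKED,	/* refcounted page, refcount still elevated */
	SKETCH_AD_BUG,		/* refcounted page with refcount 0 */
};

/* Build a toy SPTE and remember whether the backing page was refcounted. */
static uint64_t sketch_make_spte(uint64_t pfn, bool refcounted)
{
	uint64_t spte = pfn << 12;

	/* The pfn bits must not collide with the software bit. */
	assert(!(spte & SKETCH_SPTE_REFCOUNTED));
	return refcounted ? (spte | SKETCH_SPTE_REFCOUNTED) : spte;
}

/*
 * Decide what an accessed/dirty update should do.  Only SPTEs that were
 * installed for refcounted pages reach the refcount check, so a zero count
 * there can be treated as a bug rather than as a "special" page.
 */
static enum sketch_ad_result sketch_mark_accessed(uint64_t spte, int page_count)
{
	if (!(spte & SKETCH_SPTE_REFCOUNTED))
		return SKETCH_AD_SKIPPED;

	return page_count > 0 ? SKETCH_AD_MARKED : SKETCH_AD_BUG;
}
```

The point of the design is that the decision is made from state captured when
the SPTE was installed, so the page_count() assertion never has to guess
whether a zero refcount is legitimate.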
  

Patch

diff --git a/include/linux/kvm_host.h b/include/linux/kvm_host.h
index 90edc16d37e5..146f220cc25b 100644
--- a/include/linux/kvm_host.h
+++ b/include/linux/kvm_host.h
@@ -1162,8 +1162,22 @@  kvm_pfn_t __gfn_to_pfn_memslot(const struct kvm_memory_slot *slot, gfn_t gfn,
 			       bool atomic, bool interruptible, bool *async,
 			       bool write_fault, bool *writable, hva_t *hva);
 
+kvm_pfn_t gfn_to_pfn_noref(struct kvm *kvm, gfn_t gfn, struct page **page);
+kvm_pfn_t gfn_to_pfn_noref_prot(struct kvm *kvm, gfn_t gfn,
+				bool write_fault, bool *writable,
+				struct page **page);
+kvm_pfn_t gfn_to_pfn_noref_memslot(const struct kvm_memory_slot *slot,
+				   gfn_t gfn, struct page **page);
+kvm_pfn_t gfn_to_pfn_noref_memslot_atomic(const struct kvm_memory_slot *slot,
+					  gfn_t gfn, struct page **page);
+kvm_pfn_t __gfn_to_pfn_noref_memslot(const struct kvm_memory_slot *slot,
+				     gfn_t gfn, bool atomic, bool interruptible,
+				     bool *async, bool write_fault, bool *writable,
+				     hva_t *hva, struct page **page);
+
 void kvm_release_pfn_clean(kvm_pfn_t pfn);
 void kvm_release_pfn_dirty(kvm_pfn_t pfn);
+void kvm_release_pfn_noref_clean(kvm_pfn_t pfn, struct page *page);
 void kvm_set_pfn_dirty(kvm_pfn_t pfn);
 void kvm_set_pfn_accessed(kvm_pfn_t pfn);
 
@@ -1242,6 +1256,10 @@  struct kvm_memslots *kvm_vcpu_memslots(struct kvm_vcpu *vcpu);
 struct kvm_memory_slot *kvm_vcpu_gfn_to_memslot(struct kvm_vcpu *vcpu, gfn_t gfn);
 kvm_pfn_t kvm_vcpu_gfn_to_pfn_atomic(struct kvm_vcpu *vcpu, gfn_t gfn);
 kvm_pfn_t kvm_vcpu_gfn_to_pfn(struct kvm_vcpu *vcpu, gfn_t gfn);
+kvm_pfn_t kvm_vcpu_gfn_to_pfn_noref_atomic(struct kvm_vcpu *vcpu, gfn_t gfn,
+					   struct page **page);
+kvm_pfn_t kvm_vcpu_gfn_to_pfn_noref(struct kvm_vcpu *vcpu, gfn_t gfn,
+				    struct page **page);
 int kvm_vcpu_map(struct kvm_vcpu *vcpu, gpa_t gpa, struct kvm_host_map *map);
 void kvm_vcpu_unmap(struct kvm_vcpu *vcpu, struct kvm_host_map *map, bool dirty);
 unsigned long kvm_vcpu_gfn_to_hva(struct kvm_vcpu *vcpu, gfn_t gfn);
diff --git a/virt/kvm/kvm_main.c b/virt/kvm/kvm_main.c
index f40b72eb0e7b..007dd984eeea 100644
--- a/virt/kvm/kvm_main.c
+++ b/virt/kvm/kvm_main.c
@@ -2484,9 +2484,9 @@  static inline int check_user_page_hwpoison(unsigned long addr)
  * only part that runs if we can in atomic context.
  */
 static bool hva_to_pfn_fast(unsigned long addr, bool write_fault,
-			    bool *writable, kvm_pfn_t *pfn)
+			    bool *writable, kvm_pfn_t *pfn,
+			    struct page **page)
 {
-	struct page *page[1];
 
 	/*
 	 * Fast pin a writable pfn only if it is a write fault request
@@ -2497,7 +2497,7 @@  static bool hva_to_pfn_fast(unsigned long addr, bool write_fault,
 		return false;
 
 	if (get_user_page_fast_only(addr, FOLL_WRITE, page)) {
-		*pfn = page_to_pfn(page[0]);
+		*pfn = page_to_pfn(*page);
 
 		if (writable)
 			*writable = true;
@@ -2512,10 +2512,10 @@  static bool hva_to_pfn_fast(unsigned long addr, bool write_fault,
  * 1 indicates success, -errno is returned if error is detected.
  */
 static int hva_to_pfn_slow(unsigned long addr, bool *async, bool write_fault,
-			   bool interruptible, bool *writable, kvm_pfn_t *pfn)
+			   bool interruptible, bool *writable, kvm_pfn_t *pfn,
+			   struct page **page)
 {
 	unsigned int flags = FOLL_HWPOISON;
-	struct page *page;
 	int npages;
 
 	might_sleep();
@@ -2530,7 +2530,7 @@  static int hva_to_pfn_slow(unsigned long addr, bool *async, bool write_fault,
 	if (interruptible)
 		flags |= FOLL_INTERRUPTIBLE;
 
-	npages = get_user_pages_unlocked(addr, 1, &page, flags);
+	npages = get_user_pages_unlocked(addr, 1, page, flags);
 	if (npages != 1)
 		return npages;
 
@@ -2540,11 +2540,11 @@  static int hva_to_pfn_slow(unsigned long addr, bool *async, bool write_fault,
 
 		if (get_user_page_fast_only(addr, FOLL_WRITE, &wpage)) {
 			*writable = true;
-			put_page(page);
-			page = wpage;
+			put_page(*page);
+			*page = wpage;
 		}
 	}
-	*pfn = page_to_pfn(page);
+	*pfn = page_to_pfn(*page);
 	return npages;
 }
 
@@ -2559,16 +2559,6 @@  static bool vma_is_valid(struct vm_area_struct *vma, bool write_fault)
 	return true;
 }
 
-static int kvm_try_get_pfn(kvm_pfn_t pfn)
-{
-	struct page *page = kvm_pfn_to_refcounted_page(pfn);
-
-	if (!page)
-		return 1;
-
-	return get_page_unless_zero(page);
-}
-
 static int hva_to_pfn_remapped(struct vm_area_struct *vma,
 			       unsigned long addr, bool write_fault,
 			       bool *writable, kvm_pfn_t *p_pfn)
@@ -2607,26 +2597,6 @@  static int hva_to_pfn_remapped(struct vm_area_struct *vma,
 		*writable = pte_write(*ptep);
 	pfn = pte_pfn(*ptep);
 
-	/*
-	 * Get a reference here because callers of *hva_to_pfn* and
-	 * *gfn_to_pfn* ultimately call kvm_release_pfn_clean on the
-	 * returned pfn.  This is only needed if the VMA has VM_MIXEDMAP
-	 * set, but the kvm_try_get_pfn/kvm_release_pfn_clean pair will
-	 * simply do nothing for reserved pfns.
-	 *
-	 * Whoever called remap_pfn_range is also going to call e.g.
-	 * unmap_mapping_range before the underlying pages are freed,
-	 * causing a call to our MMU notifier.
-	 *
-	 * Certain IO or PFNMAP mappings can be backed with valid
-	 * struct pages, but be allocated without refcounting e.g.,
-	 * tail pages of non-compound higher order allocations, which
-	 * would then underflow the refcount when the caller does the
-	 * required put_page. Don't allow those pages here.
-	 */ 
-	if (!kvm_try_get_pfn(pfn))
-		r = -EFAULT;
-
 out:
 	pte_unmap_unlock(ptep, ptl);
 	*p_pfn = pfn;
@@ -2643,6 +2613,7 @@  static int hva_to_pfn_remapped(struct vm_area_struct *vma,
  *         host page is not in the memory
  * @write_fault: whether we should get a writable host page
  * @writable: whether it allows to map a writable host page for !@write_fault
+ * @page: outparam for the refcounted page associated with the pfn, if any
  *
  * The function will map a writable host page for these two cases:
  * 1): @write_fault = true
@@ -2650,23 +2621,25 @@  static int hva_to_pfn_remapped(struct vm_area_struct *vma,
  *     whether the mapping is writable.
  */
 kvm_pfn_t hva_to_pfn(unsigned long addr, bool atomic, bool interruptible,
-		     bool *async, bool write_fault, bool *writable)
+		     bool *async, bool write_fault, bool *writable,
+		     struct page **page)
 {
 	struct vm_area_struct *vma;
 	kvm_pfn_t pfn;
 	int npages, r;
+	*page = NULL;
 
 	/* we can do it either atomically or asynchronously, not both */
 	BUG_ON(atomic && async);
 
-	if (hva_to_pfn_fast(addr, write_fault, writable, &pfn))
+	if (hva_to_pfn_fast(addr, write_fault, writable, &pfn, page))
 		return pfn;
 
 	if (atomic)
 		return KVM_PFN_ERR_FAULT;
 
 	npages = hva_to_pfn_slow(addr, async, write_fault, interruptible,
-				 writable, &pfn);
+				 writable, &pfn, page);
 	if (npages == 1)
 		return pfn;
 	if (npages == -EINTR)
@@ -2700,9 +2673,37 @@  kvm_pfn_t hva_to_pfn(unsigned long addr, bool atomic, bool interruptible,
 	return pfn;
 }
 
-kvm_pfn_t __gfn_to_pfn_memslot(const struct kvm_memory_slot *slot, gfn_t gfn,
-			       bool atomic, bool interruptible, bool *async,
-			       bool write_fault, bool *writable, hva_t *hva)
+/*
+ * Helper function for managing refcounts of pfn returned by hva_to_pfn.
+ * @pfn: pfn returned by hva_to_pfn
+ * @page: page outparam from hva_to_pfn
+ *
+ * In cases where access to the pfn resolved by hva_to_pfn isn't protected by
+ * our MMU notifier, if the pfn was resolved by hva_to_pfn_remapped instead of
+ * gup, then its refcount needs to be bumped.
+ *
+ * Certain IO or PFNMAP mappings can be backed with valid struct pages, but be
+ * allocated without refcounting e.g., tail pages of non-compound higher order
+ * allocations, which would then underflow the refcount when the caller does the
+ * required put_page. Don't allow those pages here.
+ */
+kvm_pfn_t kvm_try_get_refcounted_page_ref(kvm_pfn_t pfn, struct page *page)
+{
+	/* If @page is valid, KVM already has a reference to the pfn/page. */
+	if (page || is_error_pfn(pfn))
+		return pfn;
+
+	page = kvm_pfn_to_refcounted_page(pfn);
+	if (!page || get_page_unless_zero(page))
+		return pfn;
+
+	return KVM_PFN_ERR_FAULT;
+}
+
+kvm_pfn_t __gfn_to_pfn_noref_memslot(const struct kvm_memory_slot *slot, gfn_t gfn,
+				     bool atomic, bool interruptible, bool *async,
+				     bool write_fault, bool *writable, hva_t *hva,
+				     struct page **page)
 {
 	unsigned long addr = __gfn_to_hva_many(slot, gfn, NULL, write_fault);
 
@@ -2728,47 +2729,134 @@  kvm_pfn_t __gfn_to_pfn_memslot(const struct kvm_memory_slot *slot, gfn_t gfn,
 	}
 
 	return hva_to_pfn(addr, atomic, interruptible, async, write_fault,
-			  writable);
+			  writable, page);
+}
+EXPORT_SYMBOL_GPL(__gfn_to_pfn_noref_memslot);
+
+kvm_pfn_t gfn_to_pfn_noref_prot(struct kvm *kvm, gfn_t gfn, bool write_fault,
+				bool *writable, struct page **page)
+{
+	return __gfn_to_pfn_noref_memslot(gfn_to_memslot(kvm, gfn), gfn, false, false,
+					  NULL, write_fault, writable, NULL, page);
+}
+EXPORT_SYMBOL_GPL(gfn_to_pfn_noref_prot);
+
+kvm_pfn_t gfn_to_pfn_noref_memslot(const struct kvm_memory_slot *slot, gfn_t gfn,
+				   struct page **page)
+{
+	return __gfn_to_pfn_noref_memslot(slot, gfn, false, false, NULL, true,
+					  NULL, NULL, page);
+}
+EXPORT_SYMBOL_GPL(gfn_to_pfn_noref_memslot);
+
+kvm_pfn_t gfn_to_pfn_noref_memslot_atomic(const struct kvm_memory_slot *slot,
+					  gfn_t gfn, struct page **page)
+{
+	return __gfn_to_pfn_noref_memslot(slot, gfn, true, false, NULL, true, NULL,
+					  NULL, page);
+}
+EXPORT_SYMBOL_GPL(gfn_to_pfn_noref_memslot_atomic);
+
+kvm_pfn_t kvm_vcpu_gfn_to_pfn_noref_atomic(struct kvm_vcpu *vcpu, gfn_t gfn,
+					   struct page **page)
+{
+	return gfn_to_pfn_noref_memslot_atomic(
+			kvm_vcpu_gfn_to_memslot(vcpu, gfn), gfn, page);
+}
+EXPORT_SYMBOL_GPL(kvm_vcpu_gfn_to_pfn_noref_atomic);
+
+kvm_pfn_t gfn_to_pfn_noref(struct kvm *kvm, gfn_t gfn, struct page **page)
+{
+	return gfn_to_pfn_noref_memslot(gfn_to_memslot(kvm, gfn), gfn, page);
+}
+EXPORT_SYMBOL_GPL(gfn_to_pfn_noref);
+
+kvm_pfn_t kvm_vcpu_gfn_to_pfn_noref(struct kvm_vcpu *vcpu, gfn_t gfn,
+				    struct page **page)
+{
+	return gfn_to_pfn_noref_memslot(kvm_vcpu_gfn_to_memslot(vcpu, gfn),
+					gfn, page);
+}
+EXPORT_SYMBOL_GPL(kvm_vcpu_gfn_to_pfn_noref);
+
+kvm_pfn_t __gfn_to_pfn_memslot(const struct kvm_memory_slot *slot, gfn_t gfn,
+			       bool atomic, bool interruptible, bool *async,
+			       bool write_fault, bool *writable, hva_t *hva)
+{
+	struct page *page;
+	kvm_pfn_t pfn;
+
+	pfn = __gfn_to_pfn_noref_memslot(slot, gfn, atomic, interruptible, async,
+					 write_fault, writable, hva, &page);
+
+	return kvm_try_get_refcounted_page_ref(pfn, page);
 }
 EXPORT_SYMBOL_GPL(__gfn_to_pfn_memslot);
 
 kvm_pfn_t gfn_to_pfn_prot(struct kvm *kvm, gfn_t gfn, bool write_fault,
 		      bool *writable)
 {
-	return __gfn_to_pfn_memslot(gfn_to_memslot(kvm, gfn), gfn, false, false,
-				    NULL, write_fault, writable, NULL);
+	struct page *page;
+	kvm_pfn_t pfn;
+
+	pfn = gfn_to_pfn_noref_prot(kvm, gfn, write_fault, writable, &page);
+
+	return kvm_try_get_refcounted_page_ref(pfn, page);
 }
 EXPORT_SYMBOL_GPL(gfn_to_pfn_prot);
 
 kvm_pfn_t gfn_to_pfn_memslot(const struct kvm_memory_slot *slot, gfn_t gfn)
 {
-	return __gfn_to_pfn_memslot(slot, gfn, false, false, NULL, true,
-				    NULL, NULL);
+	struct page *page;
+	kvm_pfn_t pfn;
+
+	pfn = gfn_to_pfn_noref_memslot(slot, gfn, &page);
+
+	return kvm_try_get_refcounted_page_ref(pfn, page);
 }
 EXPORT_SYMBOL_GPL(gfn_to_pfn_memslot);
 
 kvm_pfn_t gfn_to_pfn_memslot_atomic(const struct kvm_memory_slot *slot, gfn_t gfn)
 {
-	return __gfn_to_pfn_memslot(slot, gfn, true, false, NULL, true,
-				    NULL, NULL);
+	struct page *page;
+	kvm_pfn_t pfn;
+
+	pfn = gfn_to_pfn_noref_memslot_atomic(slot, gfn, &page);
+
+	return kvm_try_get_refcounted_page_ref(pfn, page);
 }
 EXPORT_SYMBOL_GPL(gfn_to_pfn_memslot_atomic);
 
 kvm_pfn_t kvm_vcpu_gfn_to_pfn_atomic(struct kvm_vcpu *vcpu, gfn_t gfn)
 {
-	return gfn_to_pfn_memslot_atomic(kvm_vcpu_gfn_to_memslot(vcpu, gfn), gfn);
+	struct page *page;
+	kvm_pfn_t pfn;
+
+	pfn = kvm_vcpu_gfn_to_pfn_noref_atomic(vcpu, gfn, &page);
+
+	return kvm_try_get_refcounted_page_ref(pfn, page);
 }
 EXPORT_SYMBOL_GPL(kvm_vcpu_gfn_to_pfn_atomic);
 
 kvm_pfn_t gfn_to_pfn(struct kvm *kvm, gfn_t gfn)
 {
-	return gfn_to_pfn_memslot(gfn_to_memslot(kvm, gfn), gfn);
+	struct page *page;
+	kvm_pfn_t pfn;
+
+	pfn = gfn_to_pfn_noref(kvm, gfn, &page);
+
+	return kvm_try_get_refcounted_page_ref(pfn, page);
 }
 EXPORT_SYMBOL_GPL(gfn_to_pfn);
 
 kvm_pfn_t kvm_vcpu_gfn_to_pfn(struct kvm_vcpu *vcpu, gfn_t gfn)
 {
-	return gfn_to_pfn_memslot(kvm_vcpu_gfn_to_memslot(vcpu, gfn), gfn);
+	struct page *page;
+	kvm_pfn_t pfn;
+
+	pfn = kvm_vcpu_gfn_to_pfn_noref(vcpu, gfn, &page);
+
+	return kvm_try_get_refcounted_page_ref(pfn, page);
 }
 EXPORT_SYMBOL_GPL(kvm_vcpu_gfn_to_pfn);
 
@@ -2925,6 +3013,17 @@  void kvm_release_pfn_clean(kvm_pfn_t pfn)
 }
 EXPORT_SYMBOL_GPL(kvm_release_pfn_clean);
 
+void kvm_release_pfn_noref_clean(kvm_pfn_t pfn, struct page *page)
+{
+	if (is_error_noslot_pfn(pfn))
+		return;
+
+	kvm_set_pfn_accessed(pfn);
+	if (page)
+		put_page(page);
+}
+EXPORT_SYMBOL_GPL(kvm_release_pfn_noref_clean);
+
 void kvm_release_page_dirty(struct page *page)
 {
 	WARN_ON(is_error_page(page));
diff --git a/virt/kvm/kvm_mm.h b/virt/kvm/kvm_mm.h
index 180f1a09e6ba..a4072cc5a189 100644
--- a/virt/kvm/kvm_mm.h
+++ b/virt/kvm/kvm_mm.h
@@ -3,6 +3,8 @@ 
 #ifndef __KVM_MM_H__
 #define __KVM_MM_H__ 1
 
+#include <linux/mm_types.h>
+
 /*
  * Architectures can choose whether to use an rwlock or spinlock
  * for the mmu_lock.  These macros, for use in common code
@@ -21,7 +23,9 @@ 
 #endif /* KVM_HAVE_MMU_RWLOCK */
 
 kvm_pfn_t hva_to_pfn(unsigned long addr, bool atomic, bool interruptible,
-		     bool *async, bool write_fault, bool *writable);
+		     bool *async, bool write_fault, bool *writable,
+		     struct page **page);
+kvm_pfn_t kvm_try_get_refcounted_page_ref(kvm_pfn_t pfn, struct page *page);
 
 #ifdef CONFIG_HAVE_KVM_PFNCACHE
 void gfn_to_pfn_cache_invalidate_start(struct kvm *kvm,
diff --git a/virt/kvm/pfncache.c b/virt/kvm/pfncache.c
index 2d6aba677830..e25d3af969f4 100644
--- a/virt/kvm/pfncache.c
+++ b/virt/kvm/pfncache.c
@@ -144,6 +144,7 @@  static kvm_pfn_t hva_to_pfn_retry(struct gfn_to_pfn_cache *gpc)
 	kvm_pfn_t new_pfn = KVM_PFN_ERR_FAULT;
 	void *new_khva = NULL;
 	unsigned long mmu_seq;
+	struct page *page;
 
 	lockdep_assert_held(&gpc->refresh_lock);
 
@@ -183,10 +184,19 @@  static kvm_pfn_t hva_to_pfn_retry(struct gfn_to_pfn_cache *gpc)
 		}
 
 		/* We always request a writeable mapping */
-		new_pfn = hva_to_pfn(gpc->uhva, false, false, NULL, true, NULL);
+		new_pfn = hva_to_pfn(gpc->uhva, false, false, NULL, true, NULL, &page);
 		if (is_error_noslot_pfn(new_pfn))
 			goto out_error;
 
+		/*
+		 * Filter out pages that support refcounting but aren't currently
+		 * being refcounted. Some KVM MMUs support such pages, and this
+		 * cache could support them too, but KVM internals in general
+		 * don't. Reject them here for consistency.
+		 */
+		if (kvm_try_get_refcounted_page_ref(new_pfn, page) != new_pfn)
+			goto out_error;
+
 		/*
 		 * Obtain a new kernel mapping if KVM itself will access the
 		 * pfn.  Note, kmap() and memremap() can both sleep, so this