[PATCHv11,04/16] x86/mm: Handle LAM on context switch

Message ID: 20221025001722.17466-5-kirill.shutemov@linux.intel.com
State: New
Series: Linear Address Masking enabling

Commit Message

Kirill A. Shutemov Oct. 25, 2022, 12:17 a.m. UTC
Linear Address Masking mode for userspace pointers is encoded in CR3 bits.
The mode is selected per-process and stored in mm_context_t.

switch_mm_irqs_off() now respects the selected LAM mode and constructs CR3
accordingly.

The active LAM mode is recorded in tlb_state.

Signed-off-by: Kirill A. Shutemov <kirill.shutemov@linux.intel.com>
Tested-by: Alexander Potapenko <glider@google.com>
Acked-by: Peter Zijlstra (Intel) <peterz@infradead.org>
---
 arch/x86/include/asm/mmu.h         |  3 ++
 arch/x86/include/asm/mmu_context.h | 24 +++++++++++++++
 arch/x86/include/asm/tlbflush.h    | 34 +++++++++++++++++++++
 arch/x86/mm/tlb.c                  | 48 ++++++++++++++++++++----------
 4 files changed, 93 insertions(+), 16 deletions(-)
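
For reference, LAM's CR3 encoding (per the Intel LAM spec): setting CR3 bit 61
(LAM_U57) masks userspace pointer bits 62:57, and setting CR3 bit 62 (LAM_U48)
masks bits 62:48. A minimal sketch of the constants the patch builds on and of
the u8 packing used by tlb_state below; the actual definitions live elsewhere
in the series, so take the spellings here as illustrative:

	#define X86_CR3_LAM_U57_BIT	61
	#define X86_CR3_LAM_U48_BIT	62
	#define X86_CR3_LAM_U57		BIT_ULL(X86_CR3_LAM_U57_BIT)	/* mask bits 62:57 */
	#define X86_CR3_LAM_U48		BIT_ULL(X86_CR3_LAM_U48_BIT)	/* mask bits 62:48 */

	/*
	 * tlb_state.lam stores the CR3 mask shifted right by
	 * X86_CR3_LAM_U57_BIT so that it fits in a u8:
	 *
	 *	0 (LAM off)	-> 0
	 *	X86_CR3_LAM_U57	-> 1
	 *	X86_CR3_LAM_U48	-> 2
	 */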
  

Comments

Andy Lutomirski Nov. 7, 2022, 2:58 p.m. UTC | #1
On 10/24/22 17:17, Kirill A. Shutemov wrote:
> Linear Address Masking mode for userspace pointers is encoded in CR3 bits.
> The mode is selected per-process and stored in mm_context_t.
> 
> switch_mm_irqs_off() now respects the selected LAM mode and constructs CR3
> accordingly.
> 
> The active LAM mode is recorded in tlb_state.
> 
> Signed-off-by: Kirill A. Shutemov <kirill.shutemov@linux.intel.com>
> Tested-by: Alexander Potapenko <glider@google.com>
> Acked-by: Peter Zijlstra (Intel) <peterz@infradead.org>
> ---
>   arch/x86/include/asm/mmu.h         |  3 ++
>   arch/x86/include/asm/mmu_context.h | 24 +++++++++++++++
>   arch/x86/include/asm/tlbflush.h    | 34 +++++++++++++++++++++
>   arch/x86/mm/tlb.c                  | 48 ++++++++++++++++++++----------
>   4 files changed, 93 insertions(+), 16 deletions(-)
> 
> diff --git a/arch/x86/include/asm/mmu.h b/arch/x86/include/asm/mmu.h
> index 5d7494631ea9..002889ca8978 100644
> --- a/arch/x86/include/asm/mmu.h
> +++ b/arch/x86/include/asm/mmu.h
> @@ -40,6 +40,9 @@ typedef struct {
>   
>   #ifdef CONFIG_X86_64
>   	unsigned short flags;
> +
> +	/* Active LAM mode:  X86_CR3_LAM_U48 or X86_CR3_LAM_U57 or 0 (disabled) */
> +	unsigned long lam_cr3_mask;
>   #endif
>   
>   	struct mutex lock;
> diff --git a/arch/x86/include/asm/mmu_context.h b/arch/x86/include/asm/mmu_context.h
> index b8d40ddeab00..69c943b2ae90 100644
> --- a/arch/x86/include/asm/mmu_context.h
> +++ b/arch/x86/include/asm/mmu_context.h
> @@ -91,6 +91,29 @@ static inline void switch_ldt(struct mm_struct *prev, struct mm_struct *next)
>   }
>   #endif
>   
> +#ifdef CONFIG_X86_64
> +static inline unsigned long mm_lam_cr3_mask(struct mm_struct *mm)
> +{
> +	return mm->context.lam_cr3_mask;
> +}
> +
> +static inline void dup_lam(struct mm_struct *oldmm, struct mm_struct *mm)
> +{
> +	mm->context.lam_cr3_mask = oldmm->context.lam_cr3_mask;
> +}
> +
> +#else
> +
> +static inline unsigned long mm_lam_cr3_mask(struct mm_struct *mm)
> +{
> +	return 0;
> +}
> +
> +static inline void dup_lam(struct mm_struct *oldmm, struct mm_struct *mm)
> +{
> +}
> +#endif
> +
>   #define enter_lazy_tlb enter_lazy_tlb
>   extern void enter_lazy_tlb(struct mm_struct *mm, struct task_struct *tsk);
>   
> @@ -168,6 +191,7 @@ static inline int arch_dup_mmap(struct mm_struct *oldmm, struct mm_struct *mm)
>   {
>   	arch_dup_pkeys(oldmm, mm);
>   	paravirt_arch_dup_mmap(oldmm, mm);
> +	dup_lam(oldmm, mm);
>   	return ldt_dup_context(oldmm, mm);
>   }
>   
> diff --git a/arch/x86/include/asm/tlbflush.h b/arch/x86/include/asm/tlbflush.h
> index cda3118f3b27..662598dea937 100644
> --- a/arch/x86/include/asm/tlbflush.h
> +++ b/arch/x86/include/asm/tlbflush.h
> @@ -101,6 +101,16 @@ struct tlb_state {
>   	 */
>   	bool invalidate_other;
>   
> +#ifdef CONFIG_X86_64
> +	/*
> +	 * Active LAM mode.
> +	 *
> +	 * X86_CR3_LAM_U57/U48 shifted right by X86_CR3_LAM_U57_BIT or 0 if LAM
> +	 * disabled.
> +	 */
> +	u8 lam;
> +#endif
> +
>   	/*
>   	 * Mask that contains TLB_NR_DYN_ASIDS+1 bits to indicate
>   	 * the corresponding user PCID needs a flush next time we
> @@ -357,6 +367,30 @@ static inline bool huge_pmd_needs_flush(pmd_t oldpmd, pmd_t newpmd)
>   }
>   #define huge_pmd_needs_flush huge_pmd_needs_flush
>   
> +#ifdef CONFIG_X86_64
> +static inline unsigned long tlbstate_lam_cr3_mask(void)
> +{
> +	unsigned long lam = this_cpu_read(cpu_tlbstate.lam);
> +
> +	return lam << X86_CR3_LAM_U57_BIT;
> +}
> +
> +static inline void set_tlbstate_cr3_lam_mask(unsigned long mask)
> +{
> +	this_cpu_write(cpu_tlbstate.lam, mask >> X86_CR3_LAM_U57_BIT);
> +}
> +
> +#else
> +
> +static inline unsigned long tlbstate_lam_cr3_mask(void)
> +{
> +	return 0;
> +}
> +
> +static inline void set_tlbstate_cr3_lam_mask(unsigned long mask)
> +{
> +}
> +#endif
>   #endif /* !MODULE */
>   
>   static inline void __native_tlb_flush_global(unsigned long cr4)
> diff --git a/arch/x86/mm/tlb.c b/arch/x86/mm/tlb.c
> index c1e31e9a85d7..d6c9c15d2ad2 100644
> --- a/arch/x86/mm/tlb.c
> +++ b/arch/x86/mm/tlb.c
> @@ -154,26 +154,30 @@ static inline u16 user_pcid(u16 asid)
>   	return ret;
>   }
>   
> -static inline unsigned long build_cr3(pgd_t *pgd, u16 asid)
> +static inline unsigned long build_cr3(pgd_t *pgd, u16 asid, unsigned long lam)
>   {
> +	unsigned long cr3 = __sme_pa(pgd) | lam;
> +
>   	if (static_cpu_has(X86_FEATURE_PCID)) {
> -		return __sme_pa(pgd) | kern_pcid(asid);
> +		VM_WARN_ON_ONCE(asid > MAX_ASID_AVAILABLE);
> +		cr3 |= kern_pcid(asid);
>   	} else {
>   		VM_WARN_ON_ONCE(asid != 0);
> -		return __sme_pa(pgd);
>   	}
> +
> +	return cr3;
>   }
>   
> -static inline unsigned long build_cr3_noflush(pgd_t *pgd, u16 asid)
> +static inline unsigned long build_cr3_noflush(pgd_t *pgd, u16 asid,
> +					      unsigned long lam)
>   {
> -	VM_WARN_ON_ONCE(asid > MAX_ASID_AVAILABLE);
>   	/*
>   	 * Use boot_cpu_has() instead of this_cpu_has() as this function
>   	 * might be called during early boot. This should work even after
>   	 * boot because all CPUs have the same capabilities:
>   	 */
>   	VM_WARN_ON_ONCE(!boot_cpu_has(X86_FEATURE_PCID));
> -	return __sme_pa(pgd) | kern_pcid(asid) | CR3_NOFLUSH;
> +	return build_cr3(pgd, asid, lam) | CR3_NOFLUSH;
>   }
>   
>   /*
> @@ -274,15 +278,16 @@ static inline void invalidate_user_asid(u16 asid)
>   		  (unsigned long *)this_cpu_ptr(&cpu_tlbstate.user_pcid_flush_mask));
>   }
>   
> -static void load_new_mm_cr3(pgd_t *pgdir, u16 new_asid, bool need_flush)
> +static void load_new_mm_cr3(pgd_t *pgdir, u16 new_asid, unsigned long lam,
> +			    bool need_flush)
>   {
>   	unsigned long new_mm_cr3;
>   
>   	if (need_flush) {
>   		invalidate_user_asid(new_asid);
> -		new_mm_cr3 = build_cr3(pgdir, new_asid);
> +		new_mm_cr3 = build_cr3(pgdir, new_asid, lam);
>   	} else {
> -		new_mm_cr3 = build_cr3_noflush(pgdir, new_asid);
> +		new_mm_cr3 = build_cr3_noflush(pgdir, new_asid, lam);
>   	}
>   
>   	/*
> @@ -491,6 +496,8 @@ void switch_mm_irqs_off(struct mm_struct *prev, struct mm_struct *next,
>   {
>   	struct mm_struct *real_prev = this_cpu_read(cpu_tlbstate.loaded_mm);
>   	u16 prev_asid = this_cpu_read(cpu_tlbstate.loaded_mm_asid);
> +	unsigned long prev_lam = tlbstate_lam_cr3_mask();
> +	unsigned long new_lam = mm_lam_cr3_mask(next);
>   	bool was_lazy = this_cpu_read(cpu_tlbstate_shared.is_lazy);
>   	unsigned cpu = smp_processor_id();
>   	u64 next_tlb_gen;
> @@ -520,7 +527,7 @@ void switch_mm_irqs_off(struct mm_struct *prev, struct mm_struct *next,
>   	 * isn't free.
>   	 */
>   #ifdef CONFIG_DEBUG_VM
> -	if (WARN_ON_ONCE(__read_cr3() != build_cr3(real_prev->pgd, prev_asid))) {
> +	if (WARN_ON_ONCE(__read_cr3() != build_cr3(real_prev->pgd, prev_asid, prev_lam))) {
>   		/*
>   		 * If we were to BUG here, we'd be very likely to kill
>   		 * the system so hard that we don't see the call trace.
> @@ -554,6 +561,7 @@ void switch_mm_irqs_off(struct mm_struct *prev, struct mm_struct *next,
>   	if (real_prev == next) {
>   		VM_WARN_ON(this_cpu_read(cpu_tlbstate.ctxs[prev_asid].ctx_id) !=
>   			   next->context.ctx_id);
> +		VM_WARN_ON(prev_lam != new_lam);

What prevents this warning from firing if a remote cpu does 
prctl_enable_tagged_addr() and this cpu hits this code path before 
getting the LAM-enabling IPI?  Conceptually this would be like if we 
asserted that LDTR matched the mm_context's ldt setting in this code path.

I think (haven't really verified) that you can fix this by removing the 
warning and adding a comment explaining that CR3 can be out of sync due 
to a race against changes to LAM settings.  I don't think there's any 
way to eliminate the race -- there is no lock you can take while 
changing lam that prevents a remote CPU from switching mm or scheduling.
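
A sketch of the interleaving in question, assuming the enabling side (added
later in this series) updates mm->context.lam_cr3_mask first and only then
IPIs the CPUs running the mm:

	/*
	 * CPU 0 (prctl)                     CPU 1 (same mm)
	 * -------------                     ---------------
	 * mm->context.lam_cr3_mask = U57;
	 *                                   switch_mm_irqs_off():
	 *                                     new_lam = mm_lam_cr3_mask(next);    <- sees U57
	 *                                     prev_lam = tlbstate_lam_cr3_mask(); <- still 0
	 *                                     VM_WARN_ON(prev_lam != new_lam);    <- fires
	 * on_each_cpu_mask(..., enable_lam_func, ...);
	 */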

--Andy
  
Kirill A. Shutemov Nov. 7, 2022, 5:14 p.m. UTC | #2
On Mon, Nov 07, 2022 at 06:58:59AM -0800, Andy Lutomirski wrote:
> > @@ -554,6 +561,7 @@ void switch_mm_irqs_off(struct mm_struct *prev, struct mm_struct *next,
> >   	if (real_prev == next) {
> >   		VM_WARN_ON(this_cpu_read(cpu_tlbstate.ctxs[prev_asid].ctx_id) !=
> >   			   next->context.ctx_id);
> > +		VM_WARN_ON(prev_lam != new_lam);
> 
> What prevents this warning from firing if a remote cpu does
> prctl_enable_tagged_addr() and this cpu hits this code path before getting
> the LAM-enabling IPI?  Conceptually this would be like if we asserted that
> LDTR matched the mm_context's ldt setting in this code path.
> 
> I think (haven't really verified) that you can fix this by removing the
> warning and adding a comment explaining that CR3 can be out of sync due to a
> race against changes to LAM settings.  I don't think there's any way to
> eliminate the race -- there is no lock you can take while changing lam that
> prevents a remote CPU from switching mm or scheduling.

Something like this?

diff --git a/arch/x86/mm/tlb.c b/arch/x86/mm/tlb.c
index d6c9c15d2ad2..c6cac1a1bc64 100644
--- a/arch/x86/mm/tlb.c
+++ b/arch/x86/mm/tlb.c
@@ -561,7 +561,15 @@ void switch_mm_irqs_off(struct mm_struct *prev, struct mm_struct *next,
 	if (real_prev == next) {
 		VM_WARN_ON(this_cpu_read(cpu_tlbstate.ctxs[prev_asid].ctx_id) !=
 			   next->context.ctx_id);
-		VM_WARN_ON(prev_lam != new_lam);
+
+		/*
+		 * 'prev_lam' does not necessarily match 'new_lam' here. In case
+		 * of a race with LAM enabling, the updated 'lam_cr3_mask' can
+		 * be seen before the LAM-enabling IPI kicks in.
+		 *
+		 * The race is harmless: it is okay to update CR3 with the new
+		 * LAM mode. The IPI will rewrite CR3 shortly.
+		 */
 
 		/*
 		 * Even in lazy TLB mode, the CPU should stay set in the
  
Dave Hansen Nov. 7, 2022, 6:02 p.m. UTC | #3
On 11/7/22 09:14, Kirill A. Shutemov wrote:
> --- a/arch/x86/mm/tlb.c
> +++ b/arch/x86/mm/tlb.c
> @@ -561,7 +561,15 @@ void switch_mm_irqs_off(struct mm_struct *prev, struct mm_struct *next,
>  	if (real_prev == next) {
>  		VM_WARN_ON(this_cpu_read(cpu_tlbstate.ctxs[prev_asid].ctx_id) !=
>  			   next->context.ctx_id);
> -		VM_WARN_ON(prev_lam != new_lam);
> +
> +		/*
> +		 * 'prev_lam' does not necessarily match 'new_lam' here. In case
> +		 * of a race with LAM enabling, the updated 'lam_cr3_mask' can
> +		 * be seen before the LAM-enabling IPI kicks in.
> +		 *
> +		 * The race is harmless: it is okay to update CR3 with the new
> +		 * LAM mode. The IPI will rewrite CR3 shortly.
> +		 */

So, let's do something like this in switch_mm_irqs_off():

		/* Not actually switching mm's */
		VM_WARN_ON(this_cpu_read(cpu_tlbstate....

		/*
		 * If this races with another thread that enables
		 * lam, 'new_lam' might not match 'prev_lam'.
		 */

Then, in enable_lam_func(), something like this:

	/*
	 * Update CR3 to get LAM active on the CPU
	 *
	 * This might not actually need to update CR3 if a context
	 * switch happened between updating 'lam_cr3_mask' and
	 * running this IPI handler.  Update it unconditionally for
	 * simplicity.
	 */
	cr3 = __read_cr3();
	cr3 &= ~(X86_CR3_LAM_U48 | X86_CR3_LAM_U57);
	cr3 |= lam_mask;
	write_cr3(cr3);
	set_tlbstate_cr3_lam_mask(lam_mask);


I'd much rather get folks thinking about IPI races in the IPI handler
rather than thinking about the IPI handler in the context switch path.

It's kinda silly to be describing the occasional superfluous
enable_lam_func() activity from switch_mm_irqs_off().
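
For context, the enabling side being rearranged here lands in a later patch of
this series (the LAM prctl). A rough sketch of how it would look with the
suggestion above folded in; names such as LAM_U57_BITS and
prctl_enable_tagged_addr() are taken from that patch and details may differ:

	static void enable_lam_func(void *__mm)
	{
		struct mm_struct *mm = __mm;
		unsigned long lam_mask, cr3;

		/* Only act if this CPU is currently running the target mm */
		if (this_cpu_read(cpu_tlbstate.loaded_mm) != mm)
			return;

		lam_mask = READ_ONCE(mm->context.lam_cr3_mask);

		/*
		 * Update CR3 to get LAM active on the CPU.
		 *
		 * This might not actually need to update CR3 if a context
		 * switch happened between updating 'lam_cr3_mask' and running
		 * this IPI handler.  Update it unconditionally for simplicity.
		 */
		cr3 = __read_cr3();
		cr3 &= ~(X86_CR3_LAM_U48 | X86_CR3_LAM_U57);
		cr3 |= lam_mask;
		write_cr3(cr3);
		set_tlbstate_cr3_lam_mask(lam_mask);
	}

	static int prctl_enable_tagged_addr(struct mm_struct *mm, unsigned long nr_bits)
	{
		if (!nr_bits || nr_bits > LAM_U57_BITS)
			return -EINVAL;

		mm->context.lam_cr3_mask = X86_CR3_LAM_U57;

		/* The race discussed above opens here, before the IPIs land */
		on_each_cpu_mask(mm_cpumask(mm), enable_lam_func, mm, true);
		return 0;
	}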
  

Patch

diff --git a/arch/x86/include/asm/mmu.h b/arch/x86/include/asm/mmu.h
index 5d7494631ea9..002889ca8978 100644
--- a/arch/x86/include/asm/mmu.h
+++ b/arch/x86/include/asm/mmu.h
@@ -40,6 +40,9 @@  typedef struct {
 
 #ifdef CONFIG_X86_64
 	unsigned short flags;
+
+	/* Active LAM mode:  X86_CR3_LAM_U48 or X86_CR3_LAM_U57 or 0 (disabled) */
+	unsigned long lam_cr3_mask;
 #endif
 
 	struct mutex lock;
diff --git a/arch/x86/include/asm/mmu_context.h b/arch/x86/include/asm/mmu_context.h
index b8d40ddeab00..69c943b2ae90 100644
--- a/arch/x86/include/asm/mmu_context.h
+++ b/arch/x86/include/asm/mmu_context.h
@@ -91,6 +91,29 @@  static inline void switch_ldt(struct mm_struct *prev, struct mm_struct *next)
 }
 #endif
 
+#ifdef CONFIG_X86_64
+static inline unsigned long mm_lam_cr3_mask(struct mm_struct *mm)
+{
+	return mm->context.lam_cr3_mask;
+}
+
+static inline void dup_lam(struct mm_struct *oldmm, struct mm_struct *mm)
+{
+	mm->context.lam_cr3_mask = oldmm->context.lam_cr3_mask;
+}
+
+#else
+
+static inline unsigned long mm_lam_cr3_mask(struct mm_struct *mm)
+{
+	return 0;
+}
+
+static inline void dup_lam(struct mm_struct *oldmm, struct mm_struct *mm)
+{
+}
+#endif
+
 #define enter_lazy_tlb enter_lazy_tlb
 extern void enter_lazy_tlb(struct mm_struct *mm, struct task_struct *tsk);
 
@@ -168,6 +191,7 @@  static inline int arch_dup_mmap(struct mm_struct *oldmm, struct mm_struct *mm)
 {
 	arch_dup_pkeys(oldmm, mm);
 	paravirt_arch_dup_mmap(oldmm, mm);
+	dup_lam(oldmm, mm);
 	return ldt_dup_context(oldmm, mm);
 }
 
diff --git a/arch/x86/include/asm/tlbflush.h b/arch/x86/include/asm/tlbflush.h
index cda3118f3b27..662598dea937 100644
--- a/arch/x86/include/asm/tlbflush.h
+++ b/arch/x86/include/asm/tlbflush.h
@@ -101,6 +101,16 @@  struct tlb_state {
 	 */
 	bool invalidate_other;
 
+#ifdef CONFIG_X86_64
+	/*
+	 * Active LAM mode.
+	 *
+	 * X86_CR3_LAM_U57/U48 shifted right by X86_CR3_LAM_U57_BIT or 0 if LAM
+	 * disabled.
+	 */
+	u8 lam;
+#endif
+
 	/*
 	 * Mask that contains TLB_NR_DYN_ASIDS+1 bits to indicate
 	 * the corresponding user PCID needs a flush next time we
@@ -357,6 +367,30 @@  static inline bool huge_pmd_needs_flush(pmd_t oldpmd, pmd_t newpmd)
 }
 #define huge_pmd_needs_flush huge_pmd_needs_flush
 
+#ifdef CONFIG_X86_64
+static inline unsigned long tlbstate_lam_cr3_mask(void)
+{
+	unsigned long lam = this_cpu_read(cpu_tlbstate.lam);
+
+	return lam << X86_CR3_LAM_U57_BIT;
+}
+
+static inline void set_tlbstate_cr3_lam_mask(unsigned long mask)
+{
+	this_cpu_write(cpu_tlbstate.lam, mask >> X86_CR3_LAM_U57_BIT);
+}
+
+#else
+
+static inline unsigned long tlbstate_lam_cr3_mask(void)
+{
+	return 0;
+}
+
+static inline void set_tlbstate_cr3_lam_mask(unsigned long mask)
+{
+}
+#endif
 #endif /* !MODULE */
 
 static inline void __native_tlb_flush_global(unsigned long cr4)
diff --git a/arch/x86/mm/tlb.c b/arch/x86/mm/tlb.c
index c1e31e9a85d7..d6c9c15d2ad2 100644
--- a/arch/x86/mm/tlb.c
+++ b/arch/x86/mm/tlb.c
@@ -154,26 +154,30 @@  static inline u16 user_pcid(u16 asid)
 	return ret;
 }
 
-static inline unsigned long build_cr3(pgd_t *pgd, u16 asid)
+static inline unsigned long build_cr3(pgd_t *pgd, u16 asid, unsigned long lam)
 {
+	unsigned long cr3 = __sme_pa(pgd) | lam;
+
 	if (static_cpu_has(X86_FEATURE_PCID)) {
-		return __sme_pa(pgd) | kern_pcid(asid);
+		VM_WARN_ON_ONCE(asid > MAX_ASID_AVAILABLE);
+		cr3 |= kern_pcid(asid);
 	} else {
 		VM_WARN_ON_ONCE(asid != 0);
-		return __sme_pa(pgd);
 	}
+
+	return cr3;
 }
 
-static inline unsigned long build_cr3_noflush(pgd_t *pgd, u16 asid)
+static inline unsigned long build_cr3_noflush(pgd_t *pgd, u16 asid,
+					      unsigned long lam)
 {
-	VM_WARN_ON_ONCE(asid > MAX_ASID_AVAILABLE);
 	/*
 	 * Use boot_cpu_has() instead of this_cpu_has() as this function
 	 * might be called during early boot. This should work even after
 	 * boot because all CPUs have the same capabilities:
 	 */
 	VM_WARN_ON_ONCE(!boot_cpu_has(X86_FEATURE_PCID));
-	return __sme_pa(pgd) | kern_pcid(asid) | CR3_NOFLUSH;
+	return build_cr3(pgd, asid, lam) | CR3_NOFLUSH;
 }
 
 /*
@@ -274,15 +278,16 @@  static inline void invalidate_user_asid(u16 asid)
 		  (unsigned long *)this_cpu_ptr(&cpu_tlbstate.user_pcid_flush_mask));
 }
 
-static void load_new_mm_cr3(pgd_t *pgdir, u16 new_asid, bool need_flush)
+static void load_new_mm_cr3(pgd_t *pgdir, u16 new_asid, unsigned long lam,
+			    bool need_flush)
 {
 	unsigned long new_mm_cr3;
 
 	if (need_flush) {
 		invalidate_user_asid(new_asid);
-		new_mm_cr3 = build_cr3(pgdir, new_asid);
+		new_mm_cr3 = build_cr3(pgdir, new_asid, lam);
 	} else {
-		new_mm_cr3 = build_cr3_noflush(pgdir, new_asid);
+		new_mm_cr3 = build_cr3_noflush(pgdir, new_asid, lam);
 	}
 
 	/*
@@ -491,6 +496,8 @@  void switch_mm_irqs_off(struct mm_struct *prev, struct mm_struct *next,
 {
 	struct mm_struct *real_prev = this_cpu_read(cpu_tlbstate.loaded_mm);
 	u16 prev_asid = this_cpu_read(cpu_tlbstate.loaded_mm_asid);
+	unsigned long prev_lam = tlbstate_lam_cr3_mask();
+	unsigned long new_lam = mm_lam_cr3_mask(next);
 	bool was_lazy = this_cpu_read(cpu_tlbstate_shared.is_lazy);
 	unsigned cpu = smp_processor_id();
 	u64 next_tlb_gen;
@@ -520,7 +527,7 @@  void switch_mm_irqs_off(struct mm_struct *prev, struct mm_struct *next,
 	 * isn't free.
 	 */
 #ifdef CONFIG_DEBUG_VM
-	if (WARN_ON_ONCE(__read_cr3() != build_cr3(real_prev->pgd, prev_asid))) {
+	if (WARN_ON_ONCE(__read_cr3() != build_cr3(real_prev->pgd, prev_asid, prev_lam))) {
 		/*
 		 * If we were to BUG here, we'd be very likely to kill
 		 * the system so hard that we don't see the call trace.
@@ -554,6 +561,7 @@  void switch_mm_irqs_off(struct mm_struct *prev, struct mm_struct *next,
 	if (real_prev == next) {
 		VM_WARN_ON(this_cpu_read(cpu_tlbstate.ctxs[prev_asid].ctx_id) !=
 			   next->context.ctx_id);
+		VM_WARN_ON(prev_lam != new_lam);
 
 		/*
 		 * Even in lazy TLB mode, the CPU should stay set in the
@@ -622,15 +630,16 @@  void switch_mm_irqs_off(struct mm_struct *prev, struct mm_struct *next,
 		barrier();
 	}
 
+	set_tlbstate_cr3_lam_mask(new_lam);
 	if (need_flush) {
 		this_cpu_write(cpu_tlbstate.ctxs[new_asid].ctx_id, next->context.ctx_id);
 		this_cpu_write(cpu_tlbstate.ctxs[new_asid].tlb_gen, next_tlb_gen);
-		load_new_mm_cr3(next->pgd, new_asid, true);
+		load_new_mm_cr3(next->pgd, new_asid, new_lam, true);
 
 		trace_tlb_flush(TLB_FLUSH_ON_TASK_SWITCH, TLB_FLUSH_ALL);
 	} else {
 		/* The new ASID is already up to date. */
-		load_new_mm_cr3(next->pgd, new_asid, false);
+		load_new_mm_cr3(next->pgd, new_asid, new_lam, false);
 
 		trace_tlb_flush(TLB_FLUSH_ON_TASK_SWITCH, 0);
 	}
@@ -691,6 +700,10 @@  void initialize_tlbstate_and_flush(void)
 	/* Assert that CR3 already references the right mm. */
 	WARN_ON((cr3 & CR3_ADDR_MASK) != __pa(mm->pgd));
 
+	/* LAM expected to be disabled in CR3 and init_mm */
+	WARN_ON(cr3 & (X86_CR3_LAM_U48 | X86_CR3_LAM_U57));
+	WARN_ON(mm_lam_cr3_mask(&init_mm));
+
 	/*
 	 * Assert that CR4.PCIDE is set if needed.  (CR4.PCIDE initialization
 	 * doesn't work like other CR4 bits because it can only be set from
@@ -699,8 +712,8 @@  void initialize_tlbstate_and_flush(void)
 	WARN_ON(boot_cpu_has(X86_FEATURE_PCID) &&
 		!(cr4_read_shadow() & X86_CR4_PCIDE));
 
-	/* Force ASID 0 and force a TLB flush. */
-	write_cr3(build_cr3(mm->pgd, 0));
+	/* Disable LAM, force ASID 0 and force a TLB flush. */
+	write_cr3(build_cr3(mm->pgd, 0, 0));
 
 	/* Reinitialize tlbstate. */
 	this_cpu_write(cpu_tlbstate.last_user_mm_spec, LAST_USER_MM_INIT);
@@ -708,6 +721,7 @@  void initialize_tlbstate_and_flush(void)
 	this_cpu_write(cpu_tlbstate.next_asid, 1);
 	this_cpu_write(cpu_tlbstate.ctxs[0].ctx_id, mm->context.ctx_id);
 	this_cpu_write(cpu_tlbstate.ctxs[0].tlb_gen, tlb_gen);
+	set_tlbstate_cr3_lam_mask(0);
 
 	for (i = 1; i < TLB_NR_DYN_ASIDS; i++)
 		this_cpu_write(cpu_tlbstate.ctxs[i].ctx_id, 0);
@@ -1071,8 +1085,10 @@  void flush_tlb_kernel_range(unsigned long start, unsigned long end)
  */
 unsigned long __get_current_cr3_fast(void)
 {
-	unsigned long cr3 = build_cr3(this_cpu_read(cpu_tlbstate.loaded_mm)->pgd,
-		this_cpu_read(cpu_tlbstate.loaded_mm_asid));
+	unsigned long cr3 =
+		build_cr3(this_cpu_read(cpu_tlbstate.loaded_mm)->pgd,
+			  this_cpu_read(cpu_tlbstate.loaded_mm_asid),
+			  tlbstate_lam_cr3_mask());
 
 	/* For now, be very restrictive about when this can be called. */
 	VM_WARN_ON(in_nmi() || preemptible());
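
As a worked example of the value build_cr3() composes: with a page table at
physical address 0x1234000 (an arbitrary example), ASID 0, LAM_U57 enabled,
and assuming kern_pcid() maps ASID n to PCID n + 1 as in the existing code:

	unsigned long cr3 = 0x1234000		/* __sme_pa(pgd) */
			  | (1UL << 61)		/* X86_CR3_LAM_U57 */
			  | 1;			/* kern_pcid(0) */
	/* cr3 == 0x2000000001234001 */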