[RFC] sched: Introduce per-mm/cpu concurrency id state

Message ID 20230330230911.228720-1-mathieu.desnoyers@efficios.com
State New
Headers
Series [RFC] sched: Introduce per-mm/cpu concurrency id state |

Commit Message

Mathieu Desnoyers March 30, 2023, 11:09 p.m. UTC
  Keep track of the currently allocated mm_cid for each mm/cpu rather than
freeing them immediately. This eliminates most atomic ops when context
switching back and forth between threads belonging to different memory
spaces in multi-threaded scenarios (many processes, each with many
threads).

This patch is based on v6.3-rc4 with this patch applied:

("mm: Fix memory leak on mm_init error handling")

https://lore.kernel.org/lkml/20230330133822.66271-1-mathieu.desnoyers@efficios.com/

Signed-off-by: Mathieu Desnoyers <mathieu.desnoyers@efficios.com>
Cc: Aaron Lu <aaron.lu@intel.com>
Cc: Peter Zijlstra <peterz@infradead.org>
---
 include/linux/mm_types.h | 32 ++++++++++++++++
 kernel/fork.c            |  7 +++-
 kernel/sched/core.c      | 79 ++++++++++++++++++++++++++++++++++-----
 kernel/sched/sched.h     | 81 ++++++++++++++++++++++++++++++----------
 4 files changed, 169 insertions(+), 30 deletions(-)
  

Comments

Aaron Lu March 31, 2023, 8:38 a.m. UTC | #1
On Thu, Mar 30, 2023 at 07:09:11PM -0400, Mathieu Desnoyers wrote:

>  void sched_mm_cid_exit_signals(struct task_struct *t)
>  {
>  	struct mm_struct *mm = t->mm;
> -	unsigned long flags;
> +	struct rq *rq = this_rq();

Got many of the messages below due to the above line:

[   19.294089] BUG: using smp_processor_id() in preemptible [00000000] code: kworker/u449:0/1621

> +	struct rq_flags rf;
>  
>  	if (!mm)
>  		return;
> -	local_irq_save(flags);
> +	rq_lock_irqsave(rq, &rf);
>  	mm_cid_put(mm, t->mm_cid);
>  	t->mm_cid = -1;
>  	t->mm_cid_active = 0;
> -	local_irq_restore(flags);
> +	rq_unlock_irqrestore(rq, &rf);
>  }
>  
>  void sched_mm_cid_before_execve(struct task_struct *t)
>  {
>  	struct mm_struct *mm = t->mm;
> -	unsigned long flags;
> +	struct rq *rq = this_rq();

Also here.

> +	struct rq_flags rf;
>  
>  	if (!mm)
>  		return;
> -	local_irq_save(flags);
> +	rq_lock_irqsave(rq, &rf);
>  	mm_cid_put(mm, t->mm_cid);
>  	t->mm_cid = -1;
>  	t->mm_cid_active = 0;
> -	local_irq_restore(flags);
> +	rq_unlock_irqrestore(rq, &rf);
>  }
>  
>  void sched_mm_cid_after_execve(struct task_struct *t)
>  {
>  	struct mm_struct *mm = t->mm;
> -	unsigned long flags;
> +	struct rq *rq = this_rq();

And here.

> +	struct rq_flags rf;
>  
>  	if (!mm)
>  		return;
> -	local_irq_save(flags);
> +	rq_lock_irqsave(rq, &rf);
>  	t->mm_cid = mm_cid_get(mm);
>  	t->mm_cid_active = 1;
> -	local_irq_restore(flags);
> +	rq_unlock_irqrestore(rq, &rf);
>  	rseq_set_notify_resume(t);
>  }

I used the diff below to get rid of these messages without understanding the
purpose of these functions:

diff --git a/kernel/sched/core.c b/kernel/sched/core.c
index f07b87d155bd..7194c29f3c91 100644
--- a/kernel/sched/core.c
+++ b/kernel/sched/core.c
@@ -11444,45 +11444,57 @@ void sched_mm_cid_migrate_to(struct rq *dst_rq, struct task_struct *t, int src_c
 void sched_mm_cid_exit_signals(struct task_struct *t)
 {
 	struct mm_struct *mm = t->mm;
-	struct rq *rq = this_rq();
 	struct rq_flags rf;
+	struct rq *rq;
 
 	if (!mm)
 		return;
+
+	preempt_disable();
+	rq = this_rq();
 	rq_lock_irqsave(rq, &rf);
 	mm_cid_put(mm, t->mm_cid);
 	t->mm_cid = -1;
 	t->mm_cid_active = 0;
 	rq_unlock_irqrestore(rq, &rf);
+	preempt_enable();
 }
 
 void sched_mm_cid_before_execve(struct task_struct *t)
 {
 	struct mm_struct *mm = t->mm;
-	struct rq *rq = this_rq();
 	struct rq_flags rf;
+	struct rq *rq;
 
 	if (!mm)
 		return;
+
+	preempt_disable();
+	rq = this_rq();
 	rq_lock_irqsave(rq, &rf);
 	mm_cid_put(mm, t->mm_cid);
 	t->mm_cid = -1;
 	t->mm_cid_active = 0;
 	rq_unlock_irqrestore(rq, &rf);
+	preempt_enable();
 }
 
 void sched_mm_cid_after_execve(struct task_struct *t)
 {
 	struct mm_struct *mm = t->mm;
-	struct rq *rq = this_rq();
 	struct rq_flags rf;
+	struct rq *rq;
 
 	if (!mm)
 		return;
+
+	preempt_disable();
+	rq = this_rq();
 	rq_lock_irqsave(rq, &rf);
 	t->mm_cid = mm_cid_get(mm);
 	t->mm_cid_active = 1;
 	rq_unlock_irqrestore(rq, &rf);
+	preempt_enable();
 	rseq_set_notify_resume(t);
 }
  
Aaron Lu March 31, 2023, 8:52 a.m. UTC | #2
On Thu, Mar 30, 2023 at 07:09:11PM -0400, Mathieu Desnoyers wrote:
> Keep track of the currently allocated mm_cid for each mm/cpu rather than
> freeing them immediately. This eliminates most atomic ops when context
> switching back and forth between threads belonging to different memory
> spaces in multi-threaded scenarios (many processes, each with many
> threads).

Good news, the lock contention is now gone and back to v6.2 level:

node0_0.profile:     0.07%     0.07%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
node0_1.profile:     0.06%     0.06%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
node0_2.profile:     0.09%     0.09%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
node0_3.profile:     0.08%     0.08%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
node0_4.profile:     0.09%     0.09%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
node0_5.profile:     0.10%     0.10%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
node0_6.profile:     0.10%     0.10%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
node0_7.profile:     0.07%     0.07%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
node0_8.profile:     0.08%     0.08%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
node0_9.profile:     0.06%     0.06%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
node1_0.profile:     0.41%     0.41%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
node1_1.profile:     0.38%     0.38%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
node1_2.profile:     0.44%     0.44%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
node1_3.profile:     5.64%     5.64%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
node1_4.profile:     6.08%     6.08%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
node1_5.profile:     3.45%     3.45%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
node1_6.profile:     2.09%     2.09%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
node1_7.profile:     2.72%     2.72%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
node1_8.profile:     0.16%     0.16%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
node1_9.profile:     0.15%     0.15%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
(those few profiles from node1's cpus that have more than 2% contention
are from thermal functions)

Tested-by: Aaron Lu <aaron.lu@intel.com> # lock contention part
  
Mathieu Desnoyers March 31, 2023, 11:56 p.m. UTC | #3
On 2023-03-31 04:38, Aaron Lu wrote:
> On Thu, Mar 30, 2023 at 07:09:11PM -0400, Mathieu Desnoyers wrote:
> 
>>   void sched_mm_cid_exit_signals(struct task_struct *t)
>>   {
>>   	struct mm_struct *mm = t->mm;
>> -	unsigned long flags;
>> +	struct rq *rq = this_rq();
> 
> Got many below messages due to the above line:
> 
> [   19.294089] BUG: using smp_processor_id() in preemptible [00000000] code: kworker/u449:0/1621
> 
>> +	struct rq_flags rf;
>>   
>>   	if (!mm)
>>   		return;
>> -	local_irq_save(flags);
>> +	rq_lock_irqsave(rq, &rf);
>>   	mm_cid_put(mm, t->mm_cid);
>>   	t->mm_cid = -1;
>>   	t->mm_cid_active = 0;
>> -	local_irq_restore(flags);
>> +	rq_unlock_irqrestore(rq, &rf);
>>   }
>>   
>>   void sched_mm_cid_before_execve(struct task_struct *t)
>>   {
>>   	struct mm_struct *mm = t->mm;
>> -	unsigned long flags;
>> +	struct rq *rq = this_rq();
> 
> Also here;
> 
>> +	struct rq_flags rf;
>>   
>>   	if (!mm)
>>   		return;
>> -	local_irq_save(flags);
>> +	rq_lock_irqsave(rq, &rf);
>>   	mm_cid_put(mm, t->mm_cid);
>>   	t->mm_cid = -1;
>>   	t->mm_cid_active = 0;
>> -	local_irq_restore(flags);
>> +	rq_unlock_irqrestore(rq, &rf);
>>   }
>>   
>>   void sched_mm_cid_after_execve(struct task_struct *t)
>>   {
>>   	struct mm_struct *mm = t->mm;
>> -	unsigned long flags;
>> +	struct rq *rq = this_rq();
> 
> And here.
> 
>> +	struct rq_flags rf;
>>   
>>   	if (!mm)
>>   		return;
>> -	local_irq_save(flags);
>> +	rq_lock_irqsave(rq, &rf);
>>   	t->mm_cid = mm_cid_get(mm);
>>   	t->mm_cid_active = 1;
>> -	local_irq_restore(flags);
>> +	rq_unlock_irqrestore(rq, &rf);
>>   	rseq_set_notify_resume(t);
>>   }
> 
> I used below diff to get rid of these messages without understanding the
> purpose of these functions:

I'll fold this fix into the next round, thanks!

Mathieu

> 
> diff --git a/kernel/sched/core.c b/kernel/sched/core.c
> index f07b87d155bd..7194c29f3c91 100644
> --- a/kernel/sched/core.c
> +++ b/kernel/sched/core.c
> @@ -11444,45 +11444,57 @@ void sched_mm_cid_migrate_to(struct rq *dst_rq, struct task_struct *t, int src_c
>   void sched_mm_cid_exit_signals(struct task_struct *t)
>   {
>   	struct mm_struct *mm = t->mm;
> -	struct rq *rq = this_rq();
>   	struct rq_flags rf;
> +	struct rq *rq;
>   
>   	if (!mm)
>   		return;
> +
> +	preempt_disable();
> +	rq = this_rq();
>   	rq_lock_irqsave(rq, &rf);
>   	mm_cid_put(mm, t->mm_cid);
>   	t->mm_cid = -1;
>   	t->mm_cid_active = 0;
>   	rq_unlock_irqrestore(rq, &rf);
> +	preempt_enable();
>   }
>   
>   void sched_mm_cid_before_execve(struct task_struct *t)
>   {
>   	struct mm_struct *mm = t->mm;
> -	struct rq *rq = this_rq();
>   	struct rq_flags rf;
> +	struct rq *rq;
>   
>   	if (!mm)
>   		return;
> +
> +	preempt_disable();
> +	rq = this_rq();
>   	rq_lock_irqsave(rq, &rf);
>   	mm_cid_put(mm, t->mm_cid);
>   	t->mm_cid = -1;
>   	t->mm_cid_active = 0;
>   	rq_unlock_irqrestore(rq, &rf);
> +	preempt_enable();
>   }
>   
>   void sched_mm_cid_after_execve(struct task_struct *t)
>   {
>   	struct mm_struct *mm = t->mm;
> -	struct rq *rq = this_rq();
>   	struct rq_flags rf;
> +	struct rq *rq;
>   
>   	if (!mm)
>   		return;
> +
> +	preempt_disable();
> +	rq = this_rq();
>   	rq_lock_irqsave(rq, &rf);
>   	t->mm_cid = mm_cid_get(mm);
>   	t->mm_cid_active = 1;
>   	rq_unlock_irqrestore(rq, &rf);
> +	preempt_enable();
>   	rseq_set_notify_resume(t);
>   }
>
  
Mathieu Desnoyers April 3, 2023, 6:17 p.m. UTC | #4
On 2023-03-31 04:52, Aaron Lu wrote:
> On Thu, Mar 30, 2023 at 07:09:11PM -0400, Mathieu Desnoyers wrote:
>> Keep track of the currently allocated mm_cid for each mm/cpu rather than
>> freeing them immediately. This eliminates most atomic ops when context
>> switching back and forth between threads belonging to different memory
>> spaces in multi-threaded scenarios (many processes, each with many
>> threads).
> 
> Good news, the lock contention is now gone and back to v6.2 level:

Hi Aaron,

Can you please test the updated patch I've sent? I have updated the
subject to make it clear that this is a fix for a performance regression
and improved the comments; it also now passes more thorough testing. See:

https://lore.kernel.org/lkml/20230403181342.210896-1-mathieu.desnoyers@efficios.com/

Thanks,

Mathieu

> 
> node0_0.profile:     0.07%     0.07%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
> node0_1.profile:     0.06%     0.06%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
> node0_2.profile:     0.09%     0.09%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
> node0_3.profile:     0.08%     0.08%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
> node0_4.profile:     0.09%     0.09%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
> node0_5.profile:     0.10%     0.10%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
> node0_6.profile:     0.10%     0.10%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
> node0_7.profile:     0.07%     0.07%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
> node0_8.profile:     0.08%     0.08%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
> node0_9.profile:     0.06%     0.06%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
> node1_0.profile:     0.41%     0.41%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
> node1_1.profile:     0.38%     0.38%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
> node1_2.profile:     0.44%     0.44%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
> node1_3.profile:     5.64%     5.64%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
> node1_4.profile:     6.08%     6.08%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
> node1_5.profile:     3.45%     3.45%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
> node1_6.profile:     2.09%     2.09%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
> node1_7.profile:     2.72%     2.72%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
> node1_8.profile:     0.16%     0.16%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
> node1_9.profile:     0.15%     0.15%  [kernel.vmlinux]        [k] native_queued_spin_lock_slowpath
> (those few profiles from node1's cpus that have more than 2% contention
> are from thermal functions)
> 
> Tested-by: Aaron Lu <aaron.lu@intel.com> # lock contention part
  

Patch

diff --git a/include/linux/mm_types.h b/include/linux/mm_types.h
index 0722859c3647..335af2da5b34 100644
--- a/include/linux/mm_types.h
+++ b/include/linux/mm_types.h
@@ -609,6 +609,7 @@  struct mm_struct {
 		 * were being concurrently updated by the updaters.
 		 */
 		raw_spinlock_t cid_lock;
+		int __percpu *pcpu_cid;	/* per-cpu cached cid, or PCPU_CID_UNSET */
 #endif
 #ifdef CONFIG_MMU
 		atomic_long_t pgtables_bytes;	/* size of all page tables */
@@ -872,6 +873,16 @@  static inline void vma_iter_init(struct vma_iterator *vmi,
 }

 #ifdef CONFIG_SCHED_MM_CID
+
+enum pcpu_cid_state {
+	PCPU_CID_UNSET = -1U,	/* no cid cached for this mm on this cpu */
+};
+
+static inline bool pcpu_cid_is_unset(int cid)
+{
+	return cid == PCPU_CID_UNSET;
+}
+
 /* Accessor for struct mm_struct's cidmask. */
 static inline cpumask_t *mm_cidmask(struct mm_struct *mm)
 {
@@ -885,16 +896,37 @@  static inline cpumask_t *mm_cidmask(struct mm_struct *mm)

 static inline void mm_init_cid(struct mm_struct *mm)
 {
+	int i;
+
 	raw_spin_lock_init(&mm->cid_lock);
+	for_each_possible_cpu(i)
+		*per_cpu_ptr(mm->pcpu_cid, i) = PCPU_CID_UNSET;
 	cpumask_clear(mm_cidmask(mm));
 }

+static inline int mm_alloc_cid(struct mm_struct *mm)
+{
+	mm->pcpu_cid = alloc_percpu(int);
+	if (!mm->pcpu_cid)
+		return -ENOMEM;
+	mm_init_cid(mm);
+	return 0;
+}
+
+static inline void mm_destroy_cid(struct mm_struct *mm)
+{
+	free_percpu(mm->pcpu_cid);
+	mm->pcpu_cid = NULL;
+}
+
 static inline unsigned int mm_cid_size(void)
 {
 	return cpumask_size();
 }
 #else /* CONFIG_SCHED_MM_CID */
 static inline void mm_init_cid(struct mm_struct *mm) { }
+static inline int mm_alloc_cid(struct mm_struct *mm) { return 0; }
+static inline void mm_destroy_cid(struct mm_struct *mm) { }
 static inline unsigned int mm_cid_size(void)
 {
 	return 0;
diff --git a/kernel/fork.c b/kernel/fork.c
index c983c4fe3090..57fdc96ffa49 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -790,6 +790,7 @@  void __mmdrop(struct mm_struct *mm)
 	check_mm(mm);
 	put_user_ns(mm->user_ns);
 	mm_pasid_drop(mm);
+	mm_destroy_cid(mm);

 	for (i = 0; i < NR_MM_COUNTERS; i++)
 		percpu_counter_destroy(&mm->rss_stat[i]);
@@ -1159,18 +1160,22 @@  static struct mm_struct *mm_init(struct mm_struct *mm, struct task_struct *p,
 	if (init_new_context(p, mm))
 		goto fail_nocontext;

+	if (mm_alloc_cid(mm))
+		goto fail_cid;
+
 	for (i = 0; i < NR_MM_COUNTERS; i++)
 		if (percpu_counter_init(&mm->rss_stat[i], 0, GFP_KERNEL_ACCOUNT))
 			goto fail_pcpu;

 	mm->user_ns = get_user_ns(user_ns);
 	lru_gen_init_mm(mm);
-	mm_init_cid(mm);
 	return mm;

 fail_pcpu:
 	while (i > 0)
 		percpu_counter_destroy(&mm->rss_stat[--i]);
+	mm_destroy_cid(mm);
+fail_cid:
 	destroy_context(mm);
 fail_nocontext:
 	mm_free_pgd(mm);
diff --git a/kernel/sched/core.c b/kernel/sched/core.c
index 0d18c3969f90..f07b87d155bd 100644
--- a/kernel/sched/core.c
+++ b/kernel/sched/core.c
@@ -2326,16 +2326,20 @@  static inline bool is_cpu_allowed(struct task_struct *p, int cpu)
 static struct rq *move_queued_task(struct rq *rq, struct rq_flags *rf,
 				   struct task_struct *p, int new_cpu)
 {
+	int cid;
+
 	lockdep_assert_rq_held(rq);

 	deactivate_task(rq, p, DEQUEUE_NOCLOCK);
 	set_task_cpu(p, new_cpu);
+	cid = sched_mm_cid_migrate_from(rq, p);
 	rq_unlock(rq, rf);

 	rq = cpu_rq(new_cpu);

 	rq_lock(rq, rf);
 	WARN_ON_ONCE(task_cpu(p) != new_cpu);
+	sched_mm_cid_migrate_to(rq, p, cid);
 	activate_task(rq, p, 0);
 	check_preempt_curr(rq, p, 0);
 
@@ -11383,45 +11387,117 @@  void call_trace_sched_update_nr_running(struct rq *rq, int count)
 }

 #ifdef CONFIG_SCHED_MM_CID
+/*
+ * Migration is from src cpu to dst cpu. Always called from stopper thread on
+ * src cpu with rq lock held.
+ */
+int sched_mm_cid_migrate_from(struct rq *src_rq, struct task_struct *t)
+{
+	struct mm_struct *mm = t->mm;
+	int src_cpu, src_cid;
+	int *src_pcpu_cid;
+
+	if (!mm)
+		return PCPU_CID_UNSET;
+
+	src_cpu = cpu_of(src_rq);
+	src_pcpu_cid = per_cpu_ptr(mm->pcpu_cid, src_cpu);
+	src_cid = *src_pcpu_cid;
+	if (pcpu_cid_is_unset(src_cid)) {
+		/* src_cid is unset, nothing to clear/grab. */
+		return PCPU_CID_UNSET;
+	}
+	/* Set to PCPU_CID_UNSET, grab ownership. */
+	*src_pcpu_cid = PCPU_CID_UNSET;
+	return src_cid;
+}
+
+void sched_mm_cid_migrate_to(struct rq *dst_rq, struct task_struct *t, int src_cid)
+{
+	struct mm_struct *mm = t->mm;
+	int dst_cpu, dst_cid;
+	int *dst_pcpu_cid;
+
+	if (!mm || pcpu_cid_is_unset(src_cid))
+		return;
+
+	dst_cpu = cpu_of(dst_rq);
+	dst_pcpu_cid = per_cpu_ptr(mm->pcpu_cid, dst_cpu);
+
+	/* *dst_pcpu_cid = min(src_cid, *dst_pcpu_cid) */
+	dst_cid = *dst_pcpu_cid;
+	if (!pcpu_cid_is_unset(dst_cid) && dst_cid < src_cid) {
+		__mm_cid_put(mm, src_cid);
+		return;
+	}
+	*dst_pcpu_cid = src_cid;
+	if (!pcpu_cid_is_unset(dst_cid)) {
+		/*
+		 * Put dst_cid if not currently in use, else it will be
+		 * lazy put.
+		 */
+		if (dst_rq->curr->mm != mm)
+			__mm_cid_put(mm, dst_cid);
+	}
+}
+
 void sched_mm_cid_exit_signals(struct task_struct *t)
 {
 	struct mm_struct *mm = t->mm;
-	unsigned long flags;
+	struct rq_flags rf;
+	struct rq *rq;

 	if (!mm)
 		return;
-	local_irq_save(flags);
+
+	/* this_rq() must not be called while preemptible. */
+	preempt_disable();
+	rq = this_rq();
+	rq_lock_irqsave(rq, &rf);
 	mm_cid_put(mm, t->mm_cid);
 	t->mm_cid = -1;
 	t->mm_cid_active = 0;
-	local_irq_restore(flags);
+	rq_unlock_irqrestore(rq, &rf);
+	preempt_enable();
 }

 void sched_mm_cid_before_execve(struct task_struct *t)
 {
 	struct mm_struct *mm = t->mm;
-	unsigned long flags;
+	struct rq_flags rf;
+	struct rq *rq;

 	if (!mm)
 		return;
-	local_irq_save(flags);
+
+	/* this_rq() must not be called while preemptible. */
+	preempt_disable();
+	rq = this_rq();
+	rq_lock_irqsave(rq, &rf);
 	mm_cid_put(mm, t->mm_cid);
 	t->mm_cid = -1;
 	t->mm_cid_active = 0;
-	local_irq_restore(flags);
+	rq_unlock_irqrestore(rq, &rf);
+	preempt_enable();
 }

 void sched_mm_cid_after_execve(struct task_struct *t)
 {
 	struct mm_struct *mm = t->mm;
-	unsigned long flags;
+	struct rq_flags rf;
+	struct rq *rq;

 	if (!mm)
 		return;
-	local_irq_save(flags);
+
+	/* this_rq() must not be called while preemptible. */
+	preempt_disable();
+	rq = this_rq();
+	rq_lock_irqsave(rq, &rf);
 	t->mm_cid = mm_cid_get(mm);
 	t->mm_cid_active = 1;
-	local_irq_restore(flags);
+	rq_unlock_irqrestore(rq, &rf);
+	preempt_enable();
 	rseq_set_notify_resume(t);
 }
 
diff --git a/kernel/sched/sched.h b/kernel/sched/sched.h
index 3e8df6d31c1e..7b93847b89a3 100644
--- a/kernel/sched/sched.h
+++ b/kernel/sched/sched.h
@@ -3249,7 +3249,47 @@  static inline void update_current_exec_runtime(struct task_struct *curr,
 }

 #ifdef CONFIG_SCHED_MM_CID
-static inline int __mm_cid_get(struct mm_struct *mm)
+extern int sched_mm_cid_migrate_from(struct rq *src_rq, struct task_struct *t);
+extern void sched_mm_cid_migrate_to(struct rq *dst_rq, struct task_struct *t, int cid);
+
+static inline void __mm_cid_put(struct mm_struct *mm, int cid)
+{
+	lockdep_assert_irqs_disabled();
+	if (cid < 0)
+		return;
+	raw_spin_lock(&mm->cid_lock);
+	__cpumask_clear_cpu(cid, mm_cidmask(mm));
+	raw_spin_unlock(&mm->cid_lock);
+}
+
+static inline void mm_cid_put(struct mm_struct *mm, int thread_cid)
+{
+	int *pcpu_cid, cid;
+
+	lockdep_assert_irqs_disabled();
+	if (thread_cid < 0)
+		return;
+	pcpu_cid = this_cpu_ptr(mm->pcpu_cid);
+	cid = *pcpu_cid;
+	if (cid == thread_cid)
+		*pcpu_cid = PCPU_CID_UNSET;
+	__mm_cid_put(mm, thread_cid);
+}
+
+static inline void mm_cid_put_lazy(struct mm_struct *mm, int thread_cid)
+{
+	int *pcpu_cid, cid;
+
+	lockdep_assert_irqs_disabled();
+	if (thread_cid < 0)
+		return;
+	pcpu_cid = this_cpu_ptr(mm->pcpu_cid);
+	cid = *pcpu_cid;
+	if (cid != thread_cid)	/* keep the cid if it is cached on this cpu */
+		__mm_cid_put(mm, thread_cid);
+}
+
+static inline int __mm_cid_get_locked(struct mm_struct *mm)
 {
 	struct cpumask *cpumask;
 	int cid;
@@ -3262,40 +3302,36 @@  static inline int __mm_cid_get(struct mm_struct *mm)
 	return cid;
 }

-static inline void mm_cid_put(struct mm_struct *mm, int cid)
+static inline int __mm_cid_get(struct mm_struct *mm)
 {
+	int ret;
+
 	lockdep_assert_irqs_disabled();
-	if (cid < 0)
-		return;
 	raw_spin_lock(&mm->cid_lock);
-	__cpumask_clear_cpu(cid, mm_cidmask(mm));
+	ret = __mm_cid_get_locked(mm);
 	raw_spin_unlock(&mm->cid_lock);
+	return ret;
 }

 static inline int mm_cid_get(struct mm_struct *mm)
 {
-	int ret;
+	int *pcpu_cid, cid;

 	lockdep_assert_irqs_disabled();
-	raw_spin_lock(&mm->cid_lock);
-	ret = __mm_cid_get(mm);
-	raw_spin_unlock(&mm->cid_lock);
-	return ret;
+	pcpu_cid = this_cpu_ptr(mm->pcpu_cid);
+	cid = *pcpu_cid;
+	if (pcpu_cid_is_unset(cid)) {
+		/* Nothing cached on this cpu: get a cid and cache it. */
+		cid = __mm_cid_get(mm);
+		*pcpu_cid = cid;
+	}
+	return cid;
 }

 static inline void switch_mm_cid(struct task_struct *prev, struct task_struct *next)
 {
 	if (prev->mm_cid_active) {
-		if (next->mm_cid_active && next->mm == prev->mm) {
-			/*
-			 * Context switch between threads in same mm, hand over
-			 * the mm_cid from prev to next.
-			 */
-			next->mm_cid = prev->mm_cid;
-			prev->mm_cid = -1;
-			return;
-		}
-		mm_cid_put(prev->mm, prev->mm_cid);
+		mm_cid_put_lazy(prev->mm, prev->mm_cid);
 		prev->mm_cid = -1;
 	}
 	if (next->mm_cid_active)
@@ -3304,6 +3340,8 @@  static inline void switch_mm_cid(struct task_struct *prev, struct task_struct *n

 #else
 static inline void switch_mm_cid(struct task_struct *prev, struct task_struct *next) { }
+static inline int sched_mm_cid_migrate_from(struct rq *src_rq, struct task_struct *t) { return -1; }
+static inline void sched_mm_cid_migrate_to(struct rq *dst_rq, struct task_struct *t, int cid) { }
 #endif

 #endif /* _KERNEL_SCHED_SCHED_H */