[RFC] sched: Introduce mm_cid runqueue cache

Message ID 20230327195318.137094-1-mathieu.desnoyers@efficios.com
State New
Headers
Series [RFC] sched: Introduce mm_cid runqueue cache |

Commit Message

Mathieu Desnoyers March 27, 2023, 7:53 p.m. UTC
  Introduce a per-runqueue cache containing { mm, mm_cid } entries.
Keep track of the recently allocated mm_cid for each mm rather than
freeing them immediately. This eliminates most atomic ops when
context switching back and forth between threads belonging to
different memory spaces in multi-threaded scenarios (many processes,
each with many threads).

Signed-off-by: Mathieu Desnoyers <mathieu.desnoyers@efficios.com>
Cc: Aaron Lu <aaron.lu@intel.com>
Cc: Peter Zijlstra <peterz@infradead.org>
---
 kernel/sched/core.c     |  45 +++++++++----
 kernel/sched/deadline.c |   3 +
 kernel/sched/fair.c     |   1 +
 kernel/sched/rt.c       |   2 +
 kernel/sched/sched.h    | 138 ++++++++++++++++++++++++++++++++++------
 5 files changed, 158 insertions(+), 31 deletions(-)
  

Comments

Liu, Yujie April 4, 2023, 9:15 a.m. UTC | #1
Hello,

kernel test robot noticed "BUG:KASAN:slab-use-after-free_in__lock_acquire" on:

commit: 1ed2ac17a591daac640ef7149cdc3c8e0870e474 ("[RFC PATCH] sched: Introduce mm_cid runqueue cache")
url: https://github.com/intel-lab-lkp/linux/commits/Mathieu-Desnoyers/sched-Introduce-mm_cid-runqueue-cache/20230328-035418
base: https://git.kernel.org/cgit/linux/kernel/git/tip/tip.git 05bfb338fa8dd40b008ce443e397fc374f6bd107
patch link: https://lore.kernel.org/all/20230327195318.137094-1-mathieu.desnoyers@efficios.com/
patch subject: [RFC PATCH] sched: Introduce mm_cid runqueue cache

in testcase: kernel-selftests
version: kernel-selftests-x86_64-60acb023-1_20230329
with following parameters:

	group: net

test-description: The kernel contains a set of "self tests" under the tools/testing/selftests/ directory. These are intended to be small unit tests to exercise individual code paths in the kernel.
test-url: https://www.kernel.org/doc/Documentation/kselftest.txt

compiler: gcc-11
test machine: 8 threads Intel(R) Core(TM) i7-6700 CPU @ 3.40GHz (Skylake) with 28G memory

(please refer to attached dmesg/kmsg for entire log/backtrace)


If you fix the issue, kindly add following tag
| Reported-by: kernel test robot <yujie.liu@intel.com>
| Link: https://lore.kernel.org/oe-lkp/202304041648.ed32a338-yujie.liu@intel.com


[ 1109.619462][T29663] ==================================================================
[ 1109.627355][T29663] BUG: KASAN: slab-use-after-free in __lock_acquire+0x1f45/0x2390
[ 1109.634978][T29663] Read of size 8 at addr ffff888214d05430 by task dmesg/29663
[ 1109.642245][T29663] 
[ 1109.644420][T29663] CPU: 6 PID: 29663 Comm: dmesg Not tainted 6.3.0-rc3-00009-g1ed2ac17a591 #1
[ 1109.652983][T29663] Hardware name: Dell Inc. OptiPlex 7040/0Y7WYT, BIOS 1.2.8 01/26/2016
[ 1109.661026][T29663] Call Trace:
[ 1109.664149][T29663]  <TASK>
[ 1109.666926][T29663]  dump_stack_lvl+0x4b/0x80
[ 1109.671261][T29663]  print_address_description+0x2c/0x3d0
[ 1109.677668][T29663]  print_report+0xb5/0x270
[ 1109.681916][T29663]  ? kasan_addr_to_slab+0xd/0xa0
[ 1109.686681][T29663]  ? __lock_acquire+0x1f45/0x2390
[ 1109.691533][T29663]  kasan_report+0xc5/0xf0
[ 1109.695708][T29663]  ? __lock_acquire+0x1f45/0x2390
[ 1109.700585][T29663]  ? __lock_acquire+0x1f45/0x2390
[ 1109.705438][T29663]  ? __lock_acquire+0x15c3/0x2390
[ 1109.710289][T29663]  ? mark_usage+0x2a0/0x2a0
[ 1109.714638][T29663]  ? lock_acquire+0x19d/0x4c0
[ 1109.719145][T29663]  ? mm_cid_get+0x221/0x4f0
[ 1109.724084][T29663]  ? lock_release+0x200/0x200
[ 1109.728589][T29663]  ? lock_downgrade+0x100/0x100
[ 1109.733270][T29663]  ? do_raw_spin_lock+0x137/0x280
[ 1109.738120][T29663]  ? spin_bug+0x1d0/0x1d0
[ 1109.742283][T29663]  ? _raw_spin_lock+0x30/0x40
[ 1109.746788][T29663]  ? mm_cid_get+0x221/0x4f0
[ 1109.751725][T29663]  ? mm_cid_get+0x221/0x4f0
[ 1109.756664][T29663]  ? sched_mm_cid_after_execve+0x1c2/0x4e0
[ 1109.762293][T29663]  ? bprm_execve+0x1b9/0x5e0
[ 1109.766714][T29663]  ? do_execveat_common+0x4cc/0x6b0
[ 1109.772343][T29663]  ? getname_flags+0x8e/0x450
[ 1109.777454][T29663]  ? __x64_sys_execve+0x8c/0xb0
[ 1109.782132][T29663]  ? do_syscall_64+0x5a/0x80
[ 1109.786566][T29663]  ? entry_SYSCALL_64_after_hwframe+0x5e/0xc8
[ 1109.792471][T29663]  </TASK>
[ 1109.795334][T29663] 
[ 1109.797526][T29663] Allocated by task 23542:
[ 1109.801773][T29663]  kasan_save_stack+0x27/0x50
[ 1109.806276][T29663]  kasan_set_track+0x25/0x30
[ 1109.810694][T29663]  __kasan_slab_alloc+0x55/0x60
[ 1109.815371][T29663]  kmem_cache_alloc+0x190/0x360
[ 1109.820045][T29663]  dup_mm+0x22/0x310
[ 1109.824809][T29663]  copy_process+0x52d0/0x5520
[ 1109.829313][T29663]  kernel_clone+0xc8/0x5d0
[ 1109.833580][T29663]  __do_sys_clone+0xa6/0xe0
[ 1109.837911][T29663]  do_syscall_64+0x5a/0x80
[ 1109.842160][T29663]  entry_SYSCALL_64_after_hwframe+0x5e/0xc8
[ 1109.847883][T29663] 
[ 1109.850057][T29663] Freed by task 29602:
[ 1109.853963][T29663]  kasan_save_stack+0x27/0x50
[ 1109.858475][T29663]  kasan_set_track+0x25/0x30
[ 1109.862901][T29663]  kasan_save_free_info+0x2e/0x40
[ 1109.867760][T29663]  __kasan_slab_free+0x10a/0x190
[ 1109.872541][T29663]  slab_free_freelist_hook+0xba/0x170
[ 1109.877739][T29663]  kmem_cache_free+0x1a4/0x300
[ 1109.882329][T29663]  finish_task_switch+0x556/0x910
[ 1109.887783][T29663]  __schedule+0x751/0x1740
[ 1109.892029][T29663]  schedule+0x13e/0x230
[ 1109.896015][T29663]  wait_for_partner+0x15d/0x320
[ 1109.900694][T29663]  fifo_open+0x8a3/0xa10
[ 1109.904766][T29663]  do_dentry_open+0x449/0x1020
[ 1109.909356][T29663]  do_open+0x678/0xf70
[ 1109.913257][T29663]  path_openat+0x25f/0x650
[ 1109.917521][T29663]  do_filp_open+0x1ba/0x3f0
[ 1109.921852][T29663]  do_sys_openat2+0x127/0x400
[ 1109.926355][T29663]  __x64_sys_openat+0x128/0x1e0
[ 1109.931032][T29663]  do_syscall_64+0x5a/0x80
[ 1109.935276][T29663]  entry_SYSCALL_64_after_hwframe+0x5e/0xc8
[ 1109.940988][T29663] 
[ 1109.943159][T29663] The buggy address belongs to the object at ffff888214d05380
[ 1109.943159][T29663]  which belongs to the cache mm_struct of size 2168
[ 1109.956897][T29663] The buggy address is located 176 bytes inside of
[ 1109.956897][T29663]  freed 2168-byte region [ffff888214d05380, ffff888214d05bf8)
[ 1109.970558][T29663] 
[ 1109.972733][T29663] The buggy address belongs to the physical page:
[ 1109.978962][T29663] page:ffffea0008534000 refcount:1 mapcount:0 mapping:0000000000000000 index:0x0 pfn:0x214d00
[ 1109.988990][T29663] head:ffffea0008534000 order:3 entire_mapcount:0 nr_pages_mapped:0 pincount:0
[ 1109.997724][T29663] memcg:ffff8881f2640681
[ 1110.001796][T29663] flags: 0x17ffffc0010200(slab|head|node=0|zone=2|lastcpupid=0x1fffff)
[ 1110.009844][T29663] raw: 0017ffffc0010200 ffff888100052340 ffffea001d4d8410 ffffea000d38be10
[ 1110.018233][T29663] raw: 0000000000000000 00000000000d000d 00000001ffffffff ffff8881f2640681
[ 1110.026621][T29663] page dumped because: kasan: bad access detected
[ 1110.032851][T29663] 
[ 1110.035025][T29663] Memory state around the buggy address:
[ 1110.040478][T29663]  ffff888214d05300: fc fc fc fc fc fc fc fc fc fc fc fc fc fc fc fc
[ 1110.048348][T29663]  ffff888214d05380: fa fb fb fb fb fb fb fb fb fb fb fb fb fb fb fb
[ 1110.056218][T29663] >ffff888214d05400: fb fb fb fb fb fb fb fb fb fb fb fb fb fb fb fb
[ 1110.064087][T29663]                                      ^
[ 1110.069544][T29663]  ffff888214d05480: fb fb fb fb fb fb fb fb fb fb fb fb fb fb fb fb
[ 1110.077410][T29663]  ffff888214d05500: fb fb fb fb fb fb fb fb fb fb fb fb fb fb fb fb
[ 1110.085289][T29663] ==================================================================
[ 1110.093168][T29663] Disabling lock debugging due to kernel taint
  

Patch

diff --git a/kernel/sched/core.c b/kernel/sched/core.c
index 0d18c3969f90..e91fc3b810e1 100644
--- a/kernel/sched/core.c
+++ b/kernel/sched/core.c
@@ -2329,6 +2329,7 @@  static struct rq *move_queued_task(struct rq *rq, struct rq_flags *rf,
 	lockdep_assert_rq_held(rq);
 
 	deactivate_task(rq, p, DEQUEUE_NOCLOCK);
+	rq_cid_cache_remove_mm_locked(rq, p->mm, false);
 	set_task_cpu(p, new_cpu);
 	rq_unlock(rq, rf);
 
@@ -2516,6 +2517,7 @@  int push_cpu_stop(void *arg)
 	// XXX validate p is still the highest prio task
 	if (task_rq(p) == rq) {
 		deactivate_task(rq, p, 0);
+		rq_cid_cache_remove_mm_locked(rq, p->mm, false);
 		set_task_cpu(p, lowest_rq->cpu);
 		activate_task(lowest_rq, p, 0);
 		resched_curr(lowest_rq);
@@ -3215,6 +3217,7 @@  static void __migrate_swap_task(struct task_struct *p, int cpu)
 		rq_pin_lock(dst_rq, &drf);
 
 		deactivate_task(src_rq, p, 0);
+		rq_cid_cache_remove_mm_locked(src_rq, p->mm, false);
 		set_task_cpu(p, cpu);
 		activate_task(dst_rq, p, 0);
 		check_preempt_curr(dst_rq, p, 0);
@@ -3852,6 +3855,8 @@  static void __ttwu_queue_wakelist(struct task_struct *p, int cpu, int wake_flags
 	p->sched_remote_wakeup = !!(wake_flags & WF_MIGRATED);
 
 	WRITE_ONCE(rq->ttwu_pending, 1);
+	if (WARN_ON_ONCE(task_cpu(p) != cpu_of(rq)))
+		rq_cid_cache_remove_mm(task_rq(p), p->mm, false);
 	__smp_call_single_queue(cpu, &p->wake_entry.llist);
 }
 
@@ -4269,6 +4274,7 @@  try_to_wake_up(struct task_struct *p, unsigned int state, int wake_flags)
 
 		wake_flags |= WF_MIGRATED;
 		psi_ttwu_dequeue(p);
+		rq_cid_cache_remove_mm(task_rq(p), p->mm, false);
 		set_task_cpu(p, cpu);
 	}
 #else
@@ -5114,7 +5120,7 @@  prepare_task_switch(struct rq *rq, struct task_struct *prev,
 	sched_info_switch(rq, prev, next);
 	perf_event_task_sched_out(prev, next);
 	rseq_preempt(prev);
-	switch_mm_cid(prev, next);
+	switch_mm_cid(rq, prev, next);
 	fire_sched_out_preempt_notifiers(prev, next);
 	kmap_local_sched_out();
 	prepare_task(next);
@@ -6253,6 +6259,7 @@  static bool try_steal_cookie(int this, int that)
 			goto next;
 
 		deactivate_task(src, p, 0);
+		rq_cid_cache_remove_mm_locked(src, p->mm, false);
 		set_task_cpu(p, this);
 		activate_task(dst, p, 0);
 
@@ -11386,42 +11393,54 @@  void call_trace_sched_update_nr_running(struct rq *rq, int count)
 void sched_mm_cid_exit_signals(struct task_struct *t)
 {
 	struct mm_struct *mm = t->mm;
-	unsigned long flags;
+	struct rq_flags rf;
+	struct rq *rq;
 
 	if (!mm)
 		return;
-	local_irq_save(flags);
-	mm_cid_put(mm, t->mm_cid);
+	preempt_disable();
+	rq = this_rq();
+	rq_lock_irqsave(rq, &rf);
 	t->mm_cid = -1;
 	t->mm_cid_active = 0;
-	local_irq_restore(flags);
+	rq_cid_cache_remove_mm_locked(rq, mm, true);
+	rq_unlock_irqrestore(rq, &rf);
+	preempt_enable();
 }
 
 void sched_mm_cid_before_execve(struct task_struct *t)
 {
 	struct mm_struct *mm = t->mm;
-	unsigned long flags;
+	struct rq_flags rf;
+	struct rq *rq;
 
 	if (!mm)
 		return;
-	local_irq_save(flags);
-	mm_cid_put(mm, t->mm_cid);
+	preempt_disable();
+	rq = this_rq();
+	rq_lock_irqsave(rq, &rf);
 	t->mm_cid = -1;
 	t->mm_cid_active = 0;
-	local_irq_restore(flags);
+	rq_cid_cache_remove_mm_locked(rq, mm, true);
+	rq_unlock_irqrestore(rq, &rf);
+	preempt_enable();
 }
 
 void sched_mm_cid_after_execve(struct task_struct *t)
 {
 	struct mm_struct *mm = t->mm;
-	unsigned long flags;
+	struct rq_flags rf;
+	struct rq *rq;
 
 	if (!mm)
 		return;
-	local_irq_save(flags);
-	t->mm_cid = mm_cid_get(mm);
+	preempt_disable();
+	rq = this_rq();
+	rq_lock_irqsave(rq, &rf);
+	t->mm_cid = mm_cid_get(rq, mm);
 	t->mm_cid_active = 1;
-	local_irq_restore(flags);
+	rq_unlock_irqrestore(rq, &rf);
+	preempt_enable();
 	rseq_set_notify_resume(t);
 }
 
diff --git a/kernel/sched/deadline.c b/kernel/sched/deadline.c
index 71b24371a6f7..34bb47442912 100644
--- a/kernel/sched/deadline.c
+++ b/kernel/sched/deadline.c
@@ -729,6 +729,7 @@  static struct rq *dl_task_offline_migration(struct rq *rq, struct task_struct *p
 	__dl_add(dl_b, p->dl.dl_bw, cpumask_weight(later_rq->rd->span));
 	raw_spin_unlock(&dl_b->lock);
 
+	rq_cid_cache_remove_mm_locked(rq, p->mm, false);
 	set_task_cpu(p, later_rq->cpu);
 	double_unlock_balance(later_rq, rq);
 
@@ -2357,6 +2358,7 @@  static int push_dl_task(struct rq *rq)
 	}
 
 	deactivate_task(rq, next_task, 0);
+	rq_cid_cache_remove_mm_locked(rq, next_task->mm, false);
 	set_task_cpu(next_task, later_rq->cpu);
 	activate_task(later_rq, next_task, 0);
 	ret = 1;
@@ -2445,6 +2447,7 @@  static void pull_dl_task(struct rq *this_rq)
 				push_task = get_push_task(src_rq);
 			} else {
 				deactivate_task(src_rq, p, 0);
+				rq_cid_cache_remove_mm_locked(src_rq, p->mm, false);
 				set_task_cpu(p, this_cpu);
 				activate_task(this_rq, p, 0);
 				dmin = p->dl.deadline;
diff --git a/kernel/sched/fair.c b/kernel/sched/fair.c
index 6986ea31c984..70ed6aef87ec 100644
--- a/kernel/sched/fair.c
+++ b/kernel/sched/fair.c
@@ -8542,6 +8542,7 @@  static void detach_task(struct task_struct *p, struct lb_env *env)
 	lockdep_assert_rq_held(env->src_rq);
 
 	deactivate_task(env->src_rq, p, DEQUEUE_NOCLOCK);
+	rq_cid_cache_remove_mm_locked(env->src_rq, p->mm, false);
 	set_task_cpu(p, env->dst_cpu);
 }
 
diff --git a/kernel/sched/rt.c b/kernel/sched/rt.c
index 0a11f44adee5..3ad325db1db3 100644
--- a/kernel/sched/rt.c
+++ b/kernel/sched/rt.c
@@ -2156,6 +2156,7 @@  static int push_rt_task(struct rq *rq, bool pull)
 	}
 
 	deactivate_task(rq, next_task, 0);
+	rq_cid_cache_remove_mm_locked(rq, next_task->mm, false);
 	set_task_cpu(next_task, lowest_rq->cpu);
 	activate_task(lowest_rq, next_task, 0);
 	resched_curr(lowest_rq);
@@ -2429,6 +2430,7 @@  static void pull_rt_task(struct rq *this_rq)
 				push_task = get_push_task(src_rq);
 			} else {
 				deactivate_task(src_rq, p, 0);
+				rq_cid_cache_remove_mm_locked(src_rq, p->mm, false);
 				set_task_cpu(p, this_cpu);
 				activate_task(this_rq, p, 0);
 				resched = true;
diff --git a/kernel/sched/sched.h b/kernel/sched/sched.h
index 3e8df6d31c1e..b2e12857e2c3 100644
--- a/kernel/sched/sched.h
+++ b/kernel/sched/sched.h
@@ -947,6 +947,19 @@  struct balance_callback {
 	void (*func)(struct rq *rq);
 };
 
+#ifdef CONFIG_SCHED_MM_CID
+# define RQ_CID_CACHE_SIZE    8
+struct rq_cid_entry {
+	struct mm_struct *mm;   /* NULL if unset */
+	int mm_cid;
+};
+
+struct rq_cid_cache {
+	struct rq_cid_entry entry[RQ_CID_CACHE_SIZE];
+	unsigned int head;
+};
+#endif
+
 /*
  * This is the main, per-CPU runqueue data structure.
  *
@@ -1161,6 +1174,9 @@  struct rq {
 	call_single_data_t	cfsb_csd;
 	struct list_head	cfsb_csd_list;
 #endif
+#ifdef CONFIG_SCHED_MM_CID
+	struct rq_cid_cache	cid_cache;
+#endif
 };
 
 #ifdef CONFIG_FAIR_GROUP_SCHED
@@ -3249,6 +3265,92 @@  static inline void update_current_exec_runtime(struct task_struct *curr,
 }
 
 #ifdef CONFIG_SCHED_MM_CID
+
+static inline void mm_cid_put(struct mm_struct *mm, int cid)
+{
+	lockdep_assert_irqs_disabled();
+	if (cid < 0)
+		return;
+	raw_spin_lock(&mm->cid_lock);
+	__cpumask_clear_cpu(cid, mm_cidmask(mm));
+	raw_spin_unlock(&mm->cid_lock);
+}
+
+static inline struct rq_cid_entry *rq_cid_cache_lookup(struct rq *rq, struct mm_struct *mm)
+{
+	struct rq_cid_cache *cid_cache = &rq->cid_cache;
+	int i;
+
+	for (i = 0; i < RQ_CID_CACHE_SIZE; i++) {
+		struct rq_cid_entry *entry = &cid_cache->entry[i];
+
+		if (entry->mm == mm)
+			return entry;
+	}
+	return NULL;
+}
+
+/* Removal from cache simply leaves an unused hole. */
+static inline int rq_cid_cache_lookup_remove(struct rq *rq, struct mm_struct *mm)
+{
+	struct rq_cid_entry *entry = rq_cid_cache_lookup(rq, mm);
+
+	if (!entry)
+		return -1;
+	entry->mm = NULL;       /* Remove from cache */
+	return entry->mm_cid;
+}
+
+static inline void rq_cid_cache_remove_mm_locked(struct rq *rq, struct mm_struct *mm, bool release_mm)
+{
+	int cid;
+
+	if (!mm)
+		return;
+	/*
+	 * Do not remove the cache entry for a runqueue that runs a task which
+	 * currently uses the target mm.
+	 */
+	if (!release_mm && rq->curr->mm == mm)
+		return;
+	cid = rq_cid_cache_lookup_remove(rq, mm);
+	mm_cid_put(mm, cid);
+}
+
+static inline void rq_cid_cache_remove_mm(struct rq *rq, struct mm_struct *mm, bool release_mm)
+{
+	struct rq_flags rf;
+
+	rq_lock_irqsave(rq, &rf);
+	rq_cid_cache_remove_mm_locked(rq, mm, release_mm);
+	rq_unlock_irqrestore(rq, &rf);
+}
+
+/*
+  * Add at head, move head forward. Cheap LRU cache.
+  * Only need to clear the cid mask bit from its own mm_cidmask(mm) when we
+  * overwrite an old entry from the cache. Note that this is not needed if the
+  * overwritten entry is an unused hole. This access to the old_mm from an
+  * unrelated thread requires that cache entry for a given mm gets pruned from
+  * the cache when a task is dequeued from the runqueue.
+  */
+static inline void rq_cid_cache_add(struct rq *rq, struct mm_struct *mm, int cid)
+{
+	struct rq_cid_cache *cid_cache = &rq->cid_cache;
+	struct mm_struct *old_mm;
+	struct rq_cid_entry *entry;
+	unsigned int pos;
+
+	pos = cid_cache->head;
+	entry = &cid_cache->entry[pos];
+	old_mm = entry->mm;
+	if (old_mm)
+		mm_cid_put(old_mm, entry->mm_cid);
+	entry->mm = mm;
+	entry->mm_cid = cid;
+	cid_cache->head = (pos + 1) % RQ_CID_CACHE_SIZE;
+}
+
 static inline int __mm_cid_get(struct mm_struct *mm)
 {
 	struct cpumask *cpumask;
@@ -3262,28 +3364,26 @@  static inline int __mm_cid_get(struct mm_struct *mm)
 	return cid;
 }
 
-static inline void mm_cid_put(struct mm_struct *mm, int cid)
+static inline int mm_cid_get(struct rq *rq, struct mm_struct *mm)
 {
-	lockdep_assert_irqs_disabled();
-	if (cid < 0)
-		return;
-	raw_spin_lock(&mm->cid_lock);
-	__cpumask_clear_cpu(cid, mm_cidmask(mm));
-	raw_spin_unlock(&mm->cid_lock);
-}
-
-static inline int mm_cid_get(struct mm_struct *mm)
-{
-	int ret;
+	struct rq_cid_entry *entry;
+	int cid;
 
 	lockdep_assert_irqs_disabled();
+	entry = rq_cid_cache_lookup(rq, mm);
+	if (entry) {
+		cid = entry->mm_cid;
+		goto end;
+	}
 	raw_spin_lock(&mm->cid_lock);
-	ret = __mm_cid_get(mm);
+	cid = __mm_cid_get(mm);
 	raw_spin_unlock(&mm->cid_lock);
-	return ret;
+	rq_cid_cache_add(rq, mm, cid);
+end:
+	return cid;
 }
 
-static inline void switch_mm_cid(struct task_struct *prev, struct task_struct *next)
+static inline void switch_mm_cid(struct rq *rq, struct task_struct *prev, struct task_struct *next)
 {
 	if (prev->mm_cid_active) {
 		if (next->mm_cid_active && next->mm == prev->mm) {
@@ -3295,15 +3395,17 @@  static inline void switch_mm_cid(struct task_struct *prev, struct task_struct *n
 			prev->mm_cid = -1;
 			return;
 		}
-		mm_cid_put(prev->mm, prev->mm_cid);
+		/* Leave the prev mm_cid in the cid rq cache. */
 		prev->mm_cid = -1;
 	}
 	if (next->mm_cid_active)
-		next->mm_cid = mm_cid_get(next->mm);
+		next->mm_cid = mm_cid_get(rq, next->mm);
 }
 
 #else
-static inline void switch_mm_cid(struct task_struct *prev, struct task_struct *next) { }
+static inline void switch_mm_cid(struct rq *rq, struct task_struct *prev, struct task_struct *next) { }
+static inline void rq_cid_cache_remove_mm_locked(struct rq *rq, struct mm_struct *mm, bool release_mm) { }
+static inline void rq_cid_cache_remove_mm(struct rq *rq, struct mm_struct *mm, bool release_mm) { }
 #endif
 
 #endif /* _KERNEL_SCHED_SCHED_H */