[RFC,2/5] mm: Select victim memcg using bpf prog

Message ID 20230727073632.44983-3-zhouchuyi@bytedance.com
State New
Headers
Series mm: Select victim memcg using BPF_OOM_POLICY |

Commit Message

Chuyi Zhou July 27, 2023, 7:36 a.m. UTC
  This patch use BPF prog to bypass the default select_bad_process method
and select a victim memcg when gobal oom is invoked. Specifically, we
iterate root_mem_cgroup's children and select a next iteration root
through __bpf_run_oom_policy(). Repeat until we finally find a leaf
memcg in the last layer. Then we use oom_evaluate_task() to find a
victim task in the selected memcg. If there are no suitable process
to be killed in the memcg, we go back to the default method.

Suggested-by: Abel Wu <wuyun.abel@bytedance.com>
Signed-off-by: Chuyi Zhou <zhouchuyi@bytedance.com>
---
 include/linux/memcontrol.h |  6 +++++
 mm/memcontrol.c            | 50 ++++++++++++++++++++++++++++++++++++++
 mm/oom_kill.c              | 17 +++++++++++++
 3 files changed, 73 insertions(+)
  

Patch

diff --git a/include/linux/memcontrol.h b/include/linux/memcontrol.h
index 5818af8eca5a..7fedc2521c8b 100644
--- a/include/linux/memcontrol.h
+++ b/include/linux/memcontrol.h
@@ -1155,6 +1155,7 @@  unsigned long mem_cgroup_soft_limit_reclaim(pg_data_t *pgdat, int order,
 						gfp_t gfp_mask,
 						unsigned long *total_scanned);
 
+struct mem_cgroup *select_victim_memcg(void);
 #else /* CONFIG_MEMCG */
 
 #define MEM_CGROUP_ID_SHIFT	0
@@ -1588,6 +1589,11 @@  unsigned long mem_cgroup_soft_limit_reclaim(pg_data_t *pgdat, int order,
 {
 	return 0;
 }
+
+static inline struct mem_cgroup *select_victim_memcg(void)
+{
+	return NULL;
+}
 #endif /* CONFIG_MEMCG */
 
 static inline void __inc_lruvec_kmem_state(void *p, enum node_stat_item idx)
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index e8ca4bdcb03c..c6b42635f1af 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -64,6 +64,7 @@ 
 #include <linux/psi.h>
 #include <linux/seq_buf.h>
 #include <linux/sched/isolation.h>
+#include <linux/bpf_oom.h>
 #include "internal.h"
 #include <net/sock.h>
 #include <net/ip.h>
@@ -2638,6 +2639,55 @@  void mem_cgroup_handle_over_high(void)
 	css_put(&memcg->css);
 }
 
+struct mem_cgroup *select_victim_memcg(void)
+{
+	struct cgroup_subsys_state *pos, *parent, *victim;
+	struct mem_cgroup *victim_memcg;
+
+	parent = &root_mem_cgroup->css;
+	victim_memcg = NULL;
+
+	if (!cgroup_subsys_on_dfl(memory_cgrp_subsys))
+		return NULL;
+
+	rcu_read_lock();
+	while (parent) {
+		struct cgroup_subsys_state *chosen = NULL;
+		struct mem_cgroup *pos_mem, *chosen_mem;
+		u64 chosen_id, pos_id;
+		int cmp_ret;
+
+		victim = parent;
+
+		list_for_each_entry_rcu(pos, &parent->children, sibling) {
+			pos_id = cgroup_id(pos->cgroup);
+			if (!chosen)
+				goto chose;
+
+			cmp_ret = __bpf_run_oom_policy(chosen_id, pos_id);
+			if (cmp_ret == BPF_OOM_CMP_GREATER)
+				continue;
+			if (cmp_ret == BPF_OOM_CMP_EQUAL) {
+				pos_mem = mem_cgroup_from_css(pos);
+				chosen_mem = mem_cgroup_from_css(chosen);
+				if (page_counter_read(&pos_mem->memory) <=
+					page_counter_read(&chosen_mem->memory))
+					continue;
+			}
+chose:
+			chosen = pos;
+			chosen_id = pos_id;
+		}
+		parent = chosen;
+	}
+
+	if (victim && css_tryget(victim))
+		victim_memcg = mem_cgroup_from_css(victim);
+	rcu_read_unlock();
+
+	return victim_memcg;
+}
+
 static int try_charge_memcg(struct mem_cgroup *memcg, gfp_t gfp_mask,
 			unsigned int nr_pages)
 {
diff --git a/mm/oom_kill.c b/mm/oom_kill.c
index 01af8adaa16c..b88c8c7d4ee4 100644
--- a/mm/oom_kill.c
+++ b/mm/oom_kill.c
@@ -361,6 +361,19 @@  static int oom_evaluate_task(struct task_struct *task, void *arg)
 	return 1;
 }
 
+static bool bpf_select_bad_process(struct oom_control *oc)
+{
+	struct mem_cgroup *victim_memcg;
+
+	victim_memcg = select_victim_memcg();
+	if (victim_memcg) {
+		mem_cgroup_scan_tasks(victim_memcg, oom_evaluate_task, oc);
+		css_put(&victim_memcg->css);
+	}
+
+	return !!oc->chosen;
+}
+
 /*
  * Simple selection loop. We choose the process with the highest number of
  * 'points'. In case scan was aborted, oc->chosen is set to -1.
@@ -372,6 +385,9 @@  static void select_bad_process(struct oom_control *oc)
 	if (is_memcg_oom(oc))
 		mem_cgroup_scan_tasks(oc->memcg, oom_evaluate_task, oc);
 	else {
+		if (bpf_oom_policy_enabled() && bpf_select_bad_process(oc))
+			return;
+
 		struct task_struct *p;
 
 		rcu_read_lock();
@@ -1426,3 +1442,4 @@  bool bpf_oom_policy_enabled(void)
 	rcu_read_unlock();
 	return !empty;
 }
+