[3/3] x86/static_call: Add support for Jcc tail-calls

Message ID: 20230123210607.173715335@infradead.org
State: New
Series: static_call/x86: Handle clang's conditional tail calls

Commit Message

Peter Zijlstra Jan. 23, 2023, 8:59 p.m. UTC
  Clang likes to create conditional tail calls like:

0000000000000350 <amd_pmu_add_event>:
350:       0f 1f 44 00 00          nopl   0x0(%rax,%rax,1) 351: R_X86_64_NONE      __fentry__-0x4
355:       48 83 bf 20 01 00 00 00         cmpq   $0x0,0x120(%rdi)
35d:       0f 85 00 00 00 00       jne    363 <amd_pmu_add_event+0x13>     35f: R_X86_64_PLT32     __SCT__amd_pmu_branch_add-0x4
363:       e9 00 00 00 00          jmp    368 <amd_pmu_add_event+0x18>     364: R_X86_64_PLT32     __x86_return_thunk-0x4
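
For context, the source around such a site is roughly the following (a
simplified sketch -- the static_call name is taken from the __SCT__
relocation above, the guard condition is illustrative):

    static void amd_pmu_add_event(struct perf_event *event)
    {
            if (needs_branch_stack(event))                  /* the cmpq at 0x355 */
                    static_call(amd_pmu_branch_add)(event); /* the jne at 0x35d */
    }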

Where 0x35d is a static call site that's turned into a conditional
tail-call using the Jcc class of instructions.

Teach the in-line static call text patching about this.

Notably, since there is no conditional-ret, in that case patch the Jcc
to point at an empty stub function that does the ret -- or the return
thunk when needed.
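
To make the byte-level effect concrete, here is a minimal userspace
sketch of the 6-byte Jcc.d32 encoding the new JCC case emits (mirroring
the buf[0] = 0x0f plus __text_gen_insn() rel32 logic in the patch; not
the kernel code itself):

    #include <stdint.h>
    #include <string.h>

    /*
     * Jcc.d32 is the two-byte opcode 0x0f 0x80..0x8f followed by a rel32
     * displacement measured from the end of the 6-byte instruction.
     */
    static void gen_jcc_d32(uint8_t buf[6], uint8_t op, void *site, void *target)
    {
            int32_t rel = (int32_t)((uint8_t *)target - ((uint8_t *)site + 6));

            buf[0] = 0x0f;
            buf[1] = op;              /* e.g. 0x85 for JNE */
            memcpy(&buf[2], &rel, 4); /* rel32 */
    }

When there is no target (func == NULL), the branch is instead pointed at
__static_call_return -- or at x86_return_thunk when rethunks are enabled
-- because no Jcc-with-RET encoding exists.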

Reported-by: "Erhard F." <erhard_f@mailbox.org>
Signed-off-by: Peter Zijlstra (Intel) <peterz@infradead.org>
Reviewed-by: Masami Hiramatsu (Google) <mhiramat@kernel.org>
---
 arch/x86/kernel/static_call.c |   50 +++++++++++++++++++++++++++++++++++++++---
 1 file changed, 47 insertions(+), 3 deletions(-)
  

Comments

Steven Rostedt Jan. 23, 2023, 10:44 p.m. UTC | #1
On Mon, 23 Jan 2023 21:59:18 +0100
Peter Zijlstra <peterz@infradead.org> wrote:

> Clang likes to create conditional tail calls like:
> 
> 0000000000000350 <amd_pmu_add_event>:
> 350:       0f 1f 44 00 00          nopl   0x0(%rax,%rax,1) 351: R_X86_64_NONE      __fentry__-0x4
> 355:       48 83 bf 20 01 00 00 00         cmpq   $0x0,0x120(%rdi)
> 35d:       0f 85 00 00 00 00       jne    363 <amd_pmu_add_event+0x13>     35f: R_X86_64_PLT32     __SCT__amd_pmu_branch_add-0x4
> 363:       e9 00 00 00 00          jmp    368 <amd_pmu_add_event+0x18>     364: R_X86_64_PLT32     __x86_return_thunk-0x4
> 

Just to confirm, as it's not clear whether this is the static call site or one
of the functions being called.

I'm guessing that this is an issue because clang optimizes the static call
site, right?

> Teach the in-line static call text patching about this.
> 
> Notably, since there is no conditional-ret, in that caes patch the Jcc

					      "in that case"

-- Steve

> to point at an empty stub function that does the ret -- or the return
> thunk when needed.
  
Peter Zijlstra Jan. 24, 2023, 1:06 p.m. UTC | #2
On Mon, Jan 23, 2023 at 05:44:31PM -0500, Steven Rostedt wrote:
> On Mon, 23 Jan 2023 21:59:18 +0100
> Peter Zijlstra <peterz@infradead.org> wrote:
> 
> > Clang likes to create conditional tail calls like:
> > 
> > 0000000000000350 <amd_pmu_add_event>:
> > 350:       0f 1f 44 00 00          nopl   0x0(%rax,%rax,1) 351: R_X86_64_NONE      __fentry__-0x4
> > 355:       48 83 bf 20 01 00 00 00         cmpq   $0x0,0x120(%rdi)
> > 35d:       0f 85 00 00 00 00       jne    363 <amd_pmu_add_event+0x13>     35f: R_X86_64_PLT32     __SCT__amd_pmu_branch_add-0x4
> > 363:       e9 00 00 00 00          jmp    368 <amd_pmu_add_event+0x18>     364: R_X86_64_PLT32     __x86_return_thunk-0x4
> > 
> 
> Just to confirm, as it's not clear whether this is the static call site or one
> of the functions being called.

Ah, you've not looked at enough asm then? ;-) Yes this is the static
call site, see the __SCT_ target (instruction at 0x35d).

> I'm guessing that this is an issue because clang optimizes the static call
> site, right?

Specifically using Jcc (jne in this case) to tail-call the trampoline.
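
For anyone following along, such a site is recognizable purely from the
two-byte Jcc.d32 opcode; a condensed userspace demo of the same check as
the patch's __is_Jcc() (illustrative bytes, not kernel code):

    #include <stdio.h>
    #include <stdint.h>

    /* Jcc.d32 is 0x0f 0x80..0x8f + rel32; return the opcode byte, or 0. */
    static uint8_t is_jcc_d32(const uint8_t *insn)
    {
            return (insn[0] == 0x0f && (insn[1] & 0xf0) == 0x80) ? insn[1] : 0;
    }

    int main(void)
    {
            const uint8_t site[6] = { 0x0f, 0x85, 0, 0, 0, 0 }; /* jne, as at 0x35d */

            printf("op=%#x\n", is_jcc_d32(site)); /* prints op=0x85 */
            return 0;
    }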

> > Teach the in-line static call text patching about this.
> > 
> > Notably, since there is no conditional-ret, in that caes patch the Jcc
> 
> 					      "in that case"

typing so hard.. :-)
  
Steven Rostedt Jan. 24, 2023, 3:07 p.m. UTC | #3
On Tue, 24 Jan 2023 14:06:49 +0100
Peter Zijlstra <peterz@infradead.org> wrote:

> > Just to confirm, as it's not clear if this is the static call site or one
> > of the functions that is being called.  
> 
> Ah, you've not looked at enough asm then? ;-) Yes this is the static
> call site, see the __SCT_ target (instruction at 0x35d).

Yeah, could you spell it out a bit more in the change log, so that those
looking back at this don't need to have stared at enough asm ;-)

It's actually been some time since I stared at compiler output (although
now that I'm starting to play with rust, that's going to start back up
soon).

-- Steve
  

Patch

--- a/arch/x86/kernel/static_call.c
+++ b/arch/x86/kernel/static_call.c
@@ -9,6 +9,7 @@  enum insn_type {
 	NOP = 1,  /* site cond-call */
 	JMP = 2,  /* tramp / site tail-call */
 	RET = 3,  /* tramp / site cond-tail-call */
+	JCC = 4,
 };
 
 /*
@@ -25,12 +26,40 @@  static const u8 xor5rax[] = { 0x2e, 0x2e
 
 static const u8 retinsn[] = { RET_INSN_OPCODE, 0xcc, 0xcc, 0xcc, 0xcc };
 
+static u8 __is_Jcc(u8 *insn) /* Jcc.d32 */
+{
+	u8 ret = 0;
+
+	if (insn[0] == 0x0f) {
+		u8 tmp = insn[1];
+		if ((tmp & 0xf0) == 0x80)
+			ret = tmp;
+	}
+
+	return ret;
+}
+
+extern void __static_call_return(void);
+
+asm (".global __static_call_return\n\t"
+     ".type __static_call_return, @function\n\t"
+     ASM_FUNC_ALIGN "\n\t"
+     "__static_call_return:\n\t"
+     ANNOTATE_NOENDBR
+     ANNOTATE_RETPOLINE_SAFE
+     "ret; int3\n\t"
+     ".size __static_call_return, . - __static_call_return \n\t");
+
 static void __ref __static_call_transform(void *insn, enum insn_type type,
 					  void *func, bool modinit)
 {
 	const void *emulate = NULL;
 	int size = CALL_INSN_SIZE;
 	const void *code;
+	u8 op, buf[6];
+
+	if ((type == JMP || type == RET) && (op = __is_Jcc(insn)))
+		type = JCC;
 
 	switch (type) {
 	case CALL:
@@ -57,6 +86,20 @@  static void __ref __static_call_transfor
 		else
 			code = &retinsn;
 		break;
+
+	case JCC:
+		if (!func) {
+			func = __static_call_return;
+			if (cpu_feature_enabled(X86_FEATURE_RETHUNK))
+				func = x86_return_thunk;
+		}
+
+		buf[0] = 0x0f;
+		__text_gen_insn(buf+1, op, insn+1, func, 5);
+		code = buf;
+		size = 6;
+
+		break;
 	}
 
 	if (memcmp(insn, code, size) == 0)
@@ -68,9 +111,9 @@  static void __ref __static_call_transfor
 	text_poke_bp(insn, code, size, emulate);
 }
 
-static void __static_call_validate(void *insn, bool tail, bool tramp)
+static void __static_call_validate(u8 *insn, bool tail, bool tramp)
 {
-	u8 opcode = *(u8 *)insn;
+	u8 opcode = insn[0];
 
 	if (tramp && memcmp(insn+5, tramp_ud, 3)) {
 		pr_err("trampoline signature fail");
@@ -79,7 +122,8 @@  static void __static_call_validate(void
 
 	if (tail) {
 		if (opcode == JMP32_INSN_OPCODE ||
-		    opcode == RET_INSN_OPCODE)
+		    opcode == RET_INSN_OPCODE ||
+		    __is_Jcc(insn))
 			return;
 	} else {
 		if (opcode == CALL_INSN_OPCODE ||