[3/3] x86/static_call: Add support for Jcc tail-calls
Commit Message
Clang likes to create conditional tail calls like:
0000000000000350 <amd_pmu_add_event>:
350: 0f 1f 44 00 00 nopl 0x0(%rax,%rax,1) 351: R_X86_64_NONE __fentry__-0x4
355: 48 83 bf 20 01 00 00 00 cmpq $0x0,0x120(%rdi)
35d: 0f 85 00 00 00 00 jne 363 <amd_pmu_add_event+0x13> 35f: R_X86_64_PLT32 __SCT__amd_pmu_branch_add-0x4
363: e9 00 00 00 00 jmp 368 <amd_pmu_add_event+0x18> 364: R_X86_64_PLT32 __x86_return_thunk-0x4
Teach the in-line static call text patching about this.
Notably, since there is no conditional-ret, in that caes patch the Jcc
to point at an empty stub function that does the ret -- or the return
thunk when needed.
Reported-by: "Erhard F." <erhard_f@mailbox.org>
Signed-off-by: Peter Zijlstra (Intel) <peterz@infradead.org>
Reviewed-by: Masami Hiramatsu (Google) <mhiramat@kernel.org>
---
arch/x86/kernel/static_call.c | 50 +++++++++++++++++++++++++++++++++++++++---
1 file changed, 47 insertions(+), 3 deletions(-)
Comments
On Mon, 23 Jan 2023 21:59:18 +0100
Peter Zijlstra <peterz@infradead.org> wrote:
> Clang likes to create conditional tail calls like:
>
> 0000000000000350 <amd_pmu_add_event>:
> 350: 0f 1f 44 00 00 nopl 0x0(%rax,%rax,1) 351: R_X86_64_NONE __fentry__-0x4
> 355: 48 83 bf 20 01 00 00 00 cmpq $0x0,0x120(%rdi)
> 35d: 0f 85 00 00 00 00 jne 363 <amd_pmu_add_event+0x13> 35f: R_X86_64_PLT32 __SCT__amd_pmu_branch_add-0x4
> 363: e9 00 00 00 00 jmp 368 <amd_pmu_add_event+0x18> 364: R_X86_64_PLT32 __x86_return_thunk-0x4
>
Just to confirm, as it's not clear if this is the static call site or one
of the functions that is being called.
I'm guessing that this is an issue because clang optimizes the static call
site, right?
> Teach the in-line static call text patching about this.
>
> Notably, since there is no conditional-ret, in that caes patch the Jcc
"in that case"
-- Steve
> to point at an empty stub function that does the ret -- or the return
> thunk when needed.
>
> Reported-by: "Erhard F." <erhard_f@mailbox.org>
> Signed-off-by: Peter Zijlstra (Intel) <peterz@infradead.org>
> Reviewed-by: Masami Hiramatsu (Google) <mhiramat@kernel.org>
> ---
> arch/x86/kernel/static_call.c | 50 +++++++++++++++++++++++++++++++++++++++---
> 1 file changed, 47 insertions(+), 3 deletions(-)
>
> --- a/arch/x86/kernel/static_call.c
> +++ b/arch/x86/kernel/static_call.c
> @@ -9,6 +9,7 @@ enum insn_type {
> NOP = 1, /* site cond-call */
> JMP = 2, /* tramp / site tail-call */
> RET = 3, /* tramp / site cond-tail-call */
> + JCC = 4,
> };
>
> /*
> @@ -25,12 +26,40 @@ static const u8 xor5rax[] = { 0x2e, 0x2e
>
> static const u8 retinsn[] = { RET_INSN_OPCODE, 0xcc, 0xcc, 0xcc, 0xcc };
>
> +static u8 __is_Jcc(u8 *insn) /* Jcc.d32 */
> +{
> + u8 ret = 0;
> +
> + if (insn[0] == 0x0f) {
> + u8 tmp = insn[1];
> + if ((tmp & 0xf0) == 0x80)
> + ret = tmp;
> + }
> +
> + return ret;
> +}
> +
> +extern void __static_call_return(void);
> +
> +asm (".global __static_call_return\n\t"
> + ".type __static_call_return, @function\n\t"
> + ASM_FUNC_ALIGN "\n\t"
> + "__static_call_return:\n\t"
> + ANNOTATE_NOENDBR
> + ANNOTATE_RETPOLINE_SAFE
> + "ret; int3\n\t"
> + ".size __static_call_return, . - __static_call_return \n\t");
> +
> static void __ref __static_call_transform(void *insn, enum insn_type type,
> void *func, bool modinit)
> {
> const void *emulate = NULL;
> int size = CALL_INSN_SIZE;
> const void *code;
> + u8 op, buf[6];
> +
> + if ((type == JMP || type == RET) && (op = __is_Jcc(insn)))
> + type = JCC;
>
> switch (type) {
> case CALL:
> @@ -57,6 +86,20 @@ static void __ref __static_call_transfor
> else
> code = &retinsn;
> break;
> +
> + case JCC:
> + if (!func) {
> + func = __static_call_return;
> + if (cpu_feature_enabled(X86_FEATURE_RETHUNK))
> + func = x86_return_thunk;
> + }
> +
> + buf[0] = 0x0f;
> + __text_gen_insn(buf+1, op, insn+1, func, 5);
> + code = buf;
> + size = 6;
> +
> + break;
> }
>
> if (memcmp(insn, code, size) == 0)
> @@ -68,9 +111,9 @@ static void __ref __static_call_transfor
> text_poke_bp(insn, code, size, emulate);
> }
>
> -static void __static_call_validate(void *insn, bool tail, bool tramp)
> +static void __static_call_validate(u8 *insn, bool tail, bool tramp)
> {
> - u8 opcode = *(u8 *)insn;
> + u8 opcode = insn[0];
>
> if (tramp && memcmp(insn+5, tramp_ud, 3)) {
> pr_err("trampoline signature fail");
> @@ -79,7 +122,8 @@ static void __static_call_validate(void
>
> if (tail) {
> if (opcode == JMP32_INSN_OPCODE ||
> - opcode == RET_INSN_OPCODE)
> + opcode == RET_INSN_OPCODE ||
> + __is_Jcc(insn))
> return;
> } else {
> if (opcode == CALL_INSN_OPCODE ||
>
On Mon, Jan 23, 2023 at 05:44:31PM -0500, Steven Rostedt wrote:
> On Mon, 23 Jan 2023 21:59:18 +0100
> Peter Zijlstra <peterz@infradead.org> wrote:
>
> > Clang likes to create conditional tail calls like:
> >
> > 0000000000000350 <amd_pmu_add_event>:
> > 350: 0f 1f 44 00 00 nopl 0x0(%rax,%rax,1) 351: R_X86_64_NONE __fentry__-0x4
> > 355: 48 83 bf 20 01 00 00 00 cmpq $0x0,0x120(%rdi)
> > 35d: 0f 85 00 00 00 00 jne 363 <amd_pmu_add_event+0x13> 35f: R_X86_64_PLT32 __SCT__amd_pmu_branch_add-0x4
> > 363: e9 00 00 00 00 jmp 368 <amd_pmu_add_event+0x18> 364: R_X86_64_PLT32 __x86_return_thunk-0x4
> >
>
> Just to confirm, as it's not clear if this is the static call site or one
> of the functions that is being called.
Ah, you've not looked at enough asm then? ;-) Yes this is the static
call site, see the __SCT_ target (instruction at 0x35d).
> I'm guessing that this is an issue because clang optimizes the static call
> site, right?
Specifically using Jcc (jne in this case) to tail-call the trampoline.
> > Teach the in-line static call text patching about this.
> >
> > Notably, since there is no conditional-ret, in that caes patch the Jcc
>
> "in that case"
typing so hard.. :-)
On Tue, 24 Jan 2023 14:06:49 +0100
Peter Zijlstra <peterz@infradead.org> wrote:
> > Just to confirm, as it's not clear if this is the static call site or one
> > of the functions that is being called.
>
> Ah, you've not looked at enough asm then? ;-) Yes this is the static
> call site, see the __SCT_ target (instruction at 0x35d).
Yeah, could you specify it a bit more in the change log such that those
looking back at this don't have to have that requirement of staring at
enough asm ;-)
It's actually been some time since I stared at compiler output (although
now that I'm starting to play with rust, that's going to start back up
soon).
-- Steve
@@ -9,6 +9,7 @@ enum insn_type {
NOP = 1, /* site cond-call */
JMP = 2, /* tramp / site tail-call */
RET = 3, /* tramp / site cond-tail-call */
+ JCC = 4,
};
/*
@@ -25,12 +26,40 @@ static const u8 xor5rax[] = { 0x2e, 0x2e
static const u8 retinsn[] = { RET_INSN_OPCODE, 0xcc, 0xcc, 0xcc, 0xcc };
+static u8 __is_Jcc(u8 *insn) /* Jcc.d32 */
+{
+ u8 ret = 0;
+
+ if (insn[0] == 0x0f) {
+ u8 tmp = insn[1];
+ if ((tmp & 0xf0) == 0x80)
+ ret = tmp;
+ }
+
+ return ret;
+}
+
+extern void __static_call_return(void);
+
+asm (".global __static_call_return\n\t"
+ ".type __static_call_return, @function\n\t"
+ ASM_FUNC_ALIGN "\n\t"
+ "__static_call_return:\n\t"
+ ANNOTATE_NOENDBR
+ ANNOTATE_RETPOLINE_SAFE
+ "ret; int3\n\t"
+ ".size __static_call_return, . - __static_call_return \n\t");
+
static void __ref __static_call_transform(void *insn, enum insn_type type,
void *func, bool modinit)
{
const void *emulate = NULL;
int size = CALL_INSN_SIZE;
const void *code;
+ u8 op, buf[6];
+
+ if ((type == JMP || type == RET) && (op = __is_Jcc(insn)))
+ type = JCC;
switch (type) {
case CALL:
@@ -57,6 +86,20 @@ static void __ref __static_call_transfor
else
code = &retinsn;
break;
+
+ case JCC:
+ if (!func) {
+ func = __static_call_return;
+ if (cpu_feature_enabled(X86_FEATURE_RETHUNK))
+ func = x86_return_thunk;
+ }
+
+ buf[0] = 0x0f;
+ __text_gen_insn(buf+1, op, insn+1, func, 5);
+ code = buf;
+ size = 6;
+
+ break;
}
if (memcmp(insn, code, size) == 0)
@@ -68,9 +111,9 @@ static void __ref __static_call_transfor
text_poke_bp(insn, code, size, emulate);
}
-static void __static_call_validate(void *insn, bool tail, bool tramp)
+static void __static_call_validate(u8 *insn, bool tail, bool tramp)
{
- u8 opcode = *(u8 *)insn;
+ u8 opcode = insn[0];
if (tramp && memcmp(insn+5, tramp_ud, 3)) {
pr_err("trampoline signature fail");
@@ -79,7 +122,8 @@ static void __static_call_validate(void
if (tail) {
if (opcode == JMP32_INSN_OPCODE ||
- opcode == RET_INSN_OPCODE)
+ opcode == RET_INSN_OPCODE ||
+ __is_Jcc(insn))
return;
} else {
if (opcode == CALL_INSN_OPCODE ||