[RFC,V2,3/9] riscv: add support for misaligned handling in S-mode
Commit Message
Misalignment handling is only supported for M-mode and uses direct
accesses to user memory. in S-mode, this requires to use the
get_user()/put_user() accessors. Implement load_u8(), store_u8() and
get_insn() using these accessors.
Signed-off-by: Clément Léger <cleger@rivosinc.com>
---
arch/riscv/kernel/Makefile | 2 +-
arch/riscv/kernel/traps.c | 7 --
arch/riscv/kernel/traps_misaligned.c | 118 ++++++++++++++++++++++++---
3 files changed, 106 insertions(+), 21 deletions(-)
Comments
On Tue, Jul 04, 2023 at 04:09:18PM +0200, Clément Léger wrote:
> Misalignment handling is only supported for M-mode and uses direct
> accesses to user memory. in S-mode, this requires to use the
> get_user()/put_user() accessors. Implement load_u8(), store_u8() and
> get_insn() using these accessors.
>
> Signed-off-by: Clément Léger <cleger@rivosinc.com>
> ---
> arch/riscv/kernel/Makefile | 2 +-
> arch/riscv/kernel/traps.c | 7 --
> arch/riscv/kernel/traps_misaligned.c | 118 ++++++++++++++++++++++++---
> 3 files changed, 106 insertions(+), 21 deletions(-)
>
> diff --git a/arch/riscv/kernel/Makefile b/arch/riscv/kernel/Makefile
> index 153864e4f399..79b8dafc699d 100644
> --- a/arch/riscv/kernel/Makefile
> +++ b/arch/riscv/kernel/Makefile
> @@ -55,10 +55,10 @@ obj-y += riscv_ksyms.o
> obj-y += stacktrace.o
> obj-y += cacheinfo.o
> obj-y += patch.o
> +obj-y += traps_misaligned.o
> obj-y += probes/
> obj-$(CONFIG_MMU) += vdso.o vdso/
>
> -obj-$(CONFIG_RISCV_M_MODE) += traps_misaligned.o
I think it would be better to combine this wath with "[PATCH 2/9] riscv: avoid
missing prototypes warning" from the series to avoid the breakage on commit 2/9.
Also, it would be a bit more clear of the intent.
> obj-$(CONFIG_FPU) += fpu.o
> obj-$(CONFIG_SMP) += smpboot.o
> obj-$(CONFIG_SMP) += smp.o
> diff --git a/arch/riscv/kernel/traps.c b/arch/riscv/kernel/traps.c
> index 7fcaf2fd27a1..b2fb2266fb83 100644
> --- a/arch/riscv/kernel/traps.c
> +++ b/arch/riscv/kernel/traps.c
> @@ -149,12 +149,6 @@ DO_ERROR_INFO(do_trap_insn_illegal,
> SIGILL, ILL_ILLOPC, "illegal instruction");
> DO_ERROR_INFO(do_trap_load_fault,
> SIGSEGV, SEGV_ACCERR, "load access fault");
> -#ifndef CONFIG_RISCV_M_MODE
> -DO_ERROR_INFO(do_trap_load_misaligned,
> - SIGBUS, BUS_ADRALN, "Oops - load address misaligned");
> -DO_ERROR_INFO(do_trap_store_misaligned,
> - SIGBUS, BUS_ADRALN, "Oops - store (or AMO) address misaligned");
> -#else
>
> asmlinkage __visible __trap_section void do_trap_load_misaligned(struct pt_regs *regs)
> {
> @@ -197,7 +191,6 @@ asmlinkage __visible __trap_section void do_trap_store_misaligned(struct pt_regs
> irqentry_nmi_exit(regs, state);
> }
> }
> -#endif
> DO_ERROR_INFO(do_trap_store_fault,
> SIGSEGV, SEGV_ACCERR, "store (or AMO) access fault");
> DO_ERROR_INFO(do_trap_ecall_s,
> diff --git a/arch/riscv/kernel/traps_misaligned.c b/arch/riscv/kernel/traps_misaligned.c
> index 0cccac4822a8..9daed7d756ae 100644
> --- a/arch/riscv/kernel/traps_misaligned.c
> +++ b/arch/riscv/kernel/traps_misaligned.c
> @@ -152,21 +152,25 @@
> #define PRECISION_S 0
> #define PRECISION_D 1
>
> -static inline u8 load_u8(const u8 *addr)
> +#ifdef CONFIG_RISCV_M_MODE
> +static inline int load_u8(struct pt_regs *regs, const u8 *addr, u8 *r_val)
> {
> u8 val;
>
> asm volatile("lbu %0, %1" : "=&r" (val) : "m" (*addr));
> + *r_val = val;
>
> - return val;
> + return 0;
> }
>
> -static inline void store_u8(u8 *addr, u8 val)
> +static inline int store_u8(struct pt_regs *regs, u8 *addr, u8 val)
> {
> asm volatile ("sb %0, %1\n" : : "r" (val), "m" (*addr));
> +
> + return 0;
> }
>
> -static inline ulong get_insn(ulong mepc)
> +static inline int get_insn(struct pt_regs *regs, ulong mepc, ulong *r_insn)
> {
> register ulong __mepc asm ("a2") = mepc;
> ulong val, rvc_mask = 3, tmp;
> @@ -195,9 +199,87 @@ static inline ulong get_insn(ulong mepc)
> : [addr] "r" (__mepc), [rvc_mask] "r" (rvc_mask),
> [xlen_minus_16] "i" (XLEN_MINUS_16));
>
> - return val;
> + *r_insn = val;
> +
> + return 0;
> +}
> +#else
> +static inline int load_u8(struct pt_regs *regs, const u8 *addr, u8 *r_val)
> +{
> + if (user_mode(regs)) {
> + return __get_user(*r_val, addr);
> + } else {
> + *r_val = *addr;
> + return 0;
> + }
> }
>
> +static inline int store_u8(struct pt_regs *regs, u8 *addr, u8 val)
> +{
> + if (user_mode(regs)) {
> + return __put_user(val, addr);
> + } else {
> + *addr = val;
> + return 0;
> + }
> +}
> +
> +#define __read_insn(regs, insn, insn_addr) \
> +({ \
> + int __ret; \
> + \
> + if (user_mode(regs)) { \
> + __ret = __get_user(insn, insn_addr); \
> + } else { \
> + insn = *insn_addr; \
> + __ret = 0; \
> + } \
> + \
> + __ret; \
> +})
> +
> +static inline int get_insn(struct pt_regs *regs, ulong epc, ulong *r_insn)
> +{
> + ulong insn = 0;
> +
> + if (epc & 0x2) {
> + ulong tmp = 0;
> + u16 __user *insn_addr = (u16 __user *)epc;
> +
> + if (__read_insn(regs, insn, insn_addr))
> + return -EFAULT;
> + /* __get_user() uses regular "lw" which sign extend the loaded
> + * value make sure to clear higher order bits in case we "or" it
> + * below with the upper 16 bits half.
> + */
> + insn &= GENMASK(15, 0);
> + if ((insn & __INSN_LENGTH_MASK) != __INSN_LENGTH_32) {
> + *r_insn = insn;
> + return 0;
> + }
> + insn_addr++;
> + if (__read_insn(regs, tmp, insn_addr))
> + return -EFAULT;
> + *r_insn = (tmp << 16) | insn;
> +
> + return 0;
> + } else {
> + u32 __user *insn_addr = (u32 __user *)epc;
> +
> + if (__read_insn(regs, insn, insn_addr))
> + return -EFAULT;
> + if ((insn & __INSN_LENGTH_MASK) == __INSN_LENGTH_32) {
> + *r_insn = insn;
> + return 0;
> + }
> + insn &= GENMASK(15, 0);
> + *r_insn = insn;
> +
> + return 0;
> + }
> +}
> +#endif
> +
> union reg_data {
> u8 data_bytes[8];
> ulong data_ulong;
> @@ -208,10 +290,13 @@ int handle_misaligned_load(struct pt_regs *regs)
> {
> union reg_data val;
> unsigned long epc = regs->epc;
> - unsigned long insn = get_insn(epc);
> - unsigned long addr = csr_read(mtval);
> + unsigned long insn;
> + unsigned long addr = regs->badaddr;
Could your commit messages talk a bit about the change from using mtval to
regs->badaddr? Will this sill work in M-mode? I think so, but it would be good
to explain the change.
> int i, fp = 0, shift = 0, len = 0;
>
> + if (get_insn(regs, epc, &insn))
> + return -1;
> +
> regs->epc = 0;
>
> if ((insn & INSN_MASK_LW) == INSN_MATCH_LW) {
> @@ -275,8 +360,10 @@ int handle_misaligned_load(struct pt_regs *regs)
> }
>
> val.data_u64 = 0;
> - for (i = 0; i < len; i++)
> - val.data_bytes[i] = load_u8((void *)(addr + i));
> + for (i = 0; i < len; i++) {
> + if (load_u8(regs, (void *)(addr + i), &val.data_bytes[i]))
> + return -1;
> + }
>
> if (fp)
> return -1;
> @@ -291,10 +378,13 @@ int handle_misaligned_store(struct pt_regs *regs)
> {
> union reg_data val;
> unsigned long epc = regs->epc;
> - unsigned long insn = get_insn(epc);
> - unsigned long addr = csr_read(mtval);
> + unsigned long insn;
> + unsigned long addr = regs->badaddr;
> int i, len = 0;
>
> + if (get_insn(regs, epc, &insn))
> + return -1;
> +
> regs->epc = 0;
>
> val.data_ulong = GET_RS2(insn, regs);
> @@ -328,8 +418,10 @@ int handle_misaligned_store(struct pt_regs *regs)
> return -1;
> }
>
> - for (i = 0; i < len; i++)
> - store_u8((void *)(addr + i), val.data_bytes[i]);
> + for (i = 0; i < len; i++) {
> + if (store_u8(regs, (void *)(addr + i), val.data_bytes[i]))
> + return -1;
> + }
>
> regs->epc = epc + INSN_LEN(insn);
>
> --
> 2.40.1
>
@@ -55,10 +55,10 @@ obj-y += riscv_ksyms.o
obj-y += stacktrace.o
obj-y += cacheinfo.o
obj-y += patch.o
+obj-y += traps_misaligned.o
obj-y += probes/
obj-$(CONFIG_MMU) += vdso.o vdso/
-obj-$(CONFIG_RISCV_M_MODE) += traps_misaligned.o
obj-$(CONFIG_FPU) += fpu.o
obj-$(CONFIG_SMP) += smpboot.o
obj-$(CONFIG_SMP) += smp.o
@@ -149,12 +149,6 @@ DO_ERROR_INFO(do_trap_insn_illegal,
SIGILL, ILL_ILLOPC, "illegal instruction");
DO_ERROR_INFO(do_trap_load_fault,
SIGSEGV, SEGV_ACCERR, "load access fault");
-#ifndef CONFIG_RISCV_M_MODE
-DO_ERROR_INFO(do_trap_load_misaligned,
- SIGBUS, BUS_ADRALN, "Oops - load address misaligned");
-DO_ERROR_INFO(do_trap_store_misaligned,
- SIGBUS, BUS_ADRALN, "Oops - store (or AMO) address misaligned");
-#else
asmlinkage __visible __trap_section void do_trap_load_misaligned(struct pt_regs *regs)
{
@@ -197,7 +191,6 @@ asmlinkage __visible __trap_section void do_trap_store_misaligned(struct pt_regs
irqentry_nmi_exit(regs, state);
}
}
-#endif
DO_ERROR_INFO(do_trap_store_fault,
SIGSEGV, SEGV_ACCERR, "store (or AMO) access fault");
DO_ERROR_INFO(do_trap_ecall_s,
@@ -152,21 +152,25 @@
#define PRECISION_S 0
#define PRECISION_D 1
-static inline u8 load_u8(const u8 *addr)
+#ifdef CONFIG_RISCV_M_MODE
+static inline int load_u8(struct pt_regs *regs, const u8 *addr, u8 *r_val)
{
u8 val;
asm volatile("lbu %0, %1" : "=&r" (val) : "m" (*addr));
+ *r_val = val;
- return val;
+ return 0;
}
-static inline void store_u8(u8 *addr, u8 val)
+static inline int store_u8(struct pt_regs *regs, u8 *addr, u8 val)
{
asm volatile ("sb %0, %1\n" : : "r" (val), "m" (*addr));
+
+ return 0;
}
-static inline ulong get_insn(ulong mepc)
+static inline int get_insn(struct pt_regs *regs, ulong mepc, ulong *r_insn)
{
register ulong __mepc asm ("a2") = mepc;
ulong val, rvc_mask = 3, tmp;
@@ -195,9 +199,87 @@ static inline ulong get_insn(ulong mepc)
: [addr] "r" (__mepc), [rvc_mask] "r" (rvc_mask),
[xlen_minus_16] "i" (XLEN_MINUS_16));
- return val;
+ *r_insn = val;
+
+ return 0;
+}
+#else
+static inline int load_u8(struct pt_regs *regs, const u8 *addr, u8 *r_val)
+{
+ if (user_mode(regs)) {
+ return __get_user(*r_val, addr);
+ } else {
+ *r_val = *addr;
+ return 0;
+ }
}
+static inline int store_u8(struct pt_regs *regs, u8 *addr, u8 val)
+{
+ if (user_mode(regs)) {
+ return __put_user(val, addr);
+ } else {
+ *addr = val;
+ return 0;
+ }
+}
+
+#define __read_insn(regs, insn, insn_addr) \
+({ \
+ int __ret; \
+ \
+ if (user_mode(regs)) { \
+ __ret = __get_user(insn, insn_addr); \
+ } else { \
+ insn = *insn_addr; \
+ __ret = 0; \
+ } \
+ \
+ __ret; \
+})
+
+static inline int get_insn(struct pt_regs *regs, ulong epc, ulong *r_insn)
+{
+ ulong insn = 0;
+
+ if (epc & 0x2) {
+ ulong tmp = 0;
+ u16 __user *insn_addr = (u16 __user *)epc;
+
+ if (__read_insn(regs, insn, insn_addr))
+ return -EFAULT;
+ /* __get_user() uses regular "lw" which sign extend the loaded
+ * value make sure to clear higher order bits in case we "or" it
+ * below with the upper 16 bits half.
+ */
+ insn &= GENMASK(15, 0);
+ if ((insn & __INSN_LENGTH_MASK) != __INSN_LENGTH_32) {
+ *r_insn = insn;
+ return 0;
+ }
+ insn_addr++;
+ if (__read_insn(regs, tmp, insn_addr))
+ return -EFAULT;
+ *r_insn = (tmp << 16) | insn;
+
+ return 0;
+ } else {
+ u32 __user *insn_addr = (u32 __user *)epc;
+
+ if (__read_insn(regs, insn, insn_addr))
+ return -EFAULT;
+ if ((insn & __INSN_LENGTH_MASK) == __INSN_LENGTH_32) {
+ *r_insn = insn;
+ return 0;
+ }
+ insn &= GENMASK(15, 0);
+ *r_insn = insn;
+
+ return 0;
+ }
+}
+#endif
+
union reg_data {
u8 data_bytes[8];
ulong data_ulong;
@@ -208,10 +290,13 @@ int handle_misaligned_load(struct pt_regs *regs)
{
union reg_data val;
unsigned long epc = regs->epc;
- unsigned long insn = get_insn(epc);
- unsigned long addr = csr_read(mtval);
+ unsigned long insn;
+ unsigned long addr = regs->badaddr;
int i, fp = 0, shift = 0, len = 0;
+ if (get_insn(regs, epc, &insn))
+ return -1;
+
regs->epc = 0;
if ((insn & INSN_MASK_LW) == INSN_MATCH_LW) {
@@ -275,8 +360,10 @@ int handle_misaligned_load(struct pt_regs *regs)
}
val.data_u64 = 0;
- for (i = 0; i < len; i++)
- val.data_bytes[i] = load_u8((void *)(addr + i));
+ for (i = 0; i < len; i++) {
+ if (load_u8(regs, (void *)(addr + i), &val.data_bytes[i]))
+ return -1;
+ }
if (fp)
return -1;
@@ -291,10 +378,13 @@ int handle_misaligned_store(struct pt_regs *regs)
{
union reg_data val;
unsigned long epc = regs->epc;
- unsigned long insn = get_insn(epc);
- unsigned long addr = csr_read(mtval);
+ unsigned long insn;
+ unsigned long addr = regs->badaddr;
int i, len = 0;
+ if (get_insn(regs, epc, &insn))
+ return -1;
+
regs->epc = 0;
val.data_ulong = GET_RS2(insn, regs);
@@ -328,8 +418,10 @@ int handle_misaligned_store(struct pt_regs *regs)
return -1;
}
- for (i = 0; i < len; i++)
- store_u8((void *)(addr + i), val.data_bytes[i]);
+ for (i = 0; i < len; i++) {
+ if (store_u8(regs, (void *)(addr + i), val.data_bytes[i]))
+ return -1;
+ }
regs->epc = epc + INSN_LEN(insn);