[3/3] RISC-V: cmpmem for RISCV with V extension

Message ID 20231211094728.1623032-4-slewis@rivosinc.com
State Unresolved
Headers
Series RISC-V: vectorised memory operations |

Checks

Context Check Description
snail/gcc-patch-check warning Git am fail log

Commit Message

Sergei Lewis Dec. 11, 2023, 9:47 a.m. UTC
  gcc/ChangeLog:

    * config/riscv/riscv-protos.h (riscv_vector::expand_vec_cmpmem): New function
    declaration.

    * config/riscv/riscv-string.cc (riscv_vector::expand_vec_cmpmem): New
    function; this generates an inline vectorised memory compare, if and only if
    we know the entire operation can be performed in a single vector load per
    input

    * config/riscv/riscv.md (cmpmemsi): Try riscv_vector::expand_vec_cmpmem for
    constant lengths

gcc/testsuite/ChangeLog:

    * gcc.target/riscv/rvv/base/cmpmem-1.c: New codegen tests
    * gcc.target/riscv/rvv/base/cmpmem-2.c: New execution tests
---
 gcc/config/riscv/riscv-protos.h               |   1 +
 gcc/config/riscv/riscv-string.cc              | 111 ++++++++++++++++++
 gcc/config/riscv/riscv.md                     |  15 +++
 .../gcc.target/riscv/rvv/base/cmpmem-1.c      |  85 ++++++++++++++
 .../gcc.target/riscv/rvv/base/cmpmem-2.c      |  69 +++++++++++
 5 files changed, 281 insertions(+)
 create mode 100644 gcc/testsuite/gcc.target/riscv/rvv/base/cmpmem-1.c
 create mode 100644 gcc/testsuite/gcc.target/riscv/rvv/base/cmpmem-2.c
  

Patch

diff --git a/gcc/config/riscv/riscv-protos.h b/gcc/config/riscv/riscv-protos.h
index 950cb65c910..72378438552 100644
--- a/gcc/config/riscv/riscv-protos.h
+++ b/gcc/config/riscv/riscv-protos.h
@@ -561,6 +561,7 @@  void expand_rawmemchr (machine_mode, rtx, rtx, rtx, bool = false);
 bool expand_strcmp (rtx, rtx, rtx, rtx, unsigned HOST_WIDE_INT, bool);
 void emit_vec_extract (rtx, rtx, poly_int64);
 bool expand_vec_setmem (rtx, rtx, rtx, rtx);
+bool expand_vec_cmpmem (rtx, rtx, rtx, rtx);
 
 /* Rounding mode bitfield for fixed point VXRM.  */
 enum fixed_point_rounding_mode
diff --git a/gcc/config/riscv/riscv-string.cc b/gcc/config/riscv/riscv-string.cc
index 0abbd5f8b28..6128565310b 100644
--- a/gcc/config/riscv/riscv-string.cc
+++ b/gcc/config/riscv/riscv-string.cc
@@ -1329,4 +1329,115 @@  expand_vec_setmem (rtx dst_in, rtx length_in, rtx fill_value_in,
   return true;
 }
 
+
+/* Used by cmpmemsi in riscv.md.  */
+
+bool
+expand_vec_cmpmem (rtx result_out, rtx blk_a_in, rtx blk_b_in, rtx length_in)
+{
+  /* we're generating vector code.  */
+  if (!TARGET_VECTOR)
+    return false;
+  /* if we can't reason about the length, let libc handle the operation.  */
+  if (!CONST_INT_P (length_in))
+    return false;
+
+  HOST_WIDE_INT length = INTVAL (length_in);
+  HOST_WIDE_INT lmul;
+
+  /* select an lmul such that the data just fits into one vector operation;
+     bail if we can't.  */
+  if (!select_appropriate_lmul (length, lmul))
+    return false;
+
+  /* strategy:
+     load entire blocks at a and b into vector regs
+     generate mask of bytes that differ
+     find first set bit in mask
+     find offset of first set bit in mask, use 0 if none set
+     result is ((char*)a[offset] - (char*)b[offset])
+   */
+
+  machine_mode vmode = riscv_vector::get_vector_mode (QImode,
+	    BYTES_PER_RISCV_VECTOR * lmul).require ();
+  rtx blk_a_addr = copy_addr_to_reg (XEXP (blk_a_in, 0));
+  rtx blk_a = change_address (blk_a_in, vmode, blk_a_addr);
+  rtx blk_b_addr = copy_addr_to_reg (XEXP (blk_b_in, 0));
+  rtx blk_b = change_address (blk_b_in, vmode, blk_b_addr);
+
+  rtx vec_a = gen_reg_rtx (vmode);
+  rtx vec_b = gen_reg_rtx (vmode);
+
+  machine_mode mask_mode = get_mask_mode (vmode);
+  rtx mask = gen_reg_rtx (mask_mode);
+  rtx mismatch_ofs = gen_reg_rtx (Pmode);
+
+  rtx ne = gen_rtx_NE (mask_mode, vec_a, vec_b);
+  rtx vmsops[] = {mask, ne, vec_a, vec_b};
+  rtx vfops[] = {mismatch_ofs, mask};
+
+  /* If the length is exactly vlmax for the selected mode, do that.
+     Otherwise, use a predicated store.  */
+
+  if (known_eq (GET_MODE_SIZE (vmode), INTVAL (length_in)))
+    {
+      emit_move_insn (vec_a, blk_a);
+      emit_move_insn (vec_b, blk_b);
+      emit_vlmax_insn (code_for_pred_cmp (vmode),
+	      riscv_vector::COMPARE_OP, vmsops);
+
+      emit_vlmax_insn (code_for_pred_ffs (mask_mode, Pmode),
+	      riscv_vector::CPOP_OP, vfops);
+    }
+  else
+    {
+      if (!satisfies_constraint_K (length_in))
+	      length_in= force_reg (Pmode, length_in);
+
+      rtx memmask =  CONSTM1_RTX (mask_mode);
+
+      rtx m_ops_a[] = {vec_a, memmask, blk_a};
+      rtx m_ops_b[] = {vec_b, memmask, blk_b};
+
+      emit_nonvlmax_insn (code_for_pred_mov (vmode),
+	      riscv_vector::UNARY_OP_TAMA, m_ops_a, length_in);
+      emit_nonvlmax_insn (code_for_pred_mov (vmode),
+	      riscv_vector::UNARY_OP_TAMA, m_ops_b, length_in);
+
+      emit_nonvlmax_insn (code_for_pred_cmp (vmode),
+	      riscv_vector::COMPARE_OP, vmsops, length_in);
+
+      emit_nonvlmax_insn (code_for_pred_ffs (mask_mode, Pmode),
+	      riscv_vector::CPOP_OP, vfops, length_in);
+    }
+
+  /* mismatch_ofs is -1 if blocks match, or the offset of
+     the first mismatch otherwise.  */
+    rtx ltz = gen_reg_rtx (Xmode);
+    emit_insn (gen_slt_3 (LT, Xmode, Xmode, ltz, mismatch_ofs, const0_rtx));
+  /* mismatch_ofs += (mismatch_ofs < 0) ? 1 : 0.  */
+    emit_insn (gen_rtx_SET (mismatch_ofs, gen_rtx_PLUS (Pmode,
+	    mismatch_ofs, ltz)));
+
+  /* unconditionally load the bytes at mismatch_ofs and subtract them
+     to get our result.  */
+    emit_insn (gen_rtx_SET (blk_a_addr, gen_rtx_PLUS (Pmode,
+	    mismatch_ofs, blk_a_addr)));
+    emit_insn (gen_rtx_SET (blk_b_addr, gen_rtx_PLUS (Pmode,
+	    mismatch_ofs, blk_b_addr)));
+
+    blk_a = change_address (blk_a, QImode, blk_a_addr);
+    blk_b = change_address (blk_b, QImode, blk_b_addr);
+
+    rtx byte_a = gen_reg_rtx (SImode);
+    rtx byte_b = gen_reg_rtx (SImode);
+    do_zero_extendqi2 (byte_a, blk_a);
+    do_zero_extendqi2 (byte_b, blk_b);
+
+    emit_insn (gen_rtx_SET (result_out, gen_rtx_MINUS (SImode,
+	    byte_a, byte_b)));
+
+
+  return true;
+}
 }
diff --git a/gcc/config/riscv/riscv.md b/gcc/config/riscv/riscv.md
index 29d3b1aa342..39829c8566c 100644
--- a/gcc/config/riscv/riscv.md
+++ b/gcc/config/riscv/riscv.md
@@ -2395,6 +2395,21 @@ 
     FAIL;
 })
 
+(define_expand "cmpmemsi"
+ [(set (match_operand:SI 0 "register_operand" "")
+       (compare:SI (match_operand:BLK 1 "memory_operand" "")
+				  (match_operand:BLK 2 "memory_operand" "")))
+  (use (match_operand:SI 3 "general_operand" ""))
+  (use (match_operand:SI 4 "" ""))]
+ "TARGET_VECTOR"
+{
+ if (riscv_vector::expand_vec_cmpmem (operands[0], operands[1],
+				  operands[2], operands[3]))
+   DONE;
+ else
+   FAIL;
+})
+
 ;; Expand in-line code to clear the instruction cache between operand[0] and
 ;; operand[1].
 (define_expand "clear_cache"
diff --git a/gcc/testsuite/gcc.target/riscv/rvv/base/cmpmem-1.c b/gcc/testsuite/gcc.target/riscv/rvv/base/cmpmem-1.c
new file mode 100644
index 00000000000..686ac6d6b0c
--- /dev/null
+++ b/gcc/testsuite/gcc.target/riscv/rvv/base/cmpmem-1.c
@@ -0,0 +1,85 @@ 
+/* { dg-do compile } */
+/* { dg-add-options riscv_v } */
+/* { dg-additional-options "-O3" } */
+/* { dg-final { check-function-bodies "**" "" } } */
+
+#include <string.h>
+
+#define MIN_VECTOR_BYTES (__riscv_v_min_vlen/8)
+
+/* trivial memcmp should use inline scalar ops
+** f1:
+**  lbu\s+a\d+,0\(a0\)
+**  lbu\s+a\d+,0\(a1\)
+**  subw\s+a0,a\d+,a\d+
+**  ret
+*/
+int f1 (void * a, void * b)
+{
+  return memcmp (a, b, 1);
+}
+
+/* tiny memcmp should use libc
+** f2:
+**  li\s+a2,\d+
+**  tail\s+memcmp
+*/
+int f2 (void * a, void * b)
+{
+  return memcmp (a, b, MIN_VECTOR_BYTES-1);
+}
+
+/* vectorise+inline minimum vector register width with LMUL=1
+** f3:
+**  (
+**  vsetivli\s+zero,\d+,e8,m1,ta,ma
+**  |
+**  li\s+a\d+,\d+
+**  vsetvli\s+zero,a\d+,e8,m1,ta,ma
+**  )
+**  ...
+**  ret
+*/
+int f3 (void * a, void * b)
+{
+  return memcmp (a, b, MIN_VECTOR_BYTES);
+}
+
+/* vectorised code should use smallest lmul known to fit length
+** f4:
+**  (
+**  vsetivli\s+zero,\d+,e8,m2,ta,ma
+**  |
+**  li\s+a\d+,\d+
+**  vsetvli\s+zero,a\d+,e8,m2,ta,ma
+**  )
+**  ...
+**  ret
+*/
+int f4 (void * a, void * b)
+{
+  return memcmp (a, b, MIN_VECTOR_BYTES+1);
+}
+
+/* vectorise+inline up to LMUL=8
+** f5:
+**  li\s+a\d+,\d+
+**  vsetvli\s+zero,a\d+,e8,m8,ta,ma
+**  ...
+**  ret
+*/
+int f5 (void * a, void * b)
+{
+  return memcmp (a, b, MIN_VECTOR_BYTES*8);
+}
+
+/* don't inline if the length is too large for one operation
+** f6:
+**  li\s+a2,\d+
+**  tail\s+memcmp
+*/
+int f6 (void * a, void * b)
+{
+  return memcmp (a, b, MIN_VECTOR_BYTES*8+1);
+}
+
diff --git a/gcc/testsuite/gcc.target/riscv/rvv/base/cmpmem-2.c b/gcc/testsuite/gcc.target/riscv/rvv/base/cmpmem-2.c
new file mode 100644
index 00000000000..eedd23d4db0
--- /dev/null
+++ b/gcc/testsuite/gcc.target/riscv/rvv/base/cmpmem-2.c
@@ -0,0 +1,69 @@ 
+/* { dg-do run { target { riscv_v } } } */
+/* { dg-options "-O2" } */
+
+#include <string.h>
+#include <stdlib.h>
+
+#define MIN_VECTOR_BYTES (__riscv_v_min_vlen/8)
+
+static inline __attribute__((always_inline)) 
+void do_one_test( int const size, int const diff_offset, 
+    int const diff_dir ) 
+{
+  unsigned char A[size];
+  unsigned char B[size];
+  unsigned char const fill_value = 0x55;
+  memset( A, fill_value, size );
+  memset( B, fill_value, size );
+
+  if( diff_dir != 0 ) {
+    if( diff_dir < 0 ) {
+      A[diff_offset] = fill_value-1;
+    } else {
+      A[diff_offset] = fill_value+1;
+    }
+  }
+
+  if( memcmp( A, B, size ) != diff_dir ) {
+    abort ();
+  }
+}
+
+int main()
+{
+  do_one_test( 0, 0, 0  );
+
+  do_one_test( 1, 0, -1 );
+  do_one_test( 1, 0,  0 );
+  do_one_test( 1, 0,  1 );
+
+  do_one_test( MIN_VECTOR_BYTES-1, 0, -1 );
+  do_one_test( MIN_VECTOR_BYTES-1, 0,  0 );
+  do_one_test( MIN_VECTOR_BYTES-1, 0,  1 );
+  do_one_test( MIN_VECTOR_BYTES-1, 1, -1 );
+  do_one_test( MIN_VECTOR_BYTES-1, 1,  0 );
+  do_one_test( MIN_VECTOR_BYTES-1, 1,  1 );
+
+  do_one_test( MIN_VECTOR_BYTES, 0, -1 );
+  do_one_test( MIN_VECTOR_BYTES, 0,  0 );
+  do_one_test( MIN_VECTOR_BYTES, 0,  1 );
+  do_one_test( MIN_VECTOR_BYTES, MIN_VECTOR_BYTES-1, -1 );
+  do_one_test( MIN_VECTOR_BYTES, MIN_VECTOR_BYTES-1,  0 );
+  do_one_test( MIN_VECTOR_BYTES, MIN_VECTOR_BYTES-1,  1 );
+
+  do_one_test( MIN_VECTOR_BYTES+1, 0, -1 );
+  do_one_test( MIN_VECTOR_BYTES+1, 0,  0 );
+  do_one_test( MIN_VECTOR_BYTES+1, 0,  1 );
+  do_one_test( MIN_VECTOR_BYTES+1, MIN_VECTOR_BYTES, -1 );
+  do_one_test( MIN_VECTOR_BYTES+1, MIN_VECTOR_BYTES,  0 );
+  do_one_test( MIN_VECTOR_BYTES+1, MIN_VECTOR_BYTES,  1 );
+
+  do_one_test( MIN_VECTOR_BYTES*8, 0, -1 );
+  do_one_test( MIN_VECTOR_BYTES*8, 0,  0 );
+  do_one_test( MIN_VECTOR_BYTES*8, 0,  1 );
+  do_one_test( MIN_VECTOR_BYTES*8, MIN_VECTOR_BYTES*8-1, -1 );
+  do_one_test( MIN_VECTOR_BYTES*8, MIN_VECTOR_BYTES*8-1,  0 );
+  do_one_test( MIN_VECTOR_BYTES*8, MIN_VECTOR_BYTES*8-1,  1 );
+
+  return 0;
+}