[Committed] RISC-V: Make known NITERS loop be aware of dynamic lmul cost model liveness information

Message ID 20231227081641.1031426-1-juzhe.zhong@rivai.ai
State Unresolved
Headers
Series [Committed] RISC-V: Make known NITERS loop be aware of dynamic lmul cost model liveness information |

Checks

Context Check Description
snail/gcc-patch-check warning Git am fail log

Commit Message

juzhe.zhong@rivai.ai Dec. 27, 2023, 8:16 a.m. UTC
  Consider this following case:

int f[12][100];

void bad1(int v1, int v2)
{
  for (int r = 0; r < 100; r += 4)
    {
      int i = r + 1;
      f[0][r] = f[1][r] * (f[2][r]) - f[1][i] * (f[2][i]);
      f[0][i] = f[1][r] * (f[2][i]) + f[1][i] * (f[2][r]);
      f[0][r+2] = f[1][r+2] * (f[2][r+2]) - f[1][i+2] * (f[2][i+2]);
      f[0][i+2] = f[1][r+2] * (f[2][i+2]) + f[1][i+2] * (f[2][r+2]);
    }
}

Pick up LMUL = 8 VLS blindly:

        lui     a4,%hi(f)
        addi    a4,a4,%lo(f)
        addi    sp,sp,-592
        addi    a3,a4,800
        lui     a5,%hi(.LANCHOR0)
        vl8re32.v       v24,0(a3)
        addi    a5,a5,%lo(.LANCHOR0)
        addi    a1,a4,400
        addi    a3,sp,140
        vl8re32.v       v16,0(a1)
        vl4re16.v       v4,0(a5)
        addi    a7,a5,192
        vs4r.v  v4,0(a3)
        addi    t0,a5,64
        addi    a3,sp,336
        li      t2,32
        addi    a2,a5,128
        vsetvli a5,zero,e32,m8,ta,ma
        vrgatherei16.vv v8,v16,v4
        vmul.vv v8,v8,v24
        vl8re32.v       v0,0(a7)
        vs8r.v  v8,0(a3)
        vmsltu.vx       v8,v0,t2
        addi    a3,sp,12
        addi    t2,sp,204
        vsm.v   v8,0(t2)
        vl4re16.v       v4,0(t0)
        vl4re16.v       v0,0(a2)
        vs4r.v  v4,0(a3)
        addi    t0,sp,336
        vrgatherei16.vv v8,v24,v4
        addi    a3,sp,208
        vrgatherei16.vv v24,v16,v0
        vs4r.v  v0,0(a3)
        vmul.vv v8,v8,v24
        vlm.v   v0,0(t2)
        vl8re32.v       v24,0(t0)
        addi    a3,sp,208
        vsub.vv v16,v24,v8
        addi    t6,a4,528
        vadd.vv v8,v24,v8
        addi    t5,a4,928
        vmerge.vvm      v8,v8,v16,v0
        addi    t3,a4,128
        vs8r.v  v8,0(a4)
        addi    t4,a4,1056
        addi    t1,a4,656
        addi    a0,a4,256
        addi    a6,a4,1184
        addi    a1,a4,784
        addi    a7,a4,384
        addi    a4,sp,140
        vl4re16.v       v0,0(a3)
        vl8re32.v       v24,0(t6)
        vl4re16.v       v4,0(a4)
        vrgatherei16.vv v16,v24,v0
        addi    a3,sp,12
        vs8r.v  v16,0(t0)
        vl8re32.v       v8,0(t5)
        vrgatherei16.vv v16,v24,v4
        vl4re16.v       v4,0(a3)
        vrgatherei16.vv v24,v8,v4
        vmul.vv v16,v16,v8
        vl8re32.v       v8,0(t0)
        vmul.vv v8,v8,v24
        vsub.vv v24,v16,v8
        vlm.v   v0,0(t2)
        addi    a3,sp,208
        vadd.vv v8,v8,v16
        vl8re32.v       v16,0(t4)
        vmerge.vvm      v8,v8,v24,v0
        vrgatherei16.vv v24,v16,v4
        vs8r.v  v24,0(t0)
        vl4re16.v       v28,0(a3)
        addi    a3,sp,464
        vs8r.v  v8,0(t3)
        vl8re32.v       v8,0(t1)
        vrgatherei16.vv v0,v8,v28
        vs8r.v  v0,0(a3)
        addi    a3,sp,140
        vl4re16.v       v24,0(a3)
        addi    a3,sp,464
        vrgatherei16.vv v0,v8,v24
        vl8re32.v       v24,0(t0)
        vmv8r.v v8,v0
        vl8re32.v       v0,0(a3)
        vmul.vv v8,v8,v16
        vmul.vv v24,v24,v0
        vsub.vv v16,v8,v24
        vadd.vv v8,v8,v24
        vsetivli        zero,4,e32,m8,ta,ma
        vle32.v v24,0(a6)
        vsetvli a4,zero,e32,m8,ta,ma
        addi    a4,sp,12
        vlm.v   v0,0(t2)
        vmerge.vvm      v8,v8,v16,v0
        vl4re16.v       v16,0(a4)
        vrgatherei16.vv v0,v24,v16
        vsetivli        zero,4,e32,m8,ta,ma
        vs8r.v  v0,0(a4)
        addi    a4,sp,208
        vl4re16.v       v0,0(a4)
        vs8r.v  v8,0(a0)
        vle32.v v16,0(a1)
        vsetvli a5,zero,e32,m8,ta,ma
        vrgatherei16.vv v8,v16,v0
        vs8r.v  v8,0(a4)
        addi    a4,sp,140
        vl4re16.v       v4,0(a4)
        addi    a5,sp,12
        vrgatherei16.vv v8,v16,v4
        vl8re32.v       v0,0(a5)
        vsetivli        zero,4,e32,m8,ta,ma
        addi    a5,sp,208
        vmv8r.v v16,v8
        vl8re32.v       v8,0(a5)
        vmul.vv v24,v24,v16
        vmul.vv v8,v0,v8
        vsub.vv v16,v24,v8
        vadd.vv v8,v8,v24
        vsetvli a5,zero,e8,m2,ta,ma
        vlm.v   v0,0(t2)
        vsetivli        zero,4,e32,m8,ta,ma
        vmerge.vvm      v8,v8,v16,v0
        vse32.v v8,0(a7)
        addi    sp,sp,592
        jr      ra

This patch makes loop with known NITERS be aware of liveness estimation, after this patch, choosing LMUL = 4:

	lui	a5,%hi(f)
	addi	a5,a5,%lo(f)
	addi	a3,a5,400
	addi	a4,a5,800
	vsetivli	zero,8,e32,m2,ta,ma
	vlseg4e32.v	v16,(a3)
	vlseg4e32.v	v8,(a4)
	vmul.vv	v2,v8,v16
	addi	a3,a5,528
	vmv.v.v	v24,v10
	vnmsub.vv	v24,v18,v2
	addi	a4,a5,928
	vmul.vv	v2,v12,v22
	vmul.vv	v6,v8,v18
	vmv.v.v	v30,v2
	vmacc.vv	v30,v14,v20
	vmv.v.v	v26,v6
	vmacc.vv	v26,v10,v16
	vmul.vv	v4,v12,v20
	vmv.v.v	v28,v14
	vnmsub.vv	v28,v22,v4
	vsseg4e32.v	v24,(a5)
	vlseg4e32.v	v16,(a3)
	vlseg4e32.v	v8,(a4)
	vmul.vv	v2,v8,v16
	addi	a6,a5,128
	vmv.v.v	v24,v10
	vnmsub.vv	v24,v18,v2
	addi	a0,a5,656
	vmul.vv	v2,v12,v22
	addi	a1,a5,1056
	vmv.v.v	v30,v2
	vmacc.vv	v30,v14,v20
	vmul.vv	v6,v8,v18
	vmul.vv	v4,v12,v20
	vmv.v.v	v26,v6
	vmacc.vv	v26,v10,v16
	vmv.v.v	v28,v14
	vnmsub.vv	v28,v22,v4
	vsseg4e32.v	v24,(a6)
	vlseg4e32.v	v16,(a0)
	vlseg4e32.v	v8,(a1)
	vmul.vv	v2,v8,v16
	addi	a2,a5,256
	vmv.v.v	v24,v10
	vnmsub.vv	v24,v18,v2
	addi	a3,a5,784
	vmul.vv	v2,v12,v22
	addi	a4,a5,1184
	vmv.v.v	v30,v2
	vmacc.vv	v30,v14,v20
	vmul.vv	v6,v8,v18
	vmul.vv	v4,v12,v20
	vmv.v.v	v26,v6
	vmacc.vv	v26,v10,v16
	vmv.v.v	v28,v14
	vnmsub.vv	v28,v22,v4
	addi	a5,a5,384
	vsseg4e32.v	v24,(a2)
	vsetivli	zero,1,e32,m2,ta,ma
	vlseg4e32.v	v16,(a3)
	vlseg4e32.v	v8,(a4)
	vmul.vv	v2,v16,v8
	vmul.vv	v6,v18,v8
	vmv.v.v	v24,v18
	vnmsub.vv	v24,v10,v2
	vmul.vv	v4,v20,v12
	vmul.vv	v2,v22,v12
	vmv.v.v	v26,v6
	vmacc.vv	v26,v16,v10
	vmv.v.v	v28,v22
	vnmsub.vv	v28,v14,v4
	vmv.v.v	v30,v2
	vmacc.vv	v30,v20,v14
	vsseg4e32.v	v24,(a5)
	ret

Tested on both RV32 and RV64 no regressions.

	PR target/113112

gcc/ChangeLog:

	* config/riscv/riscv-vector-costs.cc (is_gimple_assign_or_call): New function.
	(get_first_lane_point): Ditto.
	(get_last_lane_point): Ditto.
	(max_number_of_live_regs): Refine live point dump.
	(compute_estimated_lmul): Make unknown NITERS loop be aware of liveness.
	(costs::better_main_loop_than_p): Ditto.
	* config/riscv/riscv-vector-costs.h (struct stmt_point): Add new member.

gcc/testsuite/ChangeLog:

	* gcc.dg/vect/costmodel/riscv/rvv/pr113112-1.c:
	* gcc.dg/vect/costmodel/riscv/rvv/pr113112-3.c: New test.

---
 gcc/config/riscv/riscv-vector-costs.cc        | 91 ++++++++++++++++---
 gcc/config/riscv/riscv-vector-costs.h         |  1 +
 .../vect/costmodel/riscv/rvv/pr113112-1.c     |  6 +-
 .../vect/costmodel/riscv/rvv/pr113112-3.c     | 20 ++++
 4 files changed, 101 insertions(+), 17 deletions(-)
 create mode 100644 gcc/testsuite/gcc.dg/vect/costmodel/riscv/rvv/pr113112-3.c
  

Patch

diff --git a/gcc/config/riscv/riscv-vector-costs.cc b/gcc/config/riscv/riscv-vector-costs.cc
index 74b8e86a5e1..df3c0b0d93a 100644
--- a/gcc/config/riscv/riscv-vector-costs.cc
+++ b/gcc/config/riscv/riscv-vector-costs.cc
@@ -88,6 +88,39 @@  namespace riscv_vector {
 	 3. M1(M8) -> MF2(M4) -> MF4(M2) -> MF8(M1)
 */
 
+static bool
+is_gimple_assign_or_call (gimple_stmt_iterator si)
+{
+  return is_gimple_assign (gsi_stmt (si)) || is_gimple_call (gsi_stmt (si));
+}
+
+/* Return the program point of 1st vectorized lanes statement.  */
+static unsigned int
+get_first_lane_point (const vec<stmt_point> program_points,
+		      stmt_vec_info stmt_info)
+{
+  for (const auto program_point : program_points)
+    if (program_point.stmt_info == DR_GROUP_FIRST_ELEMENT (stmt_info))
+      return program_point.point;
+  return 0;
+}
+
+/* Return the program point of last vectorized lanes statement.  */
+static unsigned int
+get_last_lane_point (const vec<stmt_point> program_points,
+		     stmt_vec_info stmt_info)
+{
+  unsigned int max_point = 0;
+  for (auto s = DR_GROUP_FIRST_ELEMENT (stmt_info); s != NULL;
+       s = DR_GROUP_NEXT_ELEMENT (s))
+    {
+      for (const auto program_point : program_points)
+	if (program_point.stmt_info == s && program_point.point > max_point)
+	  max_point = program_point.point;
+    }
+  return max_point;
+}
+
 /* Collect all STMTs that are vectorized and compute their program points.
    Note that we don't care about the STMTs that are not vectorized and
    we only build the local graph (within a block) of program points.
@@ -132,15 +165,14 @@  compute_local_program_points (
 			     bb->index);
 	  for (si = gsi_start_bb (bbs[i]); !gsi_end_p (si); gsi_next (&si))
 	    {
-	      if (!(is_gimple_assign (gsi_stmt (si))
-		    || is_gimple_call (gsi_stmt (si))))
+	      if (!is_gimple_assign_or_call (si))
 		continue;
 	      stmt_vec_info stmt_info = vinfo->lookup_stmt (gsi_stmt (si));
 	      enum stmt_vec_info_type type
 		= STMT_VINFO_TYPE (vect_stmt_to_vectorize (stmt_info));
 	      if (type != undef_vec_info_type)
 		{
-		  stmt_point info = {point, gsi_stmt (si)};
+		  stmt_point info = {point, gsi_stmt (si), stmt_info};
 		  program_points.safe_push (info);
 		  point++;
 		  if (dump_enabled_p ())
@@ -219,6 +251,10 @@  compute_local_live_ranges (
 		  pair &live_range
 		    = live_ranges->get_or_insert (lhs, &existed_p);
 		  gcc_assert (!existed_p);
+		  if (STMT_VINFO_MEMORY_ACCESS_TYPE (program_point.stmt_info)
+		      == VMAT_LOAD_STORE_LANES)
+		    point = get_first_lane_point (program_points,
+						  program_point.stmt_info);
 		  live_range = pair (point, point);
 		}
 	      for (i = 0; i < gimple_num_args (stmt); i++)
@@ -241,6 +277,11 @@  compute_local_live_ranges (
 		      bool existed_p = false;
 		      pair &live_range
 			= live_ranges->get_or_insert (var, &existed_p);
+		      if (STMT_VINFO_MEMORY_ACCESS_TYPE (
+			    program_point.stmt_info)
+			  == VMAT_LOAD_STORE_LANES)
+			point = get_last_lane_point (program_points,
+						     program_point.stmt_info);
 		      if (existed_p)
 			/* We will grow the live range for each use.  */
 			live_range = pair (live_range.first, point);
@@ -313,7 +354,10 @@  max_number_of_live_regs (const basic_block bb,
 	    = compute_nregs_for_mode (mode, biggest_mode, lmul);
 	  live_vars_vec[i] += nregs;
 	  if (live_vars_vec[i] > max_nregs)
-	    max_nregs = live_vars_vec[i];
+	    {
+	      max_nregs = live_vars_vec[i];
+	      live_point = i;
+	    }
 	}
     }
 
@@ -396,8 +440,7 @@  compute_estimated_lmul (loop_vec_info loop_vinfo, machine_mode mode)
   int regno_alignment = riscv_get_v_regno_alignment (loop_vinfo->vector_mode);
   if (riscv_v_ext_vls_mode_p (loop_vinfo->vector_mode))
     return regno_alignment;
-  else if (known_eq (LOOP_VINFO_SLP_UNROLLING_FACTOR (loop_vinfo), 1U)
-	   || LOOP_VINFO_SLP_UNROLLING_FACTOR (loop_vinfo).is_constant ())
+  else if (known_eq (LOOP_VINFO_SLP_UNROLLING_FACTOR (loop_vinfo), 1U))
     {
       int estimated_vf = vect_vf_for_cost (loop_vinfo);
       return estimated_vf * GET_MODE_BITSIZE (mode).to_constant ()
@@ -408,7 +451,8 @@  compute_estimated_lmul (loop_vec_info loop_vinfo, machine_mode mode)
       /* Estimate the VLA SLP LMUL.  */
       if (regno_alignment > RVV_M1)
 	return regno_alignment;
-      else if (mode != QImode)
+      else if (mode != QImode
+	       || LOOP_VINFO_SLP_UNROLLING_FACTOR (loop_vinfo).is_constant ())
 	{
 	  int ratio;
 	  if (can_div_trunc_p (BYTES_PER_RISCV_VECTOR,
@@ -507,7 +551,7 @@  update_local_live_ranges (
 		      auto &program_points = (*program_points_per_bb.get (bb));
 		      if (program_points.is_empty ())
 			{
-			  stmt_point info = {1, phi};
+			  stmt_point info = {1, phi, stmt_info};
 			  program_points.safe_push (info);
 			}
 		      if (dump_enabled_p ())
@@ -545,8 +589,7 @@  update_local_live_ranges (
 	}
       for (si = gsi_start_bb (bb); !gsi_end_p (si); gsi_next (&si))
 	{
-	  if (!(is_gimple_assign (gsi_stmt (si))
-		|| is_gimple_call (gsi_stmt (si))))
+	  if (!is_gimple_assign_or_call (si))
 	    continue;
 	  stmt_vec_info stmt_info = vinfo->lookup_stmt (gsi_stmt (si));
 	  enum stmt_vec_info_type type
@@ -802,8 +845,7 @@  costs::better_main_loop_than_p (const vector_costs *uncast_other) const
 	  return other_prefer_unrolled;
 	}
     }
-  else if (riscv_autovec_lmul == RVV_DYNAMIC
-	   && !LOOP_VINFO_NITERS_KNOWN_P (other_loop_vinfo))
+  else if (riscv_autovec_lmul == RVV_DYNAMIC)
     {
       if (other->m_has_unexpected_spills_p)
 	{
@@ -813,8 +855,29 @@  costs::better_main_loop_than_p (const vector_costs *uncast_other) const
 			     " it has unexpected spills\n");
 	  return true;
 	}
-      else
-	return false;
+      else if (riscv_v_ext_vector_mode_p (other_loop_vinfo->vector_mode))
+	{
+	  if (LOOP_VINFO_NITERS_KNOWN_P (other_loop_vinfo))
+	    {
+	      if (maybe_gt (LOOP_VINFO_INT_NITERS (this_loop_vinfo),
+			    LOOP_VINFO_VECT_FACTOR (this_loop_vinfo)))
+		{
+		  if (dump_enabled_p ())
+		    dump_printf_loc (MSG_NOTE, vect_location,
+				     "Keep current LMUL loop because"
+				     " known NITERS exceed the new VF\n");
+		  return false;
+		}
+	    }
+	  else
+	    {
+	      if (dump_enabled_p ())
+		dump_printf_loc (MSG_NOTE, vect_location,
+				 "Keep current LMUL loop because"
+				 " it is unknown NITERS\n");
+	      return false;
+	    }
+	}
     }
 
   return vector_costs::better_main_loop_than_p (other);
diff --git a/gcc/config/riscv/riscv-vector-costs.h b/gcc/config/riscv/riscv-vector-costs.h
index ed7fff94d07..36c70fefdd8 100644
--- a/gcc/config/riscv/riscv-vector-costs.h
+++ b/gcc/config/riscv/riscv-vector-costs.h
@@ -28,6 +28,7 @@  struct stmt_point
   /* Program point.  */
   unsigned int point;
   gimple *stmt;
+  stmt_vec_info stmt_info;
 };
 
 enum cost_type_enum
diff --git a/gcc/testsuite/gcc.dg/vect/costmodel/riscv/rvv/pr113112-1.c b/gcc/testsuite/gcc.dg/vect/costmodel/riscv/rvv/pr113112-1.c
index cd0fe19b98d..95df7809d49 100644
--- a/gcc/testsuite/gcc.dg/vect/costmodel/riscv/rvv/pr113112-1.c
+++ b/gcc/testsuite/gcc.dg/vect/costmodel/riscv/rvv/pr113112-1.c
@@ -24,6 +24,6 @@  foo (int n){
 /* { dg-final { scan-assembler-not {jr} } } */
 /* { dg-final { scan-assembler-times {ret} 1 } } */
 /* { dg-final { scan-tree-dump-times "Preferring smaller LMUL loop because it has unexpected spills" 1 "vect" } } */
-/* { dg-final { scan-tree-dump "At most 8 number of live V_REG at program point 0 for bb 4" "vect" } } */
-/* { dg-final { scan-tree-dump "At most 40 number of live V_REG at program point 0 for bb 3" "vect" } } */
-/* { dg-final { scan-tree-dump "At most 8 number of live V_REG at program point 0 for bb 5" "vect" } } */
+/* { dg-final { scan-tree-dump "At most 8 number of live V_REG at program point 1 for bb 4" "vect" } } */
+/* { dg-final { scan-tree-dump "At most 40 number of live V_REG at program point 1 for bb 3" "vect" } } */
+/* { dg-final { scan-tree-dump "At most 8 number of live V_REG at program point 1 for bb 5" "vect" } } */
diff --git a/gcc/testsuite/gcc.dg/vect/costmodel/riscv/rvv/pr113112-3.c b/gcc/testsuite/gcc.dg/vect/costmodel/riscv/rvv/pr113112-3.c
new file mode 100644
index 00000000000..c80936246d7
--- /dev/null
+++ b/gcc/testsuite/gcc.dg/vect/costmodel/riscv/rvv/pr113112-3.c
@@ -0,0 +1,20 @@ 
+/* { dg-do compile } */
+/* { dg-options "-march=rv64gcv -mabi=lp64d -O3 -ftree-vectorize --param riscv-autovec-lmul=dynamic --param riscv-autovec-preference=fixed-vlmax" } */
+
+int f[12][100];
+
+void bad1(int v1, int v2)
+{
+  for (int r = 0; r < 100; r += 4)
+    {
+      int i = r + 1;
+      f[0][r] = f[1][r] * (f[2][r]) - f[1][i] * (f[2][i]);
+      f[0][i] = f[1][r] * (f[2][i]) + f[1][i] * (f[2][r]);
+      f[0][r+2] = f[1][r+2] * (f[2][r+2]) - f[1][i+2] * (f[2][i+2]);
+      f[0][i+2] = f[1][r+2] * (f[2][i+2]) + f[1][i+2] * (f[2][r+2]);
+    }
+}
+
+/* { dg-final { scan-assembler {e32,m2} } } */
+/* { dg-final { scan-assembler-not {jr} } } */
+/* { dg-final { scan-assembler-times {ret} 1 } } */