RISC-V: Add local user vsetvl instruction elimination

Message ID 20230407013413.127686-1-juzhe.zhong@rivai.ai
State Accepted
Headers
Series RISC-V: Add local user vsetvl instruction elimination |

Checks

Context Check Description
snail/gcc-patch-check success Github commit url

Commit Message

juzhe.zhong@rivai.ai April 7, 2023, 1:34 a.m. UTC
  From: Juzhe-Zhong <juzhe.zhong@rivai.ai>

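This patch enhances the vsetvl optimization for auto-vectorized code: when a
user vsetvl is followed, in the same basic block, by a vsetvl that only
re-applies the VL it produced, the second vsetvl is eliminated.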

Before this patch:

Loop:
vsetvl a5,a2...
vsetvl zero,a5...
vle

After this patch:

Loop:
vsetvl a5,a2
vle

gcc/ChangeLog:

        * config/riscv/riscv-vsetvl.cc (local_eliminate_vsetvl_insn): New function.
        (vector_insn_info::skip_avl_compatible_p): Ditto.
        (vector_insn_info::merge): Remove default value.
        (pass_vsetvl::compute_local_backward_infos): Pass LOCAL_MERGE explicitly.
        (pass_vsetvl::cleanup_insns): Add local vsetvl elimination.
        * config/riscv/riscv-vsetvl.h: Declare skip_avl_compatible_p.

---
 gcc/config/riscv/riscv-vsetvl.cc | 71 +++++++++++++++++++++++++++++++-
 gcc/config/riscv/riscv-vsetvl.h  |  1 +
 2 files changed, 70 insertions(+), 2 deletions(-)
  

Comments

Kito Cheng April 21, 2023, 6:50 a.m. UTC | #1
Committed with an extra testcase from PR109547

https://gcc.gnu.org/pipermail/gcc-patches/2023-April/616363.html

On Fri, Apr 7, 2023 at 9:34 AM <juzhe.zhong@rivai.ai> wrote:
>
> From: Juzhe-Zhong <juzhe.zhong@rivai.ai>
>
> This patch is to enhance optimization for auto-vectorization.
>
> Before this patch:
>
> Loop:
> vsetvl a5,a2...
> vsetvl zero,a5...
> vle
>
> After this patch:
>
> Loop:
> vsetvl a5,a2
> vle
>
> gcc/ChangeLog:
>
>         * config/riscv/riscv-vsetvl.cc (local_eliminate_vsetvl_insn): New function.
>         (vector_insn_info::skip_avl_compatible_p): Ditto.
>         (vector_insn_info::merge): Remove default value.
>         (pass_vsetvl::compute_local_backward_infos): Ditto.
>         (pass_vsetvl::cleanup_insns): Add local vsetvl elimination.
>         * config/riscv/riscv-vsetvl.h: Ditto.
>
> ---
>  gcc/config/riscv/riscv-vsetvl.cc | 71 +++++++++++++++++++++++++++++++-
>  gcc/config/riscv/riscv-vsetvl.h  |  1 +
>  2 files changed, 70 insertions(+), 2 deletions(-)
>
> diff --git a/gcc/config/riscv/riscv-vsetvl.cc b/gcc/config/riscv/riscv-vsetvl.cc
> index 7e8a5376705..b402035f7a5 100644
> --- a/gcc/config/riscv/riscv-vsetvl.cc
> +++ b/gcc/config/riscv/riscv-vsetvl.cc
> @@ -1054,6 +1054,51 @@ change_vsetvl_insn (const insn_info *insn, const vector_insn_info &info)
>    change_insn (rinsn, new_pat);
>  }
>
> +static void
> +local_eliminate_vsetvl_insn (const vector_insn_info &dem)
> +{
> +  const insn_info *insn = dem.get_insn ();
> +  if (!insn || insn->is_artificial ())
> +    return;
> +  rtx_insn *rinsn = insn->rtl ();
> +  const bb_info *bb = insn->bb ();
> +  if (vsetvl_insn_p (rinsn))
> +    {
> +      rtx vl = get_vl (rinsn);
> +      for (insn_info *i = insn->next_nondebug_insn ();
> +          real_insn_and_same_bb_p (i, bb); i = i->next_nondebug_insn ())
> +       {
> +         if (i->is_call () || i->is_asm ()
> +             || find_access (i->defs (), VL_REGNUM)
> +             || find_access (i->defs (), VTYPE_REGNUM))
> +           return;
> +
> +         if (has_vtype_op (i->rtl ()))
> +           {
> +             if (!vsetvl_discard_result_insn_p (PREV_INSN (i->rtl ())))
> +               return;
> +             rtx avl = get_avl (i->rtl ());
> +             if (avl != vl)
> +               return;
> +             set_info *def = find_access (i->uses (), REGNO (avl))->def ();
> +             if (def->insn () != insn)
> +               return;
> +
> +             vector_insn_info new_info;
> +             new_info.parse_insn (i);
> +             if (!new_info.skip_avl_compatible_p (dem))
> +               return;
> +
> +             new_info.set_avl_info (dem.get_avl_info ());
> +             new_info = dem.merge (new_info, LOCAL_MERGE);
> +             change_vsetvl_insn (insn, new_info);
> +             eliminate_insn (PREV_INSN (i->rtl ()));
> +             return;
> +           }
> +       }
> +    }
> +}
> +
>  static bool
>  source_equal_p (insn_info *insn1, insn_info *insn2)
>  {
> @@ -1984,6 +2029,19 @@ vector_insn_info::compatible_p (const vector_insn_info &other) const
>    return true;
>  }
>
> +bool
> +vector_insn_info::skip_avl_compatible_p (const vector_insn_info &other) const
> +{
> +  gcc_assert (valid_or_dirty_p () && other.valid_or_dirty_p ()
> +             && "Can't compare invalid demanded infos");
> +  unsigned array_size = sizeof (incompatible_conds) / sizeof (demands_cond);
> +  /* Bypass AVL incompatible cases.  */
> +  for (unsigned i = 1; i < array_size; i++)
> +    if (incompatible_conds[i].dual_incompatible_p (*this, other))
> +      return false;
> +  return true;
> +}
> +
>  bool
>  vector_insn_info::compatible_avl_p (const vl_vtype_info &other) const
>  {
> @@ -2178,7 +2236,7 @@ vector_insn_info::fuse_mask_policy (const vector_insn_info &info1,
>
>  vector_insn_info
>  vector_insn_info::merge (const vector_insn_info &merge_info,
> -                        enum merge_type type = LOCAL_MERGE) const
> +                        enum merge_type type) const
>  {
>    if (!vsetvl_insn_p (get_insn ()->rtl ()))
>      gcc_assert (this->compatible_p (merge_info)
> @@ -2716,7 +2774,7 @@ pass_vsetvl::compute_local_backward_infos (const bb_info *bb)
>                     && !reg_available_p (insn, change))
>                   && change.compatible_p (info))
>                 {
> -                 info = change.merge (info);
> +                 info = change.merge (info, LOCAL_MERGE);
>                   /* Fix PR109399, we should update user vsetvl instruction
>                      if there is a change in demand fusion.  */
>                   if (vsetvl_insn_p (insn->rtl ()))
> @@ -3998,6 +4056,15 @@ pass_vsetvl::cleanup_insns (void) const
>        for (insn_info *insn : bb->real_nondebug_insns ())
>         {
>           rtx_insn *rinsn = insn->rtl ();
> +         const auto &dem = m_vector_manager->vector_insn_infos[insn->uid ()];
> +         /* Eliminate local vsetvl:
> +              bb 0:
> +              vsetvl a5,a6,...
> +              vsetvl zero,a5.
> +
> +            Eliminate vsetvl in bb2 when a5 is only coming from
> +            bb 0.  */
> +         local_eliminate_vsetvl_insn (dem);
>
>           if (vlmax_avl_insn_p (rinsn))
>             {
> diff --git a/gcc/config/riscv/riscv-vsetvl.h b/gcc/config/riscv/riscv-vsetvl.h
> index d05472c86a0..d7a6c14e931 100644
> --- a/gcc/config/riscv/riscv-vsetvl.h
> +++ b/gcc/config/riscv/riscv-vsetvl.h
> @@ -380,6 +380,7 @@ public:
>    void fuse_mask_policy (const vector_insn_info &, const vector_insn_info &);
>
>    bool compatible_p (const vector_insn_info &) const;
> +  bool skip_avl_compatible_p (const vector_insn_info &) const;
>    bool compatible_avl_p (const vl_vtype_info &) const;
>    bool compatible_avl_p (const avl_info &) const;
>    bool compatible_vtype_p (const vl_vtype_info &) const;
> --
> 2.36.3
>
  

Patch

diff --git a/gcc/config/riscv/riscv-vsetvl.cc b/gcc/config/riscv/riscv-vsetvl.cc
index 7e8a5376705..b402035f7a5 100644
--- a/gcc/config/riscv/riscv-vsetvl.cc
+++ b/gcc/config/riscv/riscv-vsetvl.cc
@@ -1054,6 +1054,51 @@  change_vsetvl_insn (const insn_info *insn, const vector_insn_info &info)
   change_insn (rinsn, new_pat);
 }
 
+static void
+local_eliminate_vsetvl_insn (const vector_insn_info &dem)
+{
+  const insn_info *insn = dem.get_insn ();
+  if (!insn || insn->is_artificial ())
+    return;
+  rtx_insn *rinsn = insn->rtl ();
+  const bb_info *bb = insn->bb ();
+  if (vsetvl_insn_p (rinsn))
+    {
+      rtx vl = get_vl (rinsn);
+      for (insn_info *i = insn->next_nondebug_insn ();
+	   real_insn_and_same_bb_p (i, bb); i = i->next_nondebug_insn ())
+	{
+	  if (i->is_call () || i->is_asm ()
+	      || find_access (i->defs (), VL_REGNUM)
+	      || find_access (i->defs (), VTYPE_REGNUM))
+	    return;
+
+	  if (has_vtype_op (i->rtl ()))
+	    {
+	      if (!vsetvl_discard_result_insn_p (PREV_INSN (i->rtl ())))
+		return;
+	      rtx avl = get_avl (i->rtl ());
+	      if (avl != vl)
+		return;
+	      set_info *def = find_access (i->uses (), REGNO (avl))->def ();
+	      if (def->insn () != insn)
+		return;
+
+	      vector_insn_info new_info;
+	      new_info.parse_insn (i);
+	      if (!new_info.skip_avl_compatible_p (dem))
+		return;
+
+	      new_info.set_avl_info (dem.get_avl_info ());
+	      new_info = dem.merge (new_info, LOCAL_MERGE);
+	      change_vsetvl_insn (insn, new_info);
+	      eliminate_insn (PREV_INSN (i->rtl ()));
+	      return;
+	    }
+	}
+    }
+}
+
 static bool
 source_equal_p (insn_info *insn1, insn_info *insn2)
 {
@@ -1984,6 +2029,19 @@  vector_insn_info::compatible_p (const vector_insn_info &other) const
   return true;
 }
 
+bool
+vector_insn_info::skip_avl_compatible_p (const vector_insn_info &other) const
+{
+  gcc_assert (valid_or_dirty_p () && other.valid_or_dirty_p ()
+	      && "Can't compare invalid demanded infos");
+  unsigned array_size = sizeof (incompatible_conds) / sizeof (demands_cond);
+  /* Bypass AVL incompatible cases.  */
+  for (unsigned i = 1; i < array_size; i++)
+    if (incompatible_conds[i].dual_incompatible_p (*this, other))
+      return false;
+  return true;
+}
+
 bool
 vector_insn_info::compatible_avl_p (const vl_vtype_info &other) const
 {
@@ -2178,7 +2236,7 @@  vector_insn_info::fuse_mask_policy (const vector_insn_info &info1,
 
 vector_insn_info
 vector_insn_info::merge (const vector_insn_info &merge_info,
-			 enum merge_type type = LOCAL_MERGE) const
+			 enum merge_type type) const
 {
   if (!vsetvl_insn_p (get_insn ()->rtl ()))
     gcc_assert (this->compatible_p (merge_info)
@@ -2716,7 +2774,7 @@  pass_vsetvl::compute_local_backward_infos (const bb_info *bb)
 		    && !reg_available_p (insn, change))
 		  && change.compatible_p (info))
 		{
-		  info = change.merge (info);
+		  info = change.merge (info, LOCAL_MERGE);
 		  /* Fix PR109399, we should update user vsetvl instruction
 		     if there is a change in demand fusion.  */
 		  if (vsetvl_insn_p (insn->rtl ()))
@@ -3998,6 +4056,15 @@  pass_vsetvl::cleanup_insns (void) const
       for (insn_info *insn : bb->real_nondebug_insns ())
 	{
 	  rtx_insn *rinsn = insn->rtl ();
+	  const auto &dem = m_vector_manager->vector_insn_infos[insn->uid ()];
+	  /* Eliminate local vsetvl:
+	       bb 0:
+	       vsetvl a5,a6,...
+	       vsetvl zero,a5.
+
+	     Eliminate the second vsetvl (vsetvl zero,a5) when a5 is
+	     only defined by the first vsetvl in the same bb.  */
+	  local_eliminate_vsetvl_insn (dem);
 
 	  if (vlmax_avl_insn_p (rinsn))
 	    {
diff --git a/gcc/config/riscv/riscv-vsetvl.h b/gcc/config/riscv/riscv-vsetvl.h
index d05472c86a0..d7a6c14e931 100644
--- a/gcc/config/riscv/riscv-vsetvl.h
+++ b/gcc/config/riscv/riscv-vsetvl.h
@@ -380,6 +380,7 @@  public:
   void fuse_mask_policy (const vector_insn_info &, const vector_insn_info &);
 
   bool compatible_p (const vector_insn_info &) const;
+  bool skip_avl_compatible_p (const vector_insn_info &) const;
   bool compatible_avl_p (const vl_vtype_info &) const;
   bool compatible_avl_p (const avl_info &) const;
   bool compatible_vtype_p (const vl_vtype_info &) const;