[v3] RISC-V: Implement RESOLVE_OVERLOADED_BUILTIN for RVV intrinsic
Checks
Commit Message
From: Pan Li <pan2.li@intel.com>
Update in v3:
* Rewrite comment for overloaded function add.
* Move get_non_overloaded_instance to function_base.
Update in v2:
* Add get_non_overloaded_instance for function instance.
* Fix overload check for policy function.
* Enrich the test cases check.
Original log:
This patch would like add the framework to support the RVV overloaded
intrinsic API in riscv-xxx-xxx-gcc, like riscv-xxx-xxx-g++ did.
However, it almost leverage the hook TARGET_RESOLVE_OVERLOADED_BUILTIN
with below steps.
* Register overloaded functions.
* Add function_resolver for overloaded function resolving.
* Add resolve API for function shape with default implementation.
* Implement HOOK for navigating the overloaded API to non-overloaded API.
We validated this framework by the vmv_v intrinsic API(s), and we will
add more intrins API support in the underlying patches.
gcc/ChangeLog:
* config/riscv/riscv-c.cc
(riscv_resolve_overloaded_builtin): New function for the hook.
(riscv_register_pragmas): Register the hook
* config/riscv/riscv-protos.h (resolve_overloaded_builtin): New decl.
* config/riscv/riscv-vector-builtins-shapes.cc (build_one):
Register overloaded function.
(struct overloaded_base): New struct for overloaded shape.
(struct non_overloaded_base): New struct for non overloaded shape.
(struct move_def): Inherit overloaded shape.
* config/riscv/riscv-vector-builtins.cc
(function_base::get_non_overloaded_instance): New API impl.
(function_builder::add_function): Add overloaded arg.
(function_resolver::function_resolver): New constructor.
(function_builder::add_overloaded_function): New API impl.
(function_resolver::resolve): Ditto.
(function_resolver::lookup): Ditto.
(function_resolver::get_sub_code): Ditto.
(resolve_overloaded_builtin): New function impl.
* config/riscv/riscv-vector-builtins.h:
(class function_resolver): New class.
gcc/testsuite/ChangeLog:
* gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c: New test.
* gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c: New test.
* gcc.target/riscv/rvv/base/overloaded_vmv_v.h: New test.
Signed-off-by: Pan Li <pan2.li@intel.com>
---
gcc/config/riscv/riscv-c.cc | 36 ++++
gcc/config/riscv/riscv-protos.h | 1 +
.../riscv/riscv-vector-builtins-shapes.cc | 20 ++-
gcc/config/riscv/riscv-vector-builtins.cc | 155 +++++++++++++++++-
gcc/config/riscv/riscv-vector-builtins.h | 36 +++-
.../riscv/rvv/base/overloaded_rv32_vmv_v.c | 8 +
.../riscv/rvv/base/overloaded_rv64_vmv_v.c | 8 +
.../riscv/rvv/base/overloaded_vmv_v.h | 27 +++
8 files changed, 288 insertions(+), 3 deletions(-)
create mode 100644 gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c
create mode 100644 gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c
create mode 100644 gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_vmv_v.h
Comments
It looks reasonable to me now.
But let's wait for kito's more comments.
juzhe.zhong@rivai.ai
From: pan2.li
Date: 2023-09-12 16:46
To: gcc-patches
CC: juzhe.zhong; pan2.li; yanzhang.wang; kito.cheng
Subject: [PATCH v3] RISC-V: Implement RESOLVE_OVERLOADED_BUILTIN for RVV intrinsic
From: Pan Li <pan2.li@intel.com>
Update in v3:
* Rewrite comment for overloaded function add.
* Move get_non_overloaded_instance to function_base.
Update in v2:
* Add get_non_overloaded_instance for function instance.
* Fix overload check for policy function.
* Enrich the test cases check.
Original log:
This patch would like add the framework to support the RVV overloaded
intrinsic API in riscv-xxx-xxx-gcc, like riscv-xxx-xxx-g++ did.
However, it almost leverage the hook TARGET_RESOLVE_OVERLOADED_BUILTIN
with below steps.
* Register overloaded functions.
* Add function_resolver for overloaded function resolving.
* Add resolve API for function shape with default implementation.
* Implement HOOK for navigating the overloaded API to non-overloaded API.
We validated this framework by the vmv_v intrinsic API(s), and we will
add more intrins API support in the underlying patches.
gcc/ChangeLog:
* config/riscv/riscv-c.cc
(riscv_resolve_overloaded_builtin): New function for the hook.
(riscv_register_pragmas): Register the hook
* config/riscv/riscv-protos.h (resolve_overloaded_builtin): New decl.
* config/riscv/riscv-vector-builtins-shapes.cc (build_one):
Register overloaded function.
(struct overloaded_base): New struct for overloaded shape.
(struct non_overloaded_base): New struct for non overloaded shape.
(struct move_def): Inherit overloaded shape.
* config/riscv/riscv-vector-builtins.cc
(function_base::get_non_overloaded_instance): New API impl.
(function_builder::add_function): Add overloaded arg.
(function_resolver::function_resolver): New constructor.
(function_builder::add_overloaded_function): New API impl.
(function_resolver::resolve): Ditto.
(function_resolver::lookup): Ditto.
(function_resolver::get_sub_code): Ditto.
(resolve_overloaded_builtin): New function impl.
* config/riscv/riscv-vector-builtins.h:
(class function_resolver): New class.
gcc/testsuite/ChangeLog:
* gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c: New test.
* gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c: New test.
* gcc.target/riscv/rvv/base/overloaded_vmv_v.h: New test.
Signed-off-by: Pan Li <pan2.li@intel.com>
---
gcc/config/riscv/riscv-c.cc | 36 ++++
gcc/config/riscv/riscv-protos.h | 1 +
.../riscv/riscv-vector-builtins-shapes.cc | 20 ++-
gcc/config/riscv/riscv-vector-builtins.cc | 155 +++++++++++++++++-
gcc/config/riscv/riscv-vector-builtins.h | 36 +++-
.../riscv/rvv/base/overloaded_rv32_vmv_v.c | 8 +
.../riscv/rvv/base/overloaded_rv64_vmv_v.c | 8 +
.../riscv/rvv/base/overloaded_vmv_v.h | 27 +++
8 files changed, 288 insertions(+), 3 deletions(-)
create mode 100644 gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c
create mode 100644 gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c
create mode 100644 gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_vmv_v.h
diff --git a/gcc/config/riscv/riscv-c.cc b/gcc/config/riscv/riscv-c.cc
index 283052ae313..060edd3129d 100644
--- a/gcc/config/riscv/riscv-c.cc
+++ b/gcc/config/riscv/riscv-c.cc
@@ -220,11 +220,47 @@ riscv_check_builtin_call (location_t loc, vec<location_t> arg_loc, tree fndecl,
gcc_unreachable ();
}
+/* Implement TARGET_RESOLVE_OVERLOADED_BUILTIN. */
+static tree
+riscv_resolve_overloaded_builtin (unsigned int uncast_location, tree fndecl,
+ void *uncast_arglist)
+{
+ vec<tree, va_gc> empty = {};
+ location_t loc = (location_t) uncast_location;
+ vec<tree, va_gc> *arglist = (vec<tree, va_gc> *) uncast_arglist;
+ unsigned int code = DECL_MD_FUNCTION_CODE (fndecl);
+ unsigned int subcode = code >> RISCV_BUILTIN_SHIFT;
+ tree new_fndecl = NULL_TREE;
+
+ if (!arglist)
+ arglist = ∅
+
+ switch (code & RISCV_BUILTIN_CLASS)
+ {
+ case RISCV_BUILTIN_GENERAL:
+ break;
+ case RISCV_BUILTIN_VECTOR:
+ new_fndecl = riscv_vector::resolve_overloaded_builtin (loc, subcode,
+ arglist);
+ break;
+ default:
+ gcc_unreachable ();
+ }
+
+ if (new_fndecl == NULL_TREE)
+ return new_fndecl;
+
+ return build_function_call_vec (loc, vNULL, new_fndecl, arglist, NULL,
+ fndecl);
+}
+
/* Implement REGISTER_TARGET_PRAGMAS. */
void
riscv_register_pragmas (void)
{
+ targetm.resolve_overloaded_builtin = riscv_resolve_overloaded_builtin;
targetm.check_builtin_call = riscv_check_builtin_call;
+
c_register_pragma ("riscv", "intrinsic", riscv_pragma_intrinsic);
}
diff --git a/gcc/config/riscv/riscv-protos.h b/gcc/config/riscv/riscv-protos.h
index 6dbf6b9f943..5d2492dd031 100644
--- a/gcc/config/riscv/riscv-protos.h
+++ b/gcc/config/riscv/riscv-protos.h
@@ -381,6 +381,7 @@ gimple *gimple_fold_builtin (unsigned int, gimple_stmt_iterator *, gcall *);
rtx expand_builtin (unsigned int, tree, rtx);
bool check_builtin_call (location_t, vec<location_t>, unsigned int,
tree, unsigned int, tree *);
+tree resolve_overloaded_builtin (location_t, unsigned int, vec<tree, va_gc> *);
bool const_vec_all_same_in_range_p (rtx, HOST_WIDE_INT, HOST_WIDE_INT);
bool legitimize_move (rtx, rtx);
void emit_vlmax_vsetvl (machine_mode, rtx);
diff --git a/gcc/config/riscv/riscv-vector-builtins-shapes.cc b/gcc/config/riscv/riscv-vector-builtins-shapes.cc
index f8fdec863e6..1c1a2cc9488 100644
--- a/gcc/config/riscv/riscv-vector-builtins-shapes.cc
+++ b/gcc/config/riscv/riscv-vector-builtins-shapes.cc
@@ -49,6 +49,8 @@ build_one (function_builder &b, const function_group_info &group,
group.ops_infos.types[vec_type_idx].index);
b.allocate_argument_types (function_instance, argument_types);
b.apply_predication (function_instance, return_type, argument_types);
+
+ b.add_overloaded_function (function_instance, *group.shape);
b.add_unique_function (function_instance, (*group.shape), return_type,
argument_types);
}
@@ -87,6 +89,22 @@ struct build_base : public function_shape
}
};
+struct overloaded_base : public build_base
+{
+ tree resolve (function_resolver &r) const override
+ {
+ return r.lookup ();
+ }
+};
+
+struct non_overloaded_base : public build_base
+{
+ tree resolve (function_resolver &) const override
+ {
+ gcc_unreachable ();
+ }
+};
+
/* vsetvl_def class. */
struct vsetvl_def : public build_base
{
@@ -525,7 +543,7 @@ struct narrow_alu_def : public build_base
};
/* move_def class. Handle vmv.v.v/vmv.v.x. */
-struct move_def : public build_base
+struct move_def : public overloaded_base
{
char *get_name (function_builder &b, const function_instance &instance,
bool overloaded_p) const override
diff --git a/gcc/config/riscv/riscv-vector-builtins.cc b/gcc/config/riscv/riscv-vector-builtins.cc
index 6d99f970ead..4f6fbdc3e28 100644
--- a/gcc/config/riscv/riscv-vector-builtins.cc
+++ b/gcc/config/riscv/riscv-vector-builtins.cc
@@ -80,6 +80,10 @@ public:
/* The decl itself. */
tree GTY ((skip)) decl;
+
+ /* True if the decl represents an overloaded function that needs to be
+ resolved by function_resolver. */
+ bool overloaded_p;
};
/* Hash traits for registered_function. */
@@ -3196,6 +3200,77 @@ function_instance::could_trap_p () const
return false;
}
+/* Try to get the non-overloaded function instance.
+ After we register the overloaded the functions, the registered functions
+ table may look like:
+
+ +--------+---------------------------+-------------------+
+ | index | name | kind |
+ +--------+---------------------------+-------------------+
+ | 124733 | __riscv_vmv_v | Overloaded | <- Hook fun code
+ +--------+---------------------------+-------------------+
+ | 124735 | __riscv_vmv_v_v_i8mf8 | Non-overloaded |
+ +--------+---------------------------+-------------------+
+ | 124737 | __riscv_vmv_v | Placeholder |
+ +--------+---------------------------+-------------------+
+ | 124739 | __riscv_vmv_v | Overloaded |
+ +--------+---------------------------+-------------------+
+ | 124741 | __riscv_vmv_v_v_i8mf4 | Non-overloaded |
+ +--------+---------------------------+-------------------+
+ | 124743 | __riscv_vmv_v | Placeholder |
+ +--------+---------------------------+-------------------+
+ | 124745 | __riscv_vmv_v | Overloaded |
+ +--------+---------------------------+-------------------+
+ | 124747 | __riscv_vmv_v_v_i8mf2 | Non-overloaded |
+ +--------+---------------------------+-------------------+
+ | 124749 | __riscv_vmv_v | Placeholder |
+ +--------+---------------------------+-------------------+
+ | 124751 | __riscv_vmv_v | Overloaded |
+ +--------+---------------------------+-------------------+
+ | 124753 | __riscv_vmv_v_v_i8m1 | Non-overloaded |
+ +--------+---------------------------+-------------------+
+ | 124755 | __riscv_vmv_v | Placeholder |
+ +--------+---------------------------+-------------------+
+
+ When we resolve the overloaded API from the hook, we always get the first
+ function code of one API group (aka vmv_v as above table). We will search
+ start from that index to find the only one non-overloaded API with exactly
+ the same arglist. Or NULL instance will be returned.
+ */
+function_instance *
+function_base::get_non_overloaded_instance (unsigned int code,
+ vec<tree, va_gc> &arglist) const
+{
+ unsigned int code_limit = vec_safe_length (registered_functions);
+
+ for (unsigned fun_code = code; fun_code < code_limit; fun_code++)
+ {
+ registered_function *rfun = (*registered_functions)[fun_code];
+ function_instance instance = rfun->instance;
+
+ if (rfun->overloaded_p)
+ continue;
+
+ unsigned k;
+ const rvv_arg_type_info *args = instance.op_info->args;
+
+ for (k = 0; args[k].base_type != NUM_BASE_TYPES; k++)
+ {
+ if (k >= arglist.length ())
+ break;
+
+ if (TYPE_MODE (instance.get_arg_type (k))
+ != TYPE_MODE (TREE_TYPE (arglist[k])))
+ break;
+ }
+
+ if (args[k].base_type == NUM_BASE_TYPES)
+ return &rfun->instance;
+ }
+
+ return NULL;
+}
+
function_builder::function_builder ()
{
m_direct_overloads = lang_GNU_CXX ();
@@ -3357,7 +3432,8 @@ function_builder::get_attributes (const function_instance &instance)
registered_function &
function_builder::add_function (const function_instance &instance,
const char *name, tree fntype, tree attrs,
- bool placeholder_p)
+ bool placeholder_p,
+ bool overloaded_p = false)
{
unsigned int code = vec_safe_length (registered_functions);
code = (code << RISCV_BUILTIN_SHIFT) + RISCV_BUILTIN_VECTOR;
@@ -3383,6 +3459,7 @@ function_builder::add_function (const function_instance &instance,
registered_function &rfn = *ggc_alloc<registered_function> ();
rfn.instance = instance;
rfn.decl = decl;
+ rfn.overloaded_p = overloaded_p;
vec_safe_push (registered_functions, &rfn);
return rfn;
@@ -3432,6 +3509,26 @@ function_builder::add_unique_function (const function_instance &instance,
obstack_free (&m_string_obstack, name);
}
+void
+function_builder::add_overloaded_function (const function_instance &instance,
+ const function_shape *shape)
+{
+ if (!check_required_extensions (instance))
+ return;
+
+ char *name = shape->get_name (*this, instance, true);
+
+ if (name)
+ {
+ /* To avoid API conflicting, take void return type and void argument
+ for the overloaded function. */
+ tree fntype = build_function_type (void_type_node, void_list_node);
+ add_function (instance, name, fntype, NULL_TREE, m_direct_overloads,
+ true);
+ obstack_free (&m_string_obstack, name);
+ }
+}
+
function_call_info::function_call_info (location_t location_in,
const function_instance &instance_in,
tree fndecl_in)
@@ -3852,6 +3949,13 @@ function_checker::function_checker (location_t location,
m_nargs (nargs), m_args (args)
{}
+function_resolver::function_resolver (location_t location,
+ const function_instance &instance,
+ tree fndecl,
+ vec<tree, va_gc> &arglist)
+ : function_call_info (location, instance, fndecl), m_arglist (arglist)
+{}
+
/* Report that LOCATION has a call to FNDECL in which argument ARGNO
was not an integer constant expression. ARGNO counts from zero. */
void
@@ -3967,6 +4071,39 @@ function_checker::check ()
return shape->check (*this);
}
+unsigned int
+function_resolver::get_sub_code ()
+{
+ unsigned int fun_code = DECL_MD_FUNCTION_CODE (fndecl);
+
+ return fun_code >> RISCV_BUILTIN_SHIFT;
+}
+
+tree
+function_resolver::resolve ()
+{
+ return shape->resolve (*this);
+}
+
+tree
+function_resolver::lookup ()
+{
+ unsigned int fun_code = get_sub_code ();
+ function_instance *instance
+ = base->get_non_overloaded_instance (fun_code, m_arglist);
+
+ if (!instance)
+ return NULL_TREE;
+
+ hashval_t hash = instance->hash ();
+ registered_function *rfun = function_table->find_with_hash (*instance, hash);
+
+ if (!rfun)
+ return NULL_TREE;
+
+ return rfun->decl;
+}
+
inline hashval_t
registered_function_hasher::hash (value_type value)
{
@@ -4196,6 +4333,22 @@ check_builtin_call (location_t location, vec<location_t>, unsigned int code,
TREE_TYPE (rfn.decl), nargs, args).check ();
}
+tree
+resolve_overloaded_builtin (location_t loc, unsigned int code,
+ vec<tree, va_gc> *arglist)
+{
+ if (code >= vec_safe_length (registered_functions))
+ return NULL_TREE;
+
+ const registered_function *rfun = (*registered_functions)[code];
+
+ if (!rfun || !rfun->overloaded_p)
+ return NULL_TREE;
+
+ return function_resolver (loc, rfun->instance, rfun->decl, *arglist)
+ .resolve ();
+}
+
function_instance
get_read_vl_instance (void)
{
diff --git a/gcc/config/riscv/riscv-vector-builtins.h b/gcc/config/riscv/riscv-vector-builtins.h
index e358a8e4d91..e20f0f14ce4 100644
--- a/gcc/config/riscv/riscv-vector-builtins.h
+++ b/gcc/config/riscv/riscv-vector-builtins.h
@@ -277,6 +277,8 @@ public:
void apply_predication (const function_instance &, tree, vec<tree> &) const;
void add_unique_function (const function_instance &, const function_shape *,
tree, vec<tree> &);
+ void add_overloaded_function (const function_instance &,
+ const function_shape *);
void register_function_group (const function_group_info &);
void append_name (const char *);
void append_base_name (const char *);
@@ -288,7 +290,7 @@ private:
tree get_attributes (const function_instance &);
registered_function &add_function (const function_instance &, const char *,
- tree, tree, bool);
+ tree, tree, bool, bool);
/* True if we should create a separate decl for each instance of an
overloaded function, instead of using function_builder. */
@@ -424,6 +426,11 @@ public:
/* Expand the given call into rtl. Return the result of the function,
or an arbitrary value if the function doesn't return a result. */
virtual rtx expand (function_expander &) const = 0;
+
+ /* Return the non-overloaded function instance from the registered
+ function table if success, or NULL will be returned. */
+ virtual function_instance * get_non_overloaded_instance (
+ unsigned int, vec<tree, va_gc> &arglist) const;
};
/* A class for checking that the semantic constraints on a function call are
@@ -462,6 +469,29 @@ private:
tree *m_args;
};
+/* A class for resolving an overloaded function call. */
+class function_resolver : public function_call_info
+{
+public:
+ function_resolver (location_t, const function_instance &, tree,
+ vec<tree, va_gc> &);
+
+ /* Resolve the correlated non-overloaded function from the
+ the registered_functions table. */
+ tree resolve ();
+
+ /* Lookup the non-overloaded function from the registered
+ function table. */
+ tree lookup ();
+
+ /* Return the sub code of the fndecl. */
+ unsigned int get_sub_code ();
+
+private:
+ /* The arguments to the overloaded function. */
+ vec<tree, va_gc> &m_arglist;
+};
+
/* Classifies functions into "shapes" base on:
- Base name of the intrinsic function.
@@ -486,6 +516,10 @@ public:
/* Check whether the given call is semantically valid. Return true
if it is, otherwise report an error and return false. */
virtual bool check (function_checker &) const { return true; }
+
+ /* Try to resolve the overloaded call. Return the non-overloaded
+ function decl on success and NULL_TREE on failure. */
+ virtual tree resolve (function_resolver &) const { return NULL_TREE; };
};
extern const char *const operand_suffixes[NUM_OP_TYPES];
diff --git a/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c
new file mode 100644
index 00000000000..56154da155b
--- /dev/null
+++ b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c
@@ -0,0 +1,8 @@
+/* { dg-do compile } */
+/* { dg-options "-march=rv32gcv_zvfh -mabi=ilp32 -O3 -Wno-psabi" } */
+
+#include "overloaded_vmv_v.h"
+
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e32,\s*m1,\s*ta,\s*ma} 2 } } */
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e16,\s*m1,\s*ta,\s*ma} 2 } } */
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e8,\s*m4,\s*tu,\s*ma} 2 } } */
diff --git a/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c
new file mode 100644
index 00000000000..f4a63c9585d
--- /dev/null
+++ b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c
@@ -0,0 +1,8 @@
+/* { dg-do compile } */
+/* { dg-options "-march=rv64gcv_zvfh -mabi=lp64 -O3 -Wno-psabi" } */
+
+#include "overloaded_vmv_v.h"
+
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e32,\s*m1,\s*ta,\s*ma} 2 } } */
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e16,\s*m1,\s*ta,\s*ma} 2 } } */
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e8,\s*m4,\s*tu,\s*ma} 2 } } */
diff --git a/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_vmv_v.h b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_vmv_v.h
new file mode 100644
index 00000000000..8756c5e17b7
--- /dev/null
+++ b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_vmv_v.h
@@ -0,0 +1,27 @@
+#include "riscv_vector.h"
+
+vint32m1_t test_vmv_overloaded_0 (vint32m1_t src, size_t vl) {
+ return __riscv_vmv_v (src, vl);
+}
+
+vfloat16m1_t test_vmv_overloaded_1 (vfloat16m1_t src, size_t vl) {
+ return __riscv_vmv_v (src, vl);
+}
+
+vint8m4_t test_vmv_overloaded_2 (vint8m4_t maskedoff, vint8m4_t src,
+ size_t vl) {
+ return __riscv_vmv_v_tu (maskedoff, src, vl);
+}
+
+vint32m1_t test_vmv_non_overloaded_0 (vint32m1_t src, size_t vl) {
+ return __riscv_vmv_v_v_i32m1 (src, vl);
+}
+
+vfloat16m1_t test_vmv_non_overloaded_1 (vfloat16m1_t src, size_t vl) {
+ return __riscv_vmv_v_v_f16m1 (src, vl);
+}
+
+vint8m4_t test_vmv_non_overloaded_2 (vint8m4_t maskedoff, vint8m4_t src,
+ size_t vl) {
+ return __riscv_vmv_v_v_i8m4_tu (maskedoff, src, vl);
+}
--
2.34.1
Sorry for comment again.
I am not happy with current get_non_overloaeded_instance function.
I think the searching approach is very in-effective:
+function_instance *
+function_base::get_non_overloaded_instance (unsigned int code,
+ vec<tree, va_gc> &arglist) const
+{
+ unsigned int code_limit = vec_safe_length (registered_functions);
+
+ for (unsigned fun_code = code; fun_code < code_limit; fun_code++)
+ {
+ registered_function *rfun = (*registered_functions)[fun_code];
+ function_instance instance = rfun->instance;
+
+ if (rfun->overloaded_p)
+ continue;
+
+ unsigned k;
+ const rvv_arg_type_info *args = instance.op_info->args;
+
+ for (k = 0; args[k].base_type != NUM_BASE_TYPES; k++)
+ {
+ if (k >= arglist.length ())
+ break;
+
+ if (TYPE_MODE (instance.get_arg_type (k))
+ != TYPE_MODE (TREE_TYPE (arglist[k])))
+ break;
+ }
+
+ if (args[k].base_type == NUM_BASE_TYPES)
+ return &rfun->instance;
+ }
+
+ return NULL;
+}
Instead, I think we should build up a table which map non-overloaded function according to the arguments so that we could get the "instance" effectively.
E.g. For vint8mf8_t tumu vadd intrinsic the instance is like this:
function_instance ("vadd", bases::vadd, shapes::alu,
iu_ops[VECTOR_TYPE_vuint8mf8_t], PRED_TYPE_tumu, &iu_vvv_ops);
Since the get_nonoverloaed_instance is already the function of the class BASE.
So, The first 3 arguments "vadd", bases::vadd, shapes::alu
should already known since it is a known function_base.
The last 3 arguments may need some elegant analysis or map table to quickly grep.
So, I think we should consider this framework seriously.
juzhe.zhong@rivai.ai
From: pan2.li
Date: 2023-09-12 16:46
To: gcc-patches
CC: juzhe.zhong; pan2.li; yanzhang.wang; kito.cheng
Subject: [PATCH v3] RISC-V: Implement RESOLVE_OVERLOADED_BUILTIN for RVV intrinsic
From: Pan Li <pan2.li@intel.com>
Update in v3:
* Rewrite comment for overloaded function add.
* Move get_non_overloaded_instance to function_base.
Update in v2:
* Add get_non_overloaded_instance for function instance.
* Fix overload check for policy function.
* Enrich the test cases check.
Original log:
This patch would like add the framework to support the RVV overloaded
intrinsic API in riscv-xxx-xxx-gcc, like riscv-xxx-xxx-g++ did.
However, it almost leverage the hook TARGET_RESOLVE_OVERLOADED_BUILTIN
with below steps.
* Register overloaded functions.
* Add function_resolver for overloaded function resolving.
* Add resolve API for function shape with default implementation.
* Implement HOOK for navigating the overloaded API to non-overloaded API.
We validated this framework by the vmv_v intrinsic API(s), and we will
add more intrins API support in the underlying patches.
gcc/ChangeLog:
* config/riscv/riscv-c.cc
(riscv_resolve_overloaded_builtin): New function for the hook.
(riscv_register_pragmas): Register the hook
* config/riscv/riscv-protos.h (resolve_overloaded_builtin): New decl.
* config/riscv/riscv-vector-builtins-shapes.cc (build_one):
Register overloaded function.
(struct overloaded_base): New struct for overloaded shape.
(struct non_overloaded_base): New struct for non overloaded shape.
(struct move_def): Inherit overloaded shape.
* config/riscv/riscv-vector-builtins.cc
(function_base::get_non_overloaded_instance): New API impl.
(function_builder::add_function): Add overloaded arg.
(function_resolver::function_resolver): New constructor.
(function_builder::add_overloaded_function): New API impl.
(function_resolver::resolve): Ditto.
(function_resolver::lookup): Ditto.
(function_resolver::get_sub_code): Ditto.
(resolve_overloaded_builtin): New function impl.
* config/riscv/riscv-vector-builtins.h:
(class function_resolver): New class.
gcc/testsuite/ChangeLog:
* gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c: New test.
* gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c: New test.
* gcc.target/riscv/rvv/base/overloaded_vmv_v.h: New test.
Signed-off-by: Pan Li <pan2.li@intel.com>
---
gcc/config/riscv/riscv-c.cc | 36 ++++
gcc/config/riscv/riscv-protos.h | 1 +
.../riscv/riscv-vector-builtins-shapes.cc | 20 ++-
gcc/config/riscv/riscv-vector-builtins.cc | 155 +++++++++++++++++-
gcc/config/riscv/riscv-vector-builtins.h | 36 +++-
.../riscv/rvv/base/overloaded_rv32_vmv_v.c | 8 +
.../riscv/rvv/base/overloaded_rv64_vmv_v.c | 8 +
.../riscv/rvv/base/overloaded_vmv_v.h | 27 +++
8 files changed, 288 insertions(+), 3 deletions(-)
create mode 100644 gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c
create mode 100644 gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c
create mode 100644 gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_vmv_v.h
diff --git a/gcc/config/riscv/riscv-c.cc b/gcc/config/riscv/riscv-c.cc
index 283052ae313..060edd3129d 100644
--- a/gcc/config/riscv/riscv-c.cc
+++ b/gcc/config/riscv/riscv-c.cc
@@ -220,11 +220,47 @@ riscv_check_builtin_call (location_t loc, vec<location_t> arg_loc, tree fndecl,
gcc_unreachable ();
}
+/* Implement TARGET_RESOLVE_OVERLOADED_BUILTIN. */
+static tree
+riscv_resolve_overloaded_builtin (unsigned int uncast_location, tree fndecl,
+ void *uncast_arglist)
+{
+ vec<tree, va_gc> empty = {};
+ location_t loc = (location_t) uncast_location;
+ vec<tree, va_gc> *arglist = (vec<tree, va_gc> *) uncast_arglist;
+ unsigned int code = DECL_MD_FUNCTION_CODE (fndecl);
+ unsigned int subcode = code >> RISCV_BUILTIN_SHIFT;
+ tree new_fndecl = NULL_TREE;
+
+ if (!arglist)
+ arglist = ∅
+
+ switch (code & RISCV_BUILTIN_CLASS)
+ {
+ case RISCV_BUILTIN_GENERAL:
+ break;
+ case RISCV_BUILTIN_VECTOR:
+ new_fndecl = riscv_vector::resolve_overloaded_builtin (loc, subcode,
+ arglist);
+ break;
+ default:
+ gcc_unreachable ();
+ }
+
+ if (new_fndecl == NULL_TREE)
+ return new_fndecl;
+
+ return build_function_call_vec (loc, vNULL, new_fndecl, arglist, NULL,
+ fndecl);
+}
+
/* Implement REGISTER_TARGET_PRAGMAS. */
void
riscv_register_pragmas (void)
{
+ targetm.resolve_overloaded_builtin = riscv_resolve_overloaded_builtin;
targetm.check_builtin_call = riscv_check_builtin_call;
+
c_register_pragma ("riscv", "intrinsic", riscv_pragma_intrinsic);
}
diff --git a/gcc/config/riscv/riscv-protos.h b/gcc/config/riscv/riscv-protos.h
index 6dbf6b9f943..5d2492dd031 100644
--- a/gcc/config/riscv/riscv-protos.h
+++ b/gcc/config/riscv/riscv-protos.h
@@ -381,6 +381,7 @@ gimple *gimple_fold_builtin (unsigned int, gimple_stmt_iterator *, gcall *);
rtx expand_builtin (unsigned int, tree, rtx);
bool check_builtin_call (location_t, vec<location_t>, unsigned int,
tree, unsigned int, tree *);
+tree resolve_overloaded_builtin (location_t, unsigned int, vec<tree, va_gc> *);
bool const_vec_all_same_in_range_p (rtx, HOST_WIDE_INT, HOST_WIDE_INT);
bool legitimize_move (rtx, rtx);
void emit_vlmax_vsetvl (machine_mode, rtx);
diff --git a/gcc/config/riscv/riscv-vector-builtins-shapes.cc b/gcc/config/riscv/riscv-vector-builtins-shapes.cc
index f8fdec863e6..1c1a2cc9488 100644
--- a/gcc/config/riscv/riscv-vector-builtins-shapes.cc
+++ b/gcc/config/riscv/riscv-vector-builtins-shapes.cc
@@ -49,6 +49,8 @@ build_one (function_builder &b, const function_group_info &group,
group.ops_infos.types[vec_type_idx].index);
b.allocate_argument_types (function_instance, argument_types);
b.apply_predication (function_instance, return_type, argument_types);
+
+ b.add_overloaded_function (function_instance, *group.shape);
b.add_unique_function (function_instance, (*group.shape), return_type,
argument_types);
}
@@ -87,6 +89,22 @@ struct build_base : public function_shape
}
};
+struct overloaded_base : public build_base
+{
+ tree resolve (function_resolver &r) const override
+ {
+ return r.lookup ();
+ }
+};
+
+struct non_overloaded_base : public build_base
+{
+ tree resolve (function_resolver &) const override
+ {
+ gcc_unreachable ();
+ }
+};
+
/* vsetvl_def class. */
struct vsetvl_def : public build_base
{
@@ -525,7 +543,7 @@ struct narrow_alu_def : public build_base
};
/* move_def class. Handle vmv.v.v/vmv.v.x. */
-struct move_def : public build_base
+struct move_def : public overloaded_base
{
char *get_name (function_builder &b, const function_instance &instance,
bool overloaded_p) const override
diff --git a/gcc/config/riscv/riscv-vector-builtins.cc b/gcc/config/riscv/riscv-vector-builtins.cc
index 6d99f970ead..4f6fbdc3e28 100644
--- a/gcc/config/riscv/riscv-vector-builtins.cc
+++ b/gcc/config/riscv/riscv-vector-builtins.cc
@@ -80,6 +80,10 @@ public:
/* The decl itself. */
tree GTY ((skip)) decl;
+
+ /* True if the decl represents an overloaded function that needs to be
+ resolved by function_resolver. */
+ bool overloaded_p;
};
/* Hash traits for registered_function. */
@@ -3196,6 +3200,77 @@ function_instance::could_trap_p () const
return false;
}
+/* Try to get the non-overloaded function instance.
+ After we register the overloaded the functions, the registered functions
+ table may look like:
+
+ +--------+---------------------------+-------------------+
+ | index | name | kind |
+ +--------+---------------------------+-------------------+
+ | 124733 | __riscv_vmv_v | Overloaded | <- Hook fun code
+ +--------+---------------------------+-------------------+
+ | 124735 | __riscv_vmv_v_v_i8mf8 | Non-overloaded |
+ +--------+---------------------------+-------------------+
+ | 124737 | __riscv_vmv_v | Placeholder |
+ +--------+---------------------------+-------------------+
+ | 124739 | __riscv_vmv_v | Overloaded |
+ +--------+---------------------------+-------------------+
+ | 124741 | __riscv_vmv_v_v_i8mf4 | Non-overloaded |
+ +--------+---------------------------+-------------------+
+ | 124743 | __riscv_vmv_v | Placeholder |
+ +--------+---------------------------+-------------------+
+ | 124745 | __riscv_vmv_v | Overloaded |
+ +--------+---------------------------+-------------------+
+ | 124747 | __riscv_vmv_v_v_i8mf2 | Non-overloaded |
+ +--------+---------------------------+-------------------+
+ | 124749 | __riscv_vmv_v | Placeholder |
+ +--------+---------------------------+-------------------+
+ | 124751 | __riscv_vmv_v | Overloaded |
+ +--------+---------------------------+-------------------+
+ | 124753 | __riscv_vmv_v_v_i8m1 | Non-overloaded |
+ +--------+---------------------------+-------------------+
+ | 124755 | __riscv_vmv_v | Placeholder |
+ +--------+---------------------------+-------------------+
+
+ When we resolve the overloaded API from the hook, we always get the first
+ function code of one API group (aka vmv_v as above table). We will search
+ start from that index to find the only one non-overloaded API with exactly
+ the same arglist. Or NULL instance will be returned.
+ */
+function_instance *
+function_base::get_non_overloaded_instance (unsigned int code,
+ vec<tree, va_gc> &arglist) const
+{
+ unsigned int code_limit = vec_safe_length (registered_functions);
+
+ for (unsigned fun_code = code; fun_code < code_limit; fun_code++)
+ {
+ registered_function *rfun = (*registered_functions)[fun_code];
+ function_instance instance = rfun->instance;
+
+ if (rfun->overloaded_p)
+ continue;
+
+ unsigned k;
+ const rvv_arg_type_info *args = instance.op_info->args;
+
+ for (k = 0; args[k].base_type != NUM_BASE_TYPES; k++)
+ {
+ if (k >= arglist.length ())
+ break;
+
+ if (TYPE_MODE (instance.get_arg_type (k))
+ != TYPE_MODE (TREE_TYPE (arglist[k])))
+ break;
+ }
+
+ if (args[k].base_type == NUM_BASE_TYPES)
+ return &rfun->instance;
+ }
+
+ return NULL;
+}
+
function_builder::function_builder ()
{
m_direct_overloads = lang_GNU_CXX ();
@@ -3357,7 +3432,8 @@ function_builder::get_attributes (const function_instance &instance)
registered_function &
function_builder::add_function (const function_instance &instance,
const char *name, tree fntype, tree attrs,
- bool placeholder_p)
+ bool placeholder_p,
+ bool overloaded_p = false)
{
unsigned int code = vec_safe_length (registered_functions);
code = (code << RISCV_BUILTIN_SHIFT) + RISCV_BUILTIN_VECTOR;
@@ -3383,6 +3459,7 @@ function_builder::add_function (const function_instance &instance,
registered_function &rfn = *ggc_alloc<registered_function> ();
rfn.instance = instance;
rfn.decl = decl;
+ rfn.overloaded_p = overloaded_p;
vec_safe_push (registered_functions, &rfn);
return rfn;
@@ -3432,6 +3509,26 @@ function_builder::add_unique_function (const function_instance &instance,
obstack_free (&m_string_obstack, name);
}
+void
+function_builder::add_overloaded_function (const function_instance &instance,
+ const function_shape *shape)
+{
+ if (!check_required_extensions (instance))
+ return;
+
+ char *name = shape->get_name (*this, instance, true);
+
+ if (name)
+ {
+ /* To avoid API conflicting, take void return type and void argument
+ for the overloaded function. */
+ tree fntype = build_function_type (void_type_node, void_list_node);
+ add_function (instance, name, fntype, NULL_TREE, m_direct_overloads,
+ true);
+ obstack_free (&m_string_obstack, name);
+ }
+}
+
function_call_info::function_call_info (location_t location_in,
const function_instance &instance_in,
tree fndecl_in)
@@ -3852,6 +3949,13 @@ function_checker::function_checker (location_t location,
m_nargs (nargs), m_args (args)
{}
+function_resolver::function_resolver (location_t location,
+ const function_instance &instance,
+ tree fndecl,
+ vec<tree, va_gc> &arglist)
+ : function_call_info (location, instance, fndecl), m_arglist (arglist)
+{}
+
/* Report that LOCATION has a call to FNDECL in which argument ARGNO
was not an integer constant expression. ARGNO counts from zero. */
void
@@ -3967,6 +4071,39 @@ function_checker::check ()
return shape->check (*this);
}
+unsigned int
+function_resolver::get_sub_code ()
+{
+ unsigned int fun_code = DECL_MD_FUNCTION_CODE (fndecl);
+
+ return fun_code >> RISCV_BUILTIN_SHIFT;
+}
+
+tree
+function_resolver::resolve ()
+{
+ return shape->resolve (*this);
+}
+
+tree
+function_resolver::lookup ()
+{
+ unsigned int fun_code = get_sub_code ();
+ function_instance *instance
+ = base->get_non_overloaded_instance (fun_code, m_arglist);
+
+ if (!instance)
+ return NULL_TREE;
+
+ hashval_t hash = instance->hash ();
+ registered_function *rfun = function_table->find_with_hash (*instance, hash);
+
+ if (!rfun)
+ return NULL_TREE;
+
+ return rfun->decl;
+}
+
inline hashval_t
registered_function_hasher::hash (value_type value)
{
@@ -4196,6 +4333,22 @@ check_builtin_call (location_t location, vec<location_t>, unsigned int code,
TREE_TYPE (rfn.decl), nargs, args).check ();
}
+tree
+resolve_overloaded_builtin (location_t loc, unsigned int code,
+ vec<tree, va_gc> *arglist)
+{
+ if (code >= vec_safe_length (registered_functions))
+ return NULL_TREE;
+
+ const registered_function *rfun = (*registered_functions)[code];
+
+ if (!rfun || !rfun->overloaded_p)
+ return NULL_TREE;
+
+ return function_resolver (loc, rfun->instance, rfun->decl, *arglist)
+ .resolve ();
+}
+
function_instance
get_read_vl_instance (void)
{
diff --git a/gcc/config/riscv/riscv-vector-builtins.h b/gcc/config/riscv/riscv-vector-builtins.h
index e358a8e4d91..e20f0f14ce4 100644
--- a/gcc/config/riscv/riscv-vector-builtins.h
+++ b/gcc/config/riscv/riscv-vector-builtins.h
@@ -277,6 +277,8 @@ public:
void apply_predication (const function_instance &, tree, vec<tree> &) const;
void add_unique_function (const function_instance &, const function_shape *,
tree, vec<tree> &);
+ void add_overloaded_function (const function_instance &,
+ const function_shape *);
void register_function_group (const function_group_info &);
void append_name (const char *);
void append_base_name (const char *);
@@ -288,7 +290,7 @@ private:
tree get_attributes (const function_instance &);
registered_function &add_function (const function_instance &, const char *,
- tree, tree, bool);
+ tree, tree, bool, bool);
/* True if we should create a separate decl for each instance of an
overloaded function, instead of using function_builder. */
@@ -424,6 +426,11 @@ public:
/* Expand the given call into rtl. Return the result of the function,
or an arbitrary value if the function doesn't return a result. */
virtual rtx expand (function_expander &) const = 0;
+
+ /* Return the non-overloaded function instance from the registered
+ function table if success, or NULL will be returned. */
+ virtual function_instance * get_non_overloaded_instance (
+ unsigned int, vec<tree, va_gc> &arglist) const;
};
/* A class for checking that the semantic constraints on a function call are
@@ -462,6 +469,29 @@ private:
tree *m_args;
};
+/* A class for resolving an overloaded function call. */
+class function_resolver : public function_call_info
+{
+public:
+ function_resolver (location_t, const function_instance &, tree,
+ vec<tree, va_gc> &);
+
+ /* Resolve the correlated non-overloaded function from the
+ the registered_functions table. */
+ tree resolve ();
+
+ /* Lookup the non-overloaded function from the registered
+ function table. */
+ tree lookup ();
+
+ /* Return the sub code of the fndecl. */
+ unsigned int get_sub_code ();
+
+private:
+ /* The arguments to the overloaded function. */
+ vec<tree, va_gc> &m_arglist;
+};
+
/* Classifies functions into "shapes" base on:
- Base name of the intrinsic function.
@@ -486,6 +516,10 @@ public:
/* Check whether the given call is semantically valid. Return true
if it is, otherwise report an error and return false. */
virtual bool check (function_checker &) const { return true; }
+
+ /* Try to resolve the overloaded call. Return the non-overloaded
+ function decl on success and NULL_TREE on failure. */
+ virtual tree resolve (function_resolver &) const { return NULL_TREE; };
};
extern const char *const operand_suffixes[NUM_OP_TYPES];
diff --git a/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c
new file mode 100644
index 00000000000..56154da155b
--- /dev/null
+++ b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c
@@ -0,0 +1,8 @@
+/* { dg-do compile } */
+/* { dg-options "-march=rv32gcv_zvfh -mabi=ilp32 -O3 -Wno-psabi" } */
+
+#include "overloaded_vmv_v.h"
+
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e32,\s*m1,\s*ta,\s*ma} 2 } } */
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e16,\s*m1,\s*ta,\s*ma} 2 } } */
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e8,\s*m4,\s*tu,\s*ma} 2 } } */
diff --git a/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c
new file mode 100644
index 00000000000..f4a63c9585d
--- /dev/null
+++ b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c
@@ -0,0 +1,8 @@
+/* { dg-do compile } */
+/* { dg-options "-march=rv64gcv_zvfh -mabi=lp64 -O3 -Wno-psabi" } */
+
+#include "overloaded_vmv_v.h"
+
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e32,\s*m1,\s*ta,\s*ma} 2 } } */
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e16,\s*m1,\s*ta,\s*ma} 2 } } */
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e8,\s*m4,\s*tu,\s*ma} 2 } } */
diff --git a/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_vmv_v.h b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_vmv_v.h
new file mode 100644
index 00000000000..8756c5e17b7
--- /dev/null
+++ b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_vmv_v.h
@@ -0,0 +1,27 @@
+#include "riscv_vector.h"
+
+vint32m1_t test_vmv_overloaded_0 (vint32m1_t src, size_t vl) {
+ return __riscv_vmv_v (src, vl);
+}
+
+vfloat16m1_t test_vmv_overloaded_1 (vfloat16m1_t src, size_t vl) {
+ return __riscv_vmv_v (src, vl);
+}
+
+vint8m4_t test_vmv_overloaded_2 (vint8m4_t maskedoff, vint8m4_t src,
+ size_t vl) {
+ return __riscv_vmv_v_tu (maskedoff, src, vl);
+}
+
+vint32m1_t test_vmv_non_overloaded_0 (vint32m1_t src, size_t vl) {
+ return __riscv_vmv_v_v_i32m1 (src, vl);
+}
+
+vfloat16m1_t test_vmv_non_overloaded_1 (vfloat16m1_t src, size_t vl) {
+ return __riscv_vmv_v_v_f16m1 (src, vl);
+}
+
+vint8m4_t test_vmv_non_overloaded_2 (vint8m4_t maskedoff, vint8m4_t src,
+ size_t vl) {
+ return __riscv_vmv_v_v_i8m4_tu (maskedoff, src, vl);
+}
--
2.34.1
More information:
For PRED_TYPE_tumu, it's easy to analyze, just need to count how many arguments in the arglist.
If arglist has 5 arguments (mask, merge, op1, op2, len) Then it must be TUMU.
What I mean is that we should be able to quickly to compute the arguments of the construction of the function_instance.
Then we can get the non-overloaeded function.
juzhe.zhong@rivai.ai
From: juzhe.zhong@rivai.ai
Date: 2023-09-15 10:02
To: pan2.li; gcc-patches
CC: pan2.li; yanzhang.wang; kito.cheng
Subject: Re: [PATCH v3] RISC-V: Implement RESOLVE_OVERLOADED_BUILTIN for RVV intrinsic
Sorry for comment again.
I am not happy with current get_non_overloaeded_instance function.
I think the searching approach is very in-effective:
+function_instance *
+function_base::get_non_overloaded_instance (unsigned int code,
+ vec<tree, va_gc> &arglist) const
+{
+ unsigned int code_limit = vec_safe_length (registered_functions);
+
+ for (unsigned fun_code = code; fun_code < code_limit; fun_code++)
+ {
+ registered_function *rfun = (*registered_functions)[fun_code];
+ function_instance instance = rfun->instance;
+
+ if (rfun->overloaded_p)
+ continue;
+
+ unsigned k;
+ const rvv_arg_type_info *args = instance.op_info->args;
+
+ for (k = 0; args[k].base_type != NUM_BASE_TYPES; k++)
+ {
+ if (k >= arglist.length ())
+ break;
+
+ if (TYPE_MODE (instance.get_arg_type (k))
+ != TYPE_MODE (TREE_TYPE (arglist[k])))
+ break;
+ }
+
+ if (args[k].base_type == NUM_BASE_TYPES)
+ return &rfun->instance;
+ }
+
+ return NULL;
+}
Instead, I think we should build up a table which map non-overloaded function according to the arguments so that we could get the "instance" effectively.
E.g. For vint8mf8_t tumu vadd intrinsic the instance is like this:
function_instance ("vadd", bases::vadd, shapes::alu,
iu_ops[VECTOR_TYPE_vuint8mf8_t], PRED_TYPE_tumu, &iu_vvv_ops);
Since the get_nonoverloaed_instance is already the function of the class BASE.
So, The first 3 arguments "vadd", bases::vadd, shapes::alu
should already known since it is a known function_base.
The last 3 arguments may need some elegant analysis or map table to quickly grep.
So, I think we should consider this framework seriously.
juzhe.zhong@rivai.ai
From: pan2.li
Date: 2023-09-12 16:46
To: gcc-patches
CC: juzhe.zhong; pan2.li; yanzhang.wang; kito.cheng
Subject: [PATCH v3] RISC-V: Implement RESOLVE_OVERLOADED_BUILTIN for RVV intrinsic
From: Pan Li <pan2.li@intel.com>
Update in v3:
* Rewrite comment for overloaded function add.
* Move get_non_overloaded_instance to function_base.
Update in v2:
* Add get_non_overloaded_instance for function instance.
* Fix overload check for policy function.
* Enrich the test cases check.
Original log:
This patch would like add the framework to support the RVV overloaded
intrinsic API in riscv-xxx-xxx-gcc, like riscv-xxx-xxx-g++ did.
However, it almost leverage the hook TARGET_RESOLVE_OVERLOADED_BUILTIN
with below steps.
* Register overloaded functions.
* Add function_resolver for overloaded function resolving.
* Add resolve API for function shape with default implementation.
* Implement HOOK for navigating the overloaded API to non-overloaded API.
We validated this framework by the vmv_v intrinsic API(s), and we will
add more intrins API support in the underlying patches.
gcc/ChangeLog:
* config/riscv/riscv-c.cc
(riscv_resolve_overloaded_builtin): New function for the hook.
(riscv_register_pragmas): Register the hook
* config/riscv/riscv-protos.h (resolve_overloaded_builtin): New decl.
* config/riscv/riscv-vector-builtins-shapes.cc (build_one):
Register overloaded function.
(struct overloaded_base): New struct for overloaded shape.
(struct non_overloaded_base): New struct for non overloaded shape.
(struct move_def): Inherit overloaded shape.
* config/riscv/riscv-vector-builtins.cc
(function_base::get_non_overloaded_instance): New API impl.
(function_builder::add_function): Add overloaded arg.
(function_resolver::function_resolver): New constructor.
(function_builder::add_overloaded_function): New API impl.
(function_resolver::resolve): Ditto.
(function_resolver::lookup): Ditto.
(function_resolver::get_sub_code): Ditto.
(resolve_overloaded_builtin): New function impl.
* config/riscv/riscv-vector-builtins.h:
(class function_resolver): New class.
gcc/testsuite/ChangeLog:
* gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c: New test.
* gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c: New test.
* gcc.target/riscv/rvv/base/overloaded_vmv_v.h: New test.
Signed-off-by: Pan Li <pan2.li@intel.com>
---
gcc/config/riscv/riscv-c.cc | 36 ++++
gcc/config/riscv/riscv-protos.h | 1 +
.../riscv/riscv-vector-builtins-shapes.cc | 20 ++-
gcc/config/riscv/riscv-vector-builtins.cc | 155 +++++++++++++++++-
gcc/config/riscv/riscv-vector-builtins.h | 36 +++-
.../riscv/rvv/base/overloaded_rv32_vmv_v.c | 8 +
.../riscv/rvv/base/overloaded_rv64_vmv_v.c | 8 +
.../riscv/rvv/base/overloaded_vmv_v.h | 27 +++
8 files changed, 288 insertions(+), 3 deletions(-)
create mode 100644 gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c
create mode 100644 gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c
create mode 100644 gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_vmv_v.h
diff --git a/gcc/config/riscv/riscv-c.cc b/gcc/config/riscv/riscv-c.cc
index 283052ae313..060edd3129d 100644
--- a/gcc/config/riscv/riscv-c.cc
+++ b/gcc/config/riscv/riscv-c.cc
@@ -220,11 +220,47 @@ riscv_check_builtin_call (location_t loc, vec<location_t> arg_loc, tree fndecl,
gcc_unreachable ();
}
+/* Implement TARGET_RESOLVE_OVERLOADED_BUILTIN. */
+static tree
+riscv_resolve_overloaded_builtin (unsigned int uncast_location, tree fndecl,
+ void *uncast_arglist)
+{
+ vec<tree, va_gc> empty = {};
+ location_t loc = (location_t) uncast_location;
+ vec<tree, va_gc> *arglist = (vec<tree, va_gc> *) uncast_arglist;
+ unsigned int code = DECL_MD_FUNCTION_CODE (fndecl);
+ unsigned int subcode = code >> RISCV_BUILTIN_SHIFT;
+ tree new_fndecl = NULL_TREE;
+
+ if (!arglist)
+ arglist = ∅
+
+ switch (code & RISCV_BUILTIN_CLASS)
+ {
+ case RISCV_BUILTIN_GENERAL:
+ break;
+ case RISCV_BUILTIN_VECTOR:
+ new_fndecl = riscv_vector::resolve_overloaded_builtin (loc, subcode,
+ arglist);
+ break;
+ default:
+ gcc_unreachable ();
+ }
+
+ if (new_fndecl == NULL_TREE)
+ return new_fndecl;
+
+ return build_function_call_vec (loc, vNULL, new_fndecl, arglist, NULL,
+ fndecl);
+}
+
/* Implement REGISTER_TARGET_PRAGMAS. */
void
riscv_register_pragmas (void)
{
+ targetm.resolve_overloaded_builtin = riscv_resolve_overloaded_builtin;
targetm.check_builtin_call = riscv_check_builtin_call;
+
c_register_pragma ("riscv", "intrinsic", riscv_pragma_intrinsic);
}
diff --git a/gcc/config/riscv/riscv-protos.h b/gcc/config/riscv/riscv-protos.h
index 6dbf6b9f943..5d2492dd031 100644
--- a/gcc/config/riscv/riscv-protos.h
+++ b/gcc/config/riscv/riscv-protos.h
@@ -381,6 +381,7 @@ gimple *gimple_fold_builtin (unsigned int, gimple_stmt_iterator *, gcall *);
rtx expand_builtin (unsigned int, tree, rtx);
bool check_builtin_call (location_t, vec<location_t>, unsigned int,
tree, unsigned int, tree *);
+tree resolve_overloaded_builtin (location_t, unsigned int, vec<tree, va_gc> *);
bool const_vec_all_same_in_range_p (rtx, HOST_WIDE_INT, HOST_WIDE_INT);
bool legitimize_move (rtx, rtx);
void emit_vlmax_vsetvl (machine_mode, rtx);
diff --git a/gcc/config/riscv/riscv-vector-builtins-shapes.cc b/gcc/config/riscv/riscv-vector-builtins-shapes.cc
index f8fdec863e6..1c1a2cc9488 100644
--- a/gcc/config/riscv/riscv-vector-builtins-shapes.cc
+++ b/gcc/config/riscv/riscv-vector-builtins-shapes.cc
@@ -49,6 +49,8 @@ build_one (function_builder &b, const function_group_info &group,
group.ops_infos.types[vec_type_idx].index);
b.allocate_argument_types (function_instance, argument_types);
b.apply_predication (function_instance, return_type, argument_types);
+
+ b.add_overloaded_function (function_instance, *group.shape);
b.add_unique_function (function_instance, (*group.shape), return_type,
argument_types);
}
@@ -87,6 +89,22 @@ struct build_base : public function_shape
}
};
+struct overloaded_base : public build_base
+{
+ tree resolve (function_resolver &r) const override
+ {
+ return r.lookup ();
+ }
+};
+
+struct non_overloaded_base : public build_base
+{
+ tree resolve (function_resolver &) const override
+ {
+ gcc_unreachable ();
+ }
+};
+
/* vsetvl_def class. */
struct vsetvl_def : public build_base
{
@@ -525,7 +543,7 @@ struct narrow_alu_def : public build_base
};
/* move_def class. Handle vmv.v.v/vmv.v.x. */
-struct move_def : public build_base
+struct move_def : public overloaded_base
{
char *get_name (function_builder &b, const function_instance &instance,
bool overloaded_p) const override
diff --git a/gcc/config/riscv/riscv-vector-builtins.cc b/gcc/config/riscv/riscv-vector-builtins.cc
index 6d99f970ead..4f6fbdc3e28 100644
--- a/gcc/config/riscv/riscv-vector-builtins.cc
+++ b/gcc/config/riscv/riscv-vector-builtins.cc
@@ -80,6 +80,10 @@ public:
/* The decl itself. */
tree GTY ((skip)) decl;
+
+ /* True if the decl represents an overloaded function that needs to be
+ resolved by function_resolver. */
+ bool overloaded_p;
};
/* Hash traits for registered_function. */
@@ -3196,6 +3200,77 @@ function_instance::could_trap_p () const
return false;
}
+/* Try to get the non-overloaded function instance.
+ After we register the overloaded the functions, the registered functions
+ table may look like:
+
+ +--------+---------------------------+-------------------+
+ | index | name | kind |
+ +--------+---------------------------+-------------------+
+ | 124733 | __riscv_vmv_v | Overloaded | <- Hook fun code
+ +--------+---------------------------+-------------------+
+ | 124735 | __riscv_vmv_v_v_i8mf8 | Non-overloaded |
+ +--------+---------------------------+-------------------+
+ | 124737 | __riscv_vmv_v | Placeholder |
+ +--------+---------------------------+-------------------+
+ | 124739 | __riscv_vmv_v | Overloaded |
+ +--------+---------------------------+-------------------+
+ | 124741 | __riscv_vmv_v_v_i8mf4 | Non-overloaded |
+ +--------+---------------------------+-------------------+
+ | 124743 | __riscv_vmv_v | Placeholder |
+ +--------+---------------------------+-------------------+
+ | 124745 | __riscv_vmv_v | Overloaded |
+ +--------+---------------------------+-------------------+
+ | 124747 | __riscv_vmv_v_v_i8mf2 | Non-overloaded |
+ +--------+---------------------------+-------------------+
+ | 124749 | __riscv_vmv_v | Placeholder |
+ +--------+---------------------------+-------------------+
+ | 124751 | __riscv_vmv_v | Overloaded |
+ +--------+---------------------------+-------------------+
+ | 124753 | __riscv_vmv_v_v_i8m1 | Non-overloaded |
+ +--------+---------------------------+-------------------+
+ | 124755 | __riscv_vmv_v | Placeholder |
+ +--------+---------------------------+-------------------+
+
+ When we resolve the overloaded API from the hook, we always get the first
+ function code of one API group (aka vmv_v as above table). We will search
+ start from that index to find the only one non-overloaded API with exactly
+ the same arglist. Or NULL instance will be returned.
+ */
+function_instance *
+function_base::get_non_overloaded_instance (unsigned int code,
+ vec<tree, va_gc> &arglist) const
+{
+ unsigned int code_limit = vec_safe_length (registered_functions);
+
+ for (unsigned fun_code = code; fun_code < code_limit; fun_code++)
+ {
+ registered_function *rfun = (*registered_functions)[fun_code];
+ function_instance instance = rfun->instance;
+
+ if (rfun->overloaded_p)
+ continue;
+
+ unsigned k;
+ const rvv_arg_type_info *args = instance.op_info->args;
+
+ for (k = 0; args[k].base_type != NUM_BASE_TYPES; k++)
+ {
+ if (k >= arglist.length ())
+ break;
+
+ if (TYPE_MODE (instance.get_arg_type (k))
+ != TYPE_MODE (TREE_TYPE (arglist[k])))
+ break;
+ }
+
+ if (args[k].base_type == NUM_BASE_TYPES)
+ return &rfun->instance;
+ }
+
+ return NULL;
+}
+
function_builder::function_builder ()
{
m_direct_overloads = lang_GNU_CXX ();
@@ -3357,7 +3432,8 @@ function_builder::get_attributes (const function_instance &instance)
registered_function &
function_builder::add_function (const function_instance &instance,
const char *name, tree fntype, tree attrs,
- bool placeholder_p)
+ bool placeholder_p,
+ bool overloaded_p = false)
{
unsigned int code = vec_safe_length (registered_functions);
code = (code << RISCV_BUILTIN_SHIFT) + RISCV_BUILTIN_VECTOR;
@@ -3383,6 +3459,7 @@ function_builder::add_function (const function_instance &instance,
registered_function &rfn = *ggc_alloc<registered_function> ();
rfn.instance = instance;
rfn.decl = decl;
+ rfn.overloaded_p = overloaded_p;
vec_safe_push (registered_functions, &rfn);
return rfn;
@@ -3432,6 +3509,26 @@ function_builder::add_unique_function (const function_instance &instance,
obstack_free (&m_string_obstack, name);
}
+void
+function_builder::add_overloaded_function (const function_instance &instance,
+ const function_shape *shape)
+{
+ if (!check_required_extensions (instance))
+ return;
+
+ char *name = shape->get_name (*this, instance, true);
+
+ if (name)
+ {
+ /* To avoid API conflicting, take void return type and void argument
+ for the overloaded function. */
+ tree fntype = build_function_type (void_type_node, void_list_node);
+ add_function (instance, name, fntype, NULL_TREE, m_direct_overloads,
+ true);
+ obstack_free (&m_string_obstack, name);
+ }
+}
+
function_call_info::function_call_info (location_t location_in,
const function_instance &instance_in,
tree fndecl_in)
@@ -3852,6 +3949,13 @@ function_checker::function_checker (location_t location,
m_nargs (nargs), m_args (args)
{}
+function_resolver::function_resolver (location_t location,
+ const function_instance &instance,
+ tree fndecl,
+ vec<tree, va_gc> &arglist)
+ : function_call_info (location, instance, fndecl), m_arglist (arglist)
+{}
+
/* Report that LOCATION has a call to FNDECL in which argument ARGNO
was not an integer constant expression. ARGNO counts from zero. */
void
@@ -3967,6 +4071,39 @@ function_checker::check ()
return shape->check (*this);
}
+unsigned int
+function_resolver::get_sub_code ()
+{
+ unsigned int fun_code = DECL_MD_FUNCTION_CODE (fndecl);
+
+ return fun_code >> RISCV_BUILTIN_SHIFT;
+}
+
+tree
+function_resolver::resolve ()
+{
+ return shape->resolve (*this);
+}
+
+tree
+function_resolver::lookup ()
+{
+ unsigned int fun_code = get_sub_code ();
+ function_instance *instance
+ = base->get_non_overloaded_instance (fun_code, m_arglist);
+
+ if (!instance)
+ return NULL_TREE;
+
+ hashval_t hash = instance->hash ();
+ registered_function *rfun = function_table->find_with_hash (*instance, hash);
+
+ if (!rfun)
+ return NULL_TREE;
+
+ return rfun->decl;
+}
+
inline hashval_t
registered_function_hasher::hash (value_type value)
{
@@ -4196,6 +4333,22 @@ check_builtin_call (location_t location, vec<location_t>, unsigned int code,
TREE_TYPE (rfn.decl), nargs, args).check ();
}
+tree
+resolve_overloaded_builtin (location_t loc, unsigned int code,
+ vec<tree, va_gc> *arglist)
+{
+ if (code >= vec_safe_length (registered_functions))
+ return NULL_TREE;
+
+ const registered_function *rfun = (*registered_functions)[code];
+
+ if (!rfun || !rfun->overloaded_p)
+ return NULL_TREE;
+
+ return function_resolver (loc, rfun->instance, rfun->decl, *arglist)
+ .resolve ();
+}
+
function_instance
get_read_vl_instance (void)
{
diff --git a/gcc/config/riscv/riscv-vector-builtins.h b/gcc/config/riscv/riscv-vector-builtins.h
index e358a8e4d91..e20f0f14ce4 100644
--- a/gcc/config/riscv/riscv-vector-builtins.h
+++ b/gcc/config/riscv/riscv-vector-builtins.h
@@ -277,6 +277,8 @@ public:
void apply_predication (const function_instance &, tree, vec<tree> &) const;
void add_unique_function (const function_instance &, const function_shape *,
tree, vec<tree> &);
+ void add_overloaded_function (const function_instance &,
+ const function_shape *);
void register_function_group (const function_group_info &);
void append_name (const char *);
void append_base_name (const char *);
@@ -288,7 +290,7 @@ private:
tree get_attributes (const function_instance &);
registered_function &add_function (const function_instance &, const char *,
- tree, tree, bool);
+ tree, tree, bool, bool);
/* True if we should create a separate decl for each instance of an
overloaded function, instead of using function_builder. */
@@ -424,6 +426,11 @@ public:
/* Expand the given call into rtl. Return the result of the function,
or an arbitrary value if the function doesn't return a result. */
virtual rtx expand (function_expander &) const = 0;
+
+ /* Return the non-overloaded function instance from the registered
+ function table if success, or NULL will be returned. */
+ virtual function_instance * get_non_overloaded_instance (
+ unsigned int, vec<tree, va_gc> &arglist) const;
};
/* A class for checking that the semantic constraints on a function call are
@@ -462,6 +469,29 @@ private:
tree *m_args;
};
+/* A class for resolving an overloaded function call. */
+class function_resolver : public function_call_info
+{
+public:
+ function_resolver (location_t, const function_instance &, tree,
+ vec<tree, va_gc> &);
+
+ /* Resolve the correlated non-overloaded function from the
+ the registered_functions table. */
+ tree resolve ();
+
+ /* Lookup the non-overloaded function from the registered
+ function table. */
+ tree lookup ();
+
+ /* Return the sub code of the fndecl. */
+ unsigned int get_sub_code ();
+
+private:
+ /* The arguments to the overloaded function. */
+ vec<tree, va_gc> &m_arglist;
+};
+
/* Classifies functions into "shapes" base on:
- Base name of the intrinsic function.
@@ -486,6 +516,10 @@ public:
/* Check whether the given call is semantically valid. Return true
if it is, otherwise report an error and return false. */
virtual bool check (function_checker &) const { return true; }
+
+ /* Try to resolve the overloaded call. Return the non-overloaded
+ function decl on success and NULL_TREE on failure. */
+ virtual tree resolve (function_resolver &) const { return NULL_TREE; };
};
extern const char *const operand_suffixes[NUM_OP_TYPES];
diff --git a/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c
new file mode 100644
index 00000000000..56154da155b
--- /dev/null
+++ b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c
@@ -0,0 +1,8 @@
+/* { dg-do compile } */
+/* { dg-options "-march=rv32gcv_zvfh -mabi=ilp32 -O3 -Wno-psabi" } */
+
+#include "overloaded_vmv_v.h"
+
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e32,\s*m1,\s*ta,\s*ma} 2 } } */
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e16,\s*m1,\s*ta,\s*ma} 2 } } */
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e8,\s*m4,\s*tu,\s*ma} 2 } } */
diff --git a/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c
new file mode 100644
index 00000000000..f4a63c9585d
--- /dev/null
+++ b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c
@@ -0,0 +1,8 @@
+/* { dg-do compile } */
+/* { dg-options "-march=rv64gcv_zvfh -mabi=lp64 -O3 -Wno-psabi" } */
+
+#include "overloaded_vmv_v.h"
+
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e32,\s*m1,\s*ta,\s*ma} 2 } } */
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e16,\s*m1,\s*ta,\s*ma} 2 } } */
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e8,\s*m4,\s*tu,\s*ma} 2 } } */
diff --git a/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_vmv_v.h b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_vmv_v.h
new file mode 100644
index 00000000000..8756c5e17b7
--- /dev/null
+++ b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_vmv_v.h
@@ -0,0 +1,27 @@
+#include "riscv_vector.h"
+
+vint32m1_t test_vmv_overloaded_0 (vint32m1_t src, size_t vl) {
+ return __riscv_vmv_v (src, vl);
+}
+
+vfloat16m1_t test_vmv_overloaded_1 (vfloat16m1_t src, size_t vl) {
+ return __riscv_vmv_v (src, vl);
+}
+
+vint8m4_t test_vmv_overloaded_2 (vint8m4_t maskedoff, vint8m4_t src,
+ size_t vl) {
+ return __riscv_vmv_v_tu (maskedoff, src, vl);
+}
+
+vint32m1_t test_vmv_non_overloaded_0 (vint32m1_t src, size_t vl) {
+ return __riscv_vmv_v_v_i32m1 (src, vl);
+}
+
+vfloat16m1_t test_vmv_non_overloaded_1 (vfloat16m1_t src, size_t vl) {
+ return __riscv_vmv_v_v_f16m1 (src, vl);
+}
+
+vint8m4_t test_vmv_non_overloaded_2 (vint8m4_t maskedoff, vint8m4_t src,
+ size_t vl) {
+ return __riscv_vmv_v_v_i8m4_tu (maskedoff, src, vl);
+}
--
2.34.1
Thanks Juzhe for comments, got the point and will have a try for hashmap liked approach to get the non-overloaded later in PATCH v4. Sorry for that in the middle of something.
Pan
From: juzhe.zhong@rivai.ai <juzhe.zhong@rivai.ai>
Sent: Friday, September 15, 2023 10:21 AM
To: Li, Pan2 <pan2.li@intel.com>; gcc-patches <gcc-patches@gcc.gnu.org>
Cc: Li, Pan2 <pan2.li@intel.com>; Wang, Yanzhang <yanzhang.wang@intel.com>; kito.cheng <kito.cheng@gmail.com>
Subject: Re: Re: [PATCH v3] RISC-V: Implement RESOLVE_OVERLOADED_BUILTIN for RVV intrinsic
More information:
For PRED_TYPE_tumu, it's easy to analyze, just need to count how many arguments in the arglist.
If arglist has 5 arguments (mask, merge, op1, op2, len) Then it must be TUMU.
What I mean is that we should be able to quickly to compute the arguments of the construction of the function_instance.
Then we can get the non-overloaeded function.
Hi Pan,
> +function_instance *
> +function_base::get_non_overloaded_instance (unsigned int code,
> + vec<tree, va_gc> &arglist) const
> +{
> + unsigned int code_limit = vec_safe_length (registered_functions);
> +
> + for (unsigned fun_code = code; fun_code < code_limit; fun_code++)
> + {
> + registered_function *rfun = (*registered_functions)[fun_code];
> + function_instance instance = rfun->instance;
> +
> + if (rfun->overloaded_p)
> + continue;
> +
> + unsigned k;
> + const rvv_arg_type_info *args = instance.op_info->args;
> +
> + for (k = 0; args[k].base_type != NUM_BASE_TYPES; k++)
> + {
> + if (k >= arglist.length ())
> + break;
Can we fast continue if args length not equal arglist length before this
loop:
if (args lengh != arglist.length ())
continue;
for (k = 0; args[k].base_type != NUM_BASE_TYPES; k++)
{
...
Thanks Lehua, actually Yes.
Consider we will have a try for hashmap way and will keep you posted.
Pan
-----Original Message-----
From: Lehua Ding <lehua.ding@rivai.ai>
Sent: Friday, September 15, 2023 10:29 AM
To: Li, Pan2 <pan2.li@intel.com>; gcc-patches@gcc.gnu.org
Cc: Wang, Yanzhang <yanzhang.wang@intel.com>; kito.cheng@gmail.com; juzhe.zhong@rivai.ai
Subject: Re: [PATCH v3] RISC-V: Implement RESOLVE_OVERLOADED_BUILTIN for RVV intrinsic
Hi Pan,
> +function_instance *
> +function_base::get_non_overloaded_instance (unsigned int code,
> + vec<tree, va_gc> &arglist) const
> +{
> + unsigned int code_limit = vec_safe_length (registered_functions);
> +
> + for (unsigned fun_code = code; fun_code < code_limit; fun_code++)
> + {
> + registered_function *rfun = (*registered_functions)[fun_code];
> + function_instance instance = rfun->instance;
> +
> + if (rfun->overloaded_p)
> + continue;
> +
> + unsigned k;
> + const rvv_arg_type_info *args = instance.op_info->args;
> +
> + for (k = 0; args[k].base_type != NUM_BASE_TYPES; k++)
> + {
> + if (k >= arglist.length ())
> + break;
Can we fast continue if args length not equal arglist length before this
loop:
if (args lengh != arglist.length ())
continue;
for (k = 0; args[k].base_type != NUM_BASE_TYPES; k++)
{
...
--
Best,
Lehua
@@ -220,11 +220,47 @@ riscv_check_builtin_call (location_t loc, vec<location_t> arg_loc, tree fndecl,
gcc_unreachable ();
}
+/* Implement TARGET_RESOLVE_OVERLOADED_BUILTIN. */
+static tree
+riscv_resolve_overloaded_builtin (unsigned int uncast_location, tree fndecl,
+ void *uncast_arglist)
+{
+ vec<tree, va_gc> empty = {};
+ location_t loc = (location_t) uncast_location;
+ vec<tree, va_gc> *arglist = (vec<tree, va_gc> *) uncast_arglist;
+ unsigned int code = DECL_MD_FUNCTION_CODE (fndecl);
+ unsigned int subcode = code >> RISCV_BUILTIN_SHIFT;
+ tree new_fndecl = NULL_TREE;
+
+ if (!arglist)
+ arglist = ∅
+
+ switch (code & RISCV_BUILTIN_CLASS)
+ {
+ case RISCV_BUILTIN_GENERAL:
+ break;
+ case RISCV_BUILTIN_VECTOR:
+ new_fndecl = riscv_vector::resolve_overloaded_builtin (loc, subcode,
+ arglist);
+ break;
+ default:
+ gcc_unreachable ();
+ }
+
+ if (new_fndecl == NULL_TREE)
+ return new_fndecl;
+
+ return build_function_call_vec (loc, vNULL, new_fndecl, arglist, NULL,
+ fndecl);
+}
+
/* Implement REGISTER_TARGET_PRAGMAS. */
void
riscv_register_pragmas (void)
{
+ targetm.resolve_overloaded_builtin = riscv_resolve_overloaded_builtin;
targetm.check_builtin_call = riscv_check_builtin_call;
+
c_register_pragma ("riscv", "intrinsic", riscv_pragma_intrinsic);
}
@@ -381,6 +381,7 @@ gimple *gimple_fold_builtin (unsigned int, gimple_stmt_iterator *, gcall *);
rtx expand_builtin (unsigned int, tree, rtx);
bool check_builtin_call (location_t, vec<location_t>, unsigned int,
tree, unsigned int, tree *);
+tree resolve_overloaded_builtin (location_t, unsigned int, vec<tree, va_gc> *);
bool const_vec_all_same_in_range_p (rtx, HOST_WIDE_INT, HOST_WIDE_INT);
bool legitimize_move (rtx, rtx);
void emit_vlmax_vsetvl (machine_mode, rtx);
@@ -49,6 +49,8 @@ build_one (function_builder &b, const function_group_info &group,
group.ops_infos.types[vec_type_idx].index);
b.allocate_argument_types (function_instance, argument_types);
b.apply_predication (function_instance, return_type, argument_types);
+
+ b.add_overloaded_function (function_instance, *group.shape);
b.add_unique_function (function_instance, (*group.shape), return_type,
argument_types);
}
@@ -87,6 +89,22 @@ struct build_base : public function_shape
}
};
+struct overloaded_base : public build_base
+{
+ tree resolve (function_resolver &r) const override
+ {
+ return r.lookup ();
+ }
+};
+
+struct non_overloaded_base : public build_base
+{
+ tree resolve (function_resolver &) const override
+ {
+ gcc_unreachable ();
+ }
+};
+
/* vsetvl_def class. */
struct vsetvl_def : public build_base
{
@@ -525,7 +543,7 @@ struct narrow_alu_def : public build_base
};
/* move_def class. Handle vmv.v.v/vmv.v.x. */
-struct move_def : public build_base
+struct move_def : public overloaded_base
{
char *get_name (function_builder &b, const function_instance &instance,
bool overloaded_p) const override
@@ -80,6 +80,10 @@ public:
/* The decl itself. */
tree GTY ((skip)) decl;
+
+ /* True if the decl represents an overloaded function that needs to be
+ resolved by function_resolver. */
+ bool overloaded_p;
};
/* Hash traits for registered_function. */
@@ -3196,6 +3200,77 @@ function_instance::could_trap_p () const
return false;
}
+/* Try to get the non-overloaded function instance.
+ After we register the overloaded the functions, the registered functions
+ table may look like:
+
+ +--------+---------------------------+-------------------+
+ | index | name | kind |
+ +--------+---------------------------+-------------------+
+ | 124733 | __riscv_vmv_v | Overloaded | <- Hook fun code
+ +--------+---------------------------+-------------------+
+ | 124735 | __riscv_vmv_v_v_i8mf8 | Non-overloaded |
+ +--------+---------------------------+-------------------+
+ | 124737 | __riscv_vmv_v | Placeholder |
+ +--------+---------------------------+-------------------+
+ | 124739 | __riscv_vmv_v | Overloaded |
+ +--------+---------------------------+-------------------+
+ | 124741 | __riscv_vmv_v_v_i8mf4 | Non-overloaded |
+ +--------+---------------------------+-------------------+
+ | 124743 | __riscv_vmv_v | Placeholder |
+ +--------+---------------------------+-------------------+
+ | 124745 | __riscv_vmv_v | Overloaded |
+ +--------+---------------------------+-------------------+
+ | 124747 | __riscv_vmv_v_v_i8mf2 | Non-overloaded |
+ +--------+---------------------------+-------------------+
+ | 124749 | __riscv_vmv_v | Placeholder |
+ +--------+---------------------------+-------------------+
+ | 124751 | __riscv_vmv_v | Overloaded |
+ +--------+---------------------------+-------------------+
+ | 124753 | __riscv_vmv_v_v_i8m1 | Non-overloaded |
+ +--------+---------------------------+-------------------+
+ | 124755 | __riscv_vmv_v | Placeholder |
+ +--------+---------------------------+-------------------+
+
+ When we resolve the overloaded API from the hook, we always get the first
+ function code of one API group (aka vmv_v as above table). We will search
+ start from that index to find the only one non-overloaded API with exactly
+ the same arglist. Or NULL instance will be returned.
+ */
+function_instance *
+function_base::get_non_overloaded_instance (unsigned int code,
+ vec<tree, va_gc> &arglist) const
+{
+ unsigned int code_limit = vec_safe_length (registered_functions);
+
+ for (unsigned fun_code = code; fun_code < code_limit; fun_code++)
+ {
+ registered_function *rfun = (*registered_functions)[fun_code];
+ function_instance instance = rfun->instance;
+
+ if (rfun->overloaded_p)
+ continue;
+
+ unsigned k;
+ const rvv_arg_type_info *args = instance.op_info->args;
+
+ for (k = 0; args[k].base_type != NUM_BASE_TYPES; k++)
+ {
+ if (k >= arglist.length ())
+ break;
+
+ if (TYPE_MODE (instance.get_arg_type (k))
+ != TYPE_MODE (TREE_TYPE (arglist[k])))
+ break;
+ }
+
+ if (args[k].base_type == NUM_BASE_TYPES)
+ return &rfun->instance;
+ }
+
+ return NULL;
+}
+
function_builder::function_builder ()
{
m_direct_overloads = lang_GNU_CXX ();
@@ -3357,7 +3432,8 @@ function_builder::get_attributes (const function_instance &instance)
registered_function &
function_builder::add_function (const function_instance &instance,
const char *name, tree fntype, tree attrs,
- bool placeholder_p)
+ bool placeholder_p,
+ bool overloaded_p = false)
{
unsigned int code = vec_safe_length (registered_functions);
code = (code << RISCV_BUILTIN_SHIFT) + RISCV_BUILTIN_VECTOR;
@@ -3383,6 +3459,7 @@ function_builder::add_function (const function_instance &instance,
registered_function &rfn = *ggc_alloc<registered_function> ();
rfn.instance = instance;
rfn.decl = decl;
+ rfn.overloaded_p = overloaded_p;
vec_safe_push (registered_functions, &rfn);
return rfn;
@@ -3432,6 +3509,26 @@ function_builder::add_unique_function (const function_instance &instance,
obstack_free (&m_string_obstack, name);
}
+void
+function_builder::add_overloaded_function (const function_instance &instance,
+ const function_shape *shape)
+{
+ if (!check_required_extensions (instance))
+ return;
+
+ char *name = shape->get_name (*this, instance, true);
+
+ if (name)
+ {
+ /* To avoid API conflicting, take void return type and void argument
+ for the overloaded function. */
+ tree fntype = build_function_type (void_type_node, void_list_node);
+ add_function (instance, name, fntype, NULL_TREE, m_direct_overloads,
+ true);
+ obstack_free (&m_string_obstack, name);
+ }
+}
+
function_call_info::function_call_info (location_t location_in,
const function_instance &instance_in,
tree fndecl_in)
@@ -3852,6 +3949,13 @@ function_checker::function_checker (location_t location,
m_nargs (nargs), m_args (args)
{}
+function_resolver::function_resolver (location_t location,
+ const function_instance &instance,
+ tree fndecl,
+ vec<tree, va_gc> &arglist)
+ : function_call_info (location, instance, fndecl), m_arglist (arglist)
+{}
+
/* Report that LOCATION has a call to FNDECL in which argument ARGNO
was not an integer constant expression. ARGNO counts from zero. */
void
@@ -3967,6 +4071,39 @@ function_checker::check ()
return shape->check (*this);
}
+unsigned int
+function_resolver::get_sub_code ()
+{
+ unsigned int fun_code = DECL_MD_FUNCTION_CODE (fndecl);
+
+ return fun_code >> RISCV_BUILTIN_SHIFT;
+}
+
+tree
+function_resolver::resolve ()
+{
+ return shape->resolve (*this);
+}
+
+tree
+function_resolver::lookup ()
+{
+ unsigned int fun_code = get_sub_code ();
+ function_instance *instance
+ = base->get_non_overloaded_instance (fun_code, m_arglist);
+
+ if (!instance)
+ return NULL_TREE;
+
+ hashval_t hash = instance->hash ();
+ registered_function *rfun = function_table->find_with_hash (*instance, hash);
+
+ if (!rfun)
+ return NULL_TREE;
+
+ return rfun->decl;
+}
+
inline hashval_t
registered_function_hasher::hash (value_type value)
{
@@ -4196,6 +4333,22 @@ check_builtin_call (location_t location, vec<location_t>, unsigned int code,
TREE_TYPE (rfn.decl), nargs, args).check ();
}
+tree
+resolve_overloaded_builtin (location_t loc, unsigned int code,
+ vec<tree, va_gc> *arglist)
+{
+ if (code >= vec_safe_length (registered_functions))
+ return NULL_TREE;
+
+ const registered_function *rfun = (*registered_functions)[code];
+
+ if (!rfun || !rfun->overloaded_p)
+ return NULL_TREE;
+
+ return function_resolver (loc, rfun->instance, rfun->decl, *arglist)
+ .resolve ();
+}
+
function_instance
get_read_vl_instance (void)
{
@@ -277,6 +277,8 @@ public:
void apply_predication (const function_instance &, tree, vec<tree> &) const;
void add_unique_function (const function_instance &, const function_shape *,
tree, vec<tree> &);
+ void add_overloaded_function (const function_instance &,
+ const function_shape *);
void register_function_group (const function_group_info &);
void append_name (const char *);
void append_base_name (const char *);
@@ -288,7 +290,7 @@ private:
tree get_attributes (const function_instance &);
registered_function &add_function (const function_instance &, const char *,
- tree, tree, bool);
+ tree, tree, bool, bool);
/* True if we should create a separate decl for each instance of an
overloaded function, instead of using function_builder. */
@@ -424,6 +426,11 @@ public:
/* Expand the given call into rtl. Return the result of the function,
or an arbitrary value if the function doesn't return a result. */
virtual rtx expand (function_expander &) const = 0;
+
+ /* Return the non-overloaded function instance from the registered
+ function table if success, or NULL will be returned. */
+ virtual function_instance * get_non_overloaded_instance (
+ unsigned int, vec<tree, va_gc> &arglist) const;
};
/* A class for checking that the semantic constraints on a function call are
@@ -462,6 +469,29 @@ private:
tree *m_args;
};
+/* A class for resolving an overloaded function call. */
+class function_resolver : public function_call_info
+{
+public:
+ function_resolver (location_t, const function_instance &, tree,
+ vec<tree, va_gc> &);
+
+ /* Resolve the correlated non-overloaded function from the
+ the registered_functions table. */
+ tree resolve ();
+
+ /* Lookup the non-overloaded function from the registered
+ function table. */
+ tree lookup ();
+
+ /* Return the sub code of the fndecl. */
+ unsigned int get_sub_code ();
+
+private:
+ /* The arguments to the overloaded function. */
+ vec<tree, va_gc> &m_arglist;
+};
+
/* Classifies functions into "shapes" base on:
- Base name of the intrinsic function.
@@ -486,6 +516,10 @@ public:
/* Check whether the given call is semantically valid. Return true
if it is, otherwise report an error and return false. */
virtual bool check (function_checker &) const { return true; }
+
+ /* Try to resolve the overloaded call. Return the non-overloaded
+ function decl on success and NULL_TREE on failure. */
+ virtual tree resolve (function_resolver &) const { return NULL_TREE; };
};
extern const char *const operand_suffixes[NUM_OP_TYPES];
new file mode 100644
@@ -0,0 +1,8 @@
+/* { dg-do compile } */
+/* { dg-options "-march=rv32gcv_zvfh -mabi=ilp32 -O3 -Wno-psabi" } */
+
+#include "overloaded_vmv_v.h"
+
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e32,\s*m1,\s*ta,\s*ma} 2 } } */
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e16,\s*m1,\s*ta,\s*ma} 2 } } */
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e8,\s*m4,\s*tu,\s*ma} 2 } } */
new file mode 100644
@@ -0,0 +1,8 @@
+/* { dg-do compile } */
+/* { dg-options "-march=rv64gcv_zvfh -mabi=lp64 -O3 -Wno-psabi" } */
+
+#include "overloaded_vmv_v.h"
+
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e32,\s*m1,\s*ta,\s*ma} 2 } } */
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e16,\s*m1,\s*ta,\s*ma} 2 } } */
+/* { dg-final { scan-assembler-times {vsetvli\s+zero,\s*[ax][0-9]+,\s*e8,\s*m4,\s*tu,\s*ma} 2 } } */
new file mode 100644
@@ -0,0 +1,27 @@
+#include "riscv_vector.h"
+
+vint32m1_t test_vmv_overloaded_0 (vint32m1_t src, size_t vl) {
+ return __riscv_vmv_v (src, vl);
+}
+
+vfloat16m1_t test_vmv_overloaded_1 (vfloat16m1_t src, size_t vl) {
+ return __riscv_vmv_v (src, vl);
+}
+
+vint8m4_t test_vmv_overloaded_2 (vint8m4_t maskedoff, vint8m4_t src,
+ size_t vl) {
+ return __riscv_vmv_v_tu (maskedoff, src, vl);
+}
+
+vint32m1_t test_vmv_non_overloaded_0 (vint32m1_t src, size_t vl) {
+ return __riscv_vmv_v_v_i32m1 (src, vl);
+}
+
+vfloat16m1_t test_vmv_non_overloaded_1 (vfloat16m1_t src, size_t vl) {
+ return __riscv_vmv_v_v_f16m1 (src, vl);
+}
+
+vint8m4_t test_vmv_non_overloaded_2 (vint8m4_t maskedoff, vint8m4_t src,
+ size_t vl) {
+ return __riscv_vmv_v_v_i8m4_tu (maskedoff, src, vl);
+}