[committed,050/103] gccrs: Closure support at CallExpr

Message ID 20230221120230.596966-51-arthur.cohen@embecosm.com
State Unresolved
Headers
Series [committed,001/103] gccrs: Fix missing dead code analysis ICE on local enum definition |

Checks

Context Check Description
snail/gcc-patch-check warning Git am fail log

Commit Message

Arthur Cohen Feb. 21, 2023, 12:01 p.m. UTC
  From: Philip Herron <philip.herron@embecosm.com>

gcc/rust/ChangeLog:

	* backend/rust-compile-context.h: Add new functions: `insert_closure_decl` and
	`lookup_closure_decl`.
	* backend/rust-compile-expr.cc (CompileExpr::visit): Start compiling Closures properly.
	(CompileExpr::generate_closure_function): New function.
	(CompileExpr::generate_closure_fntype): Likewise.
	* backend/rust-compile-expr.h: Declare `generate_closure_function` and
	`generate_closure_fntype`.
	* backend/rust-compile-type.cc (TyTyResolveCompile::visit): Visit closure types properly.
	* backend/rust-mangle.cc (legacy_mangle_name): Add support for closures.
	* backend/rust-tree.h (RS_CLOSURE_FLAG): Add new tree macro.
	(RS_CLOSURE_TYPE_P): And checking for it on tree nodes.
	* typecheck/rust-tyty.cc (ClosureType::is_equal): Add implementation.

gcc/testsuite/ChangeLog:

	* rust/execute/torture/closure1.rs: New test.
---
 gcc/rust/backend/rust-compile-context.h       |  31 ++
 gcc/rust/backend/rust-compile-expr.cc         | 280 +++++++++++++++++-
 gcc/rust/backend/rust-compile-expr.h          |  10 +
 gcc/rust/backend/rust-compile-type.cc         |  10 +-
 gcc/rust/backend/rust-mangle.cc               |   6 +
 gcc/rust/backend/rust-tree.h                  |   5 +
 gcc/rust/typecheck/rust-tyty.cc               |  13 +-
 .../rust/execute/torture/closure1.rs          |  18 ++
 8 files changed, 361 insertions(+), 12 deletions(-)
 create mode 100644 gcc/testsuite/rust/execute/torture/closure1.rs
  

Patch

diff --git a/gcc/rust/backend/rust-compile-context.h b/gcc/rust/backend/rust-compile-context.h
index 49f78e19b20..d2d3a53f182 100644
--- a/gcc/rust/backend/rust-compile-context.h
+++ b/gcc/rust/backend/rust-compile-context.h
@@ -147,6 +147,35 @@  public:
     mono_fns[dId].push_back ({ref, fn});
   }
 
+  void insert_closure_decl (const TyTy::ClosureType *ref, tree fn)
+  {
+    auto dId = ref->get_def_id ();
+    auto it = mono_closure_fns.find (dId);
+    if (it == mono_closure_fns.end ())
+      mono_closure_fns[dId] = {};
+
+    mono_closure_fns[dId].push_back ({ref, fn});
+  }
+
+  tree lookup_closure_decl (const TyTy::ClosureType *ref)
+  {
+    auto dId = ref->get_def_id ();
+    auto it = mono_closure_fns.find (dId);
+    if (it == mono_closure_fns.end ())
+      return error_mark_node;
+
+    for (auto &i : it->second)
+      {
+	const TyTy::ClosureType *t = i.first;
+	tree fn = i.second;
+
+	if (ref->is_equal (*t))
+	  return fn;
+      }
+
+    return error_mark_node;
+  }
+
   bool lookup_function_decl (HirId id, tree *fn, DefId dId = UNKNOWN_DEFID,
 			     const TyTy::BaseType *ref = nullptr,
 			     const std::string &asm_name = std::string ())
@@ -343,6 +372,8 @@  private:
   std::vector<tree> loop_begin_labels;
   std::map<DefId, std::vector<std::pair<const TyTy::BaseType *, tree>>>
     mono_fns;
+  std::map<DefId, std::vector<std::pair<const TyTy::ClosureType *, tree>>>
+    mono_closure_fns;
   std::map<HirId, tree> implicit_pattern_bindings;
   std::map<hashval_t, tree> main_variants;
 
diff --git a/gcc/rust/backend/rust-compile-expr.cc b/gcc/rust/backend/rust-compile-expr.cc
index 724a93a68bd..d2d9ae0a233 100644
--- a/gcc/rust/backend/rust-compile-expr.cc
+++ b/gcc/rust/backend/rust-compile-expr.cc
@@ -1589,9 +1589,7 @@  CompileExpr::visit (HIR::CallExpr &expr)
     }
 
   // must be a tuple constructor
-  bool is_fn = tyty->get_kind () == TyTy::TypeKind::FNDEF
-	       || tyty->get_kind () == TyTy::TypeKind::FNPTR;
-  bool is_adt_ctor = !is_fn;
+  bool is_adt_ctor = tyty->get_kind () == TyTy::TypeKind::ADT;
   if (is_adt_ctor)
     {
       rust_assert (tyty->get_kind () == TyTy::TypeKind::ADT);
@@ -1692,6 +1690,57 @@  CompileExpr::visit (HIR::CallExpr &expr)
     return true;
   };
 
+  auto fn_address = CompileExpr::Compile (expr.get_fnexpr (), ctx);
+
+  // is this a closure call?
+  if (RS_CLOSURE_TYPE_P (TREE_TYPE (fn_address)))
+    {
+      rust_assert (tyty->get_kind () == TyTy::TypeKind::CLOSURE);
+      TyTy::ClosureType *closure = static_cast<TyTy::ClosureType *> (tyty);
+
+      std::vector<tree> tuple_arg_vals;
+      for (auto &argument : expr.get_arguments ())
+	{
+	  auto rvalue = CompileExpr::Compile (argument.get (), ctx);
+	  tuple_arg_vals.push_back (rvalue);
+	}
+
+      tree tuple_args_tyty
+	= TyTyResolveCompile::compile (ctx, &closure->get_parameters ());
+      tree tuple_args
+	= ctx->get_backend ()->constructor_expression (tuple_args_tyty, false,
+						       tuple_arg_vals, -1,
+						       expr.get_locus ());
+
+      // need to apply any autoderef's to the self argument
+      HirId autoderef_mappings_id = expr.get_mappings ().get_hirid ();
+      std::vector<Resolver::Adjustment> *adjustments = nullptr;
+      bool ok
+	= ctx->get_tyctx ()->lookup_autoderef_mappings (autoderef_mappings_id,
+							&adjustments);
+      rust_assert (ok);
+
+      // apply adjustments for the fn call
+      tree self
+	= resolve_adjustements (*adjustments, fn_address, expr.get_locus ());
+
+      // args are always self, and the tuple of the args we are passing where
+      // self is the path of the call-expr in this case the fn_address
+      std::vector<tree> args;
+      args.push_back (self);
+      args.push_back (tuple_args);
+
+      // get the fn call address
+      tree closure_call_site = ctx->lookup_closure_decl (closure);
+      tree closure_call_address
+	= address_expression (closure_call_site, expr.get_locus ());
+      translated
+	= ctx->get_backend ()->call_expression (closure_call_address, args,
+						nullptr /* static chain ?*/,
+						expr.get_locus ());
+      return;
+    }
+
   bool is_varadic = false;
   if (tyty->get_kind () == TyTy::TypeKind::FNDEF)
     {
@@ -1699,13 +1748,13 @@  CompileExpr::visit (HIR::CallExpr &expr)
       is_varadic = fn->is_varadic ();
     }
 
-  size_t required_num_args;
+  size_t required_num_args = expr.get_arguments ().size ();
   if (tyty->get_kind () == TyTy::TypeKind::FNDEF)
     {
       const TyTy::FnType *fn = static_cast<const TyTy::FnType *> (tyty);
       required_num_args = fn->num_params ();
     }
-  else
+  else if (tyty->get_kind () == TyTy::TypeKind::FNPTR)
     {
       const TyTy::FnPtr *fn = static_cast<const TyTy::FnPtr *> (tyty);
       required_num_args = fn->num_params ();
@@ -1746,8 +1795,7 @@  CompileExpr::visit (HIR::CallExpr &expr)
       args.push_back (rvalue);
     }
 
-  // must be a call to a function
-  auto fn_address = CompileExpr::Compile (expr.get_fnexpr (), ctx);
+  // must be a regular call to a function
   translated = ctx->get_backend ()->call_expression (fn_address, args, nullptr,
 						     expr.get_locus ());
 }
@@ -2806,7 +2854,223 @@  CompileExpr::visit (HIR::ArrayIndexExpr &expr)
 void
 CompileExpr::visit (HIR::ClosureExpr &expr)
 {
-  gcc_unreachable ();
+  TyTy::BaseType *closure_expr_ty = nullptr;
+  if (!ctx->get_tyctx ()->lookup_type (expr.get_mappings ().get_hirid (),
+				       &closure_expr_ty))
+    {
+      rust_fatal_error (expr.get_locus (),
+			"did not resolve type for this ClosureExpr");
+      return;
+    }
+  rust_assert (closure_expr_ty->get_kind () == TyTy::TypeKind::CLOSURE);
+  TyTy::ClosureType *closure_tyty
+    = static_cast<TyTy::ClosureType *> (closure_expr_ty);
+  tree compiled_closure_tyty = TyTyResolveCompile::compile (ctx, closure_tyty);
+
+  // generate closure function
+  generate_closure_function (expr, *closure_tyty, compiled_closure_tyty);
+
+  // lets ignore state capture for now we need to instantiate the struct anyway
+  // then generate the function
+
+  std::vector<tree> vals;
+  // TODO
+  // setup argument captures based on the mode?
+
+  translated
+    = ctx->get_backend ()->constructor_expression (compiled_closure_tyty, false,
+						   vals, -1, expr.get_locus ());
+}
+
+tree
+CompileExpr::generate_closure_function (HIR::ClosureExpr &expr,
+					TyTy::ClosureType &closure_tyty,
+					tree compiled_closure_tyty)
+{
+  TyTy::FnType *fn_tyty = nullptr;
+  tree compiled_fn_type
+    = generate_closure_fntype (expr, closure_tyty, compiled_closure_tyty,
+			       &fn_tyty);
+  if (compiled_fn_type == error_mark_node)
+    return error_mark_node;
+
+  const Resolver::CanonicalPath &parent_canonical_path
+    = closure_tyty.get_ident ().path;
+  Resolver::CanonicalPath path = parent_canonical_path.append (
+    Resolver::CanonicalPath::new_seg (UNKNOWN_NODEID, "{{closure}}"));
+
+  std::string ir_symbol_name = path.get ();
+  std::string asm_name = ctx->mangle_item (&closure_tyty, path);
+
+  unsigned int flags = 0;
+  tree fndecl
+    = ctx->get_backend ()->function (compiled_fn_type, ir_symbol_name, asm_name,
+				     flags, expr.get_locus ());
+
+  // insert into the context
+  ctx->insert_function_decl (fn_tyty, fndecl);
+  ctx->insert_closure_decl (&closure_tyty, fndecl);
+
+  // setup the parameters
+  std::vector<Bvariable *> param_vars;
+
+  // closure self
+  Bvariable *self_param
+    = ctx->get_backend ()->parameter_variable (fndecl, "$closure",
+					       compiled_closure_tyty,
+					       expr.get_locus ());
+  DECL_ARTIFICIAL (self_param->get_decl ()) = 1;
+  param_vars.push_back (self_param);
+
+  // setup the implicit argument captures
+  // TODO
+
+  // args tuple
+  tree args_type
+    = TyTyResolveCompile::compile (ctx, &closure_tyty.get_parameters ());
+  Bvariable *args_param
+    = ctx->get_backend ()->parameter_variable (fndecl, "args", args_type,
+					       expr.get_locus ());
+  param_vars.push_back (args_param);
+
+  // setup the implicit mappings for the arguments. Since argument passing to
+  // closure functions is done via passing a tuple but the closure body expects
+  // just normal arguments this means we need to destructure them similar to
+  // what we do in MatchExpr's. This means when we have a closure-param of a we
+  // actually setup the destructure to take from the args tuple
+
+  tree args_param_expr = args_param->get_tree (expr.get_locus ());
+  size_t i = 0;
+  for (auto &closure_param : expr.get_params ())
+    {
+      tree compiled_param_var = ctx->get_backend ()->struct_field_expression (
+	args_param_expr, i, closure_param.get_locus ());
+
+      const HIR::Pattern &param_pattern = *closure_param.get_pattern ();
+      ctx->insert_pattern_binding (
+	param_pattern.get_pattern_mappings ().get_hirid (), compiled_param_var);
+      i++;
+    }
+
+  if (!ctx->get_backend ()->function_set_parameters (fndecl, param_vars))
+    return error_mark_node;
+
+  // lookup locals
+  HIR::Expr *function_body = expr.get_expr ().get ();
+  auto body_mappings = function_body->get_mappings ();
+  Resolver::Rib *rib = nullptr;
+  bool ok
+    = ctx->get_resolver ()->find_name_rib (body_mappings.get_nodeid (), &rib);
+  rust_assert (ok);
+
+  std::vector<Bvariable *> locals
+    = compile_locals_for_block (ctx, *rib, fndecl);
+
+  tree enclosing_scope = NULL_TREE;
+  Location start_location = function_body->get_locus ();
+  Location end_location = function_body->get_locus ();
+  bool is_block_expr
+    = function_body->get_expression_type () == HIR::Expr::ExprType::Block;
+  if (is_block_expr)
+    {
+      HIR::BlockExpr *body = static_cast<HIR::BlockExpr *> (function_body);
+      start_location = body->get_locus ();
+      end_location = body->get_end_locus ();
+    }
+
+  tree code_block = ctx->get_backend ()->block (fndecl, enclosing_scope, locals,
+						start_location, end_location);
+  ctx->push_block (code_block);
+
+  TyTy::BaseType *tyret = &closure_tyty.get_result_type ();
+  bool function_has_return = !closure_tyty.get_result_type ().is_unit ();
+  Bvariable *return_address = nullptr;
+  if (function_has_return)
+    {
+      tree return_type = TyTyResolveCompile::compile (ctx, tyret);
+
+      bool address_is_taken = false;
+      tree ret_var_stmt = NULL_TREE;
+
+      return_address = ctx->get_backend ()->temporary_variable (
+	fndecl, code_block, return_type, NULL, address_is_taken,
+	expr.get_locus (), &ret_var_stmt);
+
+      ctx->add_statement (ret_var_stmt);
+    }
+
+  ctx->push_fn (fndecl, return_address);
+
+  if (is_block_expr)
+    {
+      HIR::BlockExpr *body = static_cast<HIR::BlockExpr *> (function_body);
+      compile_function_body (ctx, fndecl, *body, true);
+    }
+  else
+    {
+      tree value = CompileExpr::Compile (function_body, ctx);
+      tree return_expr
+	= ctx->get_backend ()->return_statement (fndecl, {value},
+						 function_body->get_locus ());
+      ctx->add_statement (return_expr);
+    }
+
+  tree bind_tree = ctx->pop_block ();
+
+  gcc_assert (TREE_CODE (bind_tree) == BIND_EXPR);
+  DECL_SAVED_TREE (fndecl) = bind_tree;
+
+  ctx->pop_fn ();
+  ctx->push_function (fndecl);
+
+  return fndecl;
+}
+
+tree
+CompileExpr::generate_closure_fntype (HIR::ClosureExpr &expr,
+				      const TyTy::ClosureType &closure_tyty,
+				      tree compiled_closure_tyty,
+				      TyTy::FnType **fn_tyty)
+{
+  // grab the specified_bound
+  rust_assert (closure_tyty.num_specified_bounds () == 1);
+  const TyTy::TypeBoundPredicate &predicate
+    = *closure_tyty.get_specified_bounds ().begin ();
+
+  // ensure the fn_once_output associated type is set
+  closure_tyty.setup_fn_once_output ();
+
+  // the function signature is based on the trait bound that the closure
+  // implements which is determined at the type resolution time
+  //
+  // https://github.com/rust-lang/rust/blob/7807a694c2f079fd3f395821bcc357eee8650071/library/core/src/ops/function.rs#L54-L71
+
+  TyTy::TypeBoundPredicateItem item = TyTy::TypeBoundPredicateItem::error ();
+  if (predicate.get_name ().compare ("FnOnce") == 0)
+    {
+      item = predicate.lookup_associated_item ("call_once");
+    }
+  else if (predicate.get_name ().compare ("FnMut") == 0)
+    {
+      item = predicate.lookup_associated_item ("call_mut");
+    }
+  else if (predicate.get_name ().compare ("Fn") == 0)
+    {
+      item = predicate.lookup_associated_item ("call");
+    }
+  else
+    {
+      // FIXME error message?
+      gcc_unreachable ();
+      return error_mark_node;
+    }
+
+  rust_assert (!item.is_error ());
+
+  TyTy::BaseType *item_tyty = item.get_tyty_for_receiver (&closure_tyty);
+  rust_assert (item_tyty->get_kind () == TyTy::TypeKind::FNDEF);
+  *fn_tyty = static_cast<TyTy::FnType *> (item_tyty);
+  return TyTyResolveCompile::compile (ctx, item_tyty);
 }
 
 } // namespace Compile
diff --git a/gcc/rust/backend/rust-compile-expr.h b/gcc/rust/backend/rust-compile-expr.h
index 7fc3f5e7f4d..c734406e0da 100644
--- a/gcc/rust/backend/rust-compile-expr.h
+++ b/gcc/rust/backend/rust-compile-expr.h
@@ -142,6 +142,16 @@  protected:
 			  const TyTy::ArrayType &array_tyty, tree array_type,
 			  HIR::ArrayElemsCopied &elems);
 
+protected:
+  tree generate_closure_function (HIR::ClosureExpr &expr,
+				  TyTy::ClosureType &closure_tyty,
+				  tree compiled_closure_tyty);
+
+  tree generate_closure_fntype (HIR::ClosureExpr &expr,
+				const TyTy::ClosureType &closure_tyty,
+				tree compiled_closure_tyty,
+				TyTy::FnType **fn_tyty);
+
 private:
   CompileExpr (Context *ctx);
 
diff --git a/gcc/rust/backend/rust-compile-type.cc b/gcc/rust/backend/rust-compile-type.cc
index fe1b7ce95e3..824cb3a56ef 100644
--- a/gcc/rust/backend/rust-compile-type.cc
+++ b/gcc/rust/backend/rust-compile-type.cc
@@ -97,9 +97,15 @@  TyTyResolveCompile::visit (const TyTy::InferType &)
 }
 
 void
-TyTyResolveCompile::visit (const TyTy::ClosureType &)
+TyTyResolveCompile::visit (const TyTy::ClosureType &type)
 {
-  gcc_unreachable ();
+  std::vector<Backend::typed_identifier> fields;
+  tree type_record = ctx->get_backend ()->struct_type (fields);
+  RS_CLOSURE_FLAG (type_record) = 1;
+
+  std::string named_struct_str = type.get_ident ().path.get () + "{{closure}}";
+  translated = ctx->get_backend ()->named_type (named_struct_str, type_record,
+						type.get_ident ().locus);
 }
 
 void
diff --git a/gcc/rust/backend/rust-mangle.cc b/gcc/rust/backend/rust-mangle.cc
index 4d202078a70..83aefa7997a 100644
--- a/gcc/rust/backend/rust-mangle.cc
+++ b/gcc/rust/backend/rust-mangle.cc
@@ -13,6 +13,8 @@  static const std::string kMangledRef = "$RF$";
 static const std::string kMangledPtr = "$BP$";
 static const std::string kMangledLeftSqParen = "$u5b$";	 // [
 static const std::string kMangledRightSqParen = "$u5d$"; // ]
+static const std::string kMangledLeftBrace = "$u7b$";	 // {
+static const std::string kMangledRightBrace = "$u7d$";	 // }
 static const std::string kQualPathBegin = "_" + kMangledSubstBegin;
 static const std::string kMangledComma = "$C$";
 
@@ -66,6 +68,10 @@  legacy_mangle_name (const std::string &name)
 	m = kMangledLeftSqParen;
       else if (c == ']')
 	m = kMangledRightSqParen;
+      else if (c == '{')
+	m = kMangledLeftBrace;
+      else if (c == '}')
+	m = kMangledRightBrace;
       else if (c == ',')
 	m = kMangledComma;
       else if (c == ':')
diff --git a/gcc/rust/backend/rust-tree.h b/gcc/rust/backend/rust-tree.h
index 41dd012bd6d..284fd873c1c 100644
--- a/gcc/rust/backend/rust-tree.h
+++ b/gcc/rust/backend/rust-tree.h
@@ -82,6 +82,11 @@ 
 #define SLICE_TYPE_P(TYPE)                                                     \
   (TREE_CODE (TYPE) == RECORD_TYPE && TREE_LANG_FLAG_0 (TYPE))
 
+// lambda?
+#define RS_CLOSURE_FLAG TREE_LANG_FLAG_1
+#define RS_CLOSURE_TYPE_P(TYPE)                                                \
+  (TREE_CODE (TYPE) == RECORD_TYPE && TREE_LANG_FLAG_1 (TYPE))
+
 /* Returns true if NODE is a pointer to member function type.  */
 #define TYPE_PTRMEMFUNC_P(NODE)                                                \
   (TREE_CODE (NODE) == RECORD_TYPE && TYPE_PTRMEMFUNC_FLAG (NODE))
diff --git a/gcc/rust/typecheck/rust-tyty.cc b/gcc/rust/typecheck/rust-tyty.cc
index 0d96c0f04fd..bdb2d909b86 100644
--- a/gcc/rust/typecheck/rust-tyty.cc
+++ b/gcc/rust/typecheck/rust-tyty.cc
@@ -1696,8 +1696,17 @@  ClosureType::can_eq (const BaseType *other, bool emit_errors) const
 bool
 ClosureType::is_equal (const BaseType &other) const
 {
-  gcc_unreachable ();
-  return false;
+  if (other.get_kind () != TypeKind::CLOSURE)
+    return false;
+
+  const ClosureType &other2 = static_cast<const ClosureType &> (other);
+  if (get_def_id () != other2.get_def_id ())
+    return false;
+
+  if (!get_parameters ().is_equal (other2.get_parameters ()))
+    return false;
+
+  return get_result_type ().is_equal (other2.get_result_type ());
 }
 
 BaseType *
diff --git a/gcc/testsuite/rust/execute/torture/closure1.rs b/gcc/testsuite/rust/execute/torture/closure1.rs
new file mode 100644
index 00000000000..62afa78a038
--- /dev/null
+++ b/gcc/testsuite/rust/execute/torture/closure1.rs
@@ -0,0 +1,18 @@ 
+extern "C" {
+    fn printf(s: *const i8, ...);
+}
+
+#[lang = "fn_once"]
+pub trait FnOnce<Args> {
+    #[lang = "fn_once_output"]
+    type Output;
+
+    extern "rust-call" fn call_once(self, args: Args) -> Self::Output;
+}
+
+fn main() -> i32 {
+    let closure_annotated = |i: i32| -> i32 { i + 1 };
+
+    let i = 1;
+    closure_annotated(i) - 2
+}