[dart2wasm] Devirtualize closure calls based on TFA direct-call metadata

Use the TFA direct-call metadata to directly call a closure in function
invocations.

Closes #55231.

Tested: existing tests cover the new code paths, but I also added a new
test.
Change-Id: Ib5f26b10efd77570e256196b4bfb07e6bef800c0
Reviewed-on: https://dart-review.googlesource.com/c/sdk/+/397260
Reviewed-by: Martin Kustermann <kustermann@google.com>
Commit-Queue: Ömer Ağacan <omersa@google.com>
diff --git a/pkg/dart2wasm/lib/closures.dart b/pkg/dart2wasm/lib/closures.dart
index 9a4c4b4..dc30f23 100644
--- a/pkg/dart2wasm/lib/closures.dart
+++ b/pkg/dart2wasm/lib/closures.dart
@@ -5,12 +5,14 @@
 import 'dart:collection';
 import 'dart:math' show min;
 
+import 'package:collection/collection.dart';
 import 'package:kernel/ast.dart';
 import 'package:vm/metadata/procedure_attributes.dart';
 import 'package:vm/transformations/type_flow/utils.dart' show UnionFind;
 import 'package:wasm_builder/wasm_builder.dart' as w;
 
 import 'class_info.dart';
+import 'param_info.dart';
 import 'translator.dart';
 
 /// Describes the implementation of a concrete closure, including its vtable
@@ -34,8 +36,16 @@
   /// The module this closure is implemented in.
   final w.ModuleBuilder module;
 
-  ClosureImplementation(this.representation, this.functions,
-      this.dynamicCallEntry, this.vtable, this.module);
+  /// [ParameterInfo] to be used when directly calling the closure.
+  final ParameterInfo directCallParamInfo;
+
+  ClosureImplementation(
+      this.representation,
+      this.functions,
+      this.dynamicCallEntry,
+      this.vtable,
+      this.module,
+      this.directCallParamInfo);
 }
 
 /// Describes the representation of closures for a particular function
@@ -131,6 +141,8 @@
   /// The field index in the vtable struct for the function entry to use when
   /// calling the closure with the given number of positional arguments and the
   /// given set of named arguments.
+  ///
+  /// `argNames` should be sorted.
   int fieldIndexForSignature(int posArgCount, List<String> argNames) {
     if (argNames.isEmpty) {
       return vtableBaseIndex + posArgCount;
@@ -155,7 +167,9 @@
 class NameCombination implements Comparable<NameCombination> {
   final List<String> names;
 
-  NameCombination(this.names);
+  NameCombination(this.names) {
+    assert(names.isSorted(Comparable.compare));
+  }
 
   @override
   int compareTo(NameCombination other) {
@@ -363,6 +377,8 @@
   /// Get the representation for closures with a specific signature, described
   /// by the number of type parameters, the maximum number of positional
   /// parameters and the names of named parameters.
+  ///
+  /// `names` should be sorted.
   ClosureRepresentation? getClosureRepresentation(
       int typeCount, int positionalCount, List<String> names) {
     final representations =
@@ -1002,10 +1018,20 @@
 /// A local function or function expression.
 class Lambda {
   final FunctionNode functionNode;
+
+  // Note: creating a `Lambda` does not add this function to the compilation
+  // queue. Make sure to get it with `Functions.getLambdaFunction` to add it
+  // to the compilation queue.
   final w.FunctionBuilder function;
+
   final Source functionNodeSource;
 
-  Lambda(this.functionNode, this.function, this.functionNodeSource);
+  /// Index of the function within the enclosing member, based on pre-order
+  /// traversal of the member body.
+  final int index;
+
+  Lambda._(
+      this.functionNode, this.function, this.functionNodeSource, this.index);
 }
 
 /// The context for one or more closures, containing their captured variables.
@@ -1141,7 +1167,10 @@
   /// does not populate [lambdas], [contexts], [captures], and
   /// [closurizedFunctions]. This mode is useful in the code generators that
   /// always have direct access to variables (instead of via a context).
-  Closures(this.translator, this._member, {bool findCaptures = true})
+  ///
+  /// When `findCaptures` is `true`, the created [Lambda]s are also added to the
+  /// compilation queue.
+  Closures(this.translator, this._member, {required bool findCaptures})
       : _nullableThisType = _member is Constructor || _member.isInstanceMember
             ? translator.preciseThisFor(_member, nullable: true) as w.RefType
             : null {
@@ -1385,7 +1414,10 @@
       functionName = "$member closure $functionNodeName at ${node.location}";
     }
     final function = module.functions.define(type, functionName);
-    closures.lambdas[node] = Lambda(node, function, _currentSource);
+    final lambda =
+        Lambda._(node, function, _currentSource, closures.lambdas.length);
+    closures.lambdas[node] = lambda;
+    translator.functions.getLambdaFunction(lambda, member, closures);
 
     functionIsSyncStarOrAsync.add(node.asyncMarker == AsyncMarker.SyncStar ||
         node.asyncMarker == AsyncMarker.Async);
diff --git a/pkg/dart2wasm/lib/code_generator.dart b/pkg/dart2wasm/lib/code_generator.dart
index 3e03a73..232d0cc 100644
--- a/pkg/dart2wasm/lib/code_generator.dart
+++ b/pkg/dart2wasm/lib/code_generator.dart
@@ -213,15 +213,6 @@
     translator.membersBeingGenerated.remove(enclosingMember);
   }
 
-  void addNestedClosuresToCompilationQueue() {
-    for (Lambda lambda in closures.lambdas.values) {
-      translator.compilationQueue.add(CompilationTask(
-          lambda.function,
-          getLambdaCodeGenerator(
-              translator, lambda, enclosingMember, closures)));
-    }
-  }
-
   // Generate the body.
   void generateInternal();
 
@@ -2388,14 +2379,11 @@
           expectedType);
     }
 
-    final Expression receiver = node.receiver;
-    final Arguments arguments = node.arguments;
-
-    int typeCount = arguments.types.length;
-    int posArgCount = arguments.positional.length;
-    List<String> argNames = arguments.named.map((a) => a.name).toList()..sort();
+    List<String> argNames = node.arguments.named.map((a) => a.name).toList()
+      ..sort();
     ClosureRepresentation? representation = translator.closureLayouter
-        .getClosureRepresentation(typeCount, posArgCount, argNames);
+        .getClosureRepresentation(node.arguments.types.length,
+            node.arguments.positional.length, argNames);
     if (representation == null) {
       // This is a dynamic function call with a signature that matches no
       // functions in the program.
@@ -2403,26 +2391,75 @@
       return translator.topInfo.nullableType;
     }
 
+    final SingleClosureTarget? directClosureCall =
+        translator.singleClosureTarget(node, representation, typeContext);
+
+    if (directClosureCall != null) {
+      return _generateDirectClosureCall(
+          node, representation, directClosureCall);
+    }
+
+    return _generateClosureInvocation(node, representation);
+  }
+
+  w.ValueType _generateDirectClosureCall(FunctionInvocation node,
+      ClosureRepresentation representation, SingleClosureTarget closureTarget) {
+    final closureStruct = representation.closureStruct;
+    final closureStructRef = w.RefType.def(closureStruct, nullable: false);
+    final signature = closureTarget.signature;
+    final paramInfo = closureTarget.paramInfo;
+    final member = closureTarget.member;
+    final lambdaFunction = closureTarget.lambdaFunction;
+
+    if (lambdaFunction == null) {
+      if (paramInfo.takesContextOrReceiver) {
+        translateExpression(node.receiver, closureStructRef);
+        b.struct_get(closureStruct, FieldIndex.closureContext);
+        translator.convertType(
+            b,
+            closureStruct.fields[FieldIndex.closureContext].type.unpacked,
+            signature.inputs[0]);
+        _visitArguments(node.arguments, signature, paramInfo, 1);
+      } else {
+        _visitArguments(node.arguments, signature, paramInfo, 0);
+      }
+      return translator.outputOrVoid(call(member.reference));
+    } else {
+      assert(paramInfo.takesContextOrReceiver);
+      translateExpression(node.receiver, closureStructRef);
+      b.struct_get(closureStruct, FieldIndex.closureContext);
+      _visitArguments(node.arguments, signature, paramInfo, 1);
+      return translator
+          .outputOrVoid(translator.callFunction(lambdaFunction, b));
+    }
+  }
+
+  w.ValueType _generateClosureInvocation(
+      FunctionInvocation node, ClosureRepresentation representation) {
+    final closureStruct = representation.closureStruct;
+
     // Evaluate receiver
-    w.StructType struct = representation.closureStruct;
-    w.Local closureLocal = addLocal(w.RefType.def(struct, nullable: false));
-    translateExpression(receiver, closureLocal.type);
+    w.Local closureLocal =
+        addLocal(w.RefType.def(closureStruct, nullable: false));
+    translateExpression(node.receiver, closureLocal.type);
     b.local_tee(closureLocal);
-    b.struct_get(struct, FieldIndex.closureContext);
+    b.struct_get(closureStruct, FieldIndex.closureContext);
 
     // Type arguments
-    for (DartType typeArg in arguments.types) {
+    for (DartType typeArg in node.arguments.types) {
       types.makeType(this, typeArg);
     }
 
     // Positional arguments
-    for (Expression arg in arguments.positional) {
+    for (Expression arg in node.arguments.positional) {
       translateExpression(arg, translator.topInfo.nullableType);
     }
 
     // Named arguments
+    final List<String> argNames =
+        node.arguments.named.map((a) => a.name).toList()..sort();
     final Map<String, w.Local> namedLocals = {};
-    for (final namedArg in arguments.named) {
+    for (final namedArg in node.arguments.named) {
       final w.Local namedLocal = addLocal(translator.topInfo.nullableType);
       namedLocals[namedArg.name] = namedLocal;
       translateExpression(namedArg.value, namedLocal.type);
@@ -2432,15 +2469,17 @@
       b.local_get(namedLocals[name]!);
     }
 
-    // Call entry point in vtable
-    int vtableFieldIndex =
-        representation.fieldIndexForSignature(posArgCount, argNames);
-    w.FunctionType functionType =
+    final int vtableFieldIndex = representation.fieldIndexForSignature(
+        node.arguments.positional.length, argNames);
+    final w.FunctionType functionType =
         representation.getVtableFieldType(vtableFieldIndex);
+
+    // Call entry point in vtable
     b.local_get(closureLocal);
-    b.struct_get(struct, FieldIndex.closureVtable);
+    b.struct_get(closureStruct, FieldIndex.closureVtable);
     b.struct_get(representation.vtableStruct, vtableFieldIndex);
     b.call_ref(functionType);
+
     return translator.topInfo.nullableType;
   }
 
@@ -3181,7 +3220,7 @@
       return;
     }
 
-    closures = Closures(translator, member);
+    closures = translator.getClosures(member);
 
     setupParametersAndContexts(member, useUncheckedEntry: useUncheckedEntry);
 
@@ -3192,7 +3231,6 @@
 
     _implicitReturn();
     b.end();
-    addNestedClosuresToCompilationQueue();
   }
 }
 
@@ -3209,7 +3247,7 @@
     // used by `makeType` below, when generating runtime types of type
     // parameters of the function type, but the type parameters are not
     // captured, always loaded from the `this` struct.
-    closures = Closures(translator, member, findCaptures: false);
+    closures = translator.getClosures(member, findCaptures: false);
 
     _initializeThis(member.reference);
     Procedure procedure = member as Procedure;
@@ -3242,7 +3280,7 @@
     // Initialize [Closures] without [Closures.captures]: Similar to
     // [TearOffCodeGenerator], type parameters will be loaded from the `this`
     // struct.
-    closures = Closures(translator, member, findCaptures: false);
+    closures = translator.getClosures(member, findCaptures: false);
     if (member is Field ||
         (member is Procedure && (member as Procedure).isSetter)) {
       _generateFieldSetterTypeCheckerMethod();
@@ -3482,7 +3520,6 @@
       generateInitializerList();
     }
     b.end();
-    addNestedClosuresToCompilationQueue();
   }
 
   // Generates a constructor's initializer list method, and returns:
@@ -3848,7 +3885,7 @@
     setSourceMapSourceAndFileOffset(source, field.fileOffset);
 
     // Static field initializer function
-    closures = Closures(translator, field);
+    closures = translator.getClosures(field);
 
     w.Global global = translator.globals.getGlobalForStaticField(field);
     w.Global? flag = translator.globals.getGlobalInitializedFlag(field);
@@ -3861,7 +3898,6 @@
     b.global_get(global);
     translator.convertType(b, global.type.type, outputs.single);
     b.end();
-    addNestedClosuresToCompilationQueue();
   }
 }
 
@@ -3944,7 +3980,7 @@
     // that instantiates types uses closure information to see whether a type
     // parameter was captured (and loads it from context chain) or not (and
     // loads it directly from `this`).
-    closures = Closures(translator, field, findCaptures: false);
+    closures = translator.getClosures(field, findCaptures: false);
 
     final source = field.enclosingComponent!.uriToSource[field.fileUri]!;
     setSourceMapSourceAndFileOffset(source, field.fileOffset);
diff --git a/pkg/dart2wasm/lib/functions.dart b/pkg/dart2wasm/lib/functions.dart
index 6d5b6f7..4b93afc 100644
--- a/pkg/dart2wasm/lib/functions.dart
+++ b/pkg/dart2wasm/lib/functions.dart
@@ -20,6 +20,8 @@
 
   // Wasm function for each Dart function
   final Map<Reference, w.BaseFunction> _functions = {};
+  // Wasm function for each function expression and local function.
+  final Map<Lambda, w.BaseFunction> _lambdas = {};
   // Names of exported functions
   final Map<Reference, String> _exports = {};
   // Selector IDs that are invoked via GDT.
@@ -129,6 +131,17 @@
     });
   }
 
+  w.BaseFunction getLambdaFunction(
+      Lambda lambda, Member enclosingMember, Closures enclosingMemberClosures) {
+    return _lambdas.putIfAbsent(lambda, () {
+      translator.compilationQueue.add(CompilationTask(
+          lambda.function,
+          getLambdaCodeGenerator(
+              translator, lambda, enclosingMember, enclosingMemberClosures)));
+      return lambda.function;
+    });
+  }
+
   w.FunctionType getFunctionType(Reference target) {
     // We first try to get the function type by seeing if we already
     // compiled the [target] function.
@@ -278,7 +291,7 @@
     // context argument if context must be shared between them. Generate the
     // contexts the first time we visit a constructor.
     translator.constructorClosures[node.reference] ??=
-        Closures(translator, node);
+        translator.getClosures(node);
 
     if (target.isInitializerReference) {
       return _getInitializerType(node, target, arguments);
diff --git a/pkg/dart2wasm/lib/state_machine.dart b/pkg/dart2wasm/lib/state_machine.dart
index 3ac5bb1..4ed78f4 100644
--- a/pkg/dart2wasm/lib/state_machine.dart
+++ b/pkg/dart2wasm/lib/state_machine.dart
@@ -595,7 +595,7 @@
     setSourceMapSource(source);
     setSourceMapFileOffset(member.fileOffset);
 
-    closures = Closures(translator, member);
+    closures = translator.getClosures(member);
 
     // We don't support inlining state machine functions atm. Only when we
     // inline and have call-site guarantees we would use the unchecked entry.
@@ -605,7 +605,6 @@
     if (context != null && context.isEmpty) context = context.parent;
 
     generateOuter(member.function, context, source);
-    addNestedClosuresToCompilationQueue();
   }
 }
 
diff --git a/pkg/dart2wasm/lib/translator.dart b/pkg/dart2wasm/lib/translator.dart
index 5e79222..8e38d8c 100644
--- a/pkg/dart2wasm/lib/translator.dart
+++ b/pkg/dart2wasm/lib/translator.dart
@@ -326,6 +326,15 @@
 
   bool isMainModule(w.ModuleBuilder module) => _builderToOutput[module]!.isMain;
 
+  /// Maps compiled members to their [Closures], with capture information.
+  final Map<Member, Closures> _memberClosures = {};
+
+  Closures getClosures(Member member, {bool findCaptures = true}) =>
+      findCaptures
+          ? _memberClosures.putIfAbsent(
+              member, () => Closures(this, member, findCaptures: true))
+          : Closures(this, member, findCaptures: false);
+
   Translator(this.component, this.coreTypes, this.index, this.recordClasses,
       this._moduleOutputData, this.options)
       : libraries = component.libraries,
@@ -869,8 +878,8 @@
     ib.struct_new(representation.vtableStruct);
     ib.end();
 
-    final implementation = ClosureImplementation(
-        representation, functions, dynamicCallEntry, vtable, targetModule);
+    final implementation = ClosureImplementation(representation, functions,
+        dynamicCallEntry, vtable, targetModule, paramInfo);
     closureImplementations[functionNode] = implementation;
     return implementation;
   }
@@ -1035,6 +1044,74 @@
     return directCallMetadata[node]?.targetMember;
   }
 
+  /// Direct call information of a [FunctionInvocation] based on TFA's direct
+  /// call metadata.
+  SingleClosureTarget? singleClosureTarget(FunctionInvocation node,
+      ClosureRepresentation representation, StaticTypeContext typeContext) {
+    final (Member, int)? directClosureCall =
+        directCallMetadata[node]?.targetClosure;
+
+    if (directClosureCall == null) {
+      return null;
+    }
+
+    // To avoid using the `Null` class, avoid devirtualizing to `Null` members.
+    // `noSuchMethod` is also not allowed as `Null` inherits it.
+    if (directClosureCall.$1.enclosingClass == coreTypes.deprecatedNullClass ||
+        directClosureCall.$1 == objectNoSuchMethod) {
+      return null;
+    }
+
+    final member = directClosureCall.$1;
+    final closureId = directClosureCall.$2;
+
+    if (closureId == 0) {
+      // The member is called as a closure (tear-off). We'll generate a direct
+      // call to the member.
+      final lambdaDartType =
+          member.function!.computeFunctionType(Nullability.nonNullable);
+
+      // Check that the target type is a subtype of the receiver's static type.
+      if (!typeEnvironment.isSubtypeOf(
+          lambdaDartType,
+          node.receiver.getStaticType(typeContext),
+          SubtypeCheckMode.withNullabilities)) {
+        return null;
+      }
+
+      return SingleClosureTarget._(
+        member,
+        paramInfoForDirectCall(member.reference),
+        signatureForDirectCall(member.reference),
+        null,
+      );
+    } else {
+      // A closure in the member is called.
+      final Closures enclosingMemberClosures =
+          getClosures(member, findCaptures: true);
+      final Lambda lambda = enclosingMemberClosures.lambdas.values
+          .firstWhere((lambda) => lambda.index == closureId - 1);
+      final FunctionType lambdaDartType =
+          lambda.functionNode.computeFunctionType(Nullability.nonNullable);
+      final w.BaseFunction lambdaFunction =
+          functions.getLambdaFunction(lambda, member, enclosingMemberClosures);
+
+      if (!typeEnvironment.isSubtypeOf(
+          lambdaDartType,
+          node.receiver.getStaticType(typeContext),
+          SubtypeCheckMode.withNullabilities)) {
+        return null;
+      }
+
+      return SingleClosureTarget._(
+        member,
+        ParameterInfo.fromLocalFunction(lambda.functionNode),
+        lambdaFunction.type,
+        lambdaFunction,
+      );
+    }
+  }
+
   bool canSkipImplicitCheck(VariableDeclaration node) {
     return inferredArgTypeMetadata[node]?.skipCheck ?? false;
   }
@@ -2106,3 +2183,24 @@
     return importingModule.tags.import(moduleName, importName, definition.type);
   }
 }
+
+class SingleClosureTarget {
+  /// When `lambdaFunction` is null, the member being directly called. Otherwise
+  /// the enclosing member of the closure being called.
+  final Member member;
+
+  /// [ParameterInfo] specifying how to compile arguments to the closure or
+  /// member.
+  final ParameterInfo paramInfo;
+
+  /// Wasm function type that goes along with the [paramInfo] for compiling
+  /// arguments.
+  final w.FunctionType signature;
+
+  /// If the callee is a local function or function expression (instead of a
+  /// member), the Wasm function for it.
+  final w.BaseFunction? lambdaFunction;
+
+  SingleClosureTarget._(
+      this.member, this.paramInfo, this.signature, this.lambdaFunction);
+}
diff --git a/pkg/vm/lib/metadata/direct_call.dart b/pkg/vm/lib/metadata/direct_call.dart
index e49910a..abba2a2 100644
--- a/pkg/vm/lib/metadata/direct_call.dart
+++ b/pkg/vm/lib/metadata/direct_call.dart
@@ -39,6 +39,15 @@
   bool get checkReceiverForNull => (_flags & flagCheckReceiverForNull) != 0;
   bool get isClosure => (_flags & flagClosure) != 0;
 
+  /// When calling a closure, the enclosing member of the closure, and the
+  /// closure index.
+  ///
+  /// Closures in a member are assigned ids based on pre-order traversal of the
+  /// member body, and the member itself also counts as a closure (for
+  /// tear-offs). So index 0 is the member itself, called as a closure
+  /// (tear-off).
+  (Member, int)? get targetClosure => isClosure ? (_member, _closureId) : null;
+
   @override
   String toString() => isClosure
       ? 'closure ${_closureId} in ${_member.toText(astTextStrategyForTesting)}'
diff --git a/tests/language/function/direct_invocation_test.dart b/tests/language/function/direct_invocation_test.dart
new file mode 100644
index 0000000..1d675f1
--- /dev/null
+++ b/tests/language/function/direct_invocation_test.dart
@@ -0,0 +1,35 @@
+// Copyright (c) 2024, the Dart project authors.  Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+// Test using TFA's direct-call metadata in function invocations.
+
+import "package:expect/expect.dart";
+
+int f(int x, [int? y]) => x;
+
+// Calls `f` directly.
+int test1(int cb(int a)) => cb(1);
+
+class A {
+  int f(int x, [int? y]) => x;
+}
+
+// Calls `A.f` directly.
+int test2(int cb(int a)) => cb(2);
+
+class B {
+  int nested() {
+    int cb1(int x, [int? y]) => x;
+    return test3(cb1);
+  }
+}
+
+// Calls the nested closure `cb1` directly.
+int test3(int cb(int a)) => cb(3);
+
+void main() {
+  Expect.equals(test1(f), 1);
+  Expect.equals(test2(A().f), 2);
+  Expect.equals(B().nested(), 3);
+}