[dart2wasm] Devirtualize closure calls based on TFA direct-call metadata
Use the TFA direct-call metadata to directly call a closure in function
invocations.
Closes #55231.
Tested: existing tests cover the new code paths, but I also added a new
test.
Change-Id: Ib5f26b10efd77570e256196b4bfb07e6bef800c0
Reviewed-on: https://dart-review.googlesource.com/c/sdk/+/397260
Reviewed-by: Martin Kustermann <kustermann@google.com>
Commit-Queue: Ömer Ağacan <omersa@google.com>
diff --git a/pkg/dart2wasm/lib/closures.dart b/pkg/dart2wasm/lib/closures.dart
index 9a4c4b4..dc30f23 100644
--- a/pkg/dart2wasm/lib/closures.dart
+++ b/pkg/dart2wasm/lib/closures.dart
@@ -5,12 +5,14 @@
import 'dart:collection';
import 'dart:math' show min;
+import 'package:collection/collection.dart';
import 'package:kernel/ast.dart';
import 'package:vm/metadata/procedure_attributes.dart';
import 'package:vm/transformations/type_flow/utils.dart' show UnionFind;
import 'package:wasm_builder/wasm_builder.dart' as w;
import 'class_info.dart';
+import 'param_info.dart';
import 'translator.dart';
/// Describes the implementation of a concrete closure, including its vtable
@@ -34,8 +36,16 @@
/// The module this closure is implemented in.
final w.ModuleBuilder module;
- ClosureImplementation(this.representation, this.functions,
- this.dynamicCallEntry, this.vtable, this.module);
+ /// [ParameterInfo] to be used when directly calling the closure.
+ final ParameterInfo directCallParamInfo;
+
+ ClosureImplementation(
+ this.representation,
+ this.functions,
+ this.dynamicCallEntry,
+ this.vtable,
+ this.module,
+ this.directCallParamInfo);
}
/// Describes the representation of closures for a particular function
@@ -131,6 +141,8 @@
/// The field index in the vtable struct for the function entry to use when
/// calling the closure with the given number of positional arguments and the
/// given set of named arguments.
+ ///
+ /// `argNames` should be sorted.
int fieldIndexForSignature(int posArgCount, List<String> argNames) {
if (argNames.isEmpty) {
return vtableBaseIndex + posArgCount;
@@ -155,7 +167,9 @@
class NameCombination implements Comparable<NameCombination> {
final List<String> names;
- NameCombination(this.names);
+ NameCombination(this.names) {
+ assert(names.isSorted(Comparable.compare));
+ }
@override
int compareTo(NameCombination other) {
@@ -363,6 +377,8 @@
/// Get the representation for closures with a specific signature, described
/// by the number of type parameters, the maximum number of positional
/// parameters and the names of named parameters.
+ ///
+ /// `names` should be sorted.
ClosureRepresentation? getClosureRepresentation(
int typeCount, int positionalCount, List<String> names) {
final representations =
@@ -1002,10 +1018,20 @@
/// A local function or function expression.
class Lambda {
final FunctionNode functionNode;
+
+ // Note: creating a `Lambda` does not add this function to the compilation
+ // queue. Make sure to get it with `Functions.getLambdaFunction` to add it
+ // to the compilation queue.
final w.FunctionBuilder function;
+
final Source functionNodeSource;
- Lambda(this.functionNode, this.function, this.functionNodeSource);
+ /// Index of the function within the enclosing member, based on pre-order
+ /// traversal of the member body.
+ final int index;
+
+ Lambda._(
+ this.functionNode, this.function, this.functionNodeSource, this.index);
}
/// The context for one or more closures, containing their captured variables.
@@ -1141,7 +1167,10 @@
/// does not populate [lambdas], [contexts], [captures], and
/// [closurizedFunctions]. This mode is useful in the code generators that
/// always have direct access to variables (instead of via a context).
- Closures(this.translator, this._member, {bool findCaptures = true})
+ ///
+ /// When `findCaptures` is `true`, the created [Lambda]s are also added to the
+ /// compilation queue.
+ Closures(this.translator, this._member, {required bool findCaptures})
: _nullableThisType = _member is Constructor || _member.isInstanceMember
? translator.preciseThisFor(_member, nullable: true) as w.RefType
: null {
@@ -1385,7 +1414,10 @@
functionName = "$member closure $functionNodeName at ${node.location}";
}
final function = module.functions.define(type, functionName);
- closures.lambdas[node] = Lambda(node, function, _currentSource);
+ final lambda =
+ Lambda._(node, function, _currentSource, closures.lambdas.length);
+ closures.lambdas[node] = lambda;
+ translator.functions.getLambdaFunction(lambda, member, closures);
functionIsSyncStarOrAsync.add(node.asyncMarker == AsyncMarker.SyncStar ||
node.asyncMarker == AsyncMarker.Async);
diff --git a/pkg/dart2wasm/lib/code_generator.dart b/pkg/dart2wasm/lib/code_generator.dart
index 3e03a73..232d0cc 100644
--- a/pkg/dart2wasm/lib/code_generator.dart
+++ b/pkg/dart2wasm/lib/code_generator.dart
@@ -213,15 +213,6 @@
translator.membersBeingGenerated.remove(enclosingMember);
}
- void addNestedClosuresToCompilationQueue() {
- for (Lambda lambda in closures.lambdas.values) {
- translator.compilationQueue.add(CompilationTask(
- lambda.function,
- getLambdaCodeGenerator(
- translator, lambda, enclosingMember, closures)));
- }
- }
-
// Generate the body.
void generateInternal();
@@ -2388,14 +2379,11 @@
expectedType);
}
- final Expression receiver = node.receiver;
- final Arguments arguments = node.arguments;
-
- int typeCount = arguments.types.length;
- int posArgCount = arguments.positional.length;
- List<String> argNames = arguments.named.map((a) => a.name).toList()..sort();
+ List<String> argNames = node.arguments.named.map((a) => a.name).toList()
+ ..sort();
ClosureRepresentation? representation = translator.closureLayouter
- .getClosureRepresentation(typeCount, posArgCount, argNames);
+ .getClosureRepresentation(node.arguments.types.length,
+ node.arguments.positional.length, argNames);
if (representation == null) {
// This is a dynamic function call with a signature that matches no
// functions in the program.
@@ -2403,26 +2391,75 @@
return translator.topInfo.nullableType;
}
+ final SingleClosureTarget? directClosureCall =
+ translator.singleClosureTarget(node, representation, typeContext);
+
+ if (directClosureCall != null) {
+ return _generateDirectClosureCall(
+ node, representation, directClosureCall);
+ }
+
+ return _generateClosureInvocation(node, representation);
+ }
+
+ w.ValueType _generateDirectClosureCall(FunctionInvocation node,
+ ClosureRepresentation representation, SingleClosureTarget closureTarget) {
+ final closureStruct = representation.closureStruct;
+ final closureStructRef = w.RefType.def(closureStruct, nullable: false);
+ final signature = closureTarget.signature;
+ final paramInfo = closureTarget.paramInfo;
+ final member = closureTarget.member;
+ final lambdaFunction = closureTarget.lambdaFunction;
+
+ if (lambdaFunction == null) {
+ if (paramInfo.takesContextOrReceiver) {
+ translateExpression(node.receiver, closureStructRef);
+ b.struct_get(closureStruct, FieldIndex.closureContext);
+ translator.convertType(
+ b,
+ closureStruct.fields[FieldIndex.closureContext].type.unpacked,
+ signature.inputs[0]);
+ _visitArguments(node.arguments, signature, paramInfo, 1);
+ } else {
+ _visitArguments(node.arguments, signature, paramInfo, 0);
+ }
+ return translator.outputOrVoid(call(member.reference));
+ } else {
+ assert(paramInfo.takesContextOrReceiver);
+ translateExpression(node.receiver, closureStructRef);
+ b.struct_get(closureStruct, FieldIndex.closureContext);
+ _visitArguments(node.arguments, signature, paramInfo, 1);
+ return translator
+ .outputOrVoid(translator.callFunction(lambdaFunction, b));
+ }
+ }
+
+ w.ValueType _generateClosureInvocation(
+ FunctionInvocation node, ClosureRepresentation representation) {
+ final closureStruct = representation.closureStruct;
+
// Evaluate receiver
- w.StructType struct = representation.closureStruct;
- w.Local closureLocal = addLocal(w.RefType.def(struct, nullable: false));
- translateExpression(receiver, closureLocal.type);
+ w.Local closureLocal =
+ addLocal(w.RefType.def(closureStruct, nullable: false));
+ translateExpression(node.receiver, closureLocal.type);
b.local_tee(closureLocal);
- b.struct_get(struct, FieldIndex.closureContext);
+ b.struct_get(closureStruct, FieldIndex.closureContext);
// Type arguments
- for (DartType typeArg in arguments.types) {
+ for (DartType typeArg in node.arguments.types) {
types.makeType(this, typeArg);
}
// Positional arguments
- for (Expression arg in arguments.positional) {
+ for (Expression arg in node.arguments.positional) {
translateExpression(arg, translator.topInfo.nullableType);
}
// Named arguments
+ final List<String> argNames =
+ node.arguments.named.map((a) => a.name).toList()..sort();
final Map<String, w.Local> namedLocals = {};
- for (final namedArg in arguments.named) {
+ for (final namedArg in node.arguments.named) {
final w.Local namedLocal = addLocal(translator.topInfo.nullableType);
namedLocals[namedArg.name] = namedLocal;
translateExpression(namedArg.value, namedLocal.type);
@@ -2432,15 +2469,17 @@
b.local_get(namedLocals[name]!);
}
- // Call entry point in vtable
- int vtableFieldIndex =
- representation.fieldIndexForSignature(posArgCount, argNames);
- w.FunctionType functionType =
+ final int vtableFieldIndex = representation.fieldIndexForSignature(
+ node.arguments.positional.length, argNames);
+ final w.FunctionType functionType =
representation.getVtableFieldType(vtableFieldIndex);
+
+ // Call entry point in vtable
b.local_get(closureLocal);
- b.struct_get(struct, FieldIndex.closureVtable);
+ b.struct_get(closureStruct, FieldIndex.closureVtable);
b.struct_get(representation.vtableStruct, vtableFieldIndex);
b.call_ref(functionType);
+
return translator.topInfo.nullableType;
}
@@ -3181,7 +3220,7 @@
return;
}
- closures = Closures(translator, member);
+ closures = translator.getClosures(member);
setupParametersAndContexts(member, useUncheckedEntry: useUncheckedEntry);
@@ -3192,7 +3231,6 @@
_implicitReturn();
b.end();
- addNestedClosuresToCompilationQueue();
}
}
@@ -3209,7 +3247,7 @@
// used by `makeType` below, when generating runtime types of type
// parameters of the function type, but the type parameters are not
// captured, always loaded from the `this` struct.
- closures = Closures(translator, member, findCaptures: false);
+ closures = translator.getClosures(member, findCaptures: false);
_initializeThis(member.reference);
Procedure procedure = member as Procedure;
@@ -3242,7 +3280,7 @@
// Initialize [Closures] without [Closures.captures]: Similar to
// [TearOffCodeGenerator], type parameters will be loaded from the `this`
// struct.
- closures = Closures(translator, member, findCaptures: false);
+ closures = translator.getClosures(member, findCaptures: false);
if (member is Field ||
(member is Procedure && (member as Procedure).isSetter)) {
_generateFieldSetterTypeCheckerMethod();
@@ -3482,7 +3520,6 @@
generateInitializerList();
}
b.end();
- addNestedClosuresToCompilationQueue();
}
// Generates a constructor's initializer list method, and returns:
@@ -3848,7 +3885,7 @@
setSourceMapSourceAndFileOffset(source, field.fileOffset);
// Static field initializer function
- closures = Closures(translator, field);
+ closures = translator.getClosures(field);
w.Global global = translator.globals.getGlobalForStaticField(field);
w.Global? flag = translator.globals.getGlobalInitializedFlag(field);
@@ -3861,7 +3898,6 @@
b.global_get(global);
translator.convertType(b, global.type.type, outputs.single);
b.end();
- addNestedClosuresToCompilationQueue();
}
}
@@ -3944,7 +3980,7 @@
// that instantiates types uses closure information to see whether a type
// parameter was captured (and loads it from context chain) or not (and
// loads it directly from `this`).
- closures = Closures(translator, field, findCaptures: false);
+ closures = translator.getClosures(field, findCaptures: false);
final source = field.enclosingComponent!.uriToSource[field.fileUri]!;
setSourceMapSourceAndFileOffset(source, field.fileOffset);
diff --git a/pkg/dart2wasm/lib/functions.dart b/pkg/dart2wasm/lib/functions.dart
index 6d5b6f7..4b93afc 100644
--- a/pkg/dart2wasm/lib/functions.dart
+++ b/pkg/dart2wasm/lib/functions.dart
@@ -20,6 +20,8 @@
// Wasm function for each Dart function
final Map<Reference, w.BaseFunction> _functions = {};
+ // Wasm function for each function expression and local function.
+ final Map<Lambda, w.BaseFunction> _lambdas = {};
// Names of exported functions
final Map<Reference, String> _exports = {};
// Selector IDs that are invoked via GDT.
@@ -129,6 +131,17 @@
});
}
+ w.BaseFunction getLambdaFunction(
+ Lambda lambda, Member enclosingMember, Closures enclosingMemberClosures) {
+ return _lambdas.putIfAbsent(lambda, () {
+ translator.compilationQueue.add(CompilationTask(
+ lambda.function,
+ getLambdaCodeGenerator(
+ translator, lambda, enclosingMember, enclosingMemberClosures)));
+ return lambda.function;
+ });
+ }
+
w.FunctionType getFunctionType(Reference target) {
// We first try to get the function type by seeing if we already
// compiled the [target] function.
@@ -278,7 +291,7 @@
// context argument if context must be shared between them. Generate the
// contexts the first time we visit a constructor.
translator.constructorClosures[node.reference] ??=
- Closures(translator, node);
+ translator.getClosures(node);
if (target.isInitializerReference) {
return _getInitializerType(node, target, arguments);
diff --git a/pkg/dart2wasm/lib/state_machine.dart b/pkg/dart2wasm/lib/state_machine.dart
index 3ac5bb1..4ed78f4 100644
--- a/pkg/dart2wasm/lib/state_machine.dart
+++ b/pkg/dart2wasm/lib/state_machine.dart
@@ -595,7 +595,7 @@
setSourceMapSource(source);
setSourceMapFileOffset(member.fileOffset);
- closures = Closures(translator, member);
+ closures = translator.getClosures(member);
// We don't support inlining state machine functions atm. Only when we
// inline and have call-site guarantees we would use the unchecked entry.
@@ -605,7 +605,6 @@
if (context != null && context.isEmpty) context = context.parent;
generateOuter(member.function, context, source);
- addNestedClosuresToCompilationQueue();
}
}
diff --git a/pkg/dart2wasm/lib/translator.dart b/pkg/dart2wasm/lib/translator.dart
index 5e79222..8e38d8c 100644
--- a/pkg/dart2wasm/lib/translator.dart
+++ b/pkg/dart2wasm/lib/translator.dart
@@ -326,6 +326,15 @@
bool isMainModule(w.ModuleBuilder module) => _builderToOutput[module]!.isMain;
+ /// Maps compiled members to their [Closures], with capture information.
+ final Map<Member, Closures> _memberClosures = {};
+
+ Closures getClosures(Member member, {bool findCaptures = true}) =>
+ findCaptures
+ ? _memberClosures.putIfAbsent(
+ member, () => Closures(this, member, findCaptures: true))
+ : Closures(this, member, findCaptures: false);
+
Translator(this.component, this.coreTypes, this.index, this.recordClasses,
this._moduleOutputData, this.options)
: libraries = component.libraries,
@@ -869,8 +878,8 @@
ib.struct_new(representation.vtableStruct);
ib.end();
- final implementation = ClosureImplementation(
- representation, functions, dynamicCallEntry, vtable, targetModule);
+ final implementation = ClosureImplementation(representation, functions,
+ dynamicCallEntry, vtable, targetModule, paramInfo);
closureImplementations[functionNode] = implementation;
return implementation;
}
@@ -1035,6 +1044,74 @@
return directCallMetadata[node]?.targetMember;
}
+ /// Direct call information of a [FunctionInvocation] based on TFA's direct
+ /// call metadata.
+ SingleClosureTarget? singleClosureTarget(FunctionInvocation node,
+ ClosureRepresentation representation, StaticTypeContext typeContext) {
+ final (Member, int)? directClosureCall =
+ directCallMetadata[node]?.targetClosure;
+
+ if (directClosureCall == null) {
+ return null;
+ }
+
+ // To avoid using the `Null` class, avoid devirtualizing to `Null` members.
+ // `noSuchMethod` is also not allowed as `Null` inherits it.
+ if (directClosureCall.$1.enclosingClass == coreTypes.deprecatedNullClass ||
+ directClosureCall.$1 == objectNoSuchMethod) {
+ return null;
+ }
+
+ final member = directClosureCall.$1;
+ final closureId = directClosureCall.$2;
+
+ if (closureId == 0) {
+ // The member is called as a closure (tear-off). We'll generate a direct
+ // call to the member.
+ final lambdaDartType =
+ member.function!.computeFunctionType(Nullability.nonNullable);
+
+ // Check that type of the receiver is a subtype of
+ if (!typeEnvironment.isSubtypeOf(
+ lambdaDartType,
+ node.receiver.getStaticType(typeContext),
+ SubtypeCheckMode.withNullabilities)) {
+ return null;
+ }
+
+ return SingleClosureTarget._(
+ member,
+ paramInfoForDirectCall(member.reference),
+ signatureForDirectCall(member.reference),
+ null,
+ );
+ } else {
+ // A closure in the member is called.
+ final Closures enclosingMemberClosures =
+ getClosures(member, findCaptures: true);
+ final Lambda lambda = enclosingMemberClosures.lambdas.values
+ .firstWhere((lambda) => lambda.index == closureId - 1);
+ final FunctionType lambdaDartType =
+ lambda.functionNode.computeFunctionType(Nullability.nonNullable);
+ final w.BaseFunction lambdaFunction =
+ functions.getLambdaFunction(lambda, member, enclosingMemberClosures);
+
+ if (!typeEnvironment.isSubtypeOf(
+ lambdaDartType,
+ node.receiver.getStaticType(typeContext),
+ SubtypeCheckMode.withNullabilities)) {
+ return null;
+ }
+
+ return SingleClosureTarget._(
+ member,
+ ParameterInfo.fromLocalFunction(lambda.functionNode),
+ lambdaFunction.type,
+ lambdaFunction,
+ );
+ }
+ }
+
bool canSkipImplicitCheck(VariableDeclaration node) {
return inferredArgTypeMetadata[node]?.skipCheck ?? false;
}
@@ -2106,3 +2183,24 @@
return importingModule.tags.import(moduleName, importName, definition.type);
}
}
+
+class SingleClosureTarget {
+ /// When `lambdaFunction` is null, the member being directly called. Otherwise
+ /// the enclosing member of the closure being called.
+ final Member member;
+
+ /// [ParameterInfo] specifying how to compile arguments to the closure or
+ /// member.
+ final ParameterInfo paramInfo;
+
+ /// Wasm function type that goes along with the [paramInfo] for compiling
+ /// arguments.
+ final w.FunctionType signature;
+
+ /// If the callee is a local function or function expression (intead of a
+ /// member), this Wasm function for it.
+ final w.BaseFunction? lambdaFunction;
+
+ SingleClosureTarget._(
+ this.member, this.paramInfo, this.signature, this.lambdaFunction);
+}
diff --git a/pkg/vm/lib/metadata/direct_call.dart b/pkg/vm/lib/metadata/direct_call.dart
index e49910a..abba2a2 100644
--- a/pkg/vm/lib/metadata/direct_call.dart
+++ b/pkg/vm/lib/metadata/direct_call.dart
@@ -39,6 +39,15 @@
bool get checkReceiverForNull => (_flags & flagCheckReceiverForNull) != 0;
bool get isClosure => (_flags & flagClosure) != 0;
+ /// When calling a closure, the enclosing member of the closure, and the
+ /// closure index.
+ ///
+ /// Closures in a member are assigned ids based on pre-order traversal of the
+ /// member body, and the member itself also counts as a closure (for
+ /// tear-offs). So index 0 is the member itself, called as a closure
+ /// (tear-off).
+ (Member, int)? get targetClosure => isClosure ? (_member, _closureId) : null;
+
@override
String toString() => isClosure
? 'closure ${_closureId} in ${_member.toText(astTextStrategyForTesting)}'
diff --git a/tests/language/function/direct_invocation_test.dart b/tests/language/function/direct_invocation_test.dart
new file mode 100644
index 0000000..1d675f1
--- /dev/null
+++ b/tests/language/function/direct_invocation_test.dart
@@ -0,0 +1,35 @@
+// Copyright (c) 2024, the Dart project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+// Test using TFA's direct-call metadata in function invocations.
+
+import "package:expect/expect.dart";
+
+int f(int x, [int? y]) => x;
+
+// Calls `f` directly.
+int test1(int cb(int a)) => cb(1);
+
+class A {
+ int f(int x, [int? y]) => x;
+}
+
+// Calls `A.f` directly.
+int test2(int cb(int a)) => cb(2);
+
+class B {
+ int nested() {
+ int cb1(int x, [int? y]) => x;
+ return test3(cb1);
+ }
+}
+
+// Calls the nested closure `cb1` directly.
+int test3(int cb(int a)) => cb(3);
+
+void main() {
+ Expect.equals(test1(f), 1);
+ Expect.equals(test2(A().f), 2);
+ Expect.equals(B().nested(), 3);
+}