[dart2wasm] Adjustments to inliner

* Inline small `get iterator` & `get current` iterator methods
* Inline bodies that are small compared to arguments
* Make AST node counter more precise
* Manually mark ListIterator methods as prefer inline

CoreLibraryReviewExempt: Only adds annotation to existing functions
Change-Id: Ib6379e73713cd47a88e5cc67cecd4b5c8344adcb
Reviewed-on: https://dart-review.googlesource.com/c/sdk/+/382882
Reviewed-by: Slava Egorov <vegorov@google.com>
diff --git a/pkg/dart2wasm/lib/code_generator.dart b/pkg/dart2wasm/lib/code_generator.dart
index e1d17b9..107f3b0 100644
--- a/pkg/dart2wasm/lib/code_generator.dart
+++ b/pkg/dart2wasm/lib/code_generator.dart
@@ -4530,7 +4530,8 @@
   bool get supportsInlining => _translator.supportsInlining(_reference);
 
   @override
-  bool get shouldInline => _translator.shouldInline(_reference);
+  bool get shouldInline =>
+      _translator.shouldInline(_reference, signature, useUncheckedEntry);
 
   @override
   CodeGenerator get inliningCodeGen => getInlinableMemberCodeGenerator(
diff --git a/pkg/dart2wasm/lib/translator.dart b/pkg/dart2wasm/lib/translator.dart
index 6e38302..c006cba 100644
--- a/pkg/dart2wasm/lib/translator.dart
+++ b/pkg/dart2wasm/lib/translator.dart
@@ -993,20 +993,65 @@
     return InterfaceType(concreteClass, nullability, typeArguments);
   }
 
-  bool shouldInline(Reference target) {
+  bool shouldInline(
+      Reference target, w.FunctionType signature, bool useUncheckedEntry) {
     if (!options.inlining) return false;
-    Member member = target.asMember;
+
+    final member = target.asMember;
     if (getPragma<bool>(member, "wasm:never-inline", true) == true) {
       return false;
     }
-    if (target.isInitializerReference) return true;
-    if (member is Field) return true;
     if (getPragma<bool>(member, "wasm:prefer-inline", true) == true) {
       return true;
     }
-    Statement? body = member.function!.body;
-    return body != null &&
-        NodeCounter().countNodes(body) <= options.inliningLimit;
+    if (member is Field) return true;
+    if (target.isInitializerReference) return true;
+
+    final function = member.function!;
+    if (function.body == null) return false;
+
+    // We never want to inline throwing functions (as they are slow paths).
+    if (member is Procedure && member.function.returnType is NeverType) {
+      return false;
+    }
+
+    final nodeCount =
+        NodeCounter(options.omitImplicitTypeChecks || useUncheckedEntry)
+            .countNodes(member);
+
+    // Special cases for iterator inlining:
+    //   class ... implements Iterable<T> {
+    //     Iterator<T> get iterator => FooIterator(...)
+    //   }
+    //   class ... implements Iterator<T> {
+    //     T get current => _current as E;
+    //   }
+    final klass = member.enclosingClass;
+    if (klass != null) {
+      final name = member.name.text;
+      if (name == 'iterator' && nodeCount <= 20) {
+        if (typeEnvironment.isSubtypeOf(
+            klass.getThisType(coreTypes, Nullability.nonNullable),
+            coreTypes.iterableRawType(Nullability.nonNullable),
+            SubtypeCheckMode.ignoringNullabilities)) {
+          return true;
+        }
+      }
+      if (name == 'current' && nodeCount <= 5) {
+        if (typeEnvironment.isSubtypeOf(
+            klass.getThisType(coreTypes, Nullability.nonNullable),
+            coreTypes.iteratorRawType(Nullability.nonNullable),
+            SubtypeCheckMode.ignoringNullabilities)) {
+          return true;
+        }
+      }
+    }
+
+    // If we think the overhead of pushing arguments is around the same as the
+    // body itself, we always inline.
+    if (nodeCount <= signature.inputs.length) return true;
+
+    return nodeCount <= options.inliningLimit;
   }
 
   bool supportsInlining(Reference target) {
@@ -1471,19 +1516,94 @@
 }
 
 class NodeCounter extends VisitorDefault<void> with VisitorVoidMixin {
+  final bool omitCovarianceChecks;
   int count = 0;
 
-  int countNodes(Node node) {
+  NodeCounter(this.omitCovarianceChecks);
+
+  int countNodes(Member member) {
     count = 0;
-    node.accept(this);
+    if (member is Constructor) {
+      count += 2; // object creation overhead
+      for (final init in member.initializers) {
+        init.accept(this);
+      }
+      for (final field in member.enclosingClass.fields) {
+        field.initializer?.accept(this);
+      }
+    }
+
+    final function = member.function!;
+    if (!omitCovarianceChecks) {
+      for (final parameter in function.positionalParameters) {
+        if (parameter.isCovariantByDeclaration ||
+            parameter.isCovariantByClass) {
+          count++;
+        }
+      }
+    }
+    for (final parameter in function.positionalParameters) {
+      if (!omitCovarianceChecks) {
+        if (parameter.isCovariantByDeclaration ||
+            parameter.isCovariantByClass) {
+          count++;
+        }
+      }
+      if (!parameter.isRequired) count++;
+    }
+
+    function.body?.accept(this);
     return count;
   }
 
+  // We only count tree nodes and do not recurse into things that aren't part of
+  // the tree (e.g. constants, variable types, ...)
+
   @override
-  void defaultNode(Node node) {
+  void defaultTreeNode(TreeNode node) {
     count++;
     node.visitChildren(this);
   }
+
+  // The following AST nodes do not actually emit any code, so we don't count
+  // those nodes but we recurse into children that do emit code and therefore
+  // should count.
+
+  @override
+  void visitBlock(Block node) {
+    node.visitChildren(this);
+  }
+
+  @override
+  void visitEmptyStatement(EmptyStatement node) {
+    node.visitChildren(this);
+  }
+
+  @override
+  void visitLabeledStatement(LabeledStatement node) {
+    node.visitChildren(this);
+  }
+
+  @override
+  void visitBlockExpression(BlockExpression node) {
+    node.visitChildren(this);
+  }
+
+  @override
+  void visitExpressionStatement(ExpressionStatement node) {
+    node.visitChildren(this);
+  }
+
+  @override
+  void visitArguments(Arguments node) {
+    count += node.types.length;
+    node.visitChildren(this);
+  }
+
+  @override
+  void visitNamedExpression(NamedExpression node) {
+    node.visitChildren(this);
+  }
 }
 
 /// Creates forwarders for generic functions where the caller passes a constant
diff --git a/sdk/lib/internal/iterable.dart b/sdk/lib/internal/iterable.dart
index 9996ddd..b6f58166 100644
--- a/sdk/lib/internal/iterable.dart
+++ b/sdk/lib/internal/iterable.dart
@@ -334,6 +334,7 @@
   int _index;
   E? _current;
 
+  @pragma("wasm:prefer-inline")
   ListIterator(Iterable<E> iterable)
       : _iterable = iterable,
         _length = iterable.length,
@@ -342,6 +343,7 @@
   E get current => _current as E;
 
   @pragma("vm:prefer-inline")
+  @pragma("wasm:prefer-inline")
   bool moveNext() {
     int length = _iterable.length;
     if (_length != length) {
diff --git a/tests/web/wasm/source_map_simple_lib.dart b/tests/web/wasm/source_map_simple_lib.dart
index 150c627..74907ae 100644
--- a/tests/web/wasm/source_map_simple_lib.dart
+++ b/tests/web/wasm/source_map_simple_lib.dart
@@ -18,12 +18,12 @@
 
 runtimeFalse() => int.parse('1') == 0;
 
-// `frameDetails` is (line, column) of the frames we check.
+// `expectedFrames` is (String, line, column) of the frames we check.
 //
 // Information we don't check are "null": we don't want to check line/column
 // of standard library functions to avoid breaking the test with unrelated
 // changes to the standard library.
-void testMain(String testName, List<(int?, int?)?> frameDetails) {
+void testMain(String testName, List<(String?, int?, int?)?> expectedFrames) {
   // Use `f` and `g` in a few places to make sure wasm-opt won't inline them
   // in the test.
   final fTearOff = f;
@@ -32,12 +32,6 @@
   if (runtimeFalse()) f();
   if (runtimeFalse()) g();
 
-  // Read source map of the current program.
-  final compilationDir = const String.fromEnvironment("TEST_COMPILATION_DIR");
-  final sourceMapFileContents =
-      readfile('$compilationDir/${testName}_test.wasm.map');
-  final mapping = parse(utf8.decode(sourceMapFileContents)) as SingleMapping;
-
   // Get some simple stack trace.
   String? stackTraceString;
   try {
@@ -51,42 +45,74 @@
   print(stackTraceString);
   print("-----");
 
-  final stackTraceLines = stackTraceString!.split('\n');
+  final actualFrames =
+      parseStack(getSourceMapping(testName), stackTraceString!);
+  print('Got stack trace:');
+  for (final frame in actualFrames) {
+    print('  $frame');
+  }
+  print('Matching against:');
+  for (final frame in expectedFrames) {
+    print('  $frame');
+  }
 
-  for (int frameIdx = 0; frameIdx < frameDetails.length; frameIdx += 1) {
-    final line = stackTraceLines[frameIdx];
+  if (actualFrames.length < expectedFrames.length) {
+    throw 'Less actual frames than expected';
+  }
+
+  for (int i = 0; i < expectedFrames.length; i++) {
+    final expected = expectedFrames[i];
+    final actual = actualFrames[i];
+    if (expected == null) continue;
+    if (actual == null) {
+      throw 'Mismatch:\n  Expected: $expected\n  Actual: <no mapping>';
+    }
+    if ((expected.$1 != null && actual.$1 != expected.$1) ||
+        (expected.$2 != null && actual.$2 != expected.$2) ||
+        (expected.$3 != null && actual.$3 != expected.$3)) {
+      throw 'Mismatch:\n  Expected: $expected\n  Actual: $actual';
+    }
+  }
+}
+
+SingleMapping getSourceMapping(String testName) {
+  // Read source map of the current program.
+  final compilationDir = const String.fromEnvironment("TEST_COMPILATION_DIR");
+  final sourceMapFileContents =
+      readfile('$compilationDir/${testName}_test.wasm.map');
+  return parse(utf8.decode(sourceMapFileContents)) as SingleMapping;
+}
+
+List<(String?, int?, int?)?> parseStack(
+    SingleMapping mapping, String stackTraceString) {
+  final parsed = <(String?, int?, int?)?>[];
+  for (final line in stackTraceString.split('\n')) {
+    if (line.contains('.mjs') || line.contains('.js')) {
+      parsed.add(null);
+      continue;
+    }
+
     final hexOffsetMatch = stackTraceHexOffsetRegExp.firstMatch(line);
     if (hexOffsetMatch == null) {
-      throw "Unable to parse hex offset from stack frame $frameIdx";
+      throw 'Unable to parse hex offset in frame "$line"';
     }
     final hexOffsetStr = hexOffsetMatch.group(1)!; // includes '0x'
     final offset = int.tryParse(hexOffsetStr);
     if (offset == null) {
-      throw "Unable to parse hex number in frame $frameIdx: $hexOffsetStr";
+      throw 'Unable to parse hex number in frame "$line"';
     }
     final span = mapping.spanFor(0, offset);
-    final frameInfo = frameDetails[frameIdx];
-    if (frameInfo == null) {
-      if (span != null) {
-        throw "Stack frame $frameIdx should not have a source span, but it is mapped: $span";
-      }
+    if (span == null) {
+      print('Stack frame "$line" not have source mapping');
+      parsed.add(null);
       continue;
     }
-    if (span == null) {
-      print("Stack frame $frameIdx does not have source mapping");
-    } else {
-      if (frameInfo.$1 != null) {
-        if (span.start.line + 1 != frameInfo.$1) {
-          throw "Stack frame $frameIdx is expected to have line ${frameInfo.$1}, but it has line ${span.start.line + 1}";
-        }
-      }
-      if (frameInfo.$2 != null) {
-        if (span.start.column + 1 != frameInfo.$2) {
-          throw "Stack frame $frameIdx is expected to have column ${frameInfo.$2}, but it has column ${span.start.column + 1}";
-        }
-      }
-    }
+    final filename = span.sourceUrl!.pathSegments.last;
+    final lineNumber = span.start.line;
+    final columnNumber = span.start.column;
+    parsed.add((filename, 1 + lineNumber, 1 + columnNumber));
   }
+  return parsed;
 }
 
 /// Read the file at the given [path].
diff --git a/tests/web/wasm/source_map_simple_optimized_test.dart b/tests/web/wasm/source_map_simple_optimized_test.dart
index 445d548..dcc60dc 100644
--- a/tests/web/wasm/source_map_simple_optimized_test.dart
+++ b/tests/web/wasm/source_map_simple_optimized_test.dart
@@ -10,12 +10,12 @@
   Lib.testMain('source_map_simple_optimized', frameDetails);
 }
 
-const List<(int?, int?)?> frameDetails = [
-  (null, null), // _throwWithCurrentStackTrace
-  (16, 3), // g
-  (12, 3), // f
-  (44, 5), // testMain, inlined in main
-  (null, null), // _invokeMain
+const List<(String?, int?, int?)?> frameDetails = [
+  ('errors_patch.dart', null, null), // _throwWithCurrentStackTrace
+  ('source_map_simple_lib.dart', 16, 3), // g
+  ('source_map_simple_lib.dart', 12, 3), // f
+  ('source_map_simple_lib.dart', 38, 5), // testMain, inlined in main
+  ('internal_patch.dart', null, null), // _invokeMain
 ];
 
 /*
diff --git a/tests/web/wasm/source_map_simple_test.dart b/tests/web/wasm/source_map_simple_test.dart
index 49065de..d49ef09 100644
--- a/tests/web/wasm/source_map_simple_test.dart
+++ b/tests/web/wasm/source_map_simple_test.dart
@@ -10,14 +10,14 @@
   Lib.testMain('source_map_simple', frameDetails);
 }
 
-const List<(int?, int?)?> frameDetails = [
-  (null, null), // _throwWithCurrentStackTrace
-  (16, 3), // g
-  (12, 3), // f
-  (44, 5), // testMain
-  (10, 7), // main
+const List<(String?, int?, int?)?> frameDetails = [
+  ('errors_patch.dart', null, null), // _throwWithCurrentStackTrace
+  ('source_map_simple_lib.dart', 16, 3), // g
+  ('source_map_simple_lib.dart', 12, 3), // f
+  ('source_map_simple_lib.dart', 38, 5), // testMain
+  ('source_map_simple_test.dart', 10, 7), // main
   null, // main tear-off, compiler generated, not mapped
-  (null, null), // _invokeMain
+  ('internal_patch.dart', null, null), // _invokeMain
 ];
 
 /*