[dart2wasm] Make polymorphic specialization code use polymorphic dispatcher helpers

Instead of doing polymorphic specialization inline using custom logic,
we

* use a shared helper instead of doing it inline (can reduce size)
* use common class-id based searching code (enables binary search)
* remove dispatch table entries if we only dispatch using polymorphic
  specialization (can reduce size)

Change-Id: I14f5347f729b1949b1447d535681991a3a79f187
Reviewed-on: https://dart-review.googlesource.com/c/sdk/+/375501
Commit-Queue: Martin Kustermann <kustermann@google.com>
Reviewed-by: Srujan Gaddam <srujzs@google.com>
diff --git a/pkg/dart2wasm/lib/class_info.dart b/pkg/dart2wasm/lib/class_info.dart
index 4c97caf..bcba05d 100644
--- a/pkg/dart2wasm/lib/class_info.dart
+++ b/pkg/dart2wasm/lib/class_info.dart
@@ -746,6 +746,13 @@
   }
 
   @override
+  int get hashCode => Object.hash(start, end);
+
+  @override
+  bool operator ==(other) =>
+      other is Range && other.start == start && other.end == end;
+
+  @override
   String toString() => isEmpty ? '[]' : '[$start, $end]';
 }
 
diff --git a/pkg/dart2wasm/lib/code_generator.dart b/pkg/dart2wasm/lib/code_generator.dart
index 75f5cab..e5ea8d1 100644
--- a/pkg/dart2wasm/lib/code_generator.dart
+++ b/pkg/dart2wasm/lib/code_generator.dart
@@ -2194,6 +2194,7 @@
     pushReceiver(selector.signature);
 
     if (selector.targetRanges.length == 1) {
+      assert(selector.staticDispatchRanges.length == 1);
       final target = selector.targetRanges[0].target;
       final signature = translator.signatureForDirectCall(target);
       final paramInfo = translator.paramInfoForDirectCall(target);
@@ -2201,13 +2202,14 @@
       return translator.outputOrVoid(call(target));
     }
 
-    int? offset = selector.offset;
-    if (offset == null) {
+    if (selector.targetRanges.isEmpty) {
       // Unreachable call
-      assert(selector.targetRanges.isEmpty);
       b.comment("Virtual call of ${selector.name} with no targets"
           " at ${node.location}");
-      b.drop();
+      pushArguments(selector.signature, selector.paramInfo);
+      for (int i = 0; i < selector.signature.inputs.length; ++i) {
+        b.drop();
+      }
       b.block(const [], selector.signature.outputs);
       b.unreachable();
       b.end();
@@ -2219,9 +2221,13 @@
     assert(!receiverVar.type.nullable);
     b.local_tee(receiverVar);
     pushArguments(selector.signature, selector.paramInfo);
-    if (options.polymorphicSpecialization) {
-      _polymorphicSpecialization(selector, receiverVar);
+
+    if (selector.staticDispatchRanges.isNotEmpty) {
+      final polymorphicDispatcher =
+          translator.polymorphicDispatchers.getPolymorphicDispatcher(selector);
+      b.call(polymorphicDispatcher);
     } else {
+      final offset = selector.offset!;
       b.comment("Instance $kind of '${selector.name}'");
       b.local_get(receiverVar);
       b.struct_get(translator.topInfo.struct, FieldIndex.classId);
@@ -2237,66 +2243,6 @@
     return translator.outputOrVoid(selector.signature.outputs);
   }
 
-  void _polymorphicSpecialization(SelectorInfo selector, w.Local receiver) {
-    final implementations = <int, Reference>{};
-    for (final (:range, :target) in selector.targetRanges) {
-      for (int classId = range.start; classId <= range.end; ++classId) {
-        implementations[classId] = target;
-      }
-    }
-
-    w.Local idVar = addLocal(w.NumType.i32);
-    b.local_get(receiver);
-    b.struct_get(translator.topInfo.struct, FieldIndex.classId);
-    b.local_set(idVar);
-
-    w.Label block =
-        b.block(selector.signature.inputs, selector.signature.outputs);
-    calls:
-    while (Set.from(implementations.values).length > 1) {
-      for (int id in implementations.keys) {
-        Reference target = implementations[id]!;
-        if (implementations.values.where((t) => t == target).length == 1) {
-          // Single class id implements method.
-          b.local_get(idVar);
-          b.i32_const(id);
-          b.i32_eq();
-          b.if_(selector.signature.inputs, selector.signature.inputs);
-          call(target);
-          b.br(block);
-          b.end();
-          implementations.remove(id);
-          continue calls;
-        }
-      }
-      // Find class id that separates remaining classes in two.
-      List<int> sorted = implementations.keys.toList()..sort();
-      int pivotId = sorted.firstWhere(
-          (id) => implementations[id] != implementations[sorted.first]);
-      // Fail compilation if no such id exists.
-      assert(sorted.lastWhere(
-              (id) => implementations[id] != implementations[pivotId]) ==
-          pivotId - 1);
-      Reference target = implementations[sorted.first]!;
-      b.local_get(idVar);
-      b.i32_const(pivotId);
-      b.i32_lt_u();
-      b.if_(selector.signature.inputs, selector.signature.inputs);
-      call(target);
-      b.br(block);
-      b.end();
-      for (int id in sorted) {
-        if (id == pivotId) break;
-        implementations.remove(id);
-      }
-      continue calls;
-    }
-    // Call remaining implementation.
-    Reference target = implementations.values.first;
-    call(target);
-    b.end();
-  }
-
   @override
   w.ValueType visitVariableGet(VariableGet node, w.ValueType expectedType) {
     w.Local? local = locals[node.variable];
diff --git a/pkg/dart2wasm/lib/dispatch_table.dart b/pkg/dart2wasm/lib/dispatch_table.dart
index 7c25c34..b637058 100644
--- a/pkg/dart2wasm/lib/dispatch_table.dart
+++ b/pkg/dart2wasm/lib/dispatch_table.dart
@@ -50,6 +50,7 @@
   late final List<({Range range, Reference target})> targetRanges;
   late final Set<Reference> targetSet =
       targetRanges.map((e) => e.target).toSet();
+  late final List<({Range range, Reference target})> staticDispatchRanges;
 
   /// Wasm function type for the selector.
   ///
@@ -415,7 +416,13 @@
         }
       }
       ranges.length = writeIndex + 1;
+
+      final staticDispatchRanges =
+          translator.options.polymorphicSpecialization || ranges.length == 1
+              ? ranges
+              : <({Range range, Reference target})>[];
       selector.targetRanges = ranges;
+      selector.staticDispatchRanges = staticDispatchRanges;
     });
 
     _selectorInfo.forEach((_, selector) {
@@ -427,20 +434,24 @@
 
     // Assign selector offsets
 
-    /// Whether the selector will be used in an instance invocation.
-    ///
-    /// If not, then we don't add the selector to the dispatch table and don't
-    /// assign it a dispatch table offset.
-    ///
-    /// Special case for `objectNoSuchMethod`: we introduce instance
-    /// invocations of `objectNoSuchMethod` in dynamic calls, so keep it alive
-    /// even if there was no references to it from the Dart code.
-    bool needsDispatch(SelectorInfo selector) =>
-        (selector.callCount > 0 && selector.targetRanges.length > 1) ||
-        (selector.paramInfo.member! == translator.objectNoSuchMethod);
+    bool isUsedViaDispatchTableCall(SelectorInfo selector) {
+      // Special case for `objectNoSuchMethod`: we introduce instance
+      // invocations of `objectNoSuchMethod` in dynamic calls, so keep it alive
+      // even if there was no references to it from the Dart code.
+      if (selector.paramInfo.member! == translator.objectNoSuchMethod) {
+        return true;
+      }
+      if (selector.callCount == 0) return false;
+      if (selector.targetRanges.length <= 1) return false;
+      if (selector.staticDispatchRanges.length ==
+          selector.targetRanges.length) {
+        return false;
+      }
+      return true;
+    }
 
     final List<SelectorInfo> selectors =
-        selectorTargets.keys.where(needsDispatch).toList();
+        selectorTargets.keys.where(isUsedViaDispatchTableCall).toList();
 
     // Sort the selectors based on number of targets and number of use sites.
     // This is a heuristic to keep the table small.
diff --git a/pkg/dart2wasm/lib/translator.dart b/pkg/dart2wasm/lib/translator.dart
index 1927e76..7ba6c52 100644
--- a/pkg/dart2wasm/lib/translator.dart
+++ b/pkg/dart2wasm/lib/translator.dart
@@ -173,8 +173,8 @@
   late final w.RefType nullableObjectArrayTypeRef =
       w.RefType.def(nullableObjectArrayType, nullable: false);
 
-  late final PartialInstantiator partialInstantiator =
-      PartialInstantiator(this);
+  late final partialInstantiator = PartialInstantiator(this);
+  late final polymorphicDispatchers = PolymorphicDispatchers(this);
 
   /// Dart types that have specialized Wasm representations.
   late final Map<Class, w.StorageType> builtinTypes = {
@@ -1542,3 +1542,58 @@
     });
   }
 }
+
+class PolymorphicDispatchers {
+  final Translator translator;
+  final cache = <SelectorInfo, w.BaseFunction>{};
+
+  PolymorphicDispatchers(this.translator);
+
+  w.BaseFunction getPolymorphicDispatcher(SelectorInfo selector) {
+    assert(selector.targetRanges.length > 1);
+    return cache.putIfAbsent(selector, () {
+      final name = '${selector.name} (polymorphic dispatcher)';
+      final signature = selector.signature;
+      final inputs = signature.inputs;
+      final outputs = signature.outputs;
+      final function = translator.m.functions
+          .define(translator.m.types.defineFunction(inputs, outputs), name);
+
+      final b = function.body;
+
+      final targetRanges = selector.staticDispatchRanges
+          .map((entry) => (range: entry.range, value: entry.target))
+          .toList();
+
+      final bool needFallback =
+          selector.targetRanges.length > selector.staticDispatchRanges.length;
+      void emitDirectCall(Reference target) {
+        for (int i = 0; i < inputs.length; ++i) {
+          b.local_get(b.locals[i]);
+        }
+        b.call(translator.functions.getFunction(target));
+      }
+
+      void emitDispatchTableCall() {
+        for (int i = 0; i < inputs.length; ++i) {
+          b.local_get(b.locals[i]);
+        }
+        b.local_get(b.locals[0]);
+        b.struct_get(translator.topInfo.struct, FieldIndex.classId);
+        b.i32_const(selector.offset!);
+        b.i32_add();
+        b.call_indirect(signature, translator.dispatchTable.wasmTable);
+        translator.functions.recordSelectorUse(selector);
+      }
+
+      b.local_get(b.locals[0]);
+      b.struct_get(translator.topInfo.struct, FieldIndex.classId);
+      b.classIdSearch(targetRanges, outputs, emitDirectCall,
+          needFallback ? emitDispatchTableCall : null);
+      b.return_();
+      b.end();
+
+      return function;
+    });
+  }
+}