[dart2wasm] Improve closure hash codes

Currently hash code for a closure is the hash code of it's runtime type.

This causes a lot hash collisions in some apps and cause performance
issues.

With this patch we now use captured objects in closures when calculating
hash codes.

The hash codes are now:

- For tear-offs:

    mix(receiver hash, closure runtime type hash)

- For instantiations:

    mix(instantiated closure hash,
        hashes of captured types)

  Note that an instantiation can be of a tear-off, in which case
  "instantiated closure hash" will calculate the tear-off hash as above.

- For others (function literals, static functions), the hash is the
  identity hash.

Fixes #54912.

CoreLibraryReviewExempt: Mark private corelib function as entry-point in Wasm
Change-Id: I6a123fdc690237f543bb8bf832f0f8119d013a55
Reviewed-on: https://dart-review.googlesource.com/c/sdk/+/353162
Reviewed-by: Martin Kustermann <kustermann@google.com>
Commit-Queue: Ömer Ağacan <omersa@google.com>
diff --git a/pkg/dart2wasm/lib/class_info.dart b/pkg/dart2wasm/lib/class_info.dart
index b683be7..df09cba 100644
--- a/pkg/dart2wasm/lib/class_info.dart
+++ b/pkg/dart2wasm/lib/class_info.dart
@@ -39,7 +39,8 @@
   static const closureRuntimeType = 4;
   static const vtableDynamicCallEntry = 0;
   static const vtableInstantiationTypeComparisonFunction = 1;
-  static const vtableInstantiationFunction = 2;
+  static const vtableInstantiationTypeHashFunction = 2;
+  static const vtableInstantiationFunction = 3;
   static const instantiationContextInner = 0;
   static const instantiationContextTypeArgumentsBase = 1;
   static const typeIsDeclaredNullable = 2;
diff --git a/pkg/dart2wasm/lib/closures.dart b/pkg/dart2wasm/lib/closures.dart
index 61b1faf..3374ccf 100644
--- a/pkg/dart2wasm/lib/closures.dart
+++ b/pkg/dart2wasm/lib/closures.dart
@@ -78,6 +78,10 @@
       _instantiationTypeComparisonFunctionThunk!();
   w.BaseFunction Function()? _instantiationTypeComparisonFunctionThunk;
 
+  late final w.BaseFunction instantiationTypeHashFunction =
+      _instantiationTypeHashFunctionThunk!();
+  w.BaseFunction Function()? _instantiationTypeHashFunctionThunk;
+
   /// The signature of the function that instantiates this generic closure.
   w.FunctionType get instantiationFunctionType {
     assert(isGeneric);
@@ -170,9 +174,10 @@
   // For generic closures. The entries are:
   // 0: Dynamic call entry
   // 1: Instantiation type comparison function
-  // 2: Instantiation function
-  // 3-...: Entries for calling the closure
-  static const int vtableBaseIndexGeneric = 3;
+  // 2: Instantiation type hash function
+  // 3: Instantiation function
+  // 4-...: Entries for calling the closure
+  static const int vtableBaseIndexGeneric = 4;
 
   // Base struct for vtables without the dynamic call entry added. Referenced
   // by [closureBaseStruct] instead of the fully initialized version
@@ -203,6 +208,10 @@
         ..add(w.FieldType(
             w.RefType.def(instantiationClosureTypeComparisonFunctionType,
                 nullable: false),
+            mutable: false))
+        ..add(w.FieldType(
+            w.RefType.def(instantiationClosureTypeHashFunctionType,
+                nullable: false),
             mutable: false)),
       superType: vtableBaseStruct);
 
@@ -216,6 +225,12 @@
     [w.NumType.i32], // bool
   );
 
+  late final w.FunctionType instantiationClosureTypeHashFunctionType =
+      m.types.defineFunction(
+    [w.RefType.def(instantiationContextBaseStruct, nullable: false)],
+    [w.NumType.i64], // hash
+  );
+
   // Base struct for closures.
   late final w.StructType closureBaseStruct = _makeClosureStruct(
       "#ClosureBase", _vtableBaseStructBare, translator.closureInfo.struct);
@@ -245,6 +260,12 @@
       _instantiationTypeComparisonFunctions.putIfAbsent(
           numTypes, () => _createInstantiationTypeComparisonFunction(numTypes));
 
+  final Map<int, w.BaseFunction> _instantiationTypeHashFunctions = {};
+
+  w.BaseFunction _getInstantiationTypeHashFunction(int numTypes) =>
+      _instantiationTypeHashFunctions.putIfAbsent(
+          numTypes, () => _createInstantiationTypeHashFunction(numTypes));
+
   w.StructType _makeClosureStruct(
       String name, w.StructType vtableStruct, w.StructType superType) {
     // A closure contains:
@@ -482,6 +503,9 @@
 
       representation._instantiationTypeComparisonFunctionThunk =
           () => _getInstantiationTypeComparisonFunction(typeCount);
+
+      representation._instantiationTypeHashFunctionThunk =
+          () => _getInstantiationTypeHashFunction(typeCount);
     }
 
     return representation;
@@ -721,6 +745,42 @@
     return function;
   }
 
+  w.BaseFunction _createInstantiationTypeHashFunction(int numTypes) {
+    final function = m.functions.define(
+        instantiationClosureTypeHashFunctionType,
+        "#InstantiationTypeHash-$numTypes");
+
+    final b = function.body;
+
+    final contextStructType = _getInstantiationContextBaseStruct(numTypes);
+    final contextRefType = w.RefType.def(contextStructType, nullable: false);
+
+    final thisContext = function.locals[0];
+    final thisContextLocal = function.addLocal(contextRefType);
+
+    b.local_get(thisContext);
+    b.ref_cast(contextRefType);
+    b.local_set(thisContextLocal);
+
+    // Same as `SystemHash.hashN` functions: combine first hash with
+    // `_hashSeed`.
+    translator.globals.readGlobal(b, translator.hashSeed);
+
+    // Field 0 is the instantiated closure. Types start at 1.
+    for (int typeFieldIdx = 1; typeFieldIdx <= numTypes; typeFieldIdx += 1) {
+      b.local_get(thisContextLocal);
+      b.struct_get(contextStructType, typeFieldIdx);
+      b.call(translator.functions
+          .getFunction(translator.runtimeTypeHashCode.reference));
+      b.call(translator.functions
+          .getFunction(translator.systemHashCombine.reference));
+    }
+
+    b.end();
+
+    return function;
+  }
+
   ClosureRepresentationsForParameterCount _representationsForCounts(
       int typeCount, int positionalCount) {
     while (representations.length <= typeCount) {
diff --git a/pkg/dart2wasm/lib/code_generator.dart b/pkg/dart2wasm/lib/code_generator.dart
index 080ee16..5123ca3 100644
--- a/pkg/dart2wasm/lib/code_generator.dart
+++ b/pkg/dart2wasm/lib/code_generator.dart
@@ -3783,8 +3783,10 @@
 }
 
 extension MacroAssembler on w.InstructionsBuilder {
-  // Expects there to be a i32 on the stack, will consume it and leave
-  // true/false on the stack.
+  /// `[i32] -> [i32]`
+  ///
+  /// Consumes an `i32` for a class ID, leaves an `i32` as `bool` for whether
+  /// the class ID is in the given list of ranges.
   void emitClassIdRangeCheck(List<Range> ranges) {
     if (ranges.isEmpty) {
       drop();
@@ -3823,4 +3825,73 @@
       end(); // done
     }
   }
+
+  /// `[ref _Closure] -> [i32]`
+  ///
+  /// Given a closure reference returns whether the closure is an
+  /// instantiation.
+  void emitInstantiationClosureCheck(Translator translator) {
+    ref_cast(w.RefType(translator.closureLayouter.closureBaseStruct,
+        nullable: false));
+    struct_get(translator.closureLayouter.closureBaseStruct,
+        FieldIndex.closureContext);
+    ref_test(w.RefType(
+        translator.closureLayouter.instantiationContextBaseStruct,
+        nullable: false));
+  }
+
+  /// `[ref _Closure] -> [ref _ClosureBase]`
+  ///
+  /// Given an instantiation closure returns the instantiated closure.
+  void emitGetInstantiatedClosure(Translator translator) {
+    // instantiation.context
+    ref_cast(w.RefType(translator.closureLayouter.closureBaseStruct,
+        nullable: false));
+    struct_get(translator.closureLayouter.closureBaseStruct,
+        FieldIndex.closureContext);
+    ref_cast(w.RefType(
+        translator.closureLayouter.instantiationContextBaseStruct,
+        nullable: false));
+    // instantiation.context.inner
+    struct_get(translator.closureLayouter.instantiationContextBaseStruct,
+        FieldIndex.instantiationContextInner);
+  }
+
+  /// `[ref #ClosureBase] -> [ref #InstantiationContextBase]`
+  ///
+  /// Given an instantiation closure returns the instantiated closure's
+  /// context.
+  void emitGetInstantiationContextInner(Translator translator) {
+    // instantiation.context
+    struct_get(translator.closureLayouter.closureBaseStruct,
+        FieldIndex.closureContext);
+    ref_cast(w.RefType(
+        translator.closureLayouter.instantiationContextBaseStruct,
+        nullable: false));
+    // instantiation.context.inner
+    struct_get(translator.closureLayouter.instantiationContextBaseStruct,
+        FieldIndex.instantiationContextInner);
+  }
+
+  /// `[ref _Closure] -> [i32]`
+  ///
+  /// Given a closure returns whether the closure is a tear-off.
+  void emitTearOffCheck(Translator translator) {
+    ref_cast(w.RefType(translator.closureLayouter.closureBaseStruct,
+        nullable: false));
+    struct_get(translator.closureLayouter.closureBaseStruct,
+        FieldIndex.closureContext);
+    ref_test(translator.topInfo.nonNullableType);
+  }
+
+  /// `[ref _Closure] -> [ref #Top]`
+  ///
+  /// Given a closure returns the receiver of the closure.
+  void emitGetTearOffReceiver(Translator translator) {
+    ref_cast(w.RefType(translator.closureLayouter.closureBaseStruct,
+        nullable: false));
+    struct_get(translator.closureLayouter.closureBaseStruct,
+        FieldIndex.closureContext);
+    ref_cast(translator.topInfo.nonNullableType);
+  }
 }
diff --git a/pkg/dart2wasm/lib/intrinsics.dart b/pkg/dart2wasm/lib/intrinsics.dart
index 0f27693..08bfc4e 100644
--- a/pkg/dart2wasm/lib/intrinsics.dart
+++ b/pkg/dart2wasm/lib/intrinsics.dart
@@ -1570,18 +1570,6 @@
       b.br_on_cast_fail(notInstantiationBlock,
           const w.RefType.struct(nullable: false), instantiationContextBase);
 
-      // Closures are instantiations. Compare inner function vtables to check
-      // that instantiations are for the same generic function.
-      void getInstantiationContextInner(w.Local fun) {
-        b.local_get(fun);
-        // instantiation.context
-        b.struct_get(closureBaseStruct, FieldIndex.closureContext);
-        b.ref_cast(instantiationContextBase);
-        // instantiation.context.inner
-        b.struct_get(translator.closureLayouter.instantiationContextBaseStruct,
-            FieldIndex.instantiationContextInner);
-      }
-
       // Closures are instantiations of the same function, compare types.
       b.local_get(fun1);
       b.struct_get(closureBaseStruct, FieldIndex.closureContext);
@@ -1589,7 +1577,8 @@
       b.local_get(fun2);
       b.struct_get(closureBaseStruct, FieldIndex.closureContext);
       b.ref_cast(instantiationContextBase);
-      getInstantiationContextInner(fun1);
+      b.local_get(fun1);
+      _getInstantiationContextInner(translator, b);
       b.struct_get(closureBaseStruct, FieldIndex.closureVtable);
       b.ref_cast(w.RefType.def(
           translator.closureLayouter.genericVtableBaseStruct,
@@ -1599,9 +1588,11 @@
       b.call_ref(translator
           .closureLayouter.instantiationClosureTypeComparisonFunctionType);
       b.if_();
-      getInstantiationContextInner(fun1);
+      b.local_get(fun1);
+      _getInstantiationContextInner(translator, b);
       b.local_tee(fun1);
-      getInstantiationContextInner(fun2);
+      b.local_get(fun2);
+      _getInstantiationContextInner(translator, b);
       b.local_tee(fun2);
       b.ref_eq();
       b.if_();
@@ -1651,12 +1642,76 @@
       return true;
     }
 
+    if (member.enclosingClass == translator.closureClass &&
+        name == "_isInstantiationClosure") {
+      assert(function.locals.length == 1);
+      b.local_get(function.locals[0]); // ref _Closure
+      b.emitInstantiationClosureCheck(translator);
+      return true;
+    }
+
+    if (member.enclosingClass == translator.closureClass &&
+        name == "_instantiatedClosure") {
+      assert(function.locals.length == 1);
+      b.local_get(function.locals[0]); // ref _Closure
+      b.emitGetInstantiatedClosure(translator);
+      return true;
+    }
+
+    if (member.enclosingClass == translator.closureClass &&
+        name == "_instantiationClosureTypeHash") {
+      assert(function.locals.length == 1);
+
+      // Instantiation context, to be passed to the hash function.
+      b.local_get(function.locals[0]); // ref _Closure
+      b.ref_cast(w.RefType(translator.closureLayouter.closureBaseStruct,
+          nullable: false));
+      b.struct_get(translator.closureLayouter.closureBaseStruct,
+          FieldIndex.closureContext);
+      b.ref_cast(w.RefType(
+          translator.closureLayouter.instantiationContextBaseStruct,
+          nullable: false));
+
+      // Hash function.
+      b.local_get(function.locals[0]); // ref _Closure
+      b.ref_cast(w.RefType(translator.closureLayouter.closureBaseStruct,
+          nullable: false));
+      _getInstantiationContextInner(translator, b);
+      b.struct_get(translator.closureLayouter.closureBaseStruct,
+          FieldIndex.closureVtable);
+      b.ref_cast(w.RefType.def(
+          translator.closureLayouter.genericVtableBaseStruct,
+          nullable: false));
+      b.struct_get(translator.closureLayouter.genericVtableBaseStruct,
+          FieldIndex.vtableInstantiationTypeHashFunction);
+      b.call_ref(
+          translator.closureLayouter.instantiationClosureTypeHashFunctionType);
+
+      return true;
+    }
+
+    if (member.enclosingClass == translator.closureClass &&
+        name == "_isInstanceTearOff") {
+      assert(function.locals.length == 1);
+      b.local_get(function.locals[0]); // ref _Closure
+      b.emitTearOffCheck(translator);
+      return true;
+    }
+
+    if (member.enclosingClass == translator.closureClass &&
+        name == "_instanceTearOffReceiver") {
+      assert(function.locals.length == 1);
+      b.local_get(function.locals[0]); // ref _Closure
+      b.emitGetTearOffReceiver(translator);
+      return true;
+    }
+
     if (member.enclosingClass == translator.coreTypes.functionClass &&
         name == "apply") {
       assert(function.type.inputs.length == 3);
 
       final closureLocal = function.locals[0]; // ref #ClosureBase
-      final posArgsNullableLocal = function.locals[1]; // ref null Object,
+      final posArgsNullableLocal = function.locals[1]; // ref null Object
       final namedArgsLocal = function.locals[2]; // ref null Object
 
       // Create empty type arguments array.
@@ -1760,3 +1815,19 @@
     return false;
   }
 }
+
+/// Expects a `ref #ClosureBase` for an instantiation closure on stack. Pops
+/// the value and pushes the instantiated closure's (not instantiation's!)
+/// context.
+void _getInstantiationContextInner(
+    Translator translator, w.InstructionsBuilder b) {
+  // instantiation.context
+  b.struct_get(
+      translator.closureLayouter.closureBaseStruct, FieldIndex.closureContext);
+  b.ref_cast(w.RefType(
+      translator.closureLayouter.instantiationContextBaseStruct,
+      nullable: false));
+  // instantiation.context.inner
+  b.struct_get(translator.closureLayouter.instantiationContextBaseStruct,
+      FieldIndex.instantiationContextInner);
+}
diff --git a/pkg/dart2wasm/lib/kernel_nodes.dart b/pkg/dart2wasm/lib/kernel_nodes.dart
index 1fa86ab..57b0e80 100644
--- a/pkg/dart2wasm/lib/kernel_nodes.dart
+++ b/pkg/dart2wasm/lib/kernel_nodes.dart
@@ -217,6 +217,8 @@
       index.getProcedure("dart:core", "_BoxedInt", "_truncDiv");
   late final Procedure runtimeTypeEquals =
       index.getTopLevelProcedure("dart:core", "_runtimeTypeEquals");
+  late final Procedure runtimeTypeHashCode =
+      index.getTopLevelProcedure("dart:core", "_runtimeTypeHashCode");
 
   // dart:core invocation/exception procedures
   late final Procedure invocationGetterFactory =
@@ -297,6 +299,11 @@
   late final Procedure wasmTableCallIndirect =
       index.getProcedure("dart:_wasm", "WasmTable", "callIndirect");
 
+  // Hash utils
+  late final Field hashSeed = index.getTopLevelField('dart:core', '_hashSeed');
+  late final Procedure systemHashCombine =
+      index.getProcedure("dart:_internal", "SystemHash", "combine");
+
   // Debugging
   late final Procedure printToConsole =
       index.getTopLevelProcedure("dart:_internal", "printToConsole");
diff --git a/pkg/dart2wasm/lib/translator.dart b/pkg/dart2wasm/lib/translator.dart
index 744eb91..5f4c085 100644
--- a/pkg/dart2wasm/lib/translator.dart
+++ b/pkg/dart2wasm/lib/translator.dart
@@ -813,6 +813,7 @@
     ib.ref_func(dynamicCallEntry);
     if (representation.isGeneric) {
       ib.ref_func(representation.instantiationTypeComparisonFunction);
+      ib.ref_func(representation.instantiationTypeHashFunction);
       ib.ref_func(representation.instantiationFunction);
     }
     for (int posArgCount = 0; posArgCount <= positionalCount; posArgCount++) {
diff --git a/sdk/lib/_internal/wasm/lib/closure.dart b/sdk/lib/_internal/wasm/lib/closure.dart
index 26e052f..2cc2169 100644
--- a/sdk/lib/_internal/wasm/lib/closure.dart
+++ b/sdk/lib/_internal/wasm/lib/closure.dart
@@ -26,9 +26,19 @@
   @pragma("wasm:prefer-inline")
   external static _FunctionType _getClosureRuntimeType(_Closure closure);
 
-  // Simple hash code for now, we can optimize later
   @override
-  int get hashCode => runtimeType.hashCode;
+  int get hashCode {
+    if (_isInstantiationClosure) {
+      return Object.hash(_instantiatedClosure, _instantiationClosureTypeHash());
+    }
+
+    if (_isInstanceTearOff) {
+      return Object.hash(
+          _instanceTearOffReceiver, _getClosureRuntimeType(this));
+    }
+
+    return Object._objectHashCode(this); // identity hash
+  }
 
   // Support dynamic tear-off of `.call` on functions
   @pragma("wasm:entry-point")
@@ -36,4 +46,30 @@
 
   @override
   String toString() => 'Closure: $runtimeType';
+
+  // Helpers for implementing `hashCode`, `operator ==`.
+
+  /// Whether the closure is an instantiation.
+  external bool get _isInstantiationClosure;
+
+  /// When the closure is an instantiation, get the instantiated closure.
+  ///
+  /// Traps when the closure is not an instantiation.
+  external _Closure? get _instantiatedClosure;
+
+  /// When the closure is an instantiation, returns the combined hash code of
+  /// the captured types.
+  ///
+  /// Traps when the closure is not an instantiation.
+  external int _instantiationClosureTypeHash();
+
+  /// Whether the closure is an instance tear-off.
+  ///
+  /// Instance tear-offs will have receivers.
+  external bool get _isInstanceTearOff;
+
+  /// When the closure is an instance tear-off, returns the receiver.
+  ///
+  /// Traps when the closure is not an instance tear-off.
+  external Object? get _instanceTearOffReceiver;
 }
diff --git a/sdk/lib/_internal/wasm/lib/type.dart b/sdk/lib/_internal/wasm/lib/type.dart
index 282b096..52565c4 100644
--- a/sdk/lib/_internal/wasm/lib/type.dart
+++ b/sdk/lib/_internal/wasm/lib/type.dart
@@ -1413,3 +1413,8 @@
 @pragma("wasm:entry-point")
 @pragma("wasm:prefer-inline")
 bool _runtimeTypeEquals(_Type t1, _Type t2) => t1 == t2;
+
+// Same as [_RuntimeTypeEquals], but for `Object.hashCode`.
+@pragma("wasm:entry-point")
+@pragma("wasm:prefer-inline")
+int _runtimeTypeHashCode(_Type t) => t.hashCode;
diff --git a/sdk/lib/core/object.dart b/sdk/lib/core/object.dart
index 6e67af1..9eae932 100644
--- a/sdk/lib/core/object.dart
+++ b/sdk/lib/core/object.dart
@@ -561,4 +561,5 @@
 }
 
 // A per-isolate seed for hash code computations.
+@pragma("wasm:entry-point")
 final int _hashSeed = identityHashCode(Object);
diff --git a/tests/web/wasm/closure_hash_code_test.dart b/tests/web/wasm/closure_hash_code_test.dart
new file mode 100644
index 0000000..1ab39f1
--- /dev/null
+++ b/tests/web/wasm/closure_hash_code_test.dart
@@ -0,0 +1,66 @@
+// Copyright (c) 2024, the Dart project authors.  Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+import 'package:expect/expect.dart';
+
+// Check closure (tear-off, instantiation, static) equality, hash code, and
+// identities.
+//
+// In principle principle unequal objects can have the same hash code, but it's
+// very unlikely in this test. So we also check that if two closures are not
+// equal then they should have different hash codes.
+
+void staticFunction() {}
+
+void genericStaticFunction<T>(T t) {}
+
+class C {
+  void memberFunction() {}
+
+  void genericMemberFunction<T>(T t) {}
+}
+
+void main() {
+  var functionExpression = () {};
+
+  var genericFunctionExpression = <T>(T t) {};
+
+  check(functionExpression, functionExpression, equal: true, isIdentical: true);
+  check(genericFunctionExpression, genericFunctionExpression,
+      equal: true, isIdentical: true);
+  check(genericFunctionExpression<int>, genericFunctionExpression<int>,
+      equal: true, isIdentical: false);
+
+  check(() {}, () {}, equal: false, isIdentical: false);
+
+  check(staticFunction, staticFunction, equal: true, isIdentical: true);
+  check(genericStaticFunction, genericStaticFunction,
+      equal: true, isIdentical: true);
+  check(genericStaticFunction<int>, genericStaticFunction<int>,
+      equal: true, isIdentical: true);
+
+  final o1 = C();
+
+  check(o1.memberFunction, o1.memberFunction, equal: true, isIdentical: false);
+  check(o1.genericMemberFunction, o1.genericMemberFunction,
+      equal: true, isIdentical: false);
+  check(o1.genericMemberFunction<int>, o1.genericMemberFunction<int>,
+      equal: true, isIdentical: false);
+
+  final o2 = C();
+
+  check(o1.memberFunction, o2.memberFunction, equal: false, isIdentical: false);
+  check(o1.genericMemberFunction, o2.genericMemberFunction,
+      equal: false, isIdentical: false);
+  check(o1.genericMemberFunction<int>, o2.genericMemberFunction<int>,
+      equal: false, isIdentical: false);
+}
+
+void check(Object? o1, Object? o2,
+    {required bool equal, required bool isIdentical}) {
+  (equal ? Expect.equals : Expect.notEquals)(o1, o2);
+  (equal ? Expect.equals : Expect.notEquals)(o1.hashCode, o2.hashCode);
+
+  (isIdentical ? Expect.isTrue : Expect.isFalse)(identical(o1, o2));
+}