[dart2wasm] Fix `this` restoration code in sync* handling.

Noticed that the same bug that was fixed in [0] also exists in other
places.

=> Remove duplicated code & share in macro assembler.
=> Make use of this in async & sync* generator.

[0] https://dart-review.googlesource.com/c/sdk/+/364321

Change-Id: Id424ab5e8ed8ab70d19977d10cf80fb8b44b3872
Reviewed-on: https://dart-review.googlesource.com/c/sdk/+/364441
Commit-Queue: Martin Kustermann <kustermann@google.com>
Reviewed-by: Slava Egorov <vegorov@google.com>
diff --git a/pkg/dart2wasm/lib/async.dart b/pkg/dart2wasm/lib/async.dart
index d39a455..f424045f 100644
--- a/pkg/dart2wasm/lib/async.dart
+++ b/pkg/dart2wasm/lib/async.dart
@@ -696,49 +696,6 @@
     b.end();
   }
 
-  /// Clones the context pointed to by the [srcContext] local. Returns a local
-  /// pointing to the cloned context.
-  ///
-  /// It is assumed that the context is the function-level context for the
-  /// `async` function.
-  w.Local _cloneContext(
-      FunctionNode functionNode, Context context, w.Local srcContext) {
-    assert(context.owner == functionNode);
-
-    final w.Local destContext = addLocal(context.currentLocal.type);
-    b.struct_new_default(context.struct);
-    b.local_set(destContext);
-
-    void copyCapture(TreeNode node) {
-      Capture? capture = closures.captures[node];
-      if (capture != null) {
-        assert(capture.context == context);
-        b.local_get(destContext);
-        b.local_get(srcContext);
-        b.struct_get(context.struct, capture.fieldIndex);
-        b.struct_set(context.struct, capture.fieldIndex);
-      }
-    }
-
-    if (context.containsThis) {
-      b.local_get(destContext);
-      b.local_get(srcContext);
-      b.struct_get(context.struct, context.thisFieldIndex);
-      b.struct_set(context.struct, context.thisFieldIndex);
-    }
-    if (context.parent != null) {
-      b.local_get(destContext);
-      b.local_get(srcContext);
-      b.struct_get(context.struct, context.parentFieldIndex);
-      b.struct_set(context.struct, context.parentFieldIndex);
-    }
-    functionNode.positionalParameters.forEach(copyCapture);
-    functionNode.namedParameters.forEach(copyCapture);
-    functionNode.typeParameters.forEach(copyCapture);
-
-    return destContext;
-  }
-
   void _generateInner(FunctionNode functionNode, Context? context,
       w.FunctionBuilder resumeFun) {
     // void Function(_AsyncSuspendState, Object?)
@@ -795,7 +752,14 @@
     _emitTargetLabel(initialTarget);
 
     // Clone context on first execution.
-    _restoreContextsAndThis(context, cloneContextFor: functionNode);
+    b.restoreSuspendStateContext(
+        suspendStateLocal,
+        asyncSuspendStateInfo.struct,
+        FieldIndex.asyncSuspendStateContext,
+        closures,
+        context,
+        thisLocal,
+        cloneContextFor: functionNode);
 
     visitStatement(functionNode.body!);
 
@@ -883,45 +847,6 @@
     }
   }
 
-  void _restoreContextsAndThis(Context? context,
-      {FunctionNode? cloneContextFor}) {
-    if (context != null) {
-      assert(!context.isEmpty);
-      b.local_get(suspendStateLocal);
-      b.struct_get(
-          asyncSuspendStateInfo.struct, FieldIndex.asyncSuspendStateContext);
-      b.ref_cast(context.currentLocal.type as w.RefType);
-      b.local_set(context.currentLocal);
-
-      if (context.owner == cloneContextFor) {
-        context.currentLocal =
-            _cloneContext(cloneContextFor!, context, context.currentLocal);
-      }
-
-      bool restoredThis = false;
-      while (context != null) {
-        if (context.containsThis) {
-          assert(!restoredThis);
-          b.local_get(context.currentLocal);
-          b.struct_get(context.struct, context.thisFieldIndex);
-          b.ref_as_non_null();
-          b.local_set(thisLocal!);
-          restoredThis = true;
-        }
-
-        final parent = context.parent;
-        if (parent != null) {
-          assert(!parent.isEmpty);
-          b.local_get(context.currentLocal);
-          b.struct_get(context.struct, context.parentFieldIndex);
-          b.ref_as_non_null();
-          b.local_set(parent.currentLocal);
-        }
-        context = parent;
-      }
-    }
-  }
-
   @override
   void visitDoStatement(DoStatement node) {
     StateTarget? inner = innerTargets[node];
@@ -1451,7 +1376,13 @@
     // Generate resume label
     _emitTargetLabel(after);
 
-    _restoreContextsAndThis(context);
+    b.restoreSuspendStateContext(
+        suspendStateLocal,
+        asyncSuspendStateInfo.struct,
+        FieldIndex.asyncSuspendStateContext,
+        closures,
+        context,
+        thisLocal);
 
     // Handle exceptions
     final exceptionBlock = b.block();
diff --git a/pkg/dart2wasm/lib/code_generator.dart b/pkg/dart2wasm/lib/code_generator.dart
index a558c37..6c96826 100644
--- a/pkg/dart2wasm/lib/code_generator.dart
+++ b/pkg/dart2wasm/lib/code_generator.dart
@@ -3927,6 +3927,101 @@
     struct_get(
         translator.closureLayouter.closureBaseStruct, FieldIndex.closureVtable);
   }
+
+  /// Will restore all context locals and `this` from a suspend state.
+  void restoreSuspendStateContext(
+      w.Local suspendStateLocal,
+      w.StructType suspendStateStruct,
+      int suspendStateContextField,
+      Closures closures,
+      Context? context,
+      w.Local? thisLocal,
+      {FunctionNode? cloneContextFor}) {
+    if (context != null) {
+      assert(!context.isEmpty);
+      local_get(suspendStateLocal);
+      struct_get(suspendStateStruct, suspendStateContextField);
+      ref_cast(context.currentLocal.type as w.RefType);
+      local_set(context.currentLocal);
+      if (context.owner == cloneContextFor) {
+        context.currentLocal =
+            cloneFunctionLevelContext(closures, context, cloneContextFor!);
+      }
+      restoreThisAndContextChain(context, thisLocal);
+    }
+  }
+
+  /// Will restore the parent context chain and `this` (if captured)
+  ///
+  /// Assumes the innermost context is already loaded.
+  void restoreThisAndContextChain(
+      Context innermostContext, w.Local? thisLocal) {
+    bool restoredThis = false;
+
+    Context? context = innermostContext;
+    while (context != null) {
+      if (context.containsThis) {
+        assert(!restoredThis);
+        local_get(context.currentLocal);
+        struct_get(context.struct, context.thisFieldIndex);
+        ref_as_non_null();
+        local_set(thisLocal!);
+        restoredThis = true;
+      }
+
+      final parent = context.parent;
+      if (parent != null) {
+        assert(!parent.isEmpty);
+        local_get(context.currentLocal);
+        struct_get(context.struct, context.parentFieldIndex);
+        ref_as_non_null();
+        local_set(parent.currentLocal);
+      }
+      context = parent;
+    }
+  }
+
+  /// Clones the [context] and returns a local to the clone it.
+  ///
+  /// It is assumed that the context is a function-level context.
+  w.Local cloneFunctionLevelContext(
+      Closures closures, Context context, FunctionNode functionNode) {
+    final w.Local srcContext = context.currentLocal;
+    final w.Local destContext =
+        addLocal(context.currentLocal.type, isParameter: false);
+
+    struct_new_default(context.struct);
+    local_set(destContext);
+
+    void copyCapture(TreeNode node) {
+      Capture? capture = closures.captures[node];
+      if (capture != null) {
+        assert(capture.context == context);
+        local_get(destContext);
+        local_get(srcContext);
+        struct_get(context.struct, capture.fieldIndex);
+        struct_set(context.struct, capture.fieldIndex);
+      }
+    }
+
+    if (context.containsThis) {
+      local_get(destContext);
+      local_get(srcContext);
+      struct_get(context.struct, context.thisFieldIndex);
+      struct_set(context.struct, context.thisFieldIndex);
+    }
+    if (context.parent != null) {
+      local_get(destContext);
+      local_get(srcContext);
+      struct_get(context.struct, context.parentFieldIndex);
+      struct_set(context.struct, context.parentFieldIndex);
+    }
+    functionNode.positionalParameters.forEach(copyCapture);
+    functionNode.namedParameters.forEach(copyCapture);
+    functionNode.typeParameters.forEach(copyCapture);
+
+    return destContext;
+  }
 }
 
 bool guardCanMatchJSException(Translator translator, DartType guard) {
diff --git a/pkg/dart2wasm/lib/sync_star.dart b/pkg/dart2wasm/lib/sync_star.dart
index a468b84..6232d10 100644
--- a/pkg/dart2wasm/lib/sync_star.dart
+++ b/pkg/dart2wasm/lib/sync_star.dart
@@ -280,49 +280,6 @@
     b.end();
   }
 
-  /// Clones the context pointed to by the [srcContext] local. Returns a local
-  /// pointing to the cloned context.
-  ///
-  /// It is assumed that the context is the function-level context for the
-  /// `sync*` function.
-  w.Local cloneContext(
-      FunctionNode functionNode, Context context, w.Local srcContext) {
-    assert(context.owner == functionNode);
-
-    final w.Local destContext = addLocal(context.currentLocal.type);
-    b.struct_new_default(context.struct);
-    b.local_set(destContext);
-
-    void copyCapture(TreeNode node) {
-      Capture? capture = closures.captures[node];
-      if (capture != null) {
-        assert(capture.context == context);
-        b.local_get(destContext);
-        b.local_get(srcContext);
-        b.struct_get(context.struct, capture.fieldIndex);
-        b.struct_set(context.struct, capture.fieldIndex);
-      }
-    }
-
-    if (context.containsThis) {
-      b.local_get(destContext);
-      b.local_get(srcContext);
-      b.struct_get(context.struct, context.thisFieldIndex);
-      b.struct_set(context.struct, context.thisFieldIndex);
-    }
-    if (context.parent != null) {
-      b.local_get(destContext);
-      b.local_get(srcContext);
-      b.struct_get(context.struct, context.parentFieldIndex);
-      b.struct_set(context.struct, context.parentFieldIndex);
-    }
-    functionNode.positionalParameters.forEach(copyCapture);
-    functionNode.namedParameters.forEach(copyCapture);
-    functionNode.typeParameters.forEach(copyCapture);
-
-    return destContext;
-  }
-
   void generateInner(FunctionNode functionNode, Context? context,
       w.FunctionBuilder resumeFun) {
     // Set the current Wasm function for the code generator to the inner
@@ -375,7 +332,9 @@
     emitTargetLabel(initialTarget);
 
     // Clone context on first execution.
-    restoreContextsAndThis(context, cloneContextFor: functionNode);
+    b.restoreSuspendStateContext(suspendStateLocal, suspendStateInfo.struct,
+        FieldIndex.suspendStateContext, closures, context, thisLocal,
+        cloneContextFor: functionNode);
 
     visitStatement(functionNode.body!);
 
@@ -421,37 +380,6 @@
     }
   }
 
-  void restoreContextsAndThis(Context? context,
-      {FunctionNode? cloneContextFor}) {
-    if (context != null) {
-      assert(!context.isEmpty);
-      b.local_get(suspendStateLocal);
-      b.struct_get(suspendStateInfo.struct, FieldIndex.suspendStateContext);
-      b.ref_cast(context.currentLocal.type as w.RefType);
-      b.local_set(context.currentLocal);
-
-      if (context.owner == cloneContextFor) {
-        context.currentLocal =
-            cloneContext(cloneContextFor!, context, context.currentLocal);
-      }
-
-      while (context!.parent != null) {
-        assert(!context.parent!.isEmpty);
-        b.local_get(context.currentLocal);
-        b.struct_get(context.struct, context.parentFieldIndex);
-        b.ref_as_non_null();
-        context = context.parent!;
-        b.local_set(context.currentLocal);
-      }
-      if (context.containsThis) {
-        b.local_get(context.currentLocal);
-        b.struct_get(context.struct, context.thisFieldIndex);
-        b.ref_as_non_null();
-        b.local_set(thisLocal!);
-      }
-    }
-  }
-
   @override
   void visitDoStatement(DoStatement node) {
     StateTarget? inner = innerTargets[node];
@@ -615,7 +543,8 @@
       b.end(); // exceptionCheck
     }
 
-    restoreContextsAndThis(context);
+    b.restoreSuspendStateContext(suspendStateLocal, suspendStateInfo.struct,
+        FieldIndex.suspendStateContext, closures, context, thisLocal);
   }
 
   @override
diff --git a/tests/web/wasm/capture_type_and_this_test.dart b/tests/web/wasm/capture_type_and_this_test.dart
index b487c57..bbc9684 100644
--- a/tests/web/wasm/capture_type_and_this_test.dart
+++ b/tests/web/wasm/capture_type_and_this_test.dart
@@ -20,7 +20,7 @@
   }
 }
 
-main() {
+void testThisRestorationInAsyncClosure() {
   final a = A<String>();
 
   if (!identical(a, capturedThis)) {
@@ -30,3 +30,39 @@
     throw 'Should have captured the correct `T`.';
   }
 }
+
+late Iterable iterable;
+
+class B<T> {
+  B() {
+    foo() sync* {
+      // This will create a context chain as follows:
+      //   Context [T]
+      //     `--> Context [<parent-context>, this]
+      //            `--> Context [<parent-context, ...]
+      yield T;
+      yield this;
+    }
+
+    iterable = foo();
+  }
+}
+
+void testThisRestorationInSyncStarClosure() {
+  final a = B<String>();
+
+  final it = iterable.iterator;
+  if (!it.moveNext()) throw 'expected first element';
+  if (!identical(String, it.current)) {
+    throw 'Should have captured the correct `T`.';
+  }
+  if (!it.moveNext()) throw 'expected second element';
+  if (!identical(a, it.current)) {
+    throw 'Should have captured the correct `this`.';
+  }
+}
+
+main() {
+  testThisRestorationInAsyncClosure();
+  testThisRestorationInSyncStarClosure();
+}