[dart2wasm] Support accessing function type parameter from within lambda

This completes the implementation of runtime access to type arguments.

Generic function expressions and local functions are now explicitly
rejected, which causes a number of new failures. These tests were
previously spurious passes, since generic functions are not supported.

Change-Id: Ibbf8bde94cddf2c0f86e3740a798c1602897f631
Cq-Include-Trybots: luci.dart.try:dart2wasm-linux-x64-d8-try
Reviewed-on: https://dart-review.googlesource.com/c/sdk/+/244402
Reviewed-by: Joshua Litt <joshualitt@google.com>
Commit-Queue: Aske Simon Christensen <askesc@google.com>
diff --git a/pkg/dart2wasm/lib/closures.dart b/pkg/dart2wasm/lib/closures.dart
index cde1b4c..7929e21 100644
--- a/pkg/dart2wasm/lib/closures.dart
+++ b/pkg/dart2wasm/lib/closures.dart
@@ -47,6 +47,9 @@
   /// The variables captured by this context.
   final List<VariableDeclaration> variables = [];
 
+  /// The type parameters captured by this context.
+  final List<TypeParameter> typeParameters = [];
+
   /// Whether this context contains a captured `this`. Only member contexts can.
   bool containsThis = false;
 
@@ -57,7 +60,8 @@
   /// generation.
   late w.Local currentLocal;
 
-  bool get isEmpty => variables.isEmpty && !containsThis;
+  bool get isEmpty =>
+      variables.isEmpty && typeParameters.isEmpty && !containsThis;
 
   int get parentFieldIndex {
     assert(parent != null);
@@ -74,7 +78,7 @@
 
 /// A captured variable.
 class Capture {
-  final VariableDeclaration variable;
+  final TreeNode variable;
   late final Context context;
   late final int fieldIndex;
   bool written = false;
@@ -88,7 +92,7 @@
 /// tree for a member.
 class Closures {
   final CodeGenerator codeGen;
-  final Map<VariableDeclaration, Capture> captures = {};
+  final Map<TreeNode, Capture> captures = {};
   bool isThisCaptured = false;
   final Map<FunctionNode, Lambda> lambdas = {};
   final Map<TreeNode, Context> contexts = {};
@@ -98,6 +102,9 @@
 
   Translator get translator => codeGen.translator;
 
+  late final w.ValueType typeType =
+      translator.classInfo[translator.typeClass]!.nullableType;
+
   void findCaptures(Member member) {
     var find = CaptureFinder(this, member);
     if (member is Constructor) {
@@ -144,6 +151,11 @@
               translator.translateType(variable.type).withNullability(true)));
           captures[variable]!.fieldIndex = index;
         }
+        for (TypeParameter parameter in context.typeParameters) {
+          int index = struct.fields.length;
+          struct.fields.add(w.FieldType(typeType));
+          captures[parameter]!.fieldIndex = index;
+        }
       }
     }
   }
@@ -152,7 +164,7 @@
 class CaptureFinder extends RecursiveVisitor {
   final Closures closures;
   final Member member;
-  final Map<VariableDeclaration, int> variableDepth = {};
+  final Map<TreeNode, int> variableDepth = {};
   int depth = 0;
 
   CaptureFinder(this.closures, this.member);
@@ -170,12 +182,23 @@
     super.visitVariableDeclaration(node);
   }
 
-  void _visitVariableUse(VariableDeclaration variable) {
+  @override
+  void visitTypeParameter(TypeParameter node) {
+    if (node.parent is FunctionNode) {
+      if (depth > 0) {
+        variableDepth[node] = depth;
+      }
+    }
+    super.visitTypeParameter(node);
+  }
+
+  void _visitVariableUse(TreeNode variable) {
     int declDepth = variableDepth[variable] ?? 0;
     assert(declDepth <= depth);
     if (declDepth < depth) {
       closures.captures[variable] = Capture(variable);
-    } else if (variable.parent is FunctionDeclaration) {
+    } else if (variable is VariableDeclaration &&
+        variable.parent is FunctionDeclaration) {
       closures.closurizedFunctions.add(variable.parent as FunctionDeclaration);
     }
   }
@@ -213,7 +236,10 @@
   void visitTypeParameterType(TypeParameterType node) {
     if (node.parameter.parent == member.enclosingClass) {
       _visitThis();
+    } else if (node.parameter.parent is FunctionNode) {
+      _visitVariableUse(node.parameter);
     }
+    super.visitTypeParameterType(node);
   }
 
   void _visitLambda(FunctionNode node) {
@@ -222,6 +248,10 @@
       throw "Not supported: Optional parameters for "
           "function expression or local function at ${node.location}";
     }
+    if (node.typeParameters.isNotEmpty) {
+      throw "Not supported: Type parameters for "
+          "function expression or local function at ${node.location}";
+    }
     int parameterCount = node.requiredParameterCount;
     w.FunctionType type = translator.closureFunctionType(parameterCount);
     w.DefinedFunction function =
@@ -311,6 +341,16 @@
   }
 
   @override
+  void visitTypeParameter(TypeParameter node) {
+    Capture? capture = closures.captures[node];
+    if (capture != null) {
+      currentContext!.typeParameters.add(node);
+      capture.context = currentContext!;
+    }
+    super.visitTypeParameter(node);
+  }
+
+  @override
   void visitVariableSet(VariableSet node) {
     closures.captures[node.variable]?.written = true;
     super.visitVariableSet(node);
diff --git a/pkg/dart2wasm/lib/code_generator.dart b/pkg/dart2wasm/lib/code_generator.dart
index 4a0d1c3..fb69fc5 100644
--- a/pkg/dart2wasm/lib/code_generator.dart
+++ b/pkg/dart2wasm/lib/code_generator.dart
@@ -400,6 +400,15 @@
         b.struct_set(capture.context.struct, capture.fieldIndex);
       }
     });
+    typeLocals.forEach((parameter, local) {
+      Capture? capture = closures.captures[parameter];
+      if (capture != null) {
+        b.local_get(capture.context.currentLocal);
+        b.local_get(local);
+        translator.convertType(function, local.type, capture.type);
+        b.struct_set(capture.context.struct, capture.fieldIndex);
+      }
+    });
   }
 
   /// Helper function to throw a Wasm ref downcast error.
@@ -475,7 +484,7 @@
     }
     for (TypeParameter typeParam in cls.typeParameters) {
       types.makeType(
-          this, TypeParameterType(typeParam, Nullability.nonNullable), node);
+          this, TypeParameterType(typeParam, Nullability.nonNullable));
     }
     _visitArguments(node.arguments, node.targetReference, 1);
     _call(node.targetReference);
@@ -493,7 +502,7 @@
       b.ref_as_non_null();
     }
     for (DartType typeArg in supertype!.typeArguments) {
-      types.makeType(this, typeArg, node);
+      types.makeType(this, typeArg);
     }
     _visitArguments(node.arguments, node.targetReference,
         1 + supertype.typeArguments.length);
@@ -1788,7 +1797,7 @@
     final w.FunctionType signature = translator.signatureFor(target);
     final ParameterInfo paramInfo = translator.paramInfoFor(target);
     for (int i = 0; i < node.types.length; i++) {
-      types.makeType(this, node.types[i], node);
+      types.makeType(this, node.types[i]);
     }
     signatureOffset += node.types.length;
     for (int i = 0; i < node.positional.length; i++) {
@@ -1827,11 +1836,8 @@
   @override
   w.ValueType visitStringConcatenation(
       StringConcatenation node, w.ValueType expectedType) {
-    makeList(
-        node.expressions,
-        translator.fixedLengthListClass,
-        InterfaceType(translator.stringBaseClass, Nullability.nonNullable),
-        node);
+    makeList(node.expressions, translator.fixedLengthListClass,
+        InterfaceType(translator.stringBaseClass, Nullability.nonNullable));
     return _call(translator.stringInterpolate.reference);
   }
 
@@ -1899,12 +1905,12 @@
 
   @override
   w.ValueType visitListLiteral(ListLiteral node, w.ValueType expectedType) {
-    return makeList(node.expressions, translator.growableListClass,
-        node.typeArgument, node);
+    return makeList(
+        node.expressions, translator.growableListClass, node.typeArgument);
   }
 
-  w.ValueType makeList(List<Expression> expressions, Class cls,
-      DartType typeArg, TreeNode node) {
+  w.ValueType makeList(
+      List<Expression> expressions, Class cls, DartType typeArg) {
     ClassInfo info = translator.classInfo[cls]!;
     translator.functions.allocateClass(info.classId);
     w.RefType refType = info.struct.fields.last.type.unpacked as w.RefType;
@@ -1914,7 +1920,7 @@
 
     b.i32_const(info.classId);
     b.i32_const(initialIdentityHash);
-    types.makeType(this, typeArg, node);
+    types.makeType(this, typeArg);
     b.i64_const(length);
     if (options.lazyConstants) {
       // Avoid array.init instruction in lazy constants mode
@@ -1950,8 +1956,8 @@
     w.BaseFunction mapFactory =
         translator.functions.getFunction(translator.mapFactory.reference);
     w.ValueType factoryReturnType = mapFactory.type.outputs.single;
-    types.makeType(this, node.keyType, node);
-    types.makeType(this, node.valueType, node);
+    types.makeType(this, node.keyType);
+    types.makeType(this, node.valueType);
     b.call(mapFactory);
     if (node.entries.isEmpty) {
       return factoryReturnType;
@@ -1980,7 +1986,7 @@
     w.BaseFunction setFactory =
         translator.functions.getFunction(translator.setFactory.reference);
     w.ValueType factoryReturnType = setFactory.type.outputs.single;
-    types.makeType(this, node.typeArgument, node);
+    types.makeType(this, node.typeArgument);
     b.call(setFactory);
     if (node.expressions.isEmpty) {
       return factoryReturnType;
@@ -2005,7 +2011,7 @@
 
   @override
   w.ValueType visitTypeLiteral(TypeLiteral node, w.ValueType expectedType) {
-    return types.makeType(this, node.type, node);
+    return types.makeType(this, node.type);
   }
 
   @override
@@ -2027,7 +2033,7 @@
     types.emitTypeTest(this, node.type, dartTypeOf(node.operand), node);
     b.br_if(asCheckBlock);
     b.local_get(operand);
-    types.makeType(this, node.type, node);
+    types.makeType(this, node.type);
     _call(translator.stackTraceCurrent.reference);
     _call(translator.throwAsCheckError.reference);
     b.unreachable();
@@ -2035,6 +2041,36 @@
     b.local_get(operand);
     return operand.type;
   }
+
+  w.ValueType instantiateTypeParameter(TypeParameter parameter) {
+    w.ValueType resultType;
+    if (parameter.parent is FunctionNode) {
+      // Type argument to function
+      w.Local? local = typeLocals[parameter];
+      if (local != null) {
+        b.local_get(local);
+        resultType = local.type;
+      } else {
+        Capture capture = closures.captures[parameter]!;
+        b.local_get(capture.context.currentLocal);
+        b.struct_get(capture.context.struct, capture.fieldIndex);
+        resultType = capture.type;
+      }
+    } else {
+      // Type argument of class
+      Class cls = parameter.parent as Class;
+      ClassInfo info = translator.classInfo[cls]!;
+      int fieldIndex = translator.typeParameterIndex[parameter]!;
+      w.ValueType thisType = visitThis(info.nonNullableType);
+      translator.convertType(function, thisType, info.nonNullableType);
+      b.struct_get(info.struct, fieldIndex);
+      resultType = info.struct.fields[fieldIndex].type.unpacked;
+    }
+    final w.ValueType nonNullableTypeType =
+        translator.classInfo[translator.typeClass]!.nonNullableType;
+    translator.convertType(function, resultType, nonNullableTypeType);
+    return nonNullableTypeType;
+  }
 }
 
 class TryBlockFinalizer {
diff --git a/pkg/dart2wasm/lib/types.dart b/pkg/dart2wasm/lib/types.dart
index b94a144..53df180 100644
--- a/pkg/dart2wasm/lib/types.dart
+++ b/pkg/dart2wasm/lib/types.dart
@@ -117,27 +117,24 @@
     throw "Unexpected DartType: $type";
   }
 
-  void _makeTypeList(
-      CodeGenerator codeGen, List<DartType> types, TreeNode node) {
+  void _makeTypeList(CodeGenerator codeGen, List<DartType> types) {
     w.ValueType listType = codeGen.makeList(
         types.map((t) => TypeLiteral(t)).toList(),
         translator.fixedLengthListClass,
-        InterfaceType(translator.typeClass, Nullability.nonNullable),
-        node);
+        InterfaceType(translator.typeClass, Nullability.nonNullable));
     translator.convertType(codeGen.function, listType, typeListExpectedType);
   }
 
-  void _makeInterfaceType(CodeGenerator codeGen, ClassInfo info,
-      InterfaceType type, TreeNode node) {
+  void _makeInterfaceType(
+      CodeGenerator codeGen, ClassInfo info, InterfaceType type) {
     w.Instructions b = codeGen.b;
     ClassInfo typeInfo = translator.classInfo[type.classNode]!;
     encodeNullability(b, type);
     b.i64_const(typeInfo.classId);
-    _makeTypeList(codeGen, type.typeArguments, node);
+    _makeTypeList(codeGen, type.typeArguments);
   }
 
-  void _makeFutureOrType(
-      CodeGenerator codeGen, FutureOrType type, TreeNode node) {
+  void _makeFutureOrType(CodeGenerator codeGen, FutureOrType type) {
     w.Instructions b = codeGen.b;
     w.DefinedFunction function = codeGen.function;
 
@@ -146,7 +143,7 @@
     // undetermined nullability. To handle this, we emit the type argument, and
     // read back its nullability at runtime.
     if (type.nullability == Nullability.undetermined) {
-      w.ValueType typeArgumentType = makeType(codeGen, type.typeArgument, node);
+      w.ValueType typeArgumentType = makeType(codeGen, type.typeArgument);
       w.Local typeArgumentTemporary = codeGen.addLocal(typeArgumentType);
       b.local_tee(typeArgumentTemporary);
       b.struct_get(typeClassInfo.struct, FieldIndex.typeIsNullable);
@@ -154,15 +151,15 @@
       translator.convertType(function, typeArgumentType, nonNullableTypeType);
     } else {
       encodeNullability(b, type);
-      makeType(codeGen, type.typeArgument, node);
+      makeType(codeGen, type.typeArgument);
     }
   }
 
   void _makeFunctionType(
-      CodeGenerator codeGen, ClassInfo info, FunctionType type, TreeNode node) {
+      CodeGenerator codeGen, ClassInfo info, FunctionType type) {
     w.Instructions b = codeGen.b;
     encodeNullability(b, type);
-    makeType(codeGen, type.returnType, node);
+    makeType(codeGen, type.returnType);
     if (type.positionalParameters.every(_isTypeConstant)) {
       translator.constants.instantiateConstant(
           codeGen.function,
@@ -170,7 +167,7 @@
           translator.constants.makeTypeList(type.positionalParameters),
           typeListExpectedType);
     } else {
-      _makeTypeList(codeGen, type.positionalParameters, node);
+      _makeTypeList(codeGen, type.positionalParameters);
     }
     b.i64_const(type.requiredParameterCount);
     if (type.namedParameters.every((n) => _isTypeConstant(n.type))) {
@@ -197,8 +194,8 @@
                   BoolLiteral(n.isRequired)
                 ])));
       }
-      w.ValueType namedParametersListType = codeGen.makeList(expressions,
-          translator.fixedLengthListClass, namedParameterType, node);
+      w.ValueType namedParametersListType = codeGen.makeList(
+          expressions, translator.fixedLengthListClass, namedParameterType);
       translator.convertType(codeGen.function, namedParametersListType,
           namedParametersExpectedType);
     }
@@ -207,7 +204,7 @@
   /// Makes a `_Type` object on the stack.
   /// TODO(joshualitt): Refactor this logic to remove the dependency on
   /// CodeGenerator.
-  w.ValueType makeType(CodeGenerator codeGen, DartType type, TreeNode node) {
+  w.ValueType makeType(CodeGenerator codeGen, DartType type) {
     w.Instructions b = codeGen.b;
     if (_isTypeConstant(type)) {
       translator.constants.instantiateConstant(
@@ -221,38 +218,16 @@
         type is FutureOrType ||
         type is FunctionType);
     if (type is TypeParameterType) {
-      if (type.parameter.parent is FunctionNode) {
-        // Type argument to function
-        w.Local? local = codeGen.typeLocals[type.parameter];
-        if (local != null) {
-          b.local_get(local);
-          translator.convertType(
-              codeGen.function, local.type, nonNullableTypeType);
-          return nonNullableTypeType;
-        } else {
-          codeGen.unimplemented(node, "Type parameter access inside lambda",
-              [nonNullableTypeType]);
-          return nonNullableTypeType;
-        }
-      }
-      // Type argument of class
-      Class cls = type.parameter.parent as Class;
-      ClassInfo info = translator.classInfo[cls]!;
-      int fieldIndex = translator.typeParameterIndex[type.parameter]!;
-      w.ValueType thisType = codeGen.visitThis(info.nonNullableType);
-      translator.convertType(codeGen.function, thisType, info.nonNullableType);
-      b.struct_get(info.struct, fieldIndex);
-      b.ref_as_non_null();
-      return nonNullableTypeType;
+      return codeGen.instantiateTypeParameter(type.parameter);
     }
     ClassInfo info = translator.classInfo[classForType(type)]!;
     translator.functions.allocateClass(info.classId);
     b.i32_const(info.classId);
     b.i32_const(initialIdentityHash);
     if (type is InterfaceType) {
-      _makeInterfaceType(codeGen, info, type, node);
+      _makeInterfaceType(codeGen, info, type);
     } else if (type is FutureOrType) {
-      _makeFutureOrType(codeGen, type, node);
+      _makeFutureOrType(codeGen, type);
     } else if (type is FunctionType) {
       if (isGenericFunction(type)) {
         // TODO(joshualitt): Implement generic function types and share most of
@@ -260,7 +235,7 @@
         print("Not implemented: RTI ${type}");
         encodeNullability(b, type);
       } else {
-        _makeFunctionType(codeGen, info, type, node);
+        _makeFunctionType(codeGen, info, type);
       }
     } else {
       throw '`$type` should have already been handled.';