Report unsafe BuildContext access in Future.then and more

Fixes https://github.com/dart-lang/linter/issues/4923

In this change we now protect code inside callback arguments to
functions like `Future.then`. For example:

```dart
Future<int>.delayed(Duration(seconds: 1), (_) {
  return 42;
}).then((int value) {
 Navigator.of(context).pop();
});
```

In code like this, each reference to BuildContext must be guarded by
a mounted check.

There are a number of "protected functions" that the rule now checks:

Constructors:
* Future.new: 1st positional parameter
* Future.delayed: 2nd positional parameter
* Future.microtask: 1st positional parameter
* Stream.eventTransformed: 2nd positional parameter
* Stream.multi: 1st positional parameter
* Stream.periodic: 2nd positional parameter

* StreamController.new: 'onListen', 'onPause', 'onResume', and
  'onCancel' named parameters
* StreamController.broadcast: 'onListen' and 'onCancel' named
  parameters

Instance Methods:
* Future.catchError: first positional parameter, and 'test' named
  parameter
* Future.onError: 1st positional parameter, and 'test' named parameter
* Future.then: 1st positional parameter, and 'onError' named
  parameter
* Future.timeout: 'onTimeout' named parameter
* Future.whenComplete: 1st positional parameter
* Stream.any: 1st positional parameter
* Stream.asBroadcastStream: 'onListen' and 'onCancel' named parameters
* Stream.asyncExpand: 1st positional parameter
* Stream.asyncMap: 1st positional parameter
* Stream.distinct: 1st positional parameter
* Stream.expand: 1st positional parameter
* Stream.firstWhere: first positional parameter, and 'orElse' named
  parameter
* Stream.fold: second positional parameter
* Stream.forEach: 1st positional parameter
* Stream.handleError: 1st positional parameter, and 'test' named
  parameter
* Stream.lastWhere: 1st positional parameter, and 'orElse' named
  parameter
* Stream.listen: 1st positional parameter, and 'onError' and 'onDone'
  named parameters
* Stream.map: 1st positional parameter
* Stream.reduce: 1st positional parameter
* Stream.singleWhere: 1st positional parameter, and 'orElse' named
  parameter
* Stream.skipWhile: 1st positional parameter
* Stream.takeWhile: 1st positional parameter
* Stream.timeout: 'onTimeout' named parameter
* Stream.where: 1st positional parameter
* Stream.onData: 1st positional parameter
* Stream.onDone: 1st positional parameter
* Stream.onError: 1st positional parameter

Static Methods:
* Future.doWhile: 1st positional parameter
* Future.forEach: 2nd positional parameter
* Future.wait: 'cleanUp' named parameter



Change-Id: I703d5cfbee5e8af1f703f40046832d66b9b59bc2
Reviewed-on: https://dart-review.googlesource.com/c/sdk/+/365541
Reviewed-by: Phil Quitslund <pquitslund@google.com>
Commit-Queue: Samuel Rawlins <srawlins@google.com>
diff --git a/pkg/linter/lib/src/rules/use_build_context_synchronously.dart b/pkg/linter/lib/src/rules/use_build_context_synchronously.dart
index 02a80a8..3f9e3f7 100644
--- a/pkg/linter/lib/src/rules/use_build_context_synchronously.dart
+++ b/pkg/linter/lib/src/rules/use_build_context_synchronously.dart
@@ -13,6 +13,7 @@
 import 'package:pub_semver/pub_semver.dart';
 
 import '../analyzer.dart';
+import '../extensions.dart';
 import '../util/flutter_utils.dart';
 
 const _desc = r'Do not use `BuildContext` across asynchronous gaps.';
@@ -923,6 +924,32 @@
   }
 }
 
+/// Function with callback parameters that should be "protected."
+///
+/// Any callback passed as a [positional] argument or [named] argument to such
+/// a function must have a mounted guard check for any references to
+/// BuildContext.
+class ProtectedFunction {
+  final String library;
+
+  /// The name of the target type of the function (for instance methods) or the
+  /// defining element (for constructors and static methods).
+  final String? type;
+
+  /// The name of the function. Can be `null` to represent an unnamed
+  /// constructor.
+  final String? name;
+
+  /// The list of positional parameters that are protected.
+  final List<int> positional;
+
+  /// The list of named parameters that are protected.
+  final List<String> named;
+
+  const ProtectedFunction(this.library, this.type, this.name,
+      {this.positional = const <int>[], this.named = const <String>[]});
+}
+
 class UseBuildContextSynchronously extends LintRule {
   static const LintCode asyncUseCode = LintCode(
     'use_build_context_synchronously',
@@ -972,6 +999,80 @@
 class _Visitor extends SimpleAstVisitor {
   static const mountedName = 'mounted';
 
+  static const protectedConstructors = [
+    // Future constructors.
+    // Protect the unnamed constructor as both `Future()` and `Future.new()`.
+    ProtectedFunction('dart.async', 'Future', null, positional: [0]),
+    ProtectedFunction('dart.async', 'Future', 'new', positional: [0]),
+    ProtectedFunction('dart.async', 'Future', 'delayed', positional: [1]),
+    ProtectedFunction('dart.async', 'Future', 'microtask', positional: [0]),
+
+    // Stream constructors.
+    ProtectedFunction('dart.async', 'Stream', 'eventTransformed',
+        positional: [1]),
+    ProtectedFunction('dart.async', 'Stream', 'multi', positional: [0]),
+    ProtectedFunction('dart.async', 'Stream', 'periodic', positional: [1]),
+
+    // StreamController constructors.
+    ProtectedFunction('dart.async', 'StreamController', null,
+        named: ['onListen', 'onPause', 'onResume', 'onCancel']),
+    ProtectedFunction('dart.async', 'StreamController', 'new',
+        named: ['onListen', 'onPause', 'onResume', 'onCancel']),
+    ProtectedFunction('dart.async', 'StreamController', 'broadcast',
+        named: ['onListen', 'onCancel']),
+  ];
+
+  static const protectedInstanceMethods = [
+    // Future instance methods.
+    ProtectedFunction('dart.async', 'Future', 'catchError',
+        positional: [0], named: ['test']),
+    ProtectedFunction('dart.async', 'Future', 'onError',
+        positional: [0], named: ['test']),
+    ProtectedFunction('dart.async', 'Future', 'then',
+        positional: [0], named: ['onError']),
+    ProtectedFunction('dart.async', 'Future', 'timeout', named: ['onTimeout']),
+    ProtectedFunction('dart.async', 'Future', 'whenComplete', positional: [0]),
+
+    // Stream instance methods.
+    ProtectedFunction('dart.async', 'Stream', 'any', positional: [0]),
+    ProtectedFunction('dart.async', 'Stream', 'asBroadcastStream',
+        named: ['onListen', 'onCancel']),
+    ProtectedFunction('dart.async', 'Stream', 'asyncExpand', positional: [0]),
+    ProtectedFunction('dart.async', 'Stream', 'asyncMap', positional: [0]),
+    ProtectedFunction('dart.async', 'Stream', 'distinct', positional: [0]),
+    ProtectedFunction('dart.async', 'Stream', 'expand', positional: [0]),
+    ProtectedFunction('dart.async', 'Stream', 'firstWhere',
+        positional: [0], named: ['orElse']),
+    ProtectedFunction('dart.async', 'Stream', 'fold', positional: [1]),
+    ProtectedFunction('dart.async', 'Stream', 'forEach', positional: [0]),
+    ProtectedFunction('dart.async', 'Stream', 'handleError',
+        positional: [0], named: ['test']),
+    ProtectedFunction('dart.async', 'Stream', 'lastWhere',
+        positional: [0], named: ['orElse']),
+    ProtectedFunction('dart.async', 'Stream', 'listen',
+        positional: [0], named: ['onError', 'onDone']),
+    ProtectedFunction('dart.async', 'Stream', 'map', positional: [0]),
+    ProtectedFunction('dart.async', 'Stream', 'reduce', positional: [0]),
+    ProtectedFunction('dart.async', 'Stream', 'singleWhere',
+        positional: [0], named: ['orElse']),
+    ProtectedFunction('dart.async', 'Stream', 'skipWhile', positional: [0]),
+    ProtectedFunction('dart.async', 'Stream', 'takeWhile', positional: [0]),
+    ProtectedFunction('dart.async', 'Stream', 'timeout', named: ['onTimeout']),
+    ProtectedFunction('dart.async', 'Stream', 'where', positional: [0]),
+
+    // StreamSubscription instance methods.
+    ProtectedFunction('dart.async', 'Stream', 'onData', positional: [0]),
+    ProtectedFunction('dart.async', 'Stream', 'onDone', positional: [0]),
+    ProtectedFunction('dart.async', 'Stream', 'onError', positional: [0]),
+  ];
+
+  static const protectedStaticMethods = [
+    // Future static methods.
+    ProtectedFunction('dart.async', 'Future', 'doWhile', positional: [0]),
+    ProtectedFunction('dart.async', 'Future', 'forEach', positional: [1]),
+    ProtectedFunction('dart.async', 'Future', 'wait', named: ['cleanUp']),
+  ];
+
   final LintRule rule;
 
   _Visitor(this.rule);
@@ -989,9 +1090,8 @@
       if (parent == null) break;
 
       var asyncState = asyncStateTracker.asyncStateFor(child, mountedElement);
-      if (asyncState.isGuarded) {
-        return;
-      }
+      if (asyncState.isGuarded) return;
+
       if (asyncState == AsyncState.asynchronous) {
         var errorCode = asyncStateTracker.hasUnrelatedMountedCheck
             ? UseBuildContextSynchronously.wrongMountedCode
@@ -1002,6 +1102,136 @@
 
       child = parent;
     }
+
+    if (child is FunctionBody) {
+      var parent = child.parent;
+      var grandparent = parent?.parent;
+      if (parent is! FunctionExpression) {
+        return;
+      }
+
+      if (grandparent is NamedExpression) {
+        // Given a FunctionBody in a named argument, like
+        // `future.catchError(test: (_) {...})`, we step up once more to the
+        // argument list.
+        grandparent = grandparent.parent;
+      }
+      if (grandparent is ArgumentList) {
+        if (grandparent.parent case InstanceCreationExpression invocation) {
+          checkConstructorCallback(invocation, parent, node);
+        }
+
+        if (grandparent.parent case MethodInvocation invocation) {
+          checkMethodCallback(invocation, parent, node);
+        }
+      }
+    }
+  }
+
+  /// Checks whether [invocation] involves a [callback] argument for a protected
+  /// constructor.
+  ///
+  /// The code inside a callback argument for a protected constructor must not
+  /// contain any references to a `BuildContext` without a guarding mounted
+  /// check.
+  void checkConstructorCallback(
+    InstanceCreationExpression invocation,
+    FunctionExpression callback,
+    Expression errorNode,
+  ) {
+    var staticType = invocation.staticType;
+    if (staticType == null) return;
+    var arguments = invocation.argumentList.arguments;
+    var positionalArguments =
+        arguments.where((a) => a is! NamedExpression).toList();
+    var namedArguments = arguments.whereType<NamedExpression>().toList();
+    for (var constructor in protectedConstructors) {
+      if (invocation.constructorName.name?.name == constructor.name &&
+          staticType.isSameAs(constructor.type, constructor.library)) {
+        checkPositionalArguments(
+            constructor.positional, positionalArguments, callback, errorNode);
+        checkNamedArguments(
+            constructor.named, namedArguments, callback, errorNode);
+      }
+    }
+  }
+
+  /// Checks whether [invocation] involves a [callback] argument for a protected
+  /// instance or static method.
+  ///
+  /// The code inside a callback argument for a protected method must not
+  /// contain any references to a `BuildContext` without a guarding mounted
+  /// check.
+  void checkMethodCallback(
+    MethodInvocation invocation,
+    FunctionExpression callback,
+    Expression errorNode,
+  ) {
+    var arguments = invocation.argumentList.arguments;
+    var positionalArguments =
+        arguments.where((a) => a is! NamedExpression).toList();
+    var namedArguments = arguments.whereType<NamedExpression>().toList();
+
+    var target = invocation.realTarget;
+    var targetElement = target is Identifier ? target.staticElement : null;
+    if (targetElement is ClassElement) {
+      // Static function called; `target` is the class.
+      for (var method in protectedStaticMethods) {
+        if (invocation.methodName.name == method.name &&
+            targetElement.name == method.type) {
+          checkPositionalArguments(
+              method.positional, positionalArguments, callback, errorNode);
+          checkNamedArguments(
+              method.named, namedArguments, callback, errorNode);
+        }
+      }
+    } else {
+      var staticType = target?.staticType;
+      if (staticType == null) return;
+      for (var method in protectedInstanceMethods) {
+        if (invocation.methodName.name == method.name &&
+            staticType.element?.name == method.type) {
+          checkPositionalArguments(
+              method.positional, positionalArguments, callback, errorNode);
+          checkNamedArguments(
+              method.named, namedArguments, callback, errorNode);
+        }
+      }
+    }
+  }
+
+  /// Checks whether [callback] is one of the [namedArguments] for one of the
+  /// protected argument [names] for a protected function.
+  void checkNamedArguments(
+      List<String> names,
+      List<NamedExpression> namedArguments,
+      Expression callback,
+      Expression errorNode) {
+    for (var named in names) {
+      var argument =
+          namedArguments.firstWhereOrNull((a) => a.name.label.name == named);
+      if (argument == null) continue;
+      if (callback == argument.expression) {
+        rule.reportLint(errorNode,
+            errorCode: UseBuildContextSynchronously.asyncUseCode);
+      }
+    }
+  }
+
+  /// Checks whether [callback] is one of the [positionalArguments] for one of
+  /// the protected argument [positions] for a protected function.
+  void checkPositionalArguments(
+      List<int> positions,
+      List<Expression> positionalArguments,
+      Expression callback,
+      Expression errorNode) {
+    for (var position in positions) {
+      if (positionalArguments.length > position &&
+          callback == positionalArguments[position]) {
+        rule.reportLint(errorNode,
+            errorCode: UseBuildContextSynchronously.asyncUseCode);
+      }
+    }
   }
 
   @override
diff --git a/pkg/linter/test/rules/use_build_context_synchronously_test.dart b/pkg/linter/test/rules/use_build_context_synchronously_test.dart
index 077f818..edb8c7e 100644
--- a/pkg/linter/test/rules/use_build_context_synchronously_test.dart
+++ b/pkg/linter/test/rules/use_build_context_synchronously_test.dart
@@ -1289,7 +1289,7 @@
     await resolveCode(r'''
 import 'package:flutter/widgets.dart';
 void foo(BuildContext context) async {
-  f(await c(), context /* ref */);
+  f(await Future.value(), context /* ref */);
 }
 void f(_, _) {}
 ''');
@@ -2309,6 +2309,155 @@
     ]);
   }
 
+  test_future_catchError_referenceToContextInNamedArgument() async {
+    // `Future.catchError` call, with use of BuildContext inside, is REPORTED.
+    await assertDiagnostics(r'''
+import 'package:flutter/widgets.dart';
+void foo(BuildContext context, Future<void> f) async {
+  f.catchError((_) {}, test: (_) {
+    Navigator.of(context);
+    return false;
+  });
+}
+''', [
+      lint(146, 7),
+    ]);
+  }
+
+  test_future_catchError_referenceToContextInPositionalArgument() async {
+    // `Future.catchError` call, with use of BuildContext inside, is REPORTED.
+    await assertDiagnostics(r'''
+import 'package:flutter/widgets.dart';
+void foo(BuildContext context, Future<void> f) async {
+  f.catchError((_) {
+    Navigator.of(context);
+  });
+}
+''', [
+      lint(132, 7),
+    ]);
+  }
+
+  test_future_catchError_referenceToContextInPositionalArgument_precedingNamedArgument() async {
+    // `Future.catchError` call, with use of BuildContext inside, is REPORTED.
+    await assertDiagnostics(r'''
+import 'package:flutter/widgets.dart';
+void foo(BuildContext context, Future<void> f) async {
+  f.catchError(test: (_) => false, (_) {
+    Navigator.of(context);
+  });
+}
+''', [
+      lint(152, 7),
+    ]);
+  }
+
+  test_future_delayed_referenceToContextInWrongArgument() async {
+    await assertDiagnostics(r'''
+import 'package:flutter/widgets.dart';
+void foo(BuildContext context) async {
+  Future.delayed((_) {
+    Navigator.of(context);
+  });
+}
+''', [
+      // Just don't crash when one argument references BuildContext, and not all
+      // positional arguments are given.
+      error(CompileTimeErrorCode.ARGUMENT_TYPE_NOT_ASSIGNABLE, 95, 36),
+    ]);
+  }
+
+  test_future_new_referenceToContextInArgument() async {
+    // `Future.new()` call, with use of BuildContext inside, is REPORTED.
+    await assertDiagnostics(r'''
+import 'package:flutter/widgets.dart';
+void foo(BuildContext context) async {
+  Future.new(() {
+    Navigator.of(context);
+  });
+}
+''', [
+      lint(113, 7),
+    ]);
+  }
+
+  test_future_then_noReferenceToContext() async {
+    // `Future.then` call, with no use of BuildContext inside, is OK.
+    await assertNoDiagnostics(r'''
+import 'package:flutter/widgets.dart';
+void foo(BuildContext context, Future<void> f) async {
+  f.then((_) {});
+}
+''');
+  }
+
+  test_future_then_referenceToContextInCallback() async {
+    // `Future.then` call, with use of BuildContext inside, is REPORTED.
+    await assertDiagnostics(r'''
+import 'package:flutter/widgets.dart';
+void foo(BuildContext context, Future<void> f) async {
+  f.then((_) {
+    Navigator.of(context);
+  });
+}
+''', [
+      lint(126, 7),
+    ]);
+  }
+
+  test_future_then_referenceToContextInCallback_expressionBody() async {
+    // `Future.then` call, with use of BuildContext inside, is REPORTED.
+    await assertDiagnostics(r'''
+import 'package:flutter/widgets.dart';
+void foo(BuildContext context, Future<void> f) async {
+  f.then((_) => Navigator.of(context));
+}
+''', [
+      lint(123, 7),
+    ]);
+  }
+
+  test_future_then_referenceToContextInCallback_mountedGuard() async {
+    // `Future.then` call, with guarded use of BuildContext inside, is OK.
+    await assertNoDiagnostics(r'''
+import 'package:flutter/widgets.dart';
+void foo(BuildContext context, Future<void> f) async {
+  f.then((_) {
+    if (!context.mounted) return;
+    Navigator.of(context);
+  });
+}
+''');
+  }
+
+  test_future_unnamed_referenceToContextInArgument() async {
+    // `Future()` call, with use of BuildContext inside, is REPORTED.
+    await assertDiagnostics(r'''
+import 'package:flutter/widgets.dart';
+void foo(BuildContext context) async {
+  Future(() {
+    Navigator.of(context);
+  });
+}
+''', [
+      lint(109, 7),
+    ]);
+  }
+
+  test_future_wait_referenceToContextInArgument() async {
+    // `Future.wait` call, with use of BuildContext inside, is REPORTED.
+    await assertDiagnostics(r'''
+import 'package:flutter/widgets.dart';
+void foo(BuildContext context) async {
+  Future.wait([], cleanUp: (_) {
+    Navigator.of(context);
+  });
+}
+''', [
+      lint(128, 7),
+    ]);
+  }
+
   test_ifConditionContainsMountedAndReferenceToContext() async {
     // Binary expression contains mounted check AND use of BuildContext, is
     // OK.
@@ -2492,6 +2641,24 @@
     ]);
   }
 
+  test_referenceToContextInFunctionExpression() async {
+    // Inside a function expression, await then use of BuildContext is REPORTED.
+    await assertDiagnostics(r'''
+import 'package:flutter/widgets.dart';
+
+void foo(BuildContext context) async {
+  () async {
+    await f();
+    Navigator.of(context);
+  }();
+}
+
+Future<void> f() async {}
+''', [
+      lint(124, 7),
+    ]);
+  }
+
   test_referenceToContextInWhileBody_thenAwait() async {
     // While statement, and inside the while-body: use of BuildContext, then
     // await, is REPORTED.