[vm/shared] Ensure context captured by isolategroup callbacks has only trivially shareable objects.

Fixes https://github.com/dart-lang/sdk/issues/61210
TEST=run_isolate_group_run_test, isolate_group_bound_callback_test

Change-Id: I9633726055007255f67eb9744186fdbe21fa832d
Reviewed-on: https://dart-review.googlesource.com/c/sdk/+/443639
Commit-Queue: Alexander Aprelev <aam@google.com>
Reviewed-by: Ryan Macnak <rmacnak@google.com>
diff --git a/runtime/lib/concurrent.cc b/runtime/lib/concurrent.cc
index e55e7c2..4621549 100644
--- a/runtime/lib/concurrent.cc
+++ b/runtime/lib/concurrent.cc
@@ -4,6 +4,8 @@
 
 #include "include/dart_api.h"
 #include "vm/bootstrap_natives.h"
+#include "vm/dart_api_impl.h"
+#include "vm/ffi_callback_metadata.h"
 #include "vm/heap/safepoint.h"
 #include "vm/os_thread.h"
 
@@ -120,9 +122,15 @@
         "Encountered shared data api when functionality is disabled. "
         "Pass --experimental-shared-data");
   }
-
   Thread* current_thread = Thread::Current();
   ASSERT(current_thread->execution_state() == Thread::kThreadInNative);
+
+  {
+    DARTSCOPE(current_thread);
+    FfiCallbackMetadata::EnsureOnlyTriviallyImmutableValuesInClosure(
+        current_thread->zone(), Closure::RawCast(Api::UnwrapHandle(closure)));
+  }
+
   Isolate* saved_isolate = current_thread->isolate();
   current_thread->ExitSafepointFromNative();
   current_thread->set_execution_state(Thread::kThreadInVM);
diff --git a/runtime/vm/ffi_callback_metadata.cc b/runtime/vm/ffi_callback_metadata.cc
index 716db15..8268e5f 100644
--- a/runtime/vm/ffi_callback_metadata.cc
+++ b/runtime/vm/ffi_callback_metadata.cc
@@ -351,6 +351,52 @@
   return handle;
 }
 
+static void ValidateTriviallyImmutabilityOfAnObject(Zone* zone,
+                                                    Object* p_obj,
+                                                    ObjectPtr object_ptr) {
+  *p_obj = object_ptr;
+  if (p_obj->IsSmi() || p_obj->IsNull()) {
+    return;
+  }
+  if (p_obj->IsClosure()) {
+    FfiCallbackMetadata::EnsureOnlyTriviallyImmutableValuesInClosure(
+        zone, Closure::RawCast(p_obj->ptr()));
+    return;
+  }
+  if (!p_obj->IsImmutable()) {
+    const String& error = String::Handle(
+        zone,
+        String::NewFormatted("Only trivially-immutable values are allowed: %s.",
+                             p_obj->ToCString()));
+    Exceptions::ThrowArgumentError(error);
+    UNREACHABLE();
+  }
+}
+
+void FfiCallbackMetadata::EnsureOnlyTriviallyImmutableValuesInClosure(
+    Zone* zone,
+    ClosurePtr closure_ptr) {
+  Closure& closure = Closure::Handle(zone, closure_ptr);
+  if (closure.IsNull()) {
+    return;
+  }
+  Object& obj = Object::Handle(zone);
+  const auto& function = Function::Handle(closure.function());
+  if (function.IsImplicitClosureFunction()) {
+    ValidateTriviallyImmutabilityOfAnObject(
+        zone, &obj, closure.GetImplicitClosureReceiver());
+  } else {
+    const Context& context = Context::Handle(zone, closure.GetContext());
+    if (context.IsNull()) {
+      return;
+    }
+    // Iterate through all elements of the context.
+    for (intptr_t i = 0; i < context.num_variables(); i++) {
+      ValidateTriviallyImmutabilityOfAnObject(zone, &obj, context.At(i));
+    }
+  }
+}
+
 FfiCallbackMetadata::Trampoline FfiCallbackMetadata::CreateLocalFfiCallback(
     Isolate* isolate,
     IsolateGroup* isolate_group,
@@ -375,6 +421,12 @@
            (isolate == nullptr && isolate_group != nullptr &&
             function.GetFfiCallbackKind() ==
                 FfiCallbackKind::kIsolateGroupBoundClosureCallback));
+
+    if (function.GetFfiCallbackKind() ==
+        FfiCallbackKind::kIsolateGroupBoundClosureCallback) {
+      EnsureOnlyTriviallyImmutableValuesInClosure(zone, closure.ptr());
+    }
+
     handle = CreatePersistentHandle(
         isolate != nullptr ? isolate->group() : isolate_group, closure);
   }
diff --git a/runtime/vm/ffi_callback_metadata.h b/runtime/vm/ffi_callback_metadata.h
index 5d5ed3f..2dc90bb 100644
--- a/runtime/vm/ffi_callback_metadata.h
+++ b/runtime/vm/ffi_callback_metadata.h
@@ -346,6 +346,10 @@
 #error What architecture?
 #endif
 
+  static void EnsureOnlyTriviallyImmutableValuesInClosure(
+      Zone* zone,
+      ClosurePtr closure_ptr);
+
   // Visible for testing.
   MetadataEntry* MetadataEntryOfTrampoline(Trampoline trampoline) const;
   Trampoline TrampolineOfMetadataEntry(MetadataEntry* metadata) const;
diff --git a/tests/ffi/isolate_group_bound_callback_test.dart b/tests/ffi/isolate_group_bound_callback_test.dart
index 64110fd..d0a8dc4 100644
--- a/tests/ffi/isolate_group_bound_callback_test.dart
+++ b/tests/ffi/isolate_group_bound_callback_test.dart
@@ -50,29 +50,24 @@
   return DynamicLibrary.open('$_dylibPrefix$name$_dylibExtension');
 }
 
-class NativeLibrary {
-  late final FnRunnerType callFunctionOnSameThread;
-  late final FnRunnerType callFunctionOnNewThreadBlocking;
-  late final FnRunnerType callFunctionOnNewThreadNonBlocking;
-  late final TwoIntFnType callTwoIntFunction;
-  late final FnSleepType sleep;
+DynamicLibrary get ffiTestFunctions =>
+    dlopenPlatformSpecific("ffi_test_functions");
 
-  NativeLibrary(DynamicLibrary ffiTestFunctions) {
-    callFunctionOnNewThreadNonBlocking = ffiTestFunctions
-        .lookupFunction<FnRunnerNativeType, FnRunnerType>(
-          "CallFunctionOnNewThreadNonBlocking",
-        );
-    callFunctionOnNewThreadBlocking = ffiTestFunctions
-        .lookupFunction<FnRunnerNativeType, FnRunnerType>(
-          "CallFunctionOnNewThreadBlocking",
-        );
-    callTwoIntFunction = ffiTestFunctions
-        .lookupFunction<TwoIntFnNativeType, TwoIntFnType>("CallTwoIntFunction");
-    sleep = ffiTestFunctions.lookupFunction<FnSleepNativeType, FnSleepType>(
-      "SleepFor",
+FnRunnerType get callFunctionOnNewThreadNonBlocking =>
+    ffiTestFunctions.lookupFunction<FnRunnerNativeType, FnRunnerType>(
+      "CallFunctionOnNewThreadNonBlocking",
     );
-  }
-}
+
+FnRunnerType get callFunctionOnNewThreadBlocking =>
+    ffiTestFunctions.lookupFunction<FnRunnerNativeType, FnRunnerType>(
+      "CallFunctionOnNewThreadBlocking",
+    );
+
+TwoIntFnType get callTwoIntFunction => ffiTestFunctions
+    .lookupFunction<TwoIntFnNativeType, TwoIntFnType>("CallTwoIntFunction");
+
+FnSleepType get sleep =>
+    ffiTestFunctions.lookupFunction<FnSleepNativeType, FnSleepType>("SleepFor");
 
 @pragma('vm:shared')
 late Mutex mutexCondvar;
@@ -88,16 +83,14 @@
 
 void simpleFunction(int a, int b) {
   result += (a * b);
-  final ffiTestFunctions = dlopenPlatformSpecific("ffi_test_functions");
-  final lib = NativeLibrary(ffiTestFunctions);
-  lib.sleep(sleepForMs);
+  sleep(sleepForMs);
   mutexCondvar.runLocked(() {
     resultIsReady = true;
     conditionVariable.notify();
   });
 }
 
-Future<void> testNativeCallableHelloWorld(NativeLibrary lib) async {
+Future<void> testNativeCallableHelloWorld() async {
   mutexCondvar = Mutex();
   conditionVariable = ConditionVariable();
   final callback = NativeCallable<CallbackNativeType>.isolateGroupBound(
@@ -106,7 +99,7 @@
 
   result = 42;
   resultIsReady = false;
-  lib.callFunctionOnNewThreadNonBlocking(1001, callback.nativeFunction);
+  callFunctionOnNewThreadNonBlocking(1001, callback.nativeFunction);
 
   mutexCondvar.runLocked(() {
     while (!resultIsReady) {
@@ -118,7 +111,7 @@
   Expect.equals(42 + (1001 * 123), result);
 
   resultIsReady = false;
-  lib.callFunctionOnNewThreadNonBlocking(1001, callback.nativeFunction);
+  callFunctionOnNewThreadNonBlocking(1001, callback.nativeFunction);
   mutexCondvar.runLocked(() {
     while (!resultIsReady) {
       conditionVariable.wait(mutexCondvar, 10 * sleepForMs);
@@ -134,7 +127,7 @@
   throw 'hello, world';
 }
 
-Future<void> testNativeCallableThrows(NativeLibrary lib) async {
+Future<void> testNativeCallableThrows() async {
   mutexCondvar = Mutex();
   conditionVariable = ConditionVariable();
   final callback = NativeCallable<CallbackNativeType>.isolateGroupBound(
@@ -147,7 +140,7 @@
   // race between invoking the callback and closing it few lines down below.
   // So the main thing this test checks is condition variable timeout,
   // which is still valuable.
-  lib.callFunctionOnNewThreadBlocking(1001, callback.nativeFunction);
+  callFunctionOnNewThreadBlocking(1001, callback.nativeFunction);
 
   mutexCondvar.runLocked(() {
     // Just have short one second sleep - the condition variable is not
@@ -158,7 +151,24 @@
   callback.close();
 }
 
-Future<void> testNativeCallableHelloWorldClosure(NativeLibrary lib) async {
+@pragma('vm:shared')
+SendPort? sp;
+
+Future<void> testFailToCaptureReceivePort() async {
+  final rp = ReceivePort();
+  Expect.throws(
+    () {
+      NativeCallable<CallbackNativeType>.isolateGroupBound((int a, int b) {
+        sp = rp.sendPort;
+      });
+    },
+    (e) =>
+        e is ArgumentError && e.toString().contains('Only trivially-immutable'),
+  );
+  rp.close();
+}
+
+Future<void> testNativeCallableHelloWorldClosure() async {
   mutexCondvar = Mutex();
   conditionVariable = ConditionVariable();
   final callback = NativeCallable<CallbackNativeType>.isolateGroupBound((
@@ -166,7 +176,7 @@
     int b,
   ) {
     result += (a * b);
-    lib.sleep(sleepForMs);
+    sleep(sleepForMs);
     mutexCondvar.runLocked(() {
       resultIsReady = true;
       conditionVariable.notify();
@@ -175,7 +185,7 @@
 
   result = 42;
   resultIsReady = false;
-  lib.callFunctionOnNewThreadNonBlocking(1001, callback.nativeFunction);
+  callFunctionOnNewThreadNonBlocking(1001, callback.nativeFunction);
 
   mutexCondvar.runLocked(() {
     while (!resultIsReady) {
@@ -186,7 +196,7 @@
   Expect.equals(42 + (1001 * 123), result);
 
   resultIsReady = false;
-  lib.callFunctionOnNewThreadNonBlocking(1001, callback.nativeFunction);
+  callFunctionOnNewThreadNonBlocking(1001, callback.nativeFunction);
   mutexCondvar.runLocked(() {
     while (!resultIsReady) {
       conditionVariable.wait(mutexCondvar);
@@ -196,7 +206,7 @@
   callback.close();
 }
 
-void testNativeCallableSync(NativeLibrary lib) {
+void testNativeCallableSync() {
   final callback =
       NativeCallable<CallbackReturningIntNativeType>.isolateGroupBound((
         int a,
@@ -205,14 +215,11 @@
         return a + b;
       }, exceptionalReturn: 1111);
 
-  Expect.equals(
-    1234,
-    lib.callTwoIntFunction(callback.nativeFunction, 1000, 234),
-  );
+  Expect.equals(1234, callTwoIntFunction(callback.nativeFunction, 1000, 234));
   callback.close();
 }
 
-void testNativeCallableSyncThrows(NativeLibrary lib) {
+void testNativeCallableSyncThrows() {
   final callback =
       NativeCallable<CallbackReturningIntNativeType>.isolateGroupBound(
         (int a, int b) {
@@ -222,16 +229,13 @@
         exceptionalReturn: 1111,
       );
 
-  Expect.equals(
-    1111,
-    lib.callTwoIntFunction(callback.nativeFunction, 1000, 234),
-  );
+  Expect.equals(1111, callTwoIntFunction(callback.nativeFunction, 1000, 234));
   callback.close();
 }
 
 int isolateVar = 10;
 
-void testNativeCallableAccessNonSharedVar(NativeLibrary lib) {
+void testNativeCallableAccessNonSharedVar() {
   final callback =
       NativeCallable<CallbackReturningIntNativeType>.isolateGroupBound((
         int a,
@@ -241,10 +245,7 @@
       }, exceptionalReturn: 1111);
 
   isolateVar = 42;
-  Expect.equals(
-    1111,
-    lib.callTwoIntFunction(callback.nativeFunction, 1000, 234),
-  );
+  Expect.equals(1111, callTwoIntFunction(callback.nativeFunction, 1000, 234));
   callback.close();
 }
 
@@ -304,14 +305,13 @@
 main(args, message) async {
   asyncStart();
   // Simple tests.
-  final ffiTestFunctions = dlopenPlatformSpecific("ffi_test_functions");
-  final lib = NativeLibrary(ffiTestFunctions);
-  await testNativeCallableHelloWorld(lib);
-  await testNativeCallableThrows(lib);
-  await testNativeCallableHelloWorldClosure(lib);
-  testNativeCallableSync(lib);
-  testNativeCallableSyncThrows(lib);
-  testNativeCallableAccessNonSharedVar(lib);
+  await testNativeCallableHelloWorld();
+  await testNativeCallableThrows();
+  await testFailToCaptureReceivePort();
+  await testNativeCallableHelloWorldClosure();
+  testNativeCallableSync();
+  testNativeCallableSyncThrows();
+  testNativeCallableAccessNonSharedVar();
   await testKeepIsolateAliveTrue();
   await testKeepIsolateAliveFalse();
   asyncEnd();
diff --git a/tests/ffi/isolate_group_bound_send_test.dart b/tests/ffi/isolate_group_bound_send_test.dart
index 5aff059..2120c23 100644
--- a/tests/ffi/isolate_group_bound_send_test.dart
+++ b/tests/ffi/isolate_group_bound_send_test.dart
@@ -21,10 +21,15 @@
 main() async {
   asyncStart();
   ReceivePort rp = ReceivePort();
-  IsolateGroup.runSync(() {
-    rp.sendPort.send("hello");
-  });
-  Expect.equals("hello", await rp.first);
+  Expect.throws(
+    () {
+      IsolateGroup.runSync(() {
+        rp.sendPort.send("hello");
+      });
+    },
+    (e) =>
+        e is ArgumentError && e.toString().contains('Only trivially-immutable'),
+  );
   rp.close();
   asyncEnd();
 }
diff --git a/tests/ffi/run_isolate_group_run_test.dart b/tests/ffi/run_isolate_group_run_test.dart
index eaacaa3..da3a520 100644
--- a/tests/ffi/run_isolate_group_run_test.dart
+++ b/tests/ffi/run_isolate_group_run_test.dart
@@ -50,7 +50,30 @@
 @pragma('vm:shared')
 String string_foo = "";
 
-main() {
+@pragma('vm:shared')
+SendPort? sp;
+
+StringMethodTearoffTest() {
+  final stringTearoff = "abc".toString;
+  IsolateGroup.runSync(() {
+    stringTearoff;
+  });
+}
+
+ListMethodTearoffTest(List<String> args) {
+  final listTearoff = args.insert;
+  Expect.throws(
+    () {
+      IsolateGroup.runSync(() {
+        listTearoff;
+      });
+    },
+    (e) =>
+        e is ArgumentError && e.toString().contains("Only trivially-immutable"),
+  );
+}
+
+main(List<String> args) {
   IsolateGroup.runSync(() {
     final l = <int>[];
     for (int i = 0; i < 100; i++) {
@@ -181,5 +204,20 @@
     });
   }, (e) => e.toString().contains("Attempt to access isolate static field"));
 
+  StringMethodTearoffTest();
+  ListMethodTearoffTest(args);
+
+  final rp = ReceivePort();
+  Expect.throws(
+    () {
+      IsolateGroup.runSync(() {
+        sp = rp.sendPort;
+      });
+    },
+    (e) =>
+        e is ArgumentError && e.toString().contains("Only trivially-immutable"),
+  );
+  rp.close();
+
   print("All tests completed :)");
 }