Add service extension for sending a sampling request to the MCP server (#325)
diff --git a/pkgs/dart_mcp/CHANGELOG.md b/pkgs/dart_mcp/CHANGELOG.md
index 2ab9f92..4c86846 100644
--- a/pkgs/dart_mcp/CHANGELOG.md
+++ b/pkgs/dart_mcp/CHANGELOG.md
@@ -6,6 +6,7 @@
been possible to use in a functional manner, so it is assumed that it had
no usage previously.
- Fix the `type` getter on `EmbeddedResource` to read the actual type field.
+- Add `toJson` method to the `CreateMessageResult` of a sampling request.
## 0.4.0
diff --git a/pkgs/dart_mcp/lib/src/api/sampling.dart b/pkgs/dart_mcp/lib/src/api/sampling.dart
index 8cbe985..efa8e69 100644
--- a/pkgs/dart_mcp/lib/src/api/sampling.dart
+++ b/pkgs/dart_mcp/lib/src/api/sampling.dart
@@ -123,6 +123,9 @@
/// Known reasons are "endTurn", "stopSequence", "maxTokens", or any other
/// reason.
String? get stopReason => _value['stopReason'] as String?;
+
+  /// The JSON representation of this object (the underlying map, not a copy).
+  Map<String, Object?> toJson() => _value;
}
/// Describes a message issued to or received from an LLM API.
diff --git a/pkgs/dart_mcp_server/CHANGELOG.md b/pkgs/dart_mcp_server/CHANGELOG.md
index 628a816..d26be82 100644
--- a/pkgs/dart_mcp_server/CHANGELOG.md
+++ b/pkgs/dart_mcp_server/CHANGELOG.md
@@ -10,6 +10,8 @@
resources or resource links (when reading a directory).
- Add `additionalProperties: false` to most schemas so they provide better
errors when invoked with incorrect arguments.
+- Add a `DartMcpServer.samplingRequest` DTD service method that forwards
+  sampling requests to the MCP client.
# 0.1.1 (Dart SDK 3.10.0)
diff --git a/pkgs/dart_mcp_server/lib/src/mixins/dtd.dart b/pkgs/dart_mcp_server/lib/src/mixins/dtd.dart
index 637b725..24916fc 100644
--- a/pkgs/dart_mcp_server/lib/src/mixins/dtd.dart
+++ b/pkgs/dart_mcp_server/lib/src/mixins/dtd.dart
@@ -20,6 +20,18 @@
import '../utils/constants.dart';
import '../utils/tools_configuration.dart';
+/// Constants used by the MCP server to register services on DTD.
+///
+/// TODO(elliette): Add these to package:dtd instead.
+extension McpServiceConstants on Never {
+  /// DTD service name under which the Dart MCP Server registers its methods.
+  static const serviceName = 'DartMcpServer';
+
+  /// The DTD method name for forwarding a sampling request through this server
+  /// to the MCP client.
+  static const samplingRequest = 'samplingRequest';
+}
+
/// Mix this in to any MCPServer to add support for connecting to the Dart
/// Tooling Daemon and all of its associated functionality (see
/// https://pub.dev/packages/dtd).
@@ -251,6 +263,7 @@
}
unawaited(_dtd!.done.then((_) async => await _resetDtd()));
+ await _registerServices();
await _listenForServices();
return CallToolResult(
content: [TextContent(text: 'Connection succeeded')],
@@ -272,6 +285,47 @@
}
}
+  /// Registers all MCP server-provided services on the connected DTD instance.
+  ///
+  /// The sampling service is only registered when the MCP client advertised
+  /// the sampling capability, because [_handleSamplingRequest] forwards
+  /// requests to that client.
+  Future<void> _registerServices() async {
+    final dtd = _dtd!;
+
+    if (clientCapabilities.sampling != null) {
+      try {
+        await dtd.registerService(
+          McpServiceConstants.serviceName,
+          McpServiceConstants.samplingRequest,
+          _handleSamplingRequest,
+        );
+      } on RpcException catch (e) {
+        // It is expected for there to be an exception if the sampling service
+        // was already registered by another Dart MCP Server.
+        if (e.code != RpcErrorCodes.kServiceAlreadyRegistered) rethrow;
+      }
+    }
+  }
+
+  /// Handles a DTD `samplingRequest` call by forwarding it to the MCP client
+  /// with [createMessage].
+  ///
+  /// [params] are decoded as a [CreateMessageRequest], and the client's
+  /// [CreateMessageResult] is returned as JSON with `type` set to `Success`.
+  Future<Map<String, Object?>> _handleSamplingRequest(Parameters params) async {
+    final result = await createMessage(
+      CreateMessageRequest.fromMap(params.asMap.cast<String, Object?>()),
+    );
+
+    return {
+      ...result.toJson(),
+      // Type is required by DTD. It is set after the spread so that a `type`
+      // key in the client's result cannot overwrite it.
+      'type': 'Success',
+    };
+  }
+
/// Listens to the `ConnectedApp` and `Editor` streams to get app and IDE
/// state information.
///
diff --git a/pkgs/dart_mcp_server/test/test_harness.dart b/pkgs/dart_mcp_server/test/test_harness.dart
index d8f8f3d..91b8568 100644
--- a/pkgs/dart_mcp_server/test/test_harness.dart
+++ b/pkgs/dart_mcp_server/test/test_harness.dart
@@ -303,7 +303,8 @@
}
/// A basic MCP client which is started as a part of the harness.
-final class DartToolingMCPClient extends MCPClient with RootsSupport {
+final class DartToolingMCPClient extends MCPClient
+ with RootsSupport, SamplingSupport {
DartToolingMCPClient()
: super(
Implementation(
@@ -311,6 +312,40 @@
version: '0.1.0',
),
);
+
+  /// Answers every sampling request with an assistant message from
+  /// `test-model` whose text echoes the request.
+  ///
+  /// The text is `TOKENS: <maxTokens>` followed by one line per message:
+  /// `[<role>] <text>` for text, `[<role>] <mimeType>` for image and audio,
+  /// `[<role>] <resource uri>` for embedded resources, and `UNKNOWN` for any
+  /// other content. Tests assert on this exact format.
+  @override
+  FutureOr<CreateMessageResult> handleCreateMessage(
+    CreateMessageRequest request,
+    Implementation serverInfo,
+  ) {
+    final messageTexts = request.messages
+        .map((message) {
+          final role = message.role.name;
+          return switch (message.content) {
+            final TextContent c when c.isText => '[$role] ${c.text}',
+            final ImageContent c when c.isImage => '[$role] ${c.mimeType}',
+            final AudioContent c when c.isAudio => '[$role] ${c.mimeType}',
+            final EmbeddedResource c when c.isEmbeddedResource =>
+              '[$role] ${c.resource.uri}',
+            _ => 'UNKNOWN',
+          };
+        })
+        .join('\n');
+    return CreateMessageResult(
+      role: Role.assistant,
+      content: Content.text(
+        text: 'TOKENS: ${request.maxTokens}\n$messageTexts',
+      ),
+      model: 'test-model',
+    );
+  }
}
/// The dart tooling daemon currently expects to get vm service uris through
diff --git a/pkgs/dart_mcp_server/test/tools/dtd_test.dart b/pkgs/dart_mcp_server/test/tools/dtd_test.dart
index 0b752df..3e84226 100644
--- a/pkgs/dart_mcp_server/test/tools/dtd_test.dart
+++ b/pkgs/dart_mcp_server/test/tools/dtd_test.dart
@@ -13,6 +13,8 @@
import 'package:dart_mcp_server/src/utils/analytics.dart';
import 'package:dart_mcp_server/src/utils/constants.dart';
import 'package:devtools_shared/devtools_shared.dart';
+import 'package:dtd/dtd.dart';
+import 'package:json_rpc_2/json_rpc_2.dart';
import 'package:test/test.dart';
import 'package:unified_analytics/testing.dart';
import 'package:unified_analytics/unified_analytics.dart' as ua;
@@ -98,6 +100,213 @@
});
});
+  group('sampling service extension', () {
+    List<String> extractResponse(DTDResponse response) {
+      final responseContent =
+          response.result['content'] as Map<String, Object?>;
+      return (responseContent['text'] as String).split('\n');
+    }
+
+    test('can make a sampling request with text', () async {
+      final dtdClient = testHarness.fakeEditorExtension.dtd;
+      final response = await dtdClient.call(
+        McpServiceConstants.serviceName,
+        McpServiceConstants.samplingRequest,
+        params: {
+          'messages': [
+            {
+              'role': 'user',
+              'content': {'type': 'text', 'text': 'hello world'},
+            },
+          ],
+          'maxTokens': 512,
+        },
+      );
+      expect(extractResponse(response), [
+        'TOKENS: 512',
+        '[user] hello world',
+      ]);
+    });
+
+    test('can make a sampling request with an image', () async {
+      final dtdClient = testHarness.fakeEditorExtension.dtd;
+      final response = await dtdClient.call(
+        McpServiceConstants.serviceName,
+        McpServiceConstants.samplingRequest,
+        params: {
+          'messages': [
+            {
+              'role': 'user',
+              'content': {
+                'type': 'image',
+                'data': 'fake-data',
+                'mimeType': 'image/png',
+              },
+            },
+          ],
+          'maxTokens': 256,
+        },
+      );
+      expect(extractResponse(response), [
+        'TOKENS: 256',
+        '[user] image/png',
+      ]);
+    });
+
+    test('can make a sampling request with audio', () async {
+      final dtdClient = testHarness.fakeEditorExtension.dtd;
+      final response = await dtdClient.call(
+        McpServiceConstants.serviceName,
+        McpServiceConstants.samplingRequest,
+        params: {
+          'messages': [
+            {
+              'role': 'user',
+              'content': {
+                'type': 'audio',
+                'data': 'fake-data',
+                'mimeType': 'audio',
+              },
+            },
+          ],
+          'maxTokens': 256,
+        },
+      );
+      expect(extractResponse(response), ['TOKENS: 256', '[user] audio']);
+    });
+
+    test('can make a sampling request with an embedded resource', () async {
+      final dtdClient = testHarness.fakeEditorExtension.dtd;
+      final response = await dtdClient.call(
+        McpServiceConstants.serviceName,
+        McpServiceConstants.samplingRequest,
+        params: {
+          'messages': [
+            {
+              'role': 'user',
+              'content': {
+                'type': 'resource',
+                'resource': {'uri': 'www.google.com', 'text': 'Google'},
+              },
+            },
+          ],
+          'maxTokens': 256,
+        },
+      );
+      expect(extractResponse(response), [
+        'TOKENS: 256',
+        '[user] www.google.com',
+      ]);
+    });
+
+    test('can make a sampling request with mixed content', () async {
+      final dtdClient = testHarness.fakeEditorExtension.dtd;
+      final response = await dtdClient.call(
+        McpServiceConstants.serviceName,
+        McpServiceConstants.samplingRequest,
+        params: {
+          'messages': [
+            {
+              'role': 'user',
+              'content': {'type': 'text', 'text': 'hello world'},
+            },
+            {
+              'role': 'user',
+              'content': {
+                'type': 'image',
+                'data': 'fake-data',
+                'mimeType': 'image/jpeg',
+              },
+            },
+          ],
+          'maxTokens': 128,
+        },
+      );
+      expect(extractResponse(response), [
+        'TOKENS: 128',
+        '[user] hello world',
+        '[user] image/jpeg',
+      ]);
+    });
+
+    test('can handle user and assistant messages', () async {
+      final dtdClient = testHarness.fakeEditorExtension.dtd;
+      final response = await dtdClient.call(
+        McpServiceConstants.serviceName,
+        McpServiceConstants.samplingRequest,
+        params: {
+          'messages': [
+            {
+              'role': 'user',
+              'content': {'type': 'text', 'text': 'Hi! I have a question.'},
+            },
+            {
+              'role': 'assistant',
+              'content': {'type': 'text', 'text': 'What is your question?'},
+            },
+            {
+              'role': 'user',
+              'content': {'type': 'text', 'text': 'How big is the sun?'},
+            },
+          ],
+          'maxTokens': 512,
+        },
+      );
+      expect(extractResponse(response), [
+        'TOKENS: 512',
+        '[user] Hi! I have a question.',
+        '[assistant] What is your question?',
+        '[user] How big is the sun?',
+      ]);
+    });
+
+    test('forwards all messages, even those with unknown types', () async {
+      final dtdClient = testHarness.fakeEditorExtension.dtd;
+      final response = await dtdClient.call(
+        McpServiceConstants.serviceName,
+        McpServiceConstants.samplingRequest,
+        params: {
+          'messages': [
+            {
+              'role': 'user',
+              'content': {
+                // Not of type text, image, audio, or resource.
+                'type': 'unknown',
+                'text': 'Hi there!',
+                'data': 'Hi there!',
+              },
+            },
+          ],
+          'maxTokens': 512,
+        },
+      );
+      expect(extractResponse(response), ['TOKENS: 512', 'UNKNOWN']);
+    });
+
+    test('throws for invalid requests', () async {
+      final dtdClient = testHarness.fakeEditorExtension.dtd;
+      await expectLater(
+        dtdClient.call(
+          McpServiceConstants.serviceName,
+          McpServiceConstants.samplingRequest,
+          params: {
+            'messages': [
+              {
+                'role': 'dog', // Invalid role.
+                'content': {
+                  'type': 'text',
+                  'text': 'Hi! I have a question.',
+                },
+              },
+            ],
+            'maxTokens': 512,
+          },
+        ),
+        throwsA(isA<RpcException>()),
+      );
+    });
+  });
+
group('dart cli tests', () {
test('can perform a hot reload', () async {
final exampleApp = await Directory.systemTemp.createTemp('dart_app');