Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,9 @@

package io.modelcontextprotocol.server;

import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA;

import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiFunction;

Expand All @@ -19,6 +16,7 @@
import io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport;
import io.modelcontextprotocol.server.transport.TomcatTestUtil;
import io.modelcontextprotocol.spec.HttpHeaders;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
import io.modelcontextprotocol.spec.McpSchema.CompleteRequest;
Expand All @@ -43,8 +41,6 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
Expand All @@ -53,9 +49,12 @@
import static io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport.APPLICATION_JSON;
import static io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport.TEXT_EVENT_STREAM;
import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER;
import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA;
import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson;
import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.InstanceOfAssertFactories.type;
import static org.awaitility.Awaitility.await;

@Timeout(15)
Expand All @@ -67,7 +66,12 @@ class HttpServletStatelessIntegrationTests {

private HttpServletStatelessServerTransport mcpStatelessServerTransport;

ConcurrentHashMap<String, McpClient.SyncSpec> clientBuilders = new ConcurrentHashMap<>();
private final McpClient.SyncSpec clientBuilder = McpClient
.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT)
.endpoint(CUSTOM_MESSAGE_ENDPOINT)
.build())
.initializationTimeout(Duration.ofHours(10))
.requestTimeout(Duration.ofHours(10));

private Tomcat tomcat;

Expand All @@ -85,12 +89,6 @@ public void before() {
catch (Exception e) {
throw new RuntimeException("Failed to start Tomcat", e);
}

clientBuilders
.put("httpclient",
McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT)
.endpoint(CUSTOM_MESSAGE_ENDPOINT)
.build()).initializationTimeout(Duration.ofHours(10)).requestTimeout(Duration.ofHours(10)));
}

@AfterEach
Expand All @@ -112,12 +110,8 @@ public void after() {
// ---------------------------------------
// Tools Tests
// ---------------------------------------
@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
void testToolCallSuccess(String clientType) {

var clientBuilder = clientBuilders.get(clientType);

@Test
void testToolCallSuccess() {
var callResponse = CallToolResult.builder()
.content(List.of(McpSchema.TextContent.builder("CALL RESPONSE").build()))
.isError(false)
Expand Down Expand Up @@ -158,12 +152,8 @@ void testToolCallSuccess(String clientType) {
}
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
void testInitialize(String clientType) {

var clientBuilder = clientBuilders.get(clientType);

@Test
void testInitialize() {
var mcpServer = McpServer.sync(mcpStatelessServerTransport).build();

try (var mcpClient = clientBuilder.build()) {
Expand All @@ -178,11 +168,8 @@ void testInitialize(String clientType) {
// ---------------------------------------
// Completion Tests
// ---------------------------------------
@ParameterizedTest(name = "{0} : Completion call")
@ValueSource(strings = { "httpclient" })
void testCompletionShouldReturnExpectedSuggestions(String clientType) {
var clientBuilder = clientBuilders.get(clientType);

@Test
void testCompletionShouldReturnExpectedSuggestions() {
var expectedValues = List.of("python", "pytorch", "pyside");
var completionResponse = new CompleteResult(new CompleteResult.CompleteCompletion(expectedValues, 10, // total
true // hasMore
Expand Down Expand Up @@ -233,11 +220,8 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) {
}
}

@ParameterizedTest(name = "{0} : Completion call without matching handler")
@ValueSource(strings = { "httpclient" })
void testCompletionWithoutMatchingHandlerReturnsEmptyResult(String clientType) {
var clientBuilder = clientBuilders.get(clientType);

@Test
void testCompletionWithoutMatchingHandlerReturnsEmptyResult() {
BiFunction<McpTransportContext, CompleteRequest, CompleteResult> completionHandler = (transportContext,
request) -> new CompleteResult(new CompleteResult.CompleteCompletion(List.of("java"), 1, false));

Expand Down Expand Up @@ -286,11 +270,8 @@ void testCompletionWithoutMatchingHandlerReturnsEmptyResult(String clientType) {
}
}

@ParameterizedTest(name = "{0} : Resource template completion call without matching handler")
@ValueSource(strings = { "httpclient" })
void testResourceTemplateCompletionWithoutMatchingHandlerReturnsEmptyResult(String clientType) {
var clientBuilder = clientBuilders.get(clientType);

@Test
void testResourceTemplateCompletionWithoutMatchingHandlerReturnsEmptyResult() {
BiFunction<McpTransportContext, CompleteRequest, CompleteResult> completionHandler = (transportContext,
request) -> new CompleteResult(new CompleteResult.CompleteCompletion(List.of("java"), 1, false));

Expand Down Expand Up @@ -337,14 +318,62 @@ void testResourceTemplateCompletionWithoutMatchingHandlerReturnsEmptyResult(Stri
}
}

@Test
void testCompletionForNonExistentPromptReturnsInvalidParams() {
var mcpServer = McpServer.sync(mcpStatelessServerTransport)
.capabilities(ServerCapabilities.builder().completions().build())
.build();

try (var mcpClient = clientBuilder.build()) {
InitializeResult initResult = mcpClient.initialize();
assertThat(initResult).isNotNull();

CompleteRequest request = CompleteRequest
.builder(new PromptReference("nonexistent-prompt"), new CompleteRequest.CompleteArgument("arg", "val"))
.build();

assertThatThrownBy(() -> mcpClient.completeCompletion(request)).isInstanceOf(McpError.class)
.asInstanceOf(type(McpError.class))
.extracting(McpError::getJsonRpcError)
.extracting(McpSchema.JSONRPCResponse.JSONRPCError::code)
.isEqualTo(ErrorCodes.INVALID_PARAMS);
}
finally {
mcpServer.close();
}
}

@Test
void testCompletionForNonExistentResourceReturnsResourceNotFound() {
var mcpServer = McpServer.sync(mcpStatelessServerTransport)
.capabilities(ServerCapabilities.builder().completions().build())
.build();

try (var mcpClient = clientBuilder.build()) {
InitializeResult initResult = mcpClient.initialize();
assertThat(initResult).isNotNull();

CompleteRequest request = CompleteRequest
.builder(new ResourceReference("test://nonexistent/{param}"),
new CompleteRequest.CompleteArgument("param", "val"))
.build();

assertThatThrownBy(() -> mcpClient.completeCompletion(request)).isInstanceOf(McpError.class)
.asInstanceOf(type(McpError.class))
.extracting(McpError::getJsonRpcError)
.extracting(McpSchema.JSONRPCResponse.JSONRPCError::code)
.isEqualTo(McpSchema.ErrorCodes.RESOURCE_NOT_FOUND);
}
finally {
mcpServer.close();
}
}

// ---------------------------------------
// Tool Structured Output Schema Tests
// ---------------------------------------
@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
void testStructuredOutputValidationSuccess(String clientType) {
var clientBuilder = clientBuilders.get(clientType);

@Test
void testStructuredOutputValidationSuccess() {
// Create a tool with output schema
Map<String, Object> outputSchema = Map.of(
"type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation",
Expand Down Expand Up @@ -409,11 +438,8 @@ void testStructuredOutputValidationSuccess(String clientType) {
}
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
void testStructuredOutputOfObjectArrayValidationSuccess(String clientType) {
var clientBuilder = clientBuilders.get(clientType);

@Test
void testStructuredOutputOfObjectArrayValidationSuccess() {
// Create a tool with output schema that returns an array of objects
Map<String, Object> outputSchema = Map
.of( // @formatter:off
Expand Down Expand Up @@ -470,11 +496,8 @@ void testStructuredOutputOfObjectArrayValidationSuccess(String clientType) {
}
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
void testStructuredOutputWithInHandlerError(String clientType) {
var clientBuilder = clientBuilders.get(clientType);

@Test
void testStructuredOutputWithInHandlerError() {
// Create a tool with output schema
Map<String, Object> outputSchema = Map.of(
"type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation",
Expand Down Expand Up @@ -528,11 +551,8 @@ void testStructuredOutputWithInHandlerError(String clientType) {
}
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
void testStructuredOutputValidationFailure(String clientType) {
var clientBuilder = clientBuilders.get(clientType);

@Test
void testStructuredOutputValidationFailure() {
// Create a tool with output schema
Map<String, Object> outputSchema = Map.of("type", "object", "properties",
Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required",
Expand Down Expand Up @@ -580,11 +600,8 @@ void testStructuredOutputValidationFailure(String clientType) {
}
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
void testStructuredOutputMissingStructuredContent(String clientType) {
var clientBuilder = clientBuilders.get(clientType);

@Test
void testStructuredOutputMissingStructuredContent() {
// Create a tool with output schema
Map<String, Object> outputSchema = Map.of("type", "object", "properties",
Map.of("result", Map.of("type", "number")), "required", List.of("result"));
Expand Down Expand Up @@ -629,11 +646,8 @@ void testStructuredOutputMissingStructuredContent(String clientType) {
}
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
void testStructuredOutputRuntimeToolAddition(String clientType) {
var clientBuilder = clientBuilders.get(clientType);

@Test
void testStructuredOutputRuntimeToolAddition() {
// Start server without tools
var mcpServer = McpServer.sync(mcpStatelessServerTransport)
.serverInfo("test-server", "1.0.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,25 @@
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider;
import io.modelcontextprotocol.server.transport.TomcatTestUtil;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.CompleteRequest;
import io.modelcontextprotocol.spec.McpSchema.CompleteResult;
import io.modelcontextprotocol.spec.McpSchema.ErrorCodes;
import io.modelcontextprotocol.spec.McpSchema.InitializeResult;
import io.modelcontextprotocol.spec.McpSchema.Prompt;
import io.modelcontextprotocol.spec.McpSchema.PromptArgument;
import io.modelcontextprotocol.spec.McpSchema.PromptReference;
import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult;
import io.modelcontextprotocol.spec.McpSchema.Resource;
import io.modelcontextprotocol.spec.McpSchema.ResourceReference;
import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate;
import io.modelcontextprotocol.spec.McpSchema.PromptReference;
import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities;
import io.modelcontextprotocol.spec.McpError;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.InstanceOfAssertFactories.type;

/**
* Tests for completion functionality with context support.
Expand Down Expand Up @@ -273,6 +275,59 @@ void testResourceTemplateCompletionWithoutMatchingHandlerReturnsEmptyResult() {
mcpServer.close();
}

@Test
void testCompletionForNonExistentPromptReturnsInvalidParams() {
var mcpServer = McpServer.sync(mcpServerTransportProvider)
.capabilities(ServerCapabilities.builder().completions().build())
.build();

try (var mcpClient = clientBuilder
.clientInfo(McpSchema.Implementation.builder("Sample " + "client", "0.0.0").build())
.build()) {
InitializeResult initResult = mcpClient.initialize();
assertThat(initResult).isNotNull();

CompleteRequest request = CompleteRequest
.builder(new PromptReference("nonexistent-prompt"), new CompleteRequest.CompleteArgument("arg", "val"))
.build();

assertThatThrownBy(() -> mcpClient.completeCompletion(request)).isInstanceOf(McpError.class)
.asInstanceOf(type(McpError.class))
.extracting(McpError::getJsonRpcError)
.extracting(McpSchema.JSONRPCResponse.JSONRPCError::code)
.isEqualTo(ErrorCodes.INVALID_PARAMS);
}

mcpServer.close();
}

@Test
void testCompletionForNonExistentResourceReturnsResourceNotFound() {
var mcpServer = McpServer.sync(mcpServerTransportProvider)
.capabilities(ServerCapabilities.builder().completions().build())
.build();

try (var mcpClient = clientBuilder
.clientInfo(McpSchema.Implementation.builder("Sample " + "client", "0.0.0").build())
.build()) {
InitializeResult initResult = mcpClient.initialize();
assertThat(initResult).isNotNull();

CompleteRequest request = CompleteRequest
.builder(new ResourceReference("test://nonexistent/{param}"),
new CompleteRequest.CompleteArgument("param", "val"))
.build();

assertThatThrownBy(() -> mcpClient.completeCompletion(request)).isInstanceOf(McpError.class)
.asInstanceOf(type(McpError.class))
.extracting(McpError::getJsonRpcError)
.extracting(McpSchema.JSONRPCResponse.JSONRPCError::code)
.isEqualTo(McpSchema.ErrorCodes.RESOURCE_NOT_FOUND);
}

mcpServer.close();
}

@Test
void testDependentCompletionScenario() {
BiFunction<McpSyncServerExchange, CompleteRequest, CompleteResult> completionHandler = (exchange, request) -> {
Expand Down
Loading