From 8f9d3d6993cb0780acee0aceb97f592337dc6dd6 Mon Sep 17 00:00:00 2001 From: Piyush Jagadish Bag Date: Sat, 27 Jun 2026 17:31:48 +0000 Subject: [PATCH] mcp: add StreamableHTTPHandler.Close for graceful shutdown Expose public Close() to tear down all sessions and reject new requests with 503, matching the API proposed in #440. Co-authored-by: Piyush Bag --- mcp/streamable.go | 44 +++++++++++++++++++++------------- mcp/streamable_test.go | 54 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 16 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index c047c156..3e46d80b 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -51,6 +51,7 @@ type StreamableHTTPHandler struct { onTransportDeletion func(sessionID string) // for testing mu sync.Mutex + closed bool sessions map[string]*sessionInfo // keyed by session ID } @@ -219,27 +220,30 @@ func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *Strea return h } -// closeAll closes all ongoing sessions, for tests. -// -// TODO(rfindley): investigate the best API for callers to configure their -// session lifecycle. (?) +// Close closes the handler, closing and removing all connected sessions, +// and preventing new sessions from being added. // -// Should we allow passing in a session store? That would allow the handler to -// be stateless. -func (h *StreamableHTTPHandler) closeAll() { - // TODO: if we ever expose this outside of tests, we'll need to do better - // than simply collecting sessions while holding the lock: we need to prevent - // new sessions from being added. - // - // Currently, sessions remove themselves from h.sessions when closed, so we - // can't call Close while holding the lock. +// Close is idempotent. +func (h *StreamableHTTPHandler) Close() error { h.mu.Lock() + if h.closed { + h.mu.Unlock() + return nil + } + h.closed = true sessionInfos := slices.Collect(maps.Values(h.sessions)) - h.sessions = nil + h.sessions = make(map[string]*sessionInfo) h.mu.Unlock() - for _, s := range sessionInfos { - s.session.Close() + + for _, info := range sessionInfos { + info.session.Close() } + return nil +} + +// closeAll closes all ongoing sessions, for tests. +func (h *StreamableHTTPHandler) closeAll() { + _ = h.Close() } // disablelocalhostprotection is a compatibility parameter that allows to disable @@ -308,6 +312,14 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque } } + h.mu.Lock() + closed := h.closed + h.mu.Unlock() + if closed { + http.Error(w, "handler closed", http.StatusServiceUnavailable) + return + } + // [§2.7] of the spec (2025-06-18): validate the MCP-Protocol-Version // header. If provided, it must be a supported version. If absent, the // version is unknown (the request may be an initialize for any version). diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index dff5e1f4..b13b48f2 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -641,6 +641,60 @@ func TestStreamableServerDisconnect(t *testing.T) { } } +func TestStreamableHTTPHandlerClose(t *testing.T) { + server := NewServer(testImpl, nil) + handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil) + httpServer := httptest.NewServer(mustNotPanic(t, handler)) + defer httpServer.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + client := NewClient(testImpl, nil) + // Pin a pre-SEP-2575 protocol version so Connect performs a single + // legacy initialize handshake (one session). Under the modern protocol the + // client first issues a stateless server/discover probe, which—when the + // server falls back to the legacy handshake—transiently leaves an extra + // session and would make the exact count below nondeterministic. This test + // exercises Close, not protocol negotiation. + clientSession, err := client.Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, &ClientSessionOptions{protocolVersion: protocolVersion20251125}) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + t.Cleanup(func() { _ = clientSession.Close() }) + + handler.mu.Lock() + if len(handler.sessions) != 1 { + t.Fatalf("want 1 session before Close, got %d", len(handler.sessions)) + } + handler.mu.Unlock() + + if err := handler.Close(); err != nil { + t.Fatalf("Close() failed: %v", err) + } + if err := handler.Close(); err != nil { + t.Fatalf("second Close() failed: %v", err) + } + + handler.mu.Lock() + if len(handler.sessions) != 0 { + t.Fatalf("want 0 sessions after Close, got %d", len(handler.sessions)) + } + if !handler.closed { + t.Fatal("want handler.closed true after Close") + } + handler.mu.Unlock() + + resp, err := http.Get(httpServer.URL) + if err != nil { + t.Fatalf("http.Get after Close failed: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusServiceUnavailable { + t.Fatalf("got status %d after Close, want %d", resp.StatusCode, http.StatusServiceUnavailable) + } +} + func TestServerTransportCleanup(t *testing.T) { nClient := 3