Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 28 additions & 16 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ type StreamableHTTPHandler struct {
onTransportDeletion func(sessionID string) // for testing

mu sync.Mutex
closed bool
sessions map[string]*sessionInfo // keyed by session ID
}

Expand Down Expand Up @@ -219,27 +220,30 @@ func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *Strea
return h
}

// closeAll closes all ongoing sessions, for tests.
//
// TODO(rfindley): investigate the best API for callers to configure their
// session lifecycle. (?)
// Close closes the handler, closing and removing all connected sessions,
// and preventing new sessions from being added.
//
// Should we allow passing in a session store? That would allow the handler to
// be stateless.
func (h *StreamableHTTPHandler) closeAll() {
// TODO: if we ever expose this outside of tests, we'll need to do better
// than simply collecting sessions while holding the lock: we need to prevent
// new sessions from being added.
//
// Currently, sessions remove themselves from h.sessions when closed, so we
// can't call Close while holding the lock.
// Close is idempotent.
func (h *StreamableHTTPHandler) Close() error {
h.mu.Lock()
if h.closed {
h.mu.Unlock()
return nil
}
h.closed = true
sessionInfos := slices.Collect(maps.Values(h.sessions))
h.sessions = nil
h.sessions = make(map[string]*sessionInfo)
h.mu.Unlock()
for _, s := range sessionInfos {
s.session.Close()

for _, info := range sessionInfos {
info.session.Close()
}
return nil
}

// closeAll closes all ongoing sessions, for tests.
func (h *StreamableHTTPHandler) closeAll() {
_ = h.Close()
}

// disablelocalhostprotection is a compatibility parameter that allows to disable
Expand Down Expand Up @@ -308,6 +312,14 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
}
}

h.mu.Lock()
closed := h.closed
h.mu.Unlock()
if closed {
http.Error(w, "handler closed", http.StatusServiceUnavailable)
return
}

// [§2.7] of the spec (2025-06-18): validate the MCP-Protocol-Version
// header. If provided, it must be a supported version. If absent, the
// version is unknown (the request may be an initialize for any version).
Expand Down
54 changes: 54 additions & 0 deletions mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,60 @@ func TestStreamableServerDisconnect(t *testing.T) {
}
}

func TestStreamableHTTPHandlerClose(t *testing.T) {
server := NewServer(testImpl, nil)
handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil)
httpServer := httptest.NewServer(mustNotPanic(t, handler))
defer httpServer.Close()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

client := NewClient(testImpl, nil)
// Pin a pre-SEP-2575 protocol version so Connect performs a single
// legacy initialize handshake (one session). Under the modern protocol the
// client first issues a stateless server/discover probe, which—when the
// server falls back to the legacy handshake—transiently leaves an extra
// session and would make the exact count below nondeterministic. This test
// exercises Close, not protocol negotiation.
clientSession, err := client.Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, &ClientSessionOptions{protocolVersion: protocolVersion20251125})
if err != nil {
t.Fatalf("client.Connect() failed: %v", err)
}
t.Cleanup(func() { _ = clientSession.Close() })

handler.mu.Lock()
if len(handler.sessions) != 1 {
t.Fatalf("want 1 session before Close, got %d", len(handler.sessions))
}
handler.mu.Unlock()

if err := handler.Close(); err != nil {
t.Fatalf("Close() failed: %v", err)
}
if err := handler.Close(); err != nil {
t.Fatalf("second Close() failed: %v", err)
}

handler.mu.Lock()
if len(handler.sessions) != 0 {
t.Fatalf("want 0 sessions after Close, got %d", len(handler.sessions))
}
if !handler.closed {
t.Fatal("want handler.closed true after Close")
}
handler.mu.Unlock()

resp, err := http.Get(httpServer.URL)
if err != nil {
t.Fatalf("http.Get after Close failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusServiceUnavailable {
t.Fatalf("got status %d after Close, want %d", resp.StatusCode, http.StatusServiceUnavailable)
}
}

func TestServerTransportCleanup(t *testing.T) {
nClient := 3

Expand Down