diff --git a/src/claude_memory/api/app.py b/src/claude_memory/api/app.py index 69096da..f1fdc15 100644 --- a/src/claude_memory/api/app.py +++ b/src/claude_memory/api/app.py @@ -15,7 +15,7 @@ from fastapi.responses import Response from fastapi.staticfiles import StaticFiles from mcp.server.fastmcp import FastMCP from mcp.server.sse import SseServerTransport -from mcp.server.streamable_http import StreamableHTTPServerTransport +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from starlette.routing import Mount, Route from starlette.types import ASGIApp, Receive, Scope, Send @@ -42,7 +42,8 @@ _current_user: ContextVar[str] = ContextVar("_current_user", default="default") @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: await init_pool() - yield + async with streamable_session_mgr.run(): + yield await close_pool() @@ -1297,12 +1298,17 @@ class HandleSSE: ) -# Streamable HTTP transport — stateless, no persistent SSE connection needed. -# Each request carries its own init+tool call. More reliable through proxies. +# Streamable HTTP transport — session manager handles lifecycle automatically. +# More reliable through proxies than SSE since responses come in HTTP body. +streamable_session_mgr = StreamableHTTPSessionManager( + app=mcp_server._mcp_server, + json_response=True, + stateless=True, +) + + class HandleStreamableHTTP: - """ASGI app for streamable-http MCP connections.""" - def __init__(self) -> None: - self._transport: StreamableHTTPServerTransport | None = None + """ASGI wrapper that sets _current_user before delegating to the session manager.""" async def __call__(self, scope: Any, receive: Any, send: Any) -> None: user_id = "default" @@ -1314,29 +1320,7 @@ class HandleStreamableHTTP: user_id = resolved break _current_user.set(user_id) - - session_id = None - for name, value in scope.get("headers", []): - if name == b"mcp-session-id": - session_id = value.decode() - break - - transport = StreamableHTTPServerTransport( - mcp_session_id=session_id, - is_json_response_enabled=True, - ) - async with transport.connect() as (read_stream, write_stream): - import anyio - async with anyio.create_task_group() as tg: - async def run_server() -> None: - await mcp_server._mcp_server.run( - read_stream, write_stream, - mcp_server._mcp_server.create_initialization_options(), - ) - - tg.start_soon(run_server) - await transport.handle_request(scope, receive, send) - tg.cancel_scope.cancel() + await streamable_session_mgr.handle_request(scope, receive, send) streamable_handler = HandleStreamableHTTP()