diff --git a/src/claude_memory/api/app.py b/src/claude_memory/api/app.py index cd8140a..c912dd2 100644 --- a/src/claude_memory/api/app.py +++ b/src/claude_memory/api/app.py @@ -10,13 +10,13 @@ from contextvars import ContextVar from datetime import datetime, timezone from typing import Any, AsyncGenerator, Optional -from fastapi import Depends, FastAPI, HTTPException, Request +from fastapi import Depends, FastAPI, HTTPException from fastapi.responses import Response from fastapi.staticfiles import StaticFiles from mcp.server.fastmcp import FastMCP from mcp.server.sse import SseServerTransport -from starlette.middleware.base import BaseHTTPMiddleware from starlette.routing import Mount, Route +from starlette.types import ASGIApp, Receive, Scope, Send from claude_memory.api.auth import AuthUser, get_current_user, _key_to_user from claude_memory.api.database import close_pool, get_pool, init_pool @@ -1253,16 +1253,22 @@ async def memory_update(id: int, content: str | None = None, tags: str | None = return json.dumps({"updated": id}) -# Auth middleware for /mcp/* routes -class MCPAuthMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next: Any) -> Response: - if request.url.path.startswith("/mcp"): - auth = request.headers.get("authorization", "") +# Auth middleware for /mcp/* routes — pure ASGI to avoid BaseHTTPMiddleware +# buffering which breaks SSE streaming (responses never reach the client). +class MCPAuthMiddleware: + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http" and scope["path"].startswith("/mcp"): + headers = dict(scope.get("headers", [])) + auth = headers.get(b"authorization", b"").decode() token = auth.removeprefix("Bearer ").strip() if not _resolve_user_from_token(token): - return Response(content="Unauthorized", status_code=401) - response: Response = await call_next(request) - return response + response = Response(content="Unauthorized", status_code=401) + await response(scope, receive, send) + return + await self.app(scope, receive, send) app.add_middleware(MCPAuthMiddleware)