From 48df739c826ef064522ada2ef47520c948be5f37 Mon Sep 17 00:00:00 2001 From: Viktor Barzin Date: Wed, 18 Mar 2026 22:44:57 +0000 Subject: [PATCH] add MCP SSE transport for direct Claude Code connection Adds SSE endpoint at /mcp/sse so Claude Code can connect over HTTPS without needing a local Python bridge script. Benefits: - No local files or sandbox permission issues - Works from any machine (OpenClaw, DevVM) - No startup delay or stderr suppression hack - Auth via Bearer token in request headers --- pyproject.toml | 2 +- src/claude_memory/api/app.py | 212 ++++++++++++++++++++++++++++++++++- 2 files changed, 211 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 41dd3a6..9728689 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ classifiers = [ ] [project.optional-dependencies] -api = ["fastapi>=0.115", "asyncpg>=0.30", "uvicorn>=0.34", "pydantic>=2.0", "alembic>=1.14", "sqlalchemy>=2.0", "psycopg2-binary>=2.9"] +api = ["fastapi>=0.115", "asyncpg>=0.30", "uvicorn>=0.34", "pydantic>=2.0", "alembic>=1.14", "sqlalchemy>=2.0", "psycopg2-binary>=2.9", "mcp>=1.0.0"] vault = ["hvac>=2.0"] dev = ["pytest>=8.0", "pytest-asyncio>=0.24", "ruff>=0.8", "mypy>=1.13", "httpx>=0.28", "cryptography>=43.0"] diff --git a/src/claude_memory/api/app.py b/src/claude_memory/api/app.py index be4d037..2c2c75c 100644 --- a/src/claude_memory/api/app.py +++ b/src/claude_memory/api/app.py @@ -1,13 +1,19 @@ """Claude Memory API -- shared persistent memory with PostgreSQL full-text search.""" +import json import logging from contextlib import asynccontextmanager from datetime import datetime, timezone from typing import Any, AsyncGenerator, Optional -from fastapi import Depends, FastAPI, HTTPException +from fastapi import Depends, FastAPI, HTTPException, Request +from fastapi.responses import Response +from mcp.server.fastmcp import FastMCP +from mcp.server.sse import SseServerTransport +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.routing import Mount, Route -from claude_memory.api.auth import AuthUser, get_current_user +from claude_memory.api.auth import AuthUser, get_current_user, _key_to_user from claude_memory.api.database import close_pool, get_pool, init_pool from claude_memory.api.models import MemoryRecall, MemoryResponse, MemoryStore, SecretResponse, SyncResponse from claude_memory.api.vault_service import ( @@ -431,3 +437,205 @@ async def import_memories( ) return imported + + +# --- MCP SSE Transport --- + + +def _resolve_user_from_token(token: str) -> str | None: + """Resolve API key to user_id, reusing auth module's key map.""" + return _key_to_user.get(token) + + +mcp_server = FastMCP("claude-memory") + + +@mcp_server.tool() +async def memory_store(content: str, category: str = "facts", tags: str = "", + expanded_keywords: str = "", importance: float = 0.5) -> str: + """Store a new memory.""" + pool = await get_pool() + # MCP SSE uses "default" user (single-user mode via middleware auth) + user_id = "default" + is_sensitive = _detect_sensitive(content) + stored_content = content if not is_sensitive else _redact_content(content) + + async with pool.acquire() as conn: + row = await conn.fetchrow( + """INSERT INTO memories (user_id, content, category, tags, expanded_keywords, importance, is_sensitive) + VALUES ($1, $2, $3, $4, $5, $6, $7) + RETURNING id""", + user_id, stored_content, category, tags, expanded_keywords, importance, is_sensitive, + ) + memory_id = row["id"] + + if is_sensitive and is_vault_configured(): + vault_path = await store_secret(user_id, memory_id, content) + await conn.execute("UPDATE memories SET vault_path = $1 WHERE id = $2", vault_path, memory_id) + + return json.dumps({"id": memory_id, "category": category, "importance": importance}) + + +@mcp_server.tool() +async def memory_recall(context: str, expanded_query: str = "", + category: str | None = None, sort_by: str = "importance", + limit: int = 10) -> str: + """Recall memories by semantic search.""" + pool = await get_pool() + user_id = "default" + query_text = f"{context} {expanded_query}".strip() + if not query_text: + return json.dumps({"error": "context is required"}) + + hybrid_score = "(ts_rank(search_vector, query) * 0.7 + importance * 0.3)" + if sort_by == "importance": + hybrid_score = "(ts_rank(search_vector, query) * 0.4 + importance * 0.6)" + + order_clause = f"{hybrid_score} DESC" + if sort_by == "recency": + order_clause = "created_at DESC" + + category_filter = "" + params: list[Any] = [user_id, query_text, limit] + if category: + category_filter = "AND category = $4" + params.append(category) + + async with pool.acquire() as conn: + rows = await conn.fetch( + f""" + SELECT id, content, category, tags, importance, is_sensitive, + ts_rank(search_vector, query) AS rank, created_at, updated_at + FROM memories, plainto_tsquery('english', $2) query + WHERE user_id = $1 AND deleted_at IS NULL + AND (search_vector @@ query OR $2 = '') + {category_filter} + ORDER BY {order_clause} + LIMIT $3 + """, + *params, + ) + + results = [] + for row in rows: + c = row["content"] + if row["is_sensitive"]: + c = f"[SENSITIVE - use secret_get(id={row['id']})]" + results.append({ + "id": row["id"], "content": c, "category": row["category"], + "tags": row["tags"], "importance": row["importance"], + "rank": float(row["rank"]), + "created_at": row["created_at"].isoformat(), + "updated_at": row["updated_at"].isoformat(), + }) + + return json.dumps({"memories": results}) + + +@mcp_server.tool() +async def memory_list(category: str | None = None, limit: int = 20) -> str: + """List stored memories.""" + pool = await get_pool() + user_id = "default" + + if category: + query = """SELECT id, content, category, tags, importance, is_sensitive, created_at, updated_at + FROM memories WHERE user_id = $1 AND deleted_at IS NULL AND category = $2 + ORDER BY importance DESC LIMIT $3""" + params: list[Any] = [user_id, category, limit] + else: + query = """SELECT id, content, category, tags, importance, is_sensitive, created_at, updated_at + FROM memories WHERE user_id = $1 AND deleted_at IS NULL + ORDER BY importance DESC LIMIT $2""" + params = [user_id, limit] + + async with pool.acquire() as conn: + rows = await conn.fetch(query, *params) + + results = [] + for row in rows: + c = row["content"] + if row["is_sensitive"]: + c = f"[SENSITIVE - use secret_get(id={row['id']})]" + results.append({ + "id": row["id"], "content": c, "category": row["category"], + "tags": row["tags"], "importance": row["importance"], + "created_at": row["created_at"].isoformat(), + "updated_at": row["updated_at"].isoformat(), + }) + + return json.dumps({"memories": results}) + + +@mcp_server.tool() +async def memory_delete(memory_id: int) -> str: + """Delete a memory by ID.""" + pool = await get_pool() + user_id = "default" + + async with pool.acquire() as conn: + row = await conn.fetchrow( + "SELECT id, vault_path, substr(content, 1, 50) AS preview FROM memories WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL", + memory_id, user_id, + ) + if not row: + return json.dumps({"deleted": memory_id, "preview": "[already deleted]"}) + + if row["vault_path"]: + await delete_secret(user_id, row["vault_path"]) + + await conn.execute( + "UPDATE memories SET deleted_at = NOW(), updated_at = NOW() WHERE id = $1 AND user_id = $2", + memory_id, user_id, + ) + + return json.dumps({"deleted": memory_id, "preview": row["preview"]}) + + +@mcp_server.tool() +async def memory_count() -> str: + """Count total memories.""" + pool = await get_pool() + user_id = "default" + async with pool.acquire() as conn: + count = await conn.fetchval("SELECT COUNT(*) FROM memories WHERE user_id = $1 AND deleted_at IS NULL", user_id) + return json.dumps({"count": count}) + + +@mcp_server.tool() +async def secret_get(key: str) -> str: + """Retrieve a secret value by key. Returns empty if not found.""" + return json.dumps({"error": "secret_get is not available via SSE transport"}) + + +# Auth middleware for /mcp/* routes +class MCPAuthMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): # type: ignore[override] + if request.url.path.startswith("/mcp"): + auth = request.headers.get("authorization", "") + token = auth.removeprefix("Bearer ").strip() + if not _resolve_user_from_token(token): + return Response(content="Unauthorized", status_code=401) + return await call_next(request) + + +app.add_middleware(MCPAuthMiddleware) + +# Mount SSE transport +sse_transport = SseServerTransport("/messages/") + + +class HandleSSE: + """ASGI app for SSE connections.""" + async def __call__(self, scope: Any, receive: Any, send: Any) -> None: + async with sse_transport.connect_sse(scope, receive, send) as (read_stream, write_stream): + await mcp_server._mcp_server.run( + read_stream, write_stream, mcp_server._mcp_server.create_initialization_options() + ) + + +# Client connects to /mcp/sse, posts to /mcp/messages/ +app.router.routes.insert(0, Mount("/mcp", routes=[ + Route("/sse", endpoint=HandleSSE()), + Mount("/messages", app=sse_transport.handle_post_message), +]))