add MCP SSE transport for direct Claude Code connection

Adds SSE endpoint at /mcp/sse so Claude Code can connect over HTTPS
without needing a local Python bridge script. Benefits:
- No local files or sandbox permission issues
- Works from any machine (OpenClaw, DevVM)
- No startup delay or stderr suppression hack
- Auth via Bearer token in request headers
This commit is contained in:
Viktor Barzin 2026-03-18 22:44:57 +00:00
parent 18e27d07d2
commit 48df739c82
No known key found for this signature in database
GPG key ID: 0EB088298288D958
2 changed files with 211 additions and 3 deletions

View file

@ -16,7 +16,7 @@ classifiers = [
]
[project.optional-dependencies]
api = ["fastapi>=0.115", "asyncpg>=0.30", "uvicorn>=0.34", "pydantic>=2.0", "alembic>=1.14", "sqlalchemy>=2.0", "psycopg2-binary>=2.9"]
api = ["fastapi>=0.115", "asyncpg>=0.30", "uvicorn>=0.34", "pydantic>=2.0", "alembic>=1.14", "sqlalchemy>=2.0", "psycopg2-binary>=2.9", "mcp>=1.0.0"]
vault = ["hvac>=2.0"]
dev = ["pytest>=8.0", "pytest-asyncio>=0.24", "ruff>=0.8", "mypy>=1.13", "httpx>=0.28", "cryptography>=43.0"]

View file

@ -1,13 +1,19 @@
"""Claude Memory API -- shared persistent memory with PostgreSQL full-text search."""
import json
import logging
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from typing import Any, AsyncGenerator, Optional
from fastapi import Depends, FastAPI, HTTPException
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import Response
from mcp.server.fastmcp import FastMCP
from mcp.server.sse import SseServerTransport
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.routing import Mount, Route
from claude_memory.api.auth import AuthUser, get_current_user
from claude_memory.api.auth import AuthUser, get_current_user, _key_to_user
from claude_memory.api.database import close_pool, get_pool, init_pool
from claude_memory.api.models import MemoryRecall, MemoryResponse, MemoryStore, SecretResponse, SyncResponse
from claude_memory.api.vault_service import (
@ -431,3 +437,205 @@ async def import_memories(
)
return imported
# --- MCP SSE Transport ---
def _resolve_user_from_token(token: str) -> str | None:
"""Resolve API key to user_id, reusing auth module's key map."""
return _key_to_user.get(token)
mcp_server = FastMCP("claude-memory")
@mcp_server.tool()
async def memory_store(content: str, category: str = "facts", tags: str = "",
expanded_keywords: str = "", importance: float = 0.5) -> str:
"""Store a new memory."""
pool = await get_pool()
# MCP SSE uses "default" user (single-user mode via middleware auth)
user_id = "default"
is_sensitive = _detect_sensitive(content)
stored_content = content if not is_sensitive else _redact_content(content)
async with pool.acquire() as conn:
row = await conn.fetchrow(
"""INSERT INTO memories (user_id, content, category, tags, expanded_keywords, importance, is_sensitive)
VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING id""",
user_id, stored_content, category, tags, expanded_keywords, importance, is_sensitive,
)
memory_id = row["id"]
if is_sensitive and is_vault_configured():
vault_path = await store_secret(user_id, memory_id, content)
await conn.execute("UPDATE memories SET vault_path = $1 WHERE id = $2", vault_path, memory_id)
return json.dumps({"id": memory_id, "category": category, "importance": importance})
@mcp_server.tool()
async def memory_recall(context: str, expanded_query: str = "",
category: str | None = None, sort_by: str = "importance",
limit: int = 10) -> str:
"""Recall memories by semantic search."""
pool = await get_pool()
user_id = "default"
query_text = f"{context} {expanded_query}".strip()
if not query_text:
return json.dumps({"error": "context is required"})
hybrid_score = "(ts_rank(search_vector, query) * 0.7 + importance * 0.3)"
if sort_by == "importance":
hybrid_score = "(ts_rank(search_vector, query) * 0.4 + importance * 0.6)"
order_clause = f"{hybrid_score} DESC"
if sort_by == "recency":
order_clause = "created_at DESC"
category_filter = ""
params: list[Any] = [user_id, query_text, limit]
if category:
category_filter = "AND category = $4"
params.append(category)
async with pool.acquire() as conn:
rows = await conn.fetch(
f"""
SELECT id, content, category, tags, importance, is_sensitive,
ts_rank(search_vector, query) AS rank, created_at, updated_at
FROM memories, plainto_tsquery('english', $2) query
WHERE user_id = $1 AND deleted_at IS NULL
AND (search_vector @@ query OR $2 = '')
{category_filter}
ORDER BY {order_clause}
LIMIT $3
""",
*params,
)
results = []
for row in rows:
c = row["content"]
if row["is_sensitive"]:
c = f"[SENSITIVE - use secret_get(id={row['id']})]"
results.append({
"id": row["id"], "content": c, "category": row["category"],
"tags": row["tags"], "importance": row["importance"],
"rank": float(row["rank"]),
"created_at": row["created_at"].isoformat(),
"updated_at": row["updated_at"].isoformat(),
})
return json.dumps({"memories": results})
@mcp_server.tool()
async def memory_list(category: str | None = None, limit: int = 20) -> str:
"""List stored memories."""
pool = await get_pool()
user_id = "default"
if category:
query = """SELECT id, content, category, tags, importance, is_sensitive, created_at, updated_at
FROM memories WHERE user_id = $1 AND deleted_at IS NULL AND category = $2
ORDER BY importance DESC LIMIT $3"""
params: list[Any] = [user_id, category, limit]
else:
query = """SELECT id, content, category, tags, importance, is_sensitive, created_at, updated_at
FROM memories WHERE user_id = $1 AND deleted_at IS NULL
ORDER BY importance DESC LIMIT $2"""
params = [user_id, limit]
async with pool.acquire() as conn:
rows = await conn.fetch(query, *params)
results = []
for row in rows:
c = row["content"]
if row["is_sensitive"]:
c = f"[SENSITIVE - use secret_get(id={row['id']})]"
results.append({
"id": row["id"], "content": c, "category": row["category"],
"tags": row["tags"], "importance": row["importance"],
"created_at": row["created_at"].isoformat(),
"updated_at": row["updated_at"].isoformat(),
})
return json.dumps({"memories": results})
@mcp_server.tool()
async def memory_delete(memory_id: int) -> str:
"""Delete a memory by ID."""
pool = await get_pool()
user_id = "default"
async with pool.acquire() as conn:
row = await conn.fetchrow(
"SELECT id, vault_path, substr(content, 1, 50) AS preview FROM memories WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL",
memory_id, user_id,
)
if not row:
return json.dumps({"deleted": memory_id, "preview": "[already deleted]"})
if row["vault_path"]:
await delete_secret(user_id, row["vault_path"])
await conn.execute(
"UPDATE memories SET deleted_at = NOW(), updated_at = NOW() WHERE id = $1 AND user_id = $2",
memory_id, user_id,
)
return json.dumps({"deleted": memory_id, "preview": row["preview"]})
@mcp_server.tool()
async def memory_count() -> str:
"""Count total memories."""
pool = await get_pool()
user_id = "default"
async with pool.acquire() as conn:
count = await conn.fetchval("SELECT COUNT(*) FROM memories WHERE user_id = $1 AND deleted_at IS NULL", user_id)
return json.dumps({"count": count})
@mcp_server.tool()
async def secret_get(key: str) -> str:
"""Retrieve a secret value by key. Returns empty if not found."""
return json.dumps({"error": "secret_get is not available via SSE transport"})
# Auth middleware for /mcp/* routes
class MCPAuthMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next): # type: ignore[override]
if request.url.path.startswith("/mcp"):
auth = request.headers.get("authorization", "")
token = auth.removeprefix("Bearer ").strip()
if not _resolve_user_from_token(token):
return Response(content="Unauthorized", status_code=401)
return await call_next(request)
app.add_middleware(MCPAuthMiddleware)
# Mount SSE transport
sse_transport = SseServerTransport("/messages/")
class HandleSSE:
"""ASGI app for SSE connections."""
async def __call__(self, scope: Any, receive: Any, send: Any) -> None:
async with sse_transport.connect_sse(scope, receive, send) as (read_stream, write_stream):
await mcp_server._mcp_server.run(
read_stream, write_stream, mcp_server._mcp_server.create_initialization_options()
)
# Client connects to /mcp/sse, posts to /mcp/messages/
app.router.routes.insert(0, Mount("/mcp", routes=[
Route("/sse", endpoint=HandleSSE()),
Mount("/messages", app=sse_transport.handle_post_message),
]))