feat: multi-user MCP SSE support + shared memories in recall/list

- Use contextvars to resolve user identity from Authorization header
  in SSE connections, replacing hardcoded "default" user_id
- memory_recall now includes shared memories (individual + tag-based)
  with deduplication and shared_by attribution
- memory_list now includes shared memories with same approach
- All 11 MCP tool functions use _current_user contextvar
This commit is contained in:
Viktor Barzin 2026-03-22 22:50:18 +02:00
parent 95dd937765
commit f242c45c73
No known key found for this signature in database
GPG key ID: 0EB088298288D958

View file

@ -4,6 +4,7 @@ import json
import logging import logging
import pathlib import pathlib
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from contextvars import ContextVar
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, AsyncGenerator, Optional from typing import Any, AsyncGenerator, Optional
@ -31,6 +32,9 @@ from claude_memory.api.vault_service import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Context variable for MCP SSE multi-user support
_current_user: ContextVar[str] = ContextVar("_current_user", default="default")
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
@ -846,8 +850,7 @@ async def memory_store(content: str, category: str = "facts", tags: str = "",
expanded_keywords: str = "", importance: float = 0.5) -> str: expanded_keywords: str = "", importance: float = 0.5) -> str:
"""Store a new memory.""" """Store a new memory."""
pool = await get_pool() pool = await get_pool()
# MCP SSE uses "default" user (single-user mode via middleware auth) user_id = _current_user.get()
user_id = "default"
is_sensitive = _detect_sensitive(content) is_sensitive = _detect_sensitive(content)
stored_content = content if not is_sensitive else _redact_content(content) stored_content = content if not is_sensitive else _redact_content(content)
@ -873,7 +876,7 @@ async def memory_recall(context: str, expanded_query: str = "",
limit: int = 10) -> str: limit: int = 10) -> str:
"""Recall memories by semantic search.""" """Recall memories by semantic search."""
pool = await get_pool() pool = await get_pool()
user_id = "default" user_id = _current_user.get()
query_text = f"{context} {expanded_query}".strip() query_text = f"{context} {expanded_query}".strip()
if not query_text: if not query_text:
return json.dumps({"error": "context is required"}) return json.dumps({"error": "context is required"})
@ -896,7 +899,8 @@ async def memory_recall(context: str, expanded_query: str = "",
rows = await conn.fetch( rows = await conn.fetch(
f""" f"""
SELECT id, content, category, tags, importance, is_sensitive, SELECT id, content, category, tags, importance, is_sensitive,
ts_rank(search_vector, query) AS rank, created_at, updated_at ts_rank(search_vector, query) AS rank, created_at, updated_at,
NULL::text AS shared_by
FROM memories, plainto_tsquery('english', $2) query FROM memories, plainto_tsquery('english', $2) query
WHERE user_id = $1 AND deleted_at IS NULL WHERE user_id = $1 AND deleted_at IS NULL
AND (search_vector @@ query OR $2 = '') AND (search_vector @@ query OR $2 = '')
@ -907,8 +911,50 @@ async def memory_recall(context: str, expanded_query: str = "",
*params, *params,
) )
# Also fetch shared memories (individual + tag-based)
shared_rows = await conn.fetch(
f"""
SELECT DISTINCT ON (m.id) m.id, m.content, m.category, m.tags, m.importance,
m.is_sensitive, ts_rank(m.search_vector, query) AS rank,
m.created_at, m.updated_at, m.user_id AS shared_by
FROM memories m, plainto_tsquery('english', $2) query
WHERE m.deleted_at IS NULL
AND (m.search_vector @@ query OR $2 = '')
AND m.user_id != $1
AND (
EXISTS (SELECT 1 FROM memory_shares ms WHERE ms.memory_id = m.id AND ms.shared_with = $1)
OR EXISTS (
SELECT 1 FROM tag_shares ts
WHERE ts.owner_id = m.user_id AND ts.shared_with = $1
AND EXISTS (SELECT 1 FROM unnest(string_to_array(m.tags, ',')) t WHERE trim(t) = ts.tag)
)
)
ORDER BY m.id
LIMIT $3
""",
*params,
)
seen_ids = set()
results = [] results = []
for row in rows: for row in rows:
seen_ids.add(row["id"])
c = row["content"]
if row["is_sensitive"]:
c = f"[SENSITIVE - use secret_get(id={row['id']})]"
entry: dict[str, Any] = {
"id": row["id"], "content": c, "category": row["category"],
"tags": row["tags"], "importance": row["importance"],
"rank": float(row["rank"]),
"created_at": row["created_at"].isoformat(),
"updated_at": row["updated_at"].isoformat(),
}
results.append(entry)
for row in shared_rows:
if row["id"] in seen_ids:
continue
seen_ids.add(row["id"])
c = row["content"] c = row["content"]
if row["is_sensitive"]: if row["is_sensitive"]:
c = f"[SENSITIVE - use secret_get(id={row['id']})]" c = f"[SENSITIVE - use secret_get(id={row['id']})]"
@ -916,6 +962,7 @@ async def memory_recall(context: str, expanded_query: str = "",
"id": row["id"], "content": c, "category": row["category"], "id": row["id"], "content": c, "category": row["category"],
"tags": row["tags"], "importance": row["importance"], "tags": row["tags"], "importance": row["importance"],
"rank": float(row["rank"]), "rank": float(row["rank"]),
"shared_by": row["shared_by"],
"created_at": row["created_at"].isoformat(), "created_at": row["created_at"].isoformat(),
"updated_at": row["updated_at"].isoformat(), "updated_at": row["updated_at"].isoformat(),
}) })
@ -927,7 +974,7 @@ async def memory_recall(context: str, expanded_query: str = "",
async def memory_list(category: str | None = None, limit: int = 20) -> str: async def memory_list(category: str | None = None, limit: int = 20) -> str:
"""List stored memories.""" """List stored memories."""
pool = await get_pool() pool = await get_pool()
user_id = "default" user_id = _current_user.get()
if category: if category:
query = """SELECT id, content, category, tags, importance, is_sensitive, created_at, updated_at query = """SELECT id, content, category, tags, importance, is_sensitive, created_at, updated_at
@ -940,11 +987,47 @@ async def memory_list(category: str | None = None, limit: int = 20) -> str:
ORDER BY importance DESC LIMIT $2""" ORDER BY importance DESC LIMIT $2"""
params = [user_id, limit] params = [user_id, limit]
if category:
shared_query = """
SELECT DISTINCT ON (m.id) m.id, m.content, m.category, m.tags, m.importance,
m.is_sensitive, m.created_at, m.updated_at, m.user_id AS shared_by
FROM memories m
WHERE m.deleted_at IS NULL AND m.category = $2 AND m.user_id != $1
AND (
EXISTS (SELECT 1 FROM memory_shares ms WHERE ms.memory_id = m.id AND ms.shared_with = $1)
OR EXISTS (
SELECT 1 FROM tag_shares ts
WHERE ts.owner_id = m.user_id AND ts.shared_with = $1
AND EXISTS (SELECT 1 FROM unnest(string_to_array(m.tags, ',')) t WHERE trim(t) = ts.tag)
)
)
ORDER BY m.id LIMIT $3"""
shared_params: list[Any] = [user_id, category, limit]
else:
shared_query = """
SELECT DISTINCT ON (m.id) m.id, m.content, m.category, m.tags, m.importance,
m.is_sensitive, m.created_at, m.updated_at, m.user_id AS shared_by
FROM memories m
WHERE m.deleted_at IS NULL AND m.user_id != $1
AND (
EXISTS (SELECT 1 FROM memory_shares ms WHERE ms.memory_id = m.id AND ms.shared_with = $1)
OR EXISTS (
SELECT 1 FROM tag_shares ts
WHERE ts.owner_id = m.user_id AND ts.shared_with = $1
AND EXISTS (SELECT 1 FROM unnest(string_to_array(m.tags, ',')) t WHERE trim(t) = ts.tag)
)
)
ORDER BY m.id LIMIT $2"""
shared_params = [user_id, limit]
async with pool.acquire() as conn: async with pool.acquire() as conn:
rows = await conn.fetch(query, *params) rows = await conn.fetch(query, *params)
shared_rows = await conn.fetch(shared_query, *shared_params)
seen_ids = set()
results = [] results = []
for row in rows: for row in rows:
seen_ids.add(row["id"])
c = row["content"] c = row["content"]
if row["is_sensitive"]: if row["is_sensitive"]:
c = f"[SENSITIVE - use secret_get(id={row['id']})]" c = f"[SENSITIVE - use secret_get(id={row['id']})]"
@ -955,6 +1038,21 @@ async def memory_list(category: str | None = None, limit: int = 20) -> str:
"updated_at": row["updated_at"].isoformat(), "updated_at": row["updated_at"].isoformat(),
}) })
for row in shared_rows:
if row["id"] in seen_ids:
continue
seen_ids.add(row["id"])
c = row["content"]
if row["is_sensitive"]:
c = f"[SENSITIVE - use secret_get(id={row['id']})]"
results.append({
"id": row["id"], "content": c, "category": row["category"],
"tags": row["tags"], "importance": row["importance"],
"shared_by": row["shared_by"],
"created_at": row["created_at"].isoformat(),
"updated_at": row["updated_at"].isoformat(),
})
return json.dumps({"memories": results}) return json.dumps({"memories": results})
@ -962,7 +1060,7 @@ async def memory_list(category: str | None = None, limit: int = 20) -> str:
async def memory_delete(memory_id: int) -> str: async def memory_delete(memory_id: int) -> str:
"""Delete a memory by ID.""" """Delete a memory by ID."""
pool = await get_pool() pool = await get_pool()
user_id = "default" user_id = _current_user.get()
async with pool.acquire() as conn: async with pool.acquire() as conn:
row = await conn.fetchrow( row = await conn.fetchrow(
@ -987,7 +1085,7 @@ async def memory_delete(memory_id: int) -> str:
async def memory_count() -> str: async def memory_count() -> str:
"""Count total memories.""" """Count total memories."""
pool = await get_pool() pool = await get_pool()
user_id = "default" user_id = _current_user.get()
async with pool.acquire() as conn: async with pool.acquire() as conn:
count = await conn.fetchval("SELECT COUNT(*) FROM memories WHERE user_id = $1 AND deleted_at IS NULL", user_id) count = await conn.fetchval("SELECT COUNT(*) FROM memories WHERE user_id = $1 AND deleted_at IS NULL", user_id)
return json.dumps({"count": count}) return json.dumps({"count": count})
@ -1005,7 +1103,7 @@ async def memory_share(id: int, shared_with: str, permission: str = "read") -> s
if permission not in ("read", "write"): if permission not in ("read", "write"):
return json.dumps({"error": "permission must be 'read' or 'write'"}) return json.dumps({"error": "permission must be 'read' or 'write'"})
pool = await get_pool() pool = await get_pool()
user_id = "default" user_id = _current_user.get()
async with pool.acquire() as conn: async with pool.acquire() as conn:
row = await conn.fetchrow( row = await conn.fetchrow(
"SELECT id FROM memories WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL", "SELECT id FROM memories WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL",
@ -1026,7 +1124,7 @@ async def memory_share(id: int, shared_with: str, permission: str = "read") -> s
async def memory_unshare(id: int, shared_with: str) -> str: async def memory_unshare(id: int, shared_with: str) -> str:
"""Revoke sharing of a memory from a user.""" """Revoke sharing of a memory from a user."""
pool = await get_pool() pool = await get_pool()
user_id = "default" user_id = _current_user.get()
async with pool.acquire() as conn: async with pool.acquire() as conn:
await conn.execute( await conn.execute(
"DELETE FROM memory_shares WHERE memory_id = $1 AND owner_id = $2 AND shared_with = $3", "DELETE FROM memory_shares WHERE memory_id = $1 AND owner_id = $2 AND shared_with = $3",
@ -1041,7 +1139,7 @@ async def memory_share_tag(tag: str, shared_with: str, permission: str = "read")
if permission not in ("read", "write"): if permission not in ("read", "write"):
return json.dumps({"error": "permission must be 'read' or 'write'"}) return json.dumps({"error": "permission must be 'read' or 'write'"})
pool = await get_pool() pool = await get_pool()
user_id = "default" user_id = _current_user.get()
async with pool.acquire() as conn: async with pool.acquire() as conn:
await conn.execute( await conn.execute(
"""INSERT INTO tag_shares (owner_id, tag, shared_with, permission) """INSERT INTO tag_shares (owner_id, tag, shared_with, permission)
@ -1056,7 +1154,7 @@ async def memory_share_tag(tag: str, shared_with: str, permission: str = "read")
async def memory_unshare_tag(tag: str, shared_with: str) -> str: async def memory_unshare_tag(tag: str, shared_with: str) -> str:
"""Revoke tag-based sharing.""" """Revoke tag-based sharing."""
pool = await get_pool() pool = await get_pool()
user_id = "default" user_id = _current_user.get()
async with pool.acquire() as conn: async with pool.acquire() as conn:
await conn.execute( await conn.execute(
"DELETE FROM tag_shares WHERE owner_id = $1 AND tag = $2 AND shared_with = $3", "DELETE FROM tag_shares WHERE owner_id = $1 AND tag = $2 AND shared_with = $3",
@ -1070,7 +1168,7 @@ async def memory_update(id: int, content: str | None = None, tags: str | None =
importance: float | None = None, expanded_keywords: str | None = None) -> str: importance: float | None = None, expanded_keywords: str | None = None) -> str:
"""Update an existing memory's content, tags, importance, or keywords.""" """Update an existing memory's content, tags, importance, or keywords."""
pool = await get_pool() pool = await get_pool()
user_id = "default" user_id = _current_user.get()
async with pool.acquire() as conn: async with pool.acquire() as conn:
allowed, owner_id = await check_memory_permission(conn, id, user_id, "write") allowed, owner_id = await check_memory_permission(conn, id, user_id, "write")
if not allowed: if not allowed:
@ -1132,6 +1230,16 @@ sse_transport = SseServerTransport("/messages/")
class HandleSSE: class HandleSSE:
"""ASGI app for SSE connections.""" """ASGI app for SSE connections."""
async def __call__(self, scope: Any, receive: Any, send: Any) -> None: async def __call__(self, scope: Any, receive: Any, send: Any) -> None:
# Extract user from Authorization header for multi-user MCP
user_id = "default"
for name, value in scope.get("headers", []):
if name == b"authorization":
token = value.decode().removeprefix("Bearer ").strip()
resolved = _resolve_user_from_token(token)
if resolved:
user_id = resolved
break
_current_user.set(user_id)
async with sse_transport.connect_sse(scope, receive, send) as (read_stream, write_stream): async with sse_transport.connect_sse(scope, receive, send) as (read_stream, write_stream):
await mcp_server._mcp_server.run( await mcp_server._mcp_server.run(
read_stream, write_stream, mcp_server._mcp_server.create_initialization_options() read_stream, write_stream, mcp_server._mcp_server.create_initialization_options()