feat: multi-user MCP SSE support + shared memories in recall/list

- Use contextvars to resolve user identity from Authorization header
  in SSE connections, replacing hardcoded "default" user_id
- memory_recall now includes shared memories (individual + tag-based)
  with deduplication and shared_by attribution
- memory_list now includes shared memories with same approach
- All 11 MCP tool functions use _current_user contextvar
This commit is contained in:
Viktor Barzin 2026-03-22 22:50:18 +02:00
parent 95dd937765
commit f242c45c73
No known key found for this signature in database
GPG key ID: 0EB088298288D958

View file

@ -4,6 +4,7 @@ import json
import logging
import pathlib
from contextlib import asynccontextmanager
from contextvars import ContextVar
from datetime import datetime, timezone
from typing import Any, AsyncGenerator, Optional
@ -31,6 +32,9 @@ from claude_memory.api.vault_service import (
logger = logging.getLogger(__name__)
# Context variable for MCP SSE multi-user support
_current_user: ContextVar[str] = ContextVar("_current_user", default="default")
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
@ -846,8 +850,7 @@ async def memory_store(content: str, category: str = "facts", tags: str = "",
expanded_keywords: str = "", importance: float = 0.5) -> str:
"""Store a new memory."""
pool = await get_pool()
# MCP SSE uses "default" user (single-user mode via middleware auth)
user_id = "default"
user_id = _current_user.get()
is_sensitive = _detect_sensitive(content)
stored_content = content if not is_sensitive else _redact_content(content)
@ -873,7 +876,7 @@ async def memory_recall(context: str, expanded_query: str = "",
limit: int = 10) -> str:
"""Recall memories by semantic search."""
pool = await get_pool()
user_id = "default"
user_id = _current_user.get()
query_text = f"{context} {expanded_query}".strip()
if not query_text:
return json.dumps({"error": "context is required"})
@ -896,7 +899,8 @@ async def memory_recall(context: str, expanded_query: str = "",
rows = await conn.fetch(
f"""
SELECT id, content, category, tags, importance, is_sensitive,
ts_rank(search_vector, query) AS rank, created_at, updated_at
ts_rank(search_vector, query) AS rank, created_at, updated_at,
NULL::text AS shared_by
FROM memories, plainto_tsquery('english', $2) query
WHERE user_id = $1 AND deleted_at IS NULL
AND (search_vector @@ query OR $2 = '')
@ -907,8 +911,50 @@ async def memory_recall(context: str, expanded_query: str = "",
*params,
)
# Also fetch shared memories (individual + tag-based)
shared_rows = await conn.fetch(
f"""
SELECT DISTINCT ON (m.id) m.id, m.content, m.category, m.tags, m.importance,
m.is_sensitive, ts_rank(m.search_vector, query) AS rank,
m.created_at, m.updated_at, m.user_id AS shared_by
FROM memories m, plainto_tsquery('english', $2) query
WHERE m.deleted_at IS NULL
AND (m.search_vector @@ query OR $2 = '')
AND m.user_id != $1
AND (
EXISTS (SELECT 1 FROM memory_shares ms WHERE ms.memory_id = m.id AND ms.shared_with = $1)
OR EXISTS (
SELECT 1 FROM tag_shares ts
WHERE ts.owner_id = m.user_id AND ts.shared_with = $1
AND EXISTS (SELECT 1 FROM unnest(string_to_array(m.tags, ',')) t WHERE trim(t) = ts.tag)
)
)
ORDER BY m.id
LIMIT $3
""",
*params,
)
seen_ids = set()
results = []
for row in rows:
seen_ids.add(row["id"])
c = row["content"]
if row["is_sensitive"]:
c = f"[SENSITIVE - use secret_get(id={row['id']})]"
entry: dict[str, Any] = {
"id": row["id"], "content": c, "category": row["category"],
"tags": row["tags"], "importance": row["importance"],
"rank": float(row["rank"]),
"created_at": row["created_at"].isoformat(),
"updated_at": row["updated_at"].isoformat(),
}
results.append(entry)
for row in shared_rows:
if row["id"] in seen_ids:
continue
seen_ids.add(row["id"])
c = row["content"]
if row["is_sensitive"]:
c = f"[SENSITIVE - use secret_get(id={row['id']})]"
@ -916,6 +962,7 @@ async def memory_recall(context: str, expanded_query: str = "",
"id": row["id"], "content": c, "category": row["category"],
"tags": row["tags"], "importance": row["importance"],
"rank": float(row["rank"]),
"shared_by": row["shared_by"],
"created_at": row["created_at"].isoformat(),
"updated_at": row["updated_at"].isoformat(),
})
@ -927,7 +974,7 @@ async def memory_recall(context: str, expanded_query: str = "",
async def memory_list(category: str | None = None, limit: int = 20) -> str:
"""List stored memories."""
pool = await get_pool()
user_id = "default"
user_id = _current_user.get()
if category:
query = """SELECT id, content, category, tags, importance, is_sensitive, created_at, updated_at
@ -940,11 +987,47 @@ async def memory_list(category: str | None = None, limit: int = 20) -> str:
ORDER BY importance DESC LIMIT $2"""
params = [user_id, limit]
if category:
shared_query = """
SELECT DISTINCT ON (m.id) m.id, m.content, m.category, m.tags, m.importance,
m.is_sensitive, m.created_at, m.updated_at, m.user_id AS shared_by
FROM memories m
WHERE m.deleted_at IS NULL AND m.category = $2 AND m.user_id != $1
AND (
EXISTS (SELECT 1 FROM memory_shares ms WHERE ms.memory_id = m.id AND ms.shared_with = $1)
OR EXISTS (
SELECT 1 FROM tag_shares ts
WHERE ts.owner_id = m.user_id AND ts.shared_with = $1
AND EXISTS (SELECT 1 FROM unnest(string_to_array(m.tags, ',')) t WHERE trim(t) = ts.tag)
)
)
ORDER BY m.id LIMIT $3"""
shared_params: list[Any] = [user_id, category, limit]
else:
shared_query = """
SELECT DISTINCT ON (m.id) m.id, m.content, m.category, m.tags, m.importance,
m.is_sensitive, m.created_at, m.updated_at, m.user_id AS shared_by
FROM memories m
WHERE m.deleted_at IS NULL AND m.user_id != $1
AND (
EXISTS (SELECT 1 FROM memory_shares ms WHERE ms.memory_id = m.id AND ms.shared_with = $1)
OR EXISTS (
SELECT 1 FROM tag_shares ts
WHERE ts.owner_id = m.user_id AND ts.shared_with = $1
AND EXISTS (SELECT 1 FROM unnest(string_to_array(m.tags, ',')) t WHERE trim(t) = ts.tag)
)
)
ORDER BY m.id LIMIT $2"""
shared_params = [user_id, limit]
async with pool.acquire() as conn:
rows = await conn.fetch(query, *params)
shared_rows = await conn.fetch(shared_query, *shared_params)
seen_ids = set()
results = []
for row in rows:
seen_ids.add(row["id"])
c = row["content"]
if row["is_sensitive"]:
c = f"[SENSITIVE - use secret_get(id={row['id']})]"
@ -955,6 +1038,21 @@ async def memory_list(category: str | None = None, limit: int = 20) -> str:
"updated_at": row["updated_at"].isoformat(),
})
for row in shared_rows:
if row["id"] in seen_ids:
continue
seen_ids.add(row["id"])
c = row["content"]
if row["is_sensitive"]:
c = f"[SENSITIVE - use secret_get(id={row['id']})]"
results.append({
"id": row["id"], "content": c, "category": row["category"],
"tags": row["tags"], "importance": row["importance"],
"shared_by": row["shared_by"],
"created_at": row["created_at"].isoformat(),
"updated_at": row["updated_at"].isoformat(),
})
return json.dumps({"memories": results})
@ -962,7 +1060,7 @@ async def memory_list(category: str | None = None, limit: int = 20) -> str:
async def memory_delete(memory_id: int) -> str:
"""Delete a memory by ID."""
pool = await get_pool()
user_id = "default"
user_id = _current_user.get()
async with pool.acquire() as conn:
row = await conn.fetchrow(
@ -987,7 +1085,7 @@ async def memory_delete(memory_id: int) -> str:
async def memory_count() -> str:
"""Count total memories."""
pool = await get_pool()
user_id = "default"
user_id = _current_user.get()
async with pool.acquire() as conn:
count = await conn.fetchval("SELECT COUNT(*) FROM memories WHERE user_id = $1 AND deleted_at IS NULL", user_id)
return json.dumps({"count": count})
@ -1005,7 +1103,7 @@ async def memory_share(id: int, shared_with: str, permission: str = "read") -> s
if permission not in ("read", "write"):
return json.dumps({"error": "permission must be 'read' or 'write'"})
pool = await get_pool()
user_id = "default"
user_id = _current_user.get()
async with pool.acquire() as conn:
row = await conn.fetchrow(
"SELECT id FROM memories WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL",
@ -1026,7 +1124,7 @@ async def memory_share(id: int, shared_with: str, permission: str = "read") -> s
async def memory_unshare(id: int, shared_with: str) -> str:
"""Revoke sharing of a memory from a user."""
pool = await get_pool()
user_id = "default"
user_id = _current_user.get()
async with pool.acquire() as conn:
await conn.execute(
"DELETE FROM memory_shares WHERE memory_id = $1 AND owner_id = $2 AND shared_with = $3",
@ -1041,7 +1139,7 @@ async def memory_share_tag(tag: str, shared_with: str, permission: str = "read")
if permission not in ("read", "write"):
return json.dumps({"error": "permission must be 'read' or 'write'"})
pool = await get_pool()
user_id = "default"
user_id = _current_user.get()
async with pool.acquire() as conn:
await conn.execute(
"""INSERT INTO tag_shares (owner_id, tag, shared_with, permission)
@ -1056,7 +1154,7 @@ async def memory_share_tag(tag: str, shared_with: str, permission: str = "read")
async def memory_unshare_tag(tag: str, shared_with: str) -> str:
"""Revoke tag-based sharing."""
pool = await get_pool()
user_id = "default"
user_id = _current_user.get()
async with pool.acquire() as conn:
await conn.execute(
"DELETE FROM tag_shares WHERE owner_id = $1 AND tag = $2 AND shared_with = $3",
@ -1070,7 +1168,7 @@ async def memory_update(id: int, content: str | None = None, tags: str | None =
importance: float | None = None, expanded_keywords: str | None = None) -> str:
"""Update an existing memory's content, tags, importance, or keywords."""
pool = await get_pool()
user_id = "default"
user_id = _current_user.get()
async with pool.acquire() as conn:
allowed, owner_id = await check_memory_permission(conn, id, user_id, "write")
if not allowed:
@ -1132,6 +1230,16 @@ sse_transport = SseServerTransport("/messages/")
class HandleSSE:
"""ASGI app for SSE connections."""
async def __call__(self, scope: Any, receive: Any, send: Any) -> None:
# Extract user from Authorization header for multi-user MCP
user_id = "default"
for name, value in scope.get("headers", []):
if name == b"authorization":
token = value.decode().removeprefix("Bearer ").strip()
resolved = _resolve_user_from_token(token)
if resolved:
user_id = resolved
break
_current_user.set(user_id)
async with sse_transport.connect_sse(scope, receive, send) as (read_stream, write_stream):
await mcp_server._mcp_server.run(
read_stream, write_stream, mcp_server._mcp_server.create_initialization_options()