feat: multi-user MCP SSE support + shared memories in recall/list
- Use contextvars to resolve user identity from Authorization header in SSE connections, replacing hardcoded "default" user_id - memory_recall now includes shared memories (individual + tag-based) with deduplication and shared_by attribution - memory_list now includes shared memories with same approach - All 11 MCP tool functions use _current_user contextvar
This commit is contained in:
parent
95dd937765
commit
f242c45c73
1 changed files with 120 additions and 12 deletions
|
|
@ -4,6 +4,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import pathlib
|
import pathlib
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from contextvars import ContextVar
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, AsyncGenerator, Optional
|
from typing import Any, AsyncGenerator, Optional
|
||||||
|
|
||||||
|
|
@ -31,6 +32,9 @@ from claude_memory.api.vault_service import (
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Context variable for MCP SSE multi-user support
|
||||||
|
_current_user: ContextVar[str] = ContextVar("_current_user", default="default")
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||||
|
|
@ -846,8 +850,7 @@ async def memory_store(content: str, category: str = "facts", tags: str = "",
|
||||||
expanded_keywords: str = "", importance: float = 0.5) -> str:
|
expanded_keywords: str = "", importance: float = 0.5) -> str:
|
||||||
"""Store a new memory."""
|
"""Store a new memory."""
|
||||||
pool = await get_pool()
|
pool = await get_pool()
|
||||||
# MCP SSE uses "default" user (single-user mode via middleware auth)
|
user_id = _current_user.get()
|
||||||
user_id = "default"
|
|
||||||
is_sensitive = _detect_sensitive(content)
|
is_sensitive = _detect_sensitive(content)
|
||||||
stored_content = content if not is_sensitive else _redact_content(content)
|
stored_content = content if not is_sensitive else _redact_content(content)
|
||||||
|
|
||||||
|
|
@ -873,7 +876,7 @@ async def memory_recall(context: str, expanded_query: str = "",
|
||||||
limit: int = 10) -> str:
|
limit: int = 10) -> str:
|
||||||
"""Recall memories by semantic search."""
|
"""Recall memories by semantic search."""
|
||||||
pool = await get_pool()
|
pool = await get_pool()
|
||||||
user_id = "default"
|
user_id = _current_user.get()
|
||||||
query_text = f"{context} {expanded_query}".strip()
|
query_text = f"{context} {expanded_query}".strip()
|
||||||
if not query_text:
|
if not query_text:
|
||||||
return json.dumps({"error": "context is required"})
|
return json.dumps({"error": "context is required"})
|
||||||
|
|
@ -896,7 +899,8 @@ async def memory_recall(context: str, expanded_query: str = "",
|
||||||
rows = await conn.fetch(
|
rows = await conn.fetch(
|
||||||
f"""
|
f"""
|
||||||
SELECT id, content, category, tags, importance, is_sensitive,
|
SELECT id, content, category, tags, importance, is_sensitive,
|
||||||
ts_rank(search_vector, query) AS rank, created_at, updated_at
|
ts_rank(search_vector, query) AS rank, created_at, updated_at,
|
||||||
|
NULL::text AS shared_by
|
||||||
FROM memories, plainto_tsquery('english', $2) query
|
FROM memories, plainto_tsquery('english', $2) query
|
||||||
WHERE user_id = $1 AND deleted_at IS NULL
|
WHERE user_id = $1 AND deleted_at IS NULL
|
||||||
AND (search_vector @@ query OR $2 = '')
|
AND (search_vector @@ query OR $2 = '')
|
||||||
|
|
@ -907,8 +911,50 @@ async def memory_recall(context: str, expanded_query: str = "",
|
||||||
*params,
|
*params,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Also fetch shared memories (individual + tag-based)
|
||||||
|
shared_rows = await conn.fetch(
|
||||||
|
f"""
|
||||||
|
SELECT DISTINCT ON (m.id) m.id, m.content, m.category, m.tags, m.importance,
|
||||||
|
m.is_sensitive, ts_rank(m.search_vector, query) AS rank,
|
||||||
|
m.created_at, m.updated_at, m.user_id AS shared_by
|
||||||
|
FROM memories m, plainto_tsquery('english', $2) query
|
||||||
|
WHERE m.deleted_at IS NULL
|
||||||
|
AND (m.search_vector @@ query OR $2 = '')
|
||||||
|
AND m.user_id != $1
|
||||||
|
AND (
|
||||||
|
EXISTS (SELECT 1 FROM memory_shares ms WHERE ms.memory_id = m.id AND ms.shared_with = $1)
|
||||||
|
OR EXISTS (
|
||||||
|
SELECT 1 FROM tag_shares ts
|
||||||
|
WHERE ts.owner_id = m.user_id AND ts.shared_with = $1
|
||||||
|
AND EXISTS (SELECT 1 FROM unnest(string_to_array(m.tags, ',')) t WHERE trim(t) = ts.tag)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
ORDER BY m.id
|
||||||
|
LIMIT $3
|
||||||
|
""",
|
||||||
|
*params,
|
||||||
|
)
|
||||||
|
|
||||||
|
seen_ids = set()
|
||||||
results = []
|
results = []
|
||||||
for row in rows:
|
for row in rows:
|
||||||
|
seen_ids.add(row["id"])
|
||||||
|
c = row["content"]
|
||||||
|
if row["is_sensitive"]:
|
||||||
|
c = f"[SENSITIVE - use secret_get(id={row['id']})]"
|
||||||
|
entry: dict[str, Any] = {
|
||||||
|
"id": row["id"], "content": c, "category": row["category"],
|
||||||
|
"tags": row["tags"], "importance": row["importance"],
|
||||||
|
"rank": float(row["rank"]),
|
||||||
|
"created_at": row["created_at"].isoformat(),
|
||||||
|
"updated_at": row["updated_at"].isoformat(),
|
||||||
|
}
|
||||||
|
results.append(entry)
|
||||||
|
|
||||||
|
for row in shared_rows:
|
||||||
|
if row["id"] in seen_ids:
|
||||||
|
continue
|
||||||
|
seen_ids.add(row["id"])
|
||||||
c = row["content"]
|
c = row["content"]
|
||||||
if row["is_sensitive"]:
|
if row["is_sensitive"]:
|
||||||
c = f"[SENSITIVE - use secret_get(id={row['id']})]"
|
c = f"[SENSITIVE - use secret_get(id={row['id']})]"
|
||||||
|
|
@ -916,6 +962,7 @@ async def memory_recall(context: str, expanded_query: str = "",
|
||||||
"id": row["id"], "content": c, "category": row["category"],
|
"id": row["id"], "content": c, "category": row["category"],
|
||||||
"tags": row["tags"], "importance": row["importance"],
|
"tags": row["tags"], "importance": row["importance"],
|
||||||
"rank": float(row["rank"]),
|
"rank": float(row["rank"]),
|
||||||
|
"shared_by": row["shared_by"],
|
||||||
"created_at": row["created_at"].isoformat(),
|
"created_at": row["created_at"].isoformat(),
|
||||||
"updated_at": row["updated_at"].isoformat(),
|
"updated_at": row["updated_at"].isoformat(),
|
||||||
})
|
})
|
||||||
|
|
@ -927,7 +974,7 @@ async def memory_recall(context: str, expanded_query: str = "",
|
||||||
async def memory_list(category: str | None = None, limit: int = 20) -> str:
|
async def memory_list(category: str | None = None, limit: int = 20) -> str:
|
||||||
"""List stored memories."""
|
"""List stored memories."""
|
||||||
pool = await get_pool()
|
pool = await get_pool()
|
||||||
user_id = "default"
|
user_id = _current_user.get()
|
||||||
|
|
||||||
if category:
|
if category:
|
||||||
query = """SELECT id, content, category, tags, importance, is_sensitive, created_at, updated_at
|
query = """SELECT id, content, category, tags, importance, is_sensitive, created_at, updated_at
|
||||||
|
|
@ -940,11 +987,47 @@ async def memory_list(category: str | None = None, limit: int = 20) -> str:
|
||||||
ORDER BY importance DESC LIMIT $2"""
|
ORDER BY importance DESC LIMIT $2"""
|
||||||
params = [user_id, limit]
|
params = [user_id, limit]
|
||||||
|
|
||||||
|
if category:
|
||||||
|
shared_query = """
|
||||||
|
SELECT DISTINCT ON (m.id) m.id, m.content, m.category, m.tags, m.importance,
|
||||||
|
m.is_sensitive, m.created_at, m.updated_at, m.user_id AS shared_by
|
||||||
|
FROM memories m
|
||||||
|
WHERE m.deleted_at IS NULL AND m.category = $2 AND m.user_id != $1
|
||||||
|
AND (
|
||||||
|
EXISTS (SELECT 1 FROM memory_shares ms WHERE ms.memory_id = m.id AND ms.shared_with = $1)
|
||||||
|
OR EXISTS (
|
||||||
|
SELECT 1 FROM tag_shares ts
|
||||||
|
WHERE ts.owner_id = m.user_id AND ts.shared_with = $1
|
||||||
|
AND EXISTS (SELECT 1 FROM unnest(string_to_array(m.tags, ',')) t WHERE trim(t) = ts.tag)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
ORDER BY m.id LIMIT $3"""
|
||||||
|
shared_params: list[Any] = [user_id, category, limit]
|
||||||
|
else:
|
||||||
|
shared_query = """
|
||||||
|
SELECT DISTINCT ON (m.id) m.id, m.content, m.category, m.tags, m.importance,
|
||||||
|
m.is_sensitive, m.created_at, m.updated_at, m.user_id AS shared_by
|
||||||
|
FROM memories m
|
||||||
|
WHERE m.deleted_at IS NULL AND m.user_id != $1
|
||||||
|
AND (
|
||||||
|
EXISTS (SELECT 1 FROM memory_shares ms WHERE ms.memory_id = m.id AND ms.shared_with = $1)
|
||||||
|
OR EXISTS (
|
||||||
|
SELECT 1 FROM tag_shares ts
|
||||||
|
WHERE ts.owner_id = m.user_id AND ts.shared_with = $1
|
||||||
|
AND EXISTS (SELECT 1 FROM unnest(string_to_array(m.tags, ',')) t WHERE trim(t) = ts.tag)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
ORDER BY m.id LIMIT $2"""
|
||||||
|
shared_params = [user_id, limit]
|
||||||
|
|
||||||
async with pool.acquire() as conn:
|
async with pool.acquire() as conn:
|
||||||
rows = await conn.fetch(query, *params)
|
rows = await conn.fetch(query, *params)
|
||||||
|
shared_rows = await conn.fetch(shared_query, *shared_params)
|
||||||
|
|
||||||
|
seen_ids = set()
|
||||||
results = []
|
results = []
|
||||||
for row in rows:
|
for row in rows:
|
||||||
|
seen_ids.add(row["id"])
|
||||||
c = row["content"]
|
c = row["content"]
|
||||||
if row["is_sensitive"]:
|
if row["is_sensitive"]:
|
||||||
c = f"[SENSITIVE - use secret_get(id={row['id']})]"
|
c = f"[SENSITIVE - use secret_get(id={row['id']})]"
|
||||||
|
|
@ -955,6 +1038,21 @@ async def memory_list(category: str | None = None, limit: int = 20) -> str:
|
||||||
"updated_at": row["updated_at"].isoformat(),
|
"updated_at": row["updated_at"].isoformat(),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
for row in shared_rows:
|
||||||
|
if row["id"] in seen_ids:
|
||||||
|
continue
|
||||||
|
seen_ids.add(row["id"])
|
||||||
|
c = row["content"]
|
||||||
|
if row["is_sensitive"]:
|
||||||
|
c = f"[SENSITIVE - use secret_get(id={row['id']})]"
|
||||||
|
results.append({
|
||||||
|
"id": row["id"], "content": c, "category": row["category"],
|
||||||
|
"tags": row["tags"], "importance": row["importance"],
|
||||||
|
"shared_by": row["shared_by"],
|
||||||
|
"created_at": row["created_at"].isoformat(),
|
||||||
|
"updated_at": row["updated_at"].isoformat(),
|
||||||
|
})
|
||||||
|
|
||||||
return json.dumps({"memories": results})
|
return json.dumps({"memories": results})
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -962,7 +1060,7 @@ async def memory_list(category: str | None = None, limit: int = 20) -> str:
|
||||||
async def memory_delete(memory_id: int) -> str:
|
async def memory_delete(memory_id: int) -> str:
|
||||||
"""Delete a memory by ID."""
|
"""Delete a memory by ID."""
|
||||||
pool = await get_pool()
|
pool = await get_pool()
|
||||||
user_id = "default"
|
user_id = _current_user.get()
|
||||||
|
|
||||||
async with pool.acquire() as conn:
|
async with pool.acquire() as conn:
|
||||||
row = await conn.fetchrow(
|
row = await conn.fetchrow(
|
||||||
|
|
@ -987,7 +1085,7 @@ async def memory_delete(memory_id: int) -> str:
|
||||||
async def memory_count() -> str:
|
async def memory_count() -> str:
|
||||||
"""Count total memories."""
|
"""Count total memories."""
|
||||||
pool = await get_pool()
|
pool = await get_pool()
|
||||||
user_id = "default"
|
user_id = _current_user.get()
|
||||||
async with pool.acquire() as conn:
|
async with pool.acquire() as conn:
|
||||||
count = await conn.fetchval("SELECT COUNT(*) FROM memories WHERE user_id = $1 AND deleted_at IS NULL", user_id)
|
count = await conn.fetchval("SELECT COUNT(*) FROM memories WHERE user_id = $1 AND deleted_at IS NULL", user_id)
|
||||||
return json.dumps({"count": count})
|
return json.dumps({"count": count})
|
||||||
|
|
@ -1005,7 +1103,7 @@ async def memory_share(id: int, shared_with: str, permission: str = "read") -> s
|
||||||
if permission not in ("read", "write"):
|
if permission not in ("read", "write"):
|
||||||
return json.dumps({"error": "permission must be 'read' or 'write'"})
|
return json.dumps({"error": "permission must be 'read' or 'write'"})
|
||||||
pool = await get_pool()
|
pool = await get_pool()
|
||||||
user_id = "default"
|
user_id = _current_user.get()
|
||||||
async with pool.acquire() as conn:
|
async with pool.acquire() as conn:
|
||||||
row = await conn.fetchrow(
|
row = await conn.fetchrow(
|
||||||
"SELECT id FROM memories WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL",
|
"SELECT id FROM memories WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL",
|
||||||
|
|
@ -1026,7 +1124,7 @@ async def memory_share(id: int, shared_with: str, permission: str = "read") -> s
|
||||||
async def memory_unshare(id: int, shared_with: str) -> str:
|
async def memory_unshare(id: int, shared_with: str) -> str:
|
||||||
"""Revoke sharing of a memory from a user."""
|
"""Revoke sharing of a memory from a user."""
|
||||||
pool = await get_pool()
|
pool = await get_pool()
|
||||||
user_id = "default"
|
user_id = _current_user.get()
|
||||||
async with pool.acquire() as conn:
|
async with pool.acquire() as conn:
|
||||||
await conn.execute(
|
await conn.execute(
|
||||||
"DELETE FROM memory_shares WHERE memory_id = $1 AND owner_id = $2 AND shared_with = $3",
|
"DELETE FROM memory_shares WHERE memory_id = $1 AND owner_id = $2 AND shared_with = $3",
|
||||||
|
|
@ -1041,7 +1139,7 @@ async def memory_share_tag(tag: str, shared_with: str, permission: str = "read")
|
||||||
if permission not in ("read", "write"):
|
if permission not in ("read", "write"):
|
||||||
return json.dumps({"error": "permission must be 'read' or 'write'"})
|
return json.dumps({"error": "permission must be 'read' or 'write'"})
|
||||||
pool = await get_pool()
|
pool = await get_pool()
|
||||||
user_id = "default"
|
user_id = _current_user.get()
|
||||||
async with pool.acquire() as conn:
|
async with pool.acquire() as conn:
|
||||||
await conn.execute(
|
await conn.execute(
|
||||||
"""INSERT INTO tag_shares (owner_id, tag, shared_with, permission)
|
"""INSERT INTO tag_shares (owner_id, tag, shared_with, permission)
|
||||||
|
|
@ -1056,7 +1154,7 @@ async def memory_share_tag(tag: str, shared_with: str, permission: str = "read")
|
||||||
async def memory_unshare_tag(tag: str, shared_with: str) -> str:
|
async def memory_unshare_tag(tag: str, shared_with: str) -> str:
|
||||||
"""Revoke tag-based sharing."""
|
"""Revoke tag-based sharing."""
|
||||||
pool = await get_pool()
|
pool = await get_pool()
|
||||||
user_id = "default"
|
user_id = _current_user.get()
|
||||||
async with pool.acquire() as conn:
|
async with pool.acquire() as conn:
|
||||||
await conn.execute(
|
await conn.execute(
|
||||||
"DELETE FROM tag_shares WHERE owner_id = $1 AND tag = $2 AND shared_with = $3",
|
"DELETE FROM tag_shares WHERE owner_id = $1 AND tag = $2 AND shared_with = $3",
|
||||||
|
|
@ -1070,7 +1168,7 @@ async def memory_update(id: int, content: str | None = None, tags: str | None =
|
||||||
importance: float | None = None, expanded_keywords: str | None = None) -> str:
|
importance: float | None = None, expanded_keywords: str | None = None) -> str:
|
||||||
"""Update an existing memory's content, tags, importance, or keywords."""
|
"""Update an existing memory's content, tags, importance, or keywords."""
|
||||||
pool = await get_pool()
|
pool = await get_pool()
|
||||||
user_id = "default"
|
user_id = _current_user.get()
|
||||||
async with pool.acquire() as conn:
|
async with pool.acquire() as conn:
|
||||||
allowed, owner_id = await check_memory_permission(conn, id, user_id, "write")
|
allowed, owner_id = await check_memory_permission(conn, id, user_id, "write")
|
||||||
if not allowed:
|
if not allowed:
|
||||||
|
|
@ -1132,6 +1230,16 @@ sse_transport = SseServerTransport("/messages/")
|
||||||
class HandleSSE:
|
class HandleSSE:
|
||||||
"""ASGI app for SSE connections."""
|
"""ASGI app for SSE connections."""
|
||||||
async def __call__(self, scope: Any, receive: Any, send: Any) -> None:
|
async def __call__(self, scope: Any, receive: Any, send: Any) -> None:
|
||||||
|
# Extract user from Authorization header for multi-user MCP
|
||||||
|
user_id = "default"
|
||||||
|
for name, value in scope.get("headers", []):
|
||||||
|
if name == b"authorization":
|
||||||
|
token = value.decode().removeprefix("Bearer ").strip()
|
||||||
|
resolved = _resolve_user_from_token(token)
|
||||||
|
if resolved:
|
||||||
|
user_id = resolved
|
||||||
|
break
|
||||||
|
_current_user.set(user_id)
|
||||||
async with sse_transport.connect_sse(scope, receive, send) as (read_stream, write_stream):
|
async with sse_transport.connect_sse(scope, receive, send) as (read_stream, write_stream):
|
||||||
await mcp_server._mcp_server.run(
|
await mcp_server._mcp_server.run(
|
||||||
read_stream, write_stream, mcp_server._mcp_server.create_initialization_options()
|
read_stream, write_stream, mcp_server._mcp_server.create_initialization_options()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue