add multi-user memory sharing with r/w permissions

- New migration 004: memory_shares and tag_shares tables with indexes
- Share individual memories or entire tags with other users (read/write)
- Tag shares are live rules: future memories with shared tags auto-visible
- Recall query merges own + shared memories via UNION, returns shared_by field
- Owner-only delete enforcement (403 for non-owners, even with write access)
- PUT /api/memories/{id} update endpoint with permission checks
- 5 new MCP SSE tools: memory_share, memory_unshare, memory_share_tag,
  memory_unshare_tag, memory_update
- Permission helper checks ownership, individual shares, and tag shares
This commit is contained in:
Viktor Barzin 2026-03-22 15:34:01 +02:00
parent 1a275e976c
commit f45e8ce2b3
No known key found for this signature in database
GPG key ID: 0EB088298288D958
4 changed files with 556 additions and 13 deletions

View file

@ -0,0 +1,62 @@
"""Add memory sharing tables.
Revision ID: 004
Revises: 003
Create Date: 2026-03-22
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = "004"
down_revision: Union[str, None] = "003"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def _table_exists(conn, table_name: str) -> bool:
result = conn.execute(
sa.text(
"SELECT EXISTS(SELECT 1 FROM information_schema.tables "
"WHERE table_name = :tbl)"
),
{"tbl": table_name},
)
return result.scalar()
def upgrade() -> None:
conn = op.get_bind()
if not _table_exists(conn, "memory_shares"):
op.create_table(
"memory_shares",
sa.Column("id", sa.Integer, primary_key=True, autoincrement=True),
sa.Column("memory_id", sa.Integer, sa.ForeignKey("memories.id"), nullable=False),
sa.Column("owner_id", sa.String(100), nullable=False),
sa.Column("shared_with", sa.String(100), nullable=False),
sa.Column("permission", sa.String(10), nullable=False, server_default="read"),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("NOW()")),
sa.UniqueConstraint("memory_id", "shared_with", name="uq_memory_shares_memory_user"),
)
op.create_index("idx_shares_shared_with", "memory_shares", ["shared_with"])
op.create_index("idx_shares_memory_id", "memory_shares", ["memory_id"])
if not _table_exists(conn, "tag_shares"):
op.create_table(
"tag_shares",
sa.Column("id", sa.Integer, primary_key=True, autoincrement=True),
sa.Column("owner_id", sa.String(100), nullable=False),
sa.Column("tag", sa.String(100), nullable=False),
sa.Column("shared_with", sa.String(100), nullable=False),
sa.Column("permission", sa.String(10), nullable=False, server_default="read"),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("NOW()")),
sa.UniqueConstraint("owner_id", "tag", "shared_with", name="uq_tag_shares_owner_tag_user"),
)
op.create_index("idx_tag_shares_shared_with", "tag_shares", ["shared_with"])
def downgrade() -> None:
op.drop_table("tag_shares")
op.drop_table("memory_shares")

View file

@ -15,7 +15,11 @@ from starlette.routing import Mount, Route
from claude_memory.api.auth import AuthUser, get_current_user, _key_to_user
from claude_memory.api.database import close_pool, get_pool, init_pool
from claude_memory.api.models import MemoryRecall, MemoryResponse, MemoryStore, SecretResponse, SyncResponse
from claude_memory.api.models import (
MemoryRecall, MemoryResponse, MemoryStore, MemoryUpdate,
SecretResponse, ShareMemory, ShareTag, SyncResponse, UnshareTag,
)
from claude_memory.api.permissions import check_memory_permission
from claude_memory.api.vault_service import (
delete_secret,
get_secret,
@ -179,13 +183,13 @@ async def recall_memories(body: MemoryRecall, user: AuthUser = Depends(get_curre
params.append(body.category)
async with pool.acquire() as conn:
# Try AND-match first (plainto_tsquery ANDs by default), fall back to
# OR-match via individual word disjunction for broader results
# Own memories (AND-match)
rows = await conn.fetch(
f"""
SELECT id, content, category, tags, importance, is_sensitive,
ts_rank(search_vector, query) AS rank,
created_at, updated_at
created_at, updated_at,
NULL::text AS shared_by, NULL::text AS share_permission
FROM memories, plainto_tsquery('english', $2) query
WHERE user_id = $1
AND deleted_at IS NULL
@ -197,8 +201,65 @@ async def recall_memories(body: MemoryRecall, user: AuthUser = Depends(get_curre
*params,
)
# If AND-match returned too few results, broaden to OR-match
if len(rows) < body.limit and query_text:
# Individually shared memories
shared_rows = await conn.fetch(
f"""
SELECT m.id, m.content, m.category, m.tags, m.importance, m.is_sensitive,
ts_rank(m.search_vector, query) AS rank,
m.created_at, m.updated_at,
m.user_id AS shared_by, ms.permission AS share_permission
FROM memories m
JOIN memory_shares ms ON ms.memory_id = m.id,
plainto_tsquery('english', $2) query
WHERE ms.shared_with = $1
AND m.deleted_at IS NULL
AND (m.search_vector @@ query OR $2 = '')
{category_filter}
ORDER BY {order_clause}
LIMIT $3
""",
*params,
)
# Tag-shared memories
tag_shared_rows = await conn.fetch(
f"""
SELECT DISTINCT ON (m.id)
m.id, m.content, m.category, m.tags, m.importance, m.is_sensitive,
ts_rank(m.search_vector, query) AS rank,
m.created_at, m.updated_at,
m.user_id AS shared_by, ts.permission AS share_permission
FROM memories m
JOIN tag_shares ts ON ts.owner_id = m.user_id,
plainto_tsquery('english', $2) query
WHERE ts.shared_with = $1
AND m.deleted_at IS NULL
AND (m.search_vector @@ query OR $2 = '')
AND EXISTS (
SELECT 1 FROM unnest(string_to_array(m.tags, ',')) t
WHERE trim(t) = ts.tag
)
{category_filter}
ORDER BY m.id
LIMIT $3
""",
*params,
)
# Merge and deduplicate
seen_ids: set[int] = set()
all_rows = []
for row in list(rows) + list(shared_rows) + list(tag_shared_rows):
if row["id"] not in seen_ids:
seen_ids.add(row["id"])
all_rows.append(row)
# Sort merged results by importance desc and trim
all_rows.sort(key=lambda r: r["importance"], reverse=True)
all_rows = all_rows[:body.limit]
# If AND-match returned too few results, broaden to OR-match (own memories only)
if len(all_rows) < body.limit and query_text:
words = query_text.split()
if len(words) > 1:
or_tsquery = " | ".join(w for w in words if w)
@ -207,12 +268,12 @@ async def recall_memories(body: MemoryRecall, user: AuthUser = Depends(get_curre
if body.category:
or_cat_filter = "AND category = $4"
or_params.append(body.category)
seen_ids = {r["id"] for r in rows}
or_rows = await conn.fetch(
f"""
SELECT id, content, category, tags, importance, is_sensitive,
ts_rank(search_vector, query) AS rank,
created_at, updated_at
created_at, updated_at,
NULL::text AS shared_by, NULL::text AS share_permission
FROM memories, to_tsquery('english', $2) query
WHERE user_id = $1
AND deleted_at IS NULL
@ -223,11 +284,11 @@ async def recall_memories(body: MemoryRecall, user: AuthUser = Depends(get_curre
""",
*or_params,
)
rows = list(rows) + [r for r in or_rows if r["id"] not in seen_ids]
rows = rows[:body.limit]
all_rows = all_rows + [r for r in or_rows if r["id"] not in seen_ids]
all_rows = all_rows[:body.limit]
results = []
for row in rows:
for row in all_rows:
content = row["content"]
if row["is_sensitive"]:
content = f"[SENSITIVE - use secret_get(id={row['id']})]"
@ -242,6 +303,8 @@ async def recall_memories(body: MemoryRecall, user: AuthUser = Depends(get_curre
"rank": float(row["rank"]),
"created_at": row["created_at"].isoformat(),
"updated_at": row["updated_at"].isoformat(),
"shared_by": row["shared_by"],
"share_permission": row["share_permission"],
}
)
@ -300,19 +363,27 @@ async def delete_memory(memory_id: int, user: AuthUser = Depends(get_current_use
pool = await get_pool()
async with pool.acquire() as conn:
# Only the owner can delete — even write-shared users cannot
row = await conn.fetchrow(
"SELECT id, vault_path, substr(content, 1, 50) AS preview FROM memories WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL",
memory_id,
user.user_id,
)
if not row:
# Check if memory exists but is owned by someone else
exists = await conn.fetchrow(
"SELECT id FROM memories WHERE id = $1 AND deleted_at IS NULL", memory_id
)
if exists:
raise HTTPException(status_code=403, detail="Only the owner can delete a memory")
# Idempotent: return success even if already deleted
# Prevents old clients without 404-handling from infinite retry loops
return {"deleted": memory_id, "preview": "[already deleted]"}
if row["vault_path"]:
await delete_secret(user.user_id, row["vault_path"])
# Also clean up any shares for this memory
await conn.execute("DELETE FROM memory_shares WHERE memory_id = $1", memory_id)
await conn.execute(
"UPDATE memories SET deleted_at = NOW(), updated_at = NOW() WHERE id = $1 AND user_id = $2",
memory_id,
@ -439,6 +510,221 @@ async def import_memories(
return imported
# --- Sharing Endpoints ---
@app.post("/api/memories/{memory_id}/share")
async def share_memory(memory_id: int, body: ShareMemory, user: AuthUser = Depends(get_current_user)) -> dict[str, Any]:
pool = await get_pool()
async with pool.acquire() as conn:
row = await conn.fetchrow(
"SELECT id FROM memories WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL",
memory_id, user.user_id,
)
if not row:
raise HTTPException(status_code=404, detail="Memory not found or not owned by you")
if body.shared_with == user.user_id:
raise HTTPException(status_code=400, detail="Cannot share with yourself")
await conn.execute(
"""
INSERT INTO memory_shares (memory_id, owner_id, shared_with, permission)
VALUES ($1, $2, $3, $4)
ON CONFLICT (memory_id, shared_with)
DO UPDATE SET permission = EXCLUDED.permission
""",
memory_id, user.user_id, body.shared_with, body.permission,
)
return {"shared": memory_id, "with": body.shared_with, "permission": body.permission}
@app.delete("/api/memories/{memory_id}/share/{target_user}")
async def unshare_memory(memory_id: int, target_user: str, user: AuthUser = Depends(get_current_user)) -> dict[str, Any]:
pool = await get_pool()
async with pool.acquire() as conn:
row = await conn.fetchrow(
"SELECT id FROM memories WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL",
memory_id, user.user_id,
)
if not row:
raise HTTPException(status_code=404, detail="Memory not found or not owned by you")
await conn.execute(
"DELETE FROM memory_shares WHERE memory_id = $1 AND shared_with = $2",
memory_id, target_user,
)
return {"unshared": memory_id, "from": target_user}
@app.post("/api/memories/share-tag")
async def share_tag(body: ShareTag, user: AuthUser = Depends(get_current_user)) -> dict[str, Any]:
pool = await get_pool()
if body.shared_with == user.user_id:
raise HTTPException(status_code=400, detail="Cannot share with yourself")
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO tag_shares (owner_id, tag, shared_with, permission)
VALUES ($1, $2, $3, $4)
ON CONFLICT (owner_id, tag, shared_with)
DO UPDATE SET permission = EXCLUDED.permission
""",
user.user_id, body.tag, body.shared_with, body.permission,
)
return {"shared_tag": body.tag, "with": body.shared_with, "permission": body.permission}
@app.delete("/api/memories/share-tag")
async def unshare_tag(body: UnshareTag, user: AuthUser = Depends(get_current_user)) -> dict[str, Any]:
pool = await get_pool()
async with pool.acquire() as conn:
await conn.execute(
"DELETE FROM tag_shares WHERE owner_id = $1 AND tag = $2 AND shared_with = $3",
user.user_id, body.tag, body.shared_with,
)
return {"unshared_tag": body.tag, "from": body.shared_with}
@app.get("/api/memories/shared-with-me")
async def shared_with_me(user: AuthUser = Depends(get_current_user)) -> dict[str, Any]:
pool = await get_pool()
async with pool.acquire() as conn:
# Individual shares
individual = await conn.fetch(
"""
SELECT m.id, m.content, m.category, m.tags, m.importance, m.is_sensitive,
m.created_at, m.updated_at, m.user_id AS shared_by, ms.permission
FROM memories m
JOIN memory_shares ms ON ms.memory_id = m.id
WHERE ms.shared_with = $1 AND m.deleted_at IS NULL
ORDER BY m.importance DESC
""",
user.user_id,
)
# Tag shares
tag_shared = await conn.fetch(
"""
SELECT DISTINCT ON (m.id) m.id, m.content, m.category, m.tags, m.importance, m.is_sensitive,
m.created_at, m.updated_at, m.user_id AS shared_by, ts.permission
FROM memories m
JOIN tag_shares ts ON ts.owner_id = m.user_id
WHERE ts.shared_with = $1 AND m.deleted_at IS NULL
AND EXISTS (
SELECT 1 FROM unnest(string_to_array(m.tags, ',')) t
WHERE trim(t) = ts.tag
)
ORDER BY m.id, m.importance DESC
""",
user.user_id,
)
seen_ids = set()
results = []
for row in list(individual) + list(tag_shared):
if row["id"] in seen_ids:
continue
seen_ids.add(row["id"])
content = row["content"]
if row["is_sensitive"]:
content = f"[SENSITIVE - use secret_get(id={row['id']})]"
results.append({
"id": row["id"], "content": content, "category": row["category"],
"tags": row["tags"], "importance": row["importance"],
"shared_by": row["shared_by"], "permission": row["permission"],
"created_at": row["created_at"].isoformat(),
"updated_at": row["updated_at"].isoformat(),
})
return {"memories": results}
@app.get("/api/memories/my-shares")
async def my_shares(user: AuthUser = Depends(get_current_user)) -> dict[str, Any]:
pool = await get_pool()
async with pool.acquire() as conn:
memory_shares = await conn.fetch(
"""
SELECT ms.memory_id, ms.shared_with, ms.permission, ms.created_at,
substr(m.content, 1, 80) AS preview
FROM memory_shares ms
JOIN memories m ON m.id = ms.memory_id
WHERE ms.owner_id = $1 AND m.deleted_at IS NULL
ORDER BY ms.created_at DESC
""",
user.user_id,
)
tag_shares = await conn.fetch(
"SELECT tag, shared_with, permission, created_at FROM tag_shares WHERE owner_id = $1 ORDER BY created_at DESC",
user.user_id,
)
return {
"memory_shares": [
{
"memory_id": r["memory_id"], "shared_with": r["shared_with"],
"permission": r["permission"], "preview": r["preview"],
"created_at": r["created_at"].isoformat(),
}
for r in memory_shares
],
"tag_shares": [
{
"tag": r["tag"], "shared_with": r["shared_with"],
"permission": r["permission"],
"created_at": r["created_at"].isoformat(),
}
for r in tag_shares
],
}
@app.put("/api/memories/{memory_id}")
async def update_memory(memory_id: int, body: MemoryUpdate, user: AuthUser = Depends(get_current_user)) -> dict[str, Any]:
pool = await get_pool()
async with pool.acquire() as conn:
allowed, owner_id = await check_memory_permission(conn, memory_id, user.user_id, "write")
if not allowed:
if owner_id is None:
raise HTTPException(status_code=404, detail="Memory not found")
raise HTTPException(status_code=403, detail="Write permission required")
updates = []
params: list[Any] = []
idx = 1
if body.content is not None:
updates.append(f"content = ${idx}")
params.append(body.content)
idx += 1
if body.tags is not None:
updates.append(f"tags = ${idx}")
params.append(body.tags)
idx += 1
if body.importance is not None:
updates.append(f"importance = ${idx}")
params.append(body.importance)
idx += 1
if body.expanded_keywords is not None:
updates.append(f"expanded_keywords = ${idx}")
params.append(body.expanded_keywords)
idx += 1
if not updates:
raise HTTPException(status_code=400, detail="No fields to update")
updates.append("updated_at = NOW()")
params.append(memory_id)
await conn.execute(
f"UPDATE memories SET {', '.join(updates)} WHERE id = ${idx}",
*params,
)
return {"updated": memory_id}
# --- MCP SSE Transport ---
@ -608,6 +894,118 @@ async def secret_get(key: str) -> str:
return json.dumps({"error": "secret_get is not available via SSE transport"})
@mcp_server.tool()
async def memory_share(id: int, shared_with: str, permission: str = "read") -> str:
"""Share a memory with another user. Permission: 'read' or 'write'."""
if permission not in ("read", "write"):
return json.dumps({"error": "permission must be 'read' or 'write'"})
pool = await get_pool()
user_id = "default"
async with pool.acquire() as conn:
row = await conn.fetchrow(
"SELECT id FROM memories WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL",
id, user_id,
)
if not row:
return json.dumps({"error": "Memory not found or not owned by you"})
await conn.execute(
"""INSERT INTO memory_shares (memory_id, owner_id, shared_with, permission)
VALUES ($1, $2, $3, $4)
ON CONFLICT (memory_id, shared_with) DO UPDATE SET permission = EXCLUDED.permission""",
id, user_id, shared_with, permission,
)
return json.dumps({"shared": id, "with": shared_with, "permission": permission})
@mcp_server.tool()
async def memory_unshare(id: int, shared_with: str) -> str:
"""Revoke sharing of a memory from a user."""
pool = await get_pool()
user_id = "default"
async with pool.acquire() as conn:
await conn.execute(
"DELETE FROM memory_shares WHERE memory_id = $1 AND owner_id = $2 AND shared_with = $3",
id, user_id, shared_with,
)
return json.dumps({"unshared": id, "from": shared_with})
@mcp_server.tool()
async def memory_share_tag(tag: str, shared_with: str, permission: str = "read") -> str:
"""Share all memories with a given tag with another user. Future memories with this tag are automatically shared."""
if permission not in ("read", "write"):
return json.dumps({"error": "permission must be 'read' or 'write'"})
pool = await get_pool()
user_id = "default"
async with pool.acquire() as conn:
await conn.execute(
"""INSERT INTO tag_shares (owner_id, tag, shared_with, permission)
VALUES ($1, $2, $3, $4)
ON CONFLICT (owner_id, tag, shared_with) DO UPDATE SET permission = EXCLUDED.permission""",
user_id, tag, shared_with, permission,
)
return json.dumps({"shared_tag": tag, "with": shared_with, "permission": permission})
@mcp_server.tool()
async def memory_unshare_tag(tag: str, shared_with: str) -> str:
"""Revoke tag-based sharing."""
pool = await get_pool()
user_id = "default"
async with pool.acquire() as conn:
await conn.execute(
"DELETE FROM tag_shares WHERE owner_id = $1 AND tag = $2 AND shared_with = $3",
user_id, tag, shared_with,
)
return json.dumps({"unshared_tag": tag, "from": shared_with})
@mcp_server.tool()
async def memory_update(id: int, content: str | None = None, tags: str | None = None,
importance: float | None = None, expanded_keywords: str | None = None) -> str:
"""Update an existing memory's content, tags, importance, or keywords."""
pool = await get_pool()
user_id = "default"
async with pool.acquire() as conn:
allowed, owner_id = await check_memory_permission(conn, id, user_id, "write")
if not allowed:
if owner_id is None:
return json.dumps({"error": "Memory not found"})
return json.dumps({"error": "Write permission required"})
updates = []
params: list[Any] = []
idx = 1
if content is not None:
updates.append(f"content = ${idx}")
params.append(content)
idx += 1
if tags is not None:
updates.append(f"tags = ${idx}")
params.append(tags)
idx += 1
if importance is not None:
updates.append(f"importance = ${idx}")
params.append(importance)
idx += 1
if expanded_keywords is not None:
updates.append(f"expanded_keywords = ${idx}")
params.append(expanded_keywords)
idx += 1
if not updates:
return json.dumps({"error": "No fields to update"})
updates.append("updated_at = NOW()")
params.append(id)
await conn.execute(
f"UPDATE memories SET {', '.join(updates)} WHERE id = ${idx}",
*params,
)
return json.dumps({"updated": id})
# Auth middleware for /mcp/* routes
class MCPAuthMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: Any) -> Response:

View file

@ -1,4 +1,4 @@
from typing import Any, Optional
from typing import Any, Literal, Optional
from pydantic import BaseModel, Field
@ -38,3 +38,26 @@ class SecretResponse(BaseModel):
class SyncResponse(BaseModel):
memories: list[dict[str, Any]]
server_time: str
class ShareMemory(BaseModel):
shared_with: str = Field(..., min_length=1, max_length=100)
permission: Literal["read", "write"] = "read"
class ShareTag(BaseModel):
tag: str = Field(..., min_length=1, max_length=100)
shared_with: str = Field(..., min_length=1, max_length=100)
permission: Literal["read", "write"] = "read"
class UnshareTag(BaseModel):
tag: str = Field(..., min_length=1, max_length=100)
shared_with: str = Field(..., min_length=1, max_length=100)
class MemoryUpdate(BaseModel):
content: Optional[str] = Field(None, max_length=MAX_MEMORY_CHARS)
tags: Optional[str] = None
importance: Optional[float] = Field(None, ge=0.0, le=1.0)
expanded_keywords: Optional[str] = None

View file

@ -0,0 +1,60 @@
"""Permission checking for shared memories."""
import asyncpg
async def check_memory_permission(
conn: asyncpg.Connection, memory_id: int, user_id: str, required: str
) -> tuple[bool, str | None]:
"""Check if user_id has the required permission on memory_id.
Returns (allowed, owner_id).
- Owner always has full access.
- Shared users checked via memory_shares and tag_shares.
- required: "read" or "write". "read" is satisfied by either permission.
"""
row = await conn.fetchrow(
"SELECT user_id FROM memories WHERE id = $1 AND deleted_at IS NULL",
memory_id,
)
if not row:
return False, None
owner_id = row["user_id"]
# Owner always has access
if owner_id == user_id:
return True, owner_id
# Check individual memory share
share = await conn.fetchrow(
"SELECT permission FROM memory_shares WHERE memory_id = $1 AND shared_with = $2",
memory_id, user_id,
)
if share:
if required == "read" or share["permission"] == "write":
return True, owner_id
return False, owner_id
# Check tag-based shares
tag_share = await conn.fetchrow(
"""
SELECT ts.permission
FROM tag_shares ts
JOIN memories m ON m.user_id = ts.owner_id
WHERE m.id = $1 AND ts.shared_with = $2
AND EXISTS (
SELECT 1 FROM unnest(string_to_array(m.tags, ',')) t
WHERE trim(t) = ts.tag
)
ORDER BY CASE WHEN ts.permission = 'write' THEN 0 ELSE 1 END
LIMIT 1
""",
memory_id, user_id,
)
if tag_share:
if required == "read" or tag_share["permission"] == "write":
return True, owner_id
return False, owner_id
return False, owner_id