add multi-user memory sharing with r/w permissions
- New migration 004: memory_shares and tag_shares tables with indexes
- Share individual memories or entire tags with other users (read/write)
- Tag shares are live rules: future memories with shared tags auto-visible
- Recall query merges own + shared memories via UNION, returns shared_by field
- Owner-only delete enforcement (403 for non-owners, even with write access)
- PUT /api/memories/{id} update endpoint with permission checks
- 5 new MCP SSE tools: memory_share, memory_unshare, memory_share_tag,
memory_unshare_tag, memory_update
- Permission helper checks ownership, individual shares, and tag shares
This commit is contained in:
parent
1a275e976c
commit
f45e8ce2b3
4 changed files with 556 additions and 13 deletions
62
migrations/versions/004_add_sharing.py
Normal file
62
migrations/versions/004_add_sharing.py
Normal file
|
|
@ -0,0 +1,62 @@
|
||||||
|
"""Add memory sharing tables.
|
||||||
|
|
||||||
|
Revision ID: 004
|
||||||
|
Revises: 003
|
||||||
|
Create Date: 2026-03-22
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
revision: str = "004"
|
||||||
|
down_revision: Union[str, None] = "003"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def _table_exists(conn, table_name: str) -> bool:
|
||||||
|
result = conn.execute(
|
||||||
|
sa.text(
|
||||||
|
"SELECT EXISTS(SELECT 1 FROM information_schema.tables "
|
||||||
|
"WHERE table_name = :tbl)"
|
||||||
|
),
|
||||||
|
{"tbl": table_name},
|
||||||
|
)
|
||||||
|
return result.scalar()
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
conn = op.get_bind()
|
||||||
|
|
||||||
|
if not _table_exists(conn, "memory_shares"):
|
||||||
|
op.create_table(
|
||||||
|
"memory_shares",
|
||||||
|
sa.Column("id", sa.Integer, primary_key=True, autoincrement=True),
|
||||||
|
sa.Column("memory_id", sa.Integer, sa.ForeignKey("memories.id"), nullable=False),
|
||||||
|
sa.Column("owner_id", sa.String(100), nullable=False),
|
||||||
|
sa.Column("shared_with", sa.String(100), nullable=False),
|
||||||
|
sa.Column("permission", sa.String(10), nullable=False, server_default="read"),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("NOW()")),
|
||||||
|
sa.UniqueConstraint("memory_id", "shared_with", name="uq_memory_shares_memory_user"),
|
||||||
|
)
|
||||||
|
op.create_index("idx_shares_shared_with", "memory_shares", ["shared_with"])
|
||||||
|
op.create_index("idx_shares_memory_id", "memory_shares", ["memory_id"])
|
||||||
|
|
||||||
|
if not _table_exists(conn, "tag_shares"):
|
||||||
|
op.create_table(
|
||||||
|
"tag_shares",
|
||||||
|
sa.Column("id", sa.Integer, primary_key=True, autoincrement=True),
|
||||||
|
sa.Column("owner_id", sa.String(100), nullable=False),
|
||||||
|
sa.Column("tag", sa.String(100), nullable=False),
|
||||||
|
sa.Column("shared_with", sa.String(100), nullable=False),
|
||||||
|
sa.Column("permission", sa.String(10), nullable=False, server_default="read"),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("NOW()")),
|
||||||
|
sa.UniqueConstraint("owner_id", "tag", "shared_with", name="uq_tag_shares_owner_tag_user"),
|
||||||
|
)
|
||||||
|
op.create_index("idx_tag_shares_shared_with", "tag_shares", ["shared_with"])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_table("tag_shares")
|
||||||
|
op.drop_table("memory_shares")
|
||||||
|
|
@ -15,7 +15,11 @@ from starlette.routing import Mount, Route
|
||||||
|
|
||||||
from claude_memory.api.auth import AuthUser, get_current_user, _key_to_user
|
from claude_memory.api.auth import AuthUser, get_current_user, _key_to_user
|
||||||
from claude_memory.api.database import close_pool, get_pool, init_pool
|
from claude_memory.api.database import close_pool, get_pool, init_pool
|
||||||
from claude_memory.api.models import MemoryRecall, MemoryResponse, MemoryStore, SecretResponse, SyncResponse
|
from claude_memory.api.models import (
|
||||||
|
MemoryRecall, MemoryResponse, MemoryStore, MemoryUpdate,
|
||||||
|
SecretResponse, ShareMemory, ShareTag, SyncResponse, UnshareTag,
|
||||||
|
)
|
||||||
|
from claude_memory.api.permissions import check_memory_permission
|
||||||
from claude_memory.api.vault_service import (
|
from claude_memory.api.vault_service import (
|
||||||
delete_secret,
|
delete_secret,
|
||||||
get_secret,
|
get_secret,
|
||||||
|
|
@ -179,13 +183,13 @@ async def recall_memories(body: MemoryRecall, user: AuthUser = Depends(get_curre
|
||||||
params.append(body.category)
|
params.append(body.category)
|
||||||
|
|
||||||
async with pool.acquire() as conn:
|
async with pool.acquire() as conn:
|
||||||
# Try AND-match first (plainto_tsquery ANDs by default), fall back to
|
# Own memories (AND-match)
|
||||||
# OR-match via individual word disjunction for broader results
|
|
||||||
rows = await conn.fetch(
|
rows = await conn.fetch(
|
||||||
f"""
|
f"""
|
||||||
SELECT id, content, category, tags, importance, is_sensitive,
|
SELECT id, content, category, tags, importance, is_sensitive,
|
||||||
ts_rank(search_vector, query) AS rank,
|
ts_rank(search_vector, query) AS rank,
|
||||||
created_at, updated_at
|
created_at, updated_at,
|
||||||
|
NULL::text AS shared_by, NULL::text AS share_permission
|
||||||
FROM memories, plainto_tsquery('english', $2) query
|
FROM memories, plainto_tsquery('english', $2) query
|
||||||
WHERE user_id = $1
|
WHERE user_id = $1
|
||||||
AND deleted_at IS NULL
|
AND deleted_at IS NULL
|
||||||
|
|
@ -197,8 +201,65 @@ async def recall_memories(body: MemoryRecall, user: AuthUser = Depends(get_curre
|
||||||
*params,
|
*params,
|
||||||
)
|
)
|
||||||
|
|
||||||
# If AND-match returned too few results, broaden to OR-match
|
# Individually shared memories
|
||||||
if len(rows) < body.limit and query_text:
|
shared_rows = await conn.fetch(
|
||||||
|
f"""
|
||||||
|
SELECT m.id, m.content, m.category, m.tags, m.importance, m.is_sensitive,
|
||||||
|
ts_rank(m.search_vector, query) AS rank,
|
||||||
|
m.created_at, m.updated_at,
|
||||||
|
m.user_id AS shared_by, ms.permission AS share_permission
|
||||||
|
FROM memories m
|
||||||
|
JOIN memory_shares ms ON ms.memory_id = m.id,
|
||||||
|
plainto_tsquery('english', $2) query
|
||||||
|
WHERE ms.shared_with = $1
|
||||||
|
AND m.deleted_at IS NULL
|
||||||
|
AND (m.search_vector @@ query OR $2 = '')
|
||||||
|
{category_filter}
|
||||||
|
ORDER BY {order_clause}
|
||||||
|
LIMIT $3
|
||||||
|
""",
|
||||||
|
*params,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Tag-shared memories
|
||||||
|
tag_shared_rows = await conn.fetch(
|
||||||
|
f"""
|
||||||
|
SELECT DISTINCT ON (m.id)
|
||||||
|
m.id, m.content, m.category, m.tags, m.importance, m.is_sensitive,
|
||||||
|
ts_rank(m.search_vector, query) AS rank,
|
||||||
|
m.created_at, m.updated_at,
|
||||||
|
m.user_id AS shared_by, ts.permission AS share_permission
|
||||||
|
FROM memories m
|
||||||
|
JOIN tag_shares ts ON ts.owner_id = m.user_id,
|
||||||
|
plainto_tsquery('english', $2) query
|
||||||
|
WHERE ts.shared_with = $1
|
||||||
|
AND m.deleted_at IS NULL
|
||||||
|
AND (m.search_vector @@ query OR $2 = '')
|
||||||
|
AND EXISTS (
|
||||||
|
SELECT 1 FROM unnest(string_to_array(m.tags, ',')) t
|
||||||
|
WHERE trim(t) = ts.tag
|
||||||
|
)
|
||||||
|
{category_filter}
|
||||||
|
ORDER BY m.id
|
||||||
|
LIMIT $3
|
||||||
|
""",
|
||||||
|
*params,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Merge and deduplicate
|
||||||
|
seen_ids: set[int] = set()
|
||||||
|
all_rows = []
|
||||||
|
for row in list(rows) + list(shared_rows) + list(tag_shared_rows):
|
||||||
|
if row["id"] not in seen_ids:
|
||||||
|
seen_ids.add(row["id"])
|
||||||
|
all_rows.append(row)
|
||||||
|
|
||||||
|
# Sort merged results by importance desc and trim
|
||||||
|
all_rows.sort(key=lambda r: r["importance"], reverse=True)
|
||||||
|
all_rows = all_rows[:body.limit]
|
||||||
|
|
||||||
|
# If AND-match returned too few results, broaden to OR-match (own memories only)
|
||||||
|
if len(all_rows) < body.limit and query_text:
|
||||||
words = query_text.split()
|
words = query_text.split()
|
||||||
if len(words) > 1:
|
if len(words) > 1:
|
||||||
or_tsquery = " | ".join(w for w in words if w)
|
or_tsquery = " | ".join(w for w in words if w)
|
||||||
|
|
@ -207,12 +268,12 @@ async def recall_memories(body: MemoryRecall, user: AuthUser = Depends(get_curre
|
||||||
if body.category:
|
if body.category:
|
||||||
or_cat_filter = "AND category = $4"
|
or_cat_filter = "AND category = $4"
|
||||||
or_params.append(body.category)
|
or_params.append(body.category)
|
||||||
seen_ids = {r["id"] for r in rows}
|
|
||||||
or_rows = await conn.fetch(
|
or_rows = await conn.fetch(
|
||||||
f"""
|
f"""
|
||||||
SELECT id, content, category, tags, importance, is_sensitive,
|
SELECT id, content, category, tags, importance, is_sensitive,
|
||||||
ts_rank(search_vector, query) AS rank,
|
ts_rank(search_vector, query) AS rank,
|
||||||
created_at, updated_at
|
created_at, updated_at,
|
||||||
|
NULL::text AS shared_by, NULL::text AS share_permission
|
||||||
FROM memories, to_tsquery('english', $2) query
|
FROM memories, to_tsquery('english', $2) query
|
||||||
WHERE user_id = $1
|
WHERE user_id = $1
|
||||||
AND deleted_at IS NULL
|
AND deleted_at IS NULL
|
||||||
|
|
@ -223,11 +284,11 @@ async def recall_memories(body: MemoryRecall, user: AuthUser = Depends(get_curre
|
||||||
""",
|
""",
|
||||||
*or_params,
|
*or_params,
|
||||||
)
|
)
|
||||||
rows = list(rows) + [r for r in or_rows if r["id"] not in seen_ids]
|
all_rows = all_rows + [r for r in or_rows if r["id"] not in seen_ids]
|
||||||
rows = rows[:body.limit]
|
all_rows = all_rows[:body.limit]
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for row in rows:
|
for row in all_rows:
|
||||||
content = row["content"]
|
content = row["content"]
|
||||||
if row["is_sensitive"]:
|
if row["is_sensitive"]:
|
||||||
content = f"[SENSITIVE - use secret_get(id={row['id']})]"
|
content = f"[SENSITIVE - use secret_get(id={row['id']})]"
|
||||||
|
|
@ -242,6 +303,8 @@ async def recall_memories(body: MemoryRecall, user: AuthUser = Depends(get_curre
|
||||||
"rank": float(row["rank"]),
|
"rank": float(row["rank"]),
|
||||||
"created_at": row["created_at"].isoformat(),
|
"created_at": row["created_at"].isoformat(),
|
||||||
"updated_at": row["updated_at"].isoformat(),
|
"updated_at": row["updated_at"].isoformat(),
|
||||||
|
"shared_by": row["shared_by"],
|
||||||
|
"share_permission": row["share_permission"],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -300,19 +363,27 @@ async def delete_memory(memory_id: int, user: AuthUser = Depends(get_current_use
|
||||||
pool = await get_pool()
|
pool = await get_pool()
|
||||||
|
|
||||||
async with pool.acquire() as conn:
|
async with pool.acquire() as conn:
|
||||||
|
# Only the owner can delete — even write-shared users cannot
|
||||||
row = await conn.fetchrow(
|
row = await conn.fetchrow(
|
||||||
"SELECT id, vault_path, substr(content, 1, 50) AS preview FROM memories WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL",
|
"SELECT id, vault_path, substr(content, 1, 50) AS preview FROM memories WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL",
|
||||||
memory_id,
|
memory_id,
|
||||||
user.user_id,
|
user.user_id,
|
||||||
)
|
)
|
||||||
if not row:
|
if not row:
|
||||||
|
# Check if memory exists but is owned by someone else
|
||||||
|
exists = await conn.fetchrow(
|
||||||
|
"SELECT id FROM memories WHERE id = $1 AND deleted_at IS NULL", memory_id
|
||||||
|
)
|
||||||
|
if exists:
|
||||||
|
raise HTTPException(status_code=403, detail="Only the owner can delete a memory")
|
||||||
# Idempotent: return success even if already deleted
|
# Idempotent: return success even if already deleted
|
||||||
# Prevents old clients without 404-handling from infinite retry loops
|
|
||||||
return {"deleted": memory_id, "preview": "[already deleted]"}
|
return {"deleted": memory_id, "preview": "[already deleted]"}
|
||||||
|
|
||||||
if row["vault_path"]:
|
if row["vault_path"]:
|
||||||
await delete_secret(user.user_id, row["vault_path"])
|
await delete_secret(user.user_id, row["vault_path"])
|
||||||
|
|
||||||
|
# Also clean up any shares for this memory
|
||||||
|
await conn.execute("DELETE FROM memory_shares WHERE memory_id = $1", memory_id)
|
||||||
await conn.execute(
|
await conn.execute(
|
||||||
"UPDATE memories SET deleted_at = NOW(), updated_at = NOW() WHERE id = $1 AND user_id = $2",
|
"UPDATE memories SET deleted_at = NOW(), updated_at = NOW() WHERE id = $1 AND user_id = $2",
|
||||||
memory_id,
|
memory_id,
|
||||||
|
|
@ -439,6 +510,221 @@ async def import_memories(
|
||||||
return imported
|
return imported
|
||||||
|
|
||||||
|
|
||||||
|
# --- Sharing Endpoints ---
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/memories/{memory_id}/share")
|
||||||
|
async def share_memory(memory_id: int, body: ShareMemory, user: AuthUser = Depends(get_current_user)) -> dict[str, Any]:
|
||||||
|
pool = await get_pool()
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
row = await conn.fetchrow(
|
||||||
|
"SELECT id FROM memories WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL",
|
||||||
|
memory_id, user.user_id,
|
||||||
|
)
|
||||||
|
if not row:
|
||||||
|
raise HTTPException(status_code=404, detail="Memory not found or not owned by you")
|
||||||
|
if body.shared_with == user.user_id:
|
||||||
|
raise HTTPException(status_code=400, detail="Cannot share with yourself")
|
||||||
|
|
||||||
|
await conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO memory_shares (memory_id, owner_id, shared_with, permission)
|
||||||
|
VALUES ($1, $2, $3, $4)
|
||||||
|
ON CONFLICT (memory_id, shared_with)
|
||||||
|
DO UPDATE SET permission = EXCLUDED.permission
|
||||||
|
""",
|
||||||
|
memory_id, user.user_id, body.shared_with, body.permission,
|
||||||
|
)
|
||||||
|
return {"shared": memory_id, "with": body.shared_with, "permission": body.permission}
|
||||||
|
|
||||||
|
|
||||||
|
@app.delete("/api/memories/{memory_id}/share/{target_user}")
|
||||||
|
async def unshare_memory(memory_id: int, target_user: str, user: AuthUser = Depends(get_current_user)) -> dict[str, Any]:
|
||||||
|
pool = await get_pool()
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
row = await conn.fetchrow(
|
||||||
|
"SELECT id FROM memories WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL",
|
||||||
|
memory_id, user.user_id,
|
||||||
|
)
|
||||||
|
if not row:
|
||||||
|
raise HTTPException(status_code=404, detail="Memory not found or not owned by you")
|
||||||
|
|
||||||
|
await conn.execute(
|
||||||
|
"DELETE FROM memory_shares WHERE memory_id = $1 AND shared_with = $2",
|
||||||
|
memory_id, target_user,
|
||||||
|
)
|
||||||
|
return {"unshared": memory_id, "from": target_user}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/memories/share-tag")
|
||||||
|
async def share_tag(body: ShareTag, user: AuthUser = Depends(get_current_user)) -> dict[str, Any]:
|
||||||
|
pool = await get_pool()
|
||||||
|
if body.shared_with == user.user_id:
|
||||||
|
raise HTTPException(status_code=400, detail="Cannot share with yourself")
|
||||||
|
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
await conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO tag_shares (owner_id, tag, shared_with, permission)
|
||||||
|
VALUES ($1, $2, $3, $4)
|
||||||
|
ON CONFLICT (owner_id, tag, shared_with)
|
||||||
|
DO UPDATE SET permission = EXCLUDED.permission
|
||||||
|
""",
|
||||||
|
user.user_id, body.tag, body.shared_with, body.permission,
|
||||||
|
)
|
||||||
|
return {"shared_tag": body.tag, "with": body.shared_with, "permission": body.permission}
|
||||||
|
|
||||||
|
|
||||||
|
@app.delete("/api/memories/share-tag")
|
||||||
|
async def unshare_tag(body: UnshareTag, user: AuthUser = Depends(get_current_user)) -> dict[str, Any]:
|
||||||
|
pool = await get_pool()
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
await conn.execute(
|
||||||
|
"DELETE FROM tag_shares WHERE owner_id = $1 AND tag = $2 AND shared_with = $3",
|
||||||
|
user.user_id, body.tag, body.shared_with,
|
||||||
|
)
|
||||||
|
return {"unshared_tag": body.tag, "from": body.shared_with}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/api/memories/shared-with-me")
|
||||||
|
async def shared_with_me(user: AuthUser = Depends(get_current_user)) -> dict[str, Any]:
|
||||||
|
pool = await get_pool()
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
# Individual shares
|
||||||
|
individual = await conn.fetch(
|
||||||
|
"""
|
||||||
|
SELECT m.id, m.content, m.category, m.tags, m.importance, m.is_sensitive,
|
||||||
|
m.created_at, m.updated_at, m.user_id AS shared_by, ms.permission
|
||||||
|
FROM memories m
|
||||||
|
JOIN memory_shares ms ON ms.memory_id = m.id
|
||||||
|
WHERE ms.shared_with = $1 AND m.deleted_at IS NULL
|
||||||
|
ORDER BY m.importance DESC
|
||||||
|
""",
|
||||||
|
user.user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Tag shares
|
||||||
|
tag_shared = await conn.fetch(
|
||||||
|
"""
|
||||||
|
SELECT DISTINCT ON (m.id) m.id, m.content, m.category, m.tags, m.importance, m.is_sensitive,
|
||||||
|
m.created_at, m.updated_at, m.user_id AS shared_by, ts.permission
|
||||||
|
FROM memories m
|
||||||
|
JOIN tag_shares ts ON ts.owner_id = m.user_id
|
||||||
|
WHERE ts.shared_with = $1 AND m.deleted_at IS NULL
|
||||||
|
AND EXISTS (
|
||||||
|
SELECT 1 FROM unnest(string_to_array(m.tags, ',')) t
|
||||||
|
WHERE trim(t) = ts.tag
|
||||||
|
)
|
||||||
|
ORDER BY m.id, m.importance DESC
|
||||||
|
""",
|
||||||
|
user.user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
seen_ids = set()
|
||||||
|
results = []
|
||||||
|
for row in list(individual) + list(tag_shared):
|
||||||
|
if row["id"] in seen_ids:
|
||||||
|
continue
|
||||||
|
seen_ids.add(row["id"])
|
||||||
|
content = row["content"]
|
||||||
|
if row["is_sensitive"]:
|
||||||
|
content = f"[SENSITIVE - use secret_get(id={row['id']})]"
|
||||||
|
results.append({
|
||||||
|
"id": row["id"], "content": content, "category": row["category"],
|
||||||
|
"tags": row["tags"], "importance": row["importance"],
|
||||||
|
"shared_by": row["shared_by"], "permission": row["permission"],
|
||||||
|
"created_at": row["created_at"].isoformat(),
|
||||||
|
"updated_at": row["updated_at"].isoformat(),
|
||||||
|
})
|
||||||
|
|
||||||
|
return {"memories": results}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/api/memories/my-shares")
|
||||||
|
async def my_shares(user: AuthUser = Depends(get_current_user)) -> dict[str, Any]:
|
||||||
|
pool = await get_pool()
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
memory_shares = await conn.fetch(
|
||||||
|
"""
|
||||||
|
SELECT ms.memory_id, ms.shared_with, ms.permission, ms.created_at,
|
||||||
|
substr(m.content, 1, 80) AS preview
|
||||||
|
FROM memory_shares ms
|
||||||
|
JOIN memories m ON m.id = ms.memory_id
|
||||||
|
WHERE ms.owner_id = $1 AND m.deleted_at IS NULL
|
||||||
|
ORDER BY ms.created_at DESC
|
||||||
|
""",
|
||||||
|
user.user_id,
|
||||||
|
)
|
||||||
|
tag_shares = await conn.fetch(
|
||||||
|
"SELECT tag, shared_with, permission, created_at FROM tag_shares WHERE owner_id = $1 ORDER BY created_at DESC",
|
||||||
|
user.user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"memory_shares": [
|
||||||
|
{
|
||||||
|
"memory_id": r["memory_id"], "shared_with": r["shared_with"],
|
||||||
|
"permission": r["permission"], "preview": r["preview"],
|
||||||
|
"created_at": r["created_at"].isoformat(),
|
||||||
|
}
|
||||||
|
for r in memory_shares
|
||||||
|
],
|
||||||
|
"tag_shares": [
|
||||||
|
{
|
||||||
|
"tag": r["tag"], "shared_with": r["shared_with"],
|
||||||
|
"permission": r["permission"],
|
||||||
|
"created_at": r["created_at"].isoformat(),
|
||||||
|
}
|
||||||
|
for r in tag_shares
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.put("/api/memories/{memory_id}")
|
||||||
|
async def update_memory(memory_id: int, body: MemoryUpdate, user: AuthUser = Depends(get_current_user)) -> dict[str, Any]:
|
||||||
|
pool = await get_pool()
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
allowed, owner_id = await check_memory_permission(conn, memory_id, user.user_id, "write")
|
||||||
|
if not allowed:
|
||||||
|
if owner_id is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Memory not found")
|
||||||
|
raise HTTPException(status_code=403, detail="Write permission required")
|
||||||
|
|
||||||
|
updates = []
|
||||||
|
params: list[Any] = []
|
||||||
|
idx = 1
|
||||||
|
|
||||||
|
if body.content is not None:
|
||||||
|
updates.append(f"content = ${idx}")
|
||||||
|
params.append(body.content)
|
||||||
|
idx += 1
|
||||||
|
if body.tags is not None:
|
||||||
|
updates.append(f"tags = ${idx}")
|
||||||
|
params.append(body.tags)
|
||||||
|
idx += 1
|
||||||
|
if body.importance is not None:
|
||||||
|
updates.append(f"importance = ${idx}")
|
||||||
|
params.append(body.importance)
|
||||||
|
idx += 1
|
||||||
|
if body.expanded_keywords is not None:
|
||||||
|
updates.append(f"expanded_keywords = ${idx}")
|
||||||
|
params.append(body.expanded_keywords)
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
if not updates:
|
||||||
|
raise HTTPException(status_code=400, detail="No fields to update")
|
||||||
|
|
||||||
|
updates.append("updated_at = NOW()")
|
||||||
|
params.append(memory_id)
|
||||||
|
|
||||||
|
await conn.execute(
|
||||||
|
f"UPDATE memories SET {', '.join(updates)} WHERE id = ${idx}",
|
||||||
|
*params,
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"updated": memory_id}
|
||||||
|
|
||||||
|
|
||||||
# --- MCP SSE Transport ---
|
# --- MCP SSE Transport ---
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -608,6 +894,118 @@ async def secret_get(key: str) -> str:
|
||||||
return json.dumps({"error": "secret_get is not available via SSE transport"})
|
return json.dumps({"error": "secret_get is not available via SSE transport"})
|
||||||
|
|
||||||
|
|
||||||
|
@mcp_server.tool()
|
||||||
|
async def memory_share(id: int, shared_with: str, permission: str = "read") -> str:
|
||||||
|
"""Share a memory with another user. Permission: 'read' or 'write'."""
|
||||||
|
if permission not in ("read", "write"):
|
||||||
|
return json.dumps({"error": "permission must be 'read' or 'write'"})
|
||||||
|
pool = await get_pool()
|
||||||
|
user_id = "default"
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
row = await conn.fetchrow(
|
||||||
|
"SELECT id FROM memories WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL",
|
||||||
|
id, user_id,
|
||||||
|
)
|
||||||
|
if not row:
|
||||||
|
return json.dumps({"error": "Memory not found or not owned by you"})
|
||||||
|
await conn.execute(
|
||||||
|
"""INSERT INTO memory_shares (memory_id, owner_id, shared_with, permission)
|
||||||
|
VALUES ($1, $2, $3, $4)
|
||||||
|
ON CONFLICT (memory_id, shared_with) DO UPDATE SET permission = EXCLUDED.permission""",
|
||||||
|
id, user_id, shared_with, permission,
|
||||||
|
)
|
||||||
|
return json.dumps({"shared": id, "with": shared_with, "permission": permission})
|
||||||
|
|
||||||
|
|
||||||
|
@mcp_server.tool()
|
||||||
|
async def memory_unshare(id: int, shared_with: str) -> str:
|
||||||
|
"""Revoke sharing of a memory from a user."""
|
||||||
|
pool = await get_pool()
|
||||||
|
user_id = "default"
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
await conn.execute(
|
||||||
|
"DELETE FROM memory_shares WHERE memory_id = $1 AND owner_id = $2 AND shared_with = $3",
|
||||||
|
id, user_id, shared_with,
|
||||||
|
)
|
||||||
|
return json.dumps({"unshared": id, "from": shared_with})
|
||||||
|
|
||||||
|
|
||||||
|
@mcp_server.tool()
|
||||||
|
async def memory_share_tag(tag: str, shared_with: str, permission: str = "read") -> str:
|
||||||
|
"""Share all memories with a given tag with another user. Future memories with this tag are automatically shared."""
|
||||||
|
if permission not in ("read", "write"):
|
||||||
|
return json.dumps({"error": "permission must be 'read' or 'write'"})
|
||||||
|
pool = await get_pool()
|
||||||
|
user_id = "default"
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
await conn.execute(
|
||||||
|
"""INSERT INTO tag_shares (owner_id, tag, shared_with, permission)
|
||||||
|
VALUES ($1, $2, $3, $4)
|
||||||
|
ON CONFLICT (owner_id, tag, shared_with) DO UPDATE SET permission = EXCLUDED.permission""",
|
||||||
|
user_id, tag, shared_with, permission,
|
||||||
|
)
|
||||||
|
return json.dumps({"shared_tag": tag, "with": shared_with, "permission": permission})
|
||||||
|
|
||||||
|
|
||||||
|
@mcp_server.tool()
|
||||||
|
async def memory_unshare_tag(tag: str, shared_with: str) -> str:
|
||||||
|
"""Revoke tag-based sharing."""
|
||||||
|
pool = await get_pool()
|
||||||
|
user_id = "default"
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
await conn.execute(
|
||||||
|
"DELETE FROM tag_shares WHERE owner_id = $1 AND tag = $2 AND shared_with = $3",
|
||||||
|
user_id, tag, shared_with,
|
||||||
|
)
|
||||||
|
return json.dumps({"unshared_tag": tag, "from": shared_with})
|
||||||
|
|
||||||
|
|
||||||
|
@mcp_server.tool()
|
||||||
|
async def memory_update(id: int, content: str | None = None, tags: str | None = None,
|
||||||
|
importance: float | None = None, expanded_keywords: str | None = None) -> str:
|
||||||
|
"""Update an existing memory's content, tags, importance, or keywords."""
|
||||||
|
pool = await get_pool()
|
||||||
|
user_id = "default"
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
allowed, owner_id = await check_memory_permission(conn, id, user_id, "write")
|
||||||
|
if not allowed:
|
||||||
|
if owner_id is None:
|
||||||
|
return json.dumps({"error": "Memory not found"})
|
||||||
|
return json.dumps({"error": "Write permission required"})
|
||||||
|
|
||||||
|
updates = []
|
||||||
|
params: list[Any] = []
|
||||||
|
idx = 1
|
||||||
|
if content is not None:
|
||||||
|
updates.append(f"content = ${idx}")
|
||||||
|
params.append(content)
|
||||||
|
idx += 1
|
||||||
|
if tags is not None:
|
||||||
|
updates.append(f"tags = ${idx}")
|
||||||
|
params.append(tags)
|
||||||
|
idx += 1
|
||||||
|
if importance is not None:
|
||||||
|
updates.append(f"importance = ${idx}")
|
||||||
|
params.append(importance)
|
||||||
|
idx += 1
|
||||||
|
if expanded_keywords is not None:
|
||||||
|
updates.append(f"expanded_keywords = ${idx}")
|
||||||
|
params.append(expanded_keywords)
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
if not updates:
|
||||||
|
return json.dumps({"error": "No fields to update"})
|
||||||
|
|
||||||
|
updates.append("updated_at = NOW()")
|
||||||
|
params.append(id)
|
||||||
|
await conn.execute(
|
||||||
|
f"UPDATE memories SET {', '.join(updates)} WHERE id = ${idx}",
|
||||||
|
*params,
|
||||||
|
)
|
||||||
|
|
||||||
|
return json.dumps({"updated": id})
|
||||||
|
|
||||||
|
|
||||||
# Auth middleware for /mcp/* routes
|
# Auth middleware for /mcp/* routes
|
||||||
class MCPAuthMiddleware(BaseHTTPMiddleware):
|
class MCPAuthMiddleware(BaseHTTPMiddleware):
|
||||||
async def dispatch(self, request: Request, call_next: Any) -> Response:
|
async def dispatch(self, request: Request, call_next: Any) -> Response:
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Any, Optional
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
@ -38,3 +38,26 @@ class SecretResponse(BaseModel):
|
||||||
class SyncResponse(BaseModel):
|
class SyncResponse(BaseModel):
|
||||||
memories: list[dict[str, Any]]
|
memories: list[dict[str, Any]]
|
||||||
server_time: str
|
server_time: str
|
||||||
|
|
||||||
|
|
||||||
|
class ShareMemory(BaseModel):
|
||||||
|
shared_with: str = Field(..., min_length=1, max_length=100)
|
||||||
|
permission: Literal["read", "write"] = "read"
|
||||||
|
|
||||||
|
|
||||||
|
class ShareTag(BaseModel):
|
||||||
|
tag: str = Field(..., min_length=1, max_length=100)
|
||||||
|
shared_with: str = Field(..., min_length=1, max_length=100)
|
||||||
|
permission: Literal["read", "write"] = "read"
|
||||||
|
|
||||||
|
|
||||||
|
class UnshareTag(BaseModel):
|
||||||
|
tag: str = Field(..., min_length=1, max_length=100)
|
||||||
|
shared_with: str = Field(..., min_length=1, max_length=100)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryUpdate(BaseModel):
|
||||||
|
content: Optional[str] = Field(None, max_length=MAX_MEMORY_CHARS)
|
||||||
|
tags: Optional[str] = None
|
||||||
|
importance: Optional[float] = Field(None, ge=0.0, le=1.0)
|
||||||
|
expanded_keywords: Optional[str] = None
|
||||||
|
|
|
||||||
60
src/claude_memory/api/permissions.py
Normal file
60
src/claude_memory/api/permissions.py
Normal file
|
|
@ -0,0 +1,60 @@
|
||||||
|
"""Permission checking for shared memories."""
|
||||||
|
|
||||||
|
import asyncpg
|
||||||
|
|
||||||
|
|
||||||
|
async def check_memory_permission(
|
||||||
|
conn: asyncpg.Connection, memory_id: int, user_id: str, required: str
|
||||||
|
) -> tuple[bool, str | None]:
|
||||||
|
"""Check if user_id has the required permission on memory_id.
|
||||||
|
|
||||||
|
Returns (allowed, owner_id).
|
||||||
|
- Owner always has full access.
|
||||||
|
- Shared users checked via memory_shares and tag_shares.
|
||||||
|
- required: "read" or "write". "read" is satisfied by either permission.
|
||||||
|
"""
|
||||||
|
row = await conn.fetchrow(
|
||||||
|
"SELECT user_id FROM memories WHERE id = $1 AND deleted_at IS NULL",
|
||||||
|
memory_id,
|
||||||
|
)
|
||||||
|
if not row:
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
owner_id = row["user_id"]
|
||||||
|
|
||||||
|
# Owner always has access
|
||||||
|
if owner_id == user_id:
|
||||||
|
return True, owner_id
|
||||||
|
|
||||||
|
# Check individual memory share
|
||||||
|
share = await conn.fetchrow(
|
||||||
|
"SELECT permission FROM memory_shares WHERE memory_id = $1 AND shared_with = $2",
|
||||||
|
memory_id, user_id,
|
||||||
|
)
|
||||||
|
if share:
|
||||||
|
if required == "read" or share["permission"] == "write":
|
||||||
|
return True, owner_id
|
||||||
|
return False, owner_id
|
||||||
|
|
||||||
|
# Check tag-based shares
|
||||||
|
tag_share = await conn.fetchrow(
|
||||||
|
"""
|
||||||
|
SELECT ts.permission
|
||||||
|
FROM tag_shares ts
|
||||||
|
JOIN memories m ON m.user_id = ts.owner_id
|
||||||
|
WHERE m.id = $1 AND ts.shared_with = $2
|
||||||
|
AND EXISTS (
|
||||||
|
SELECT 1 FROM unnest(string_to_array(m.tags, ',')) t
|
||||||
|
WHERE trim(t) = ts.tag
|
||||||
|
)
|
||||||
|
ORDER BY CASE WHEN ts.permission = 'write' THEN 0 ELSE 1 END
|
||||||
|
LIMIT 1
|
||||||
|
""",
|
||||||
|
memory_id, user_id,
|
||||||
|
)
|
||||||
|
if tag_share:
|
||||||
|
if required == "read" or tag_share["permission"] == "write":
|
||||||
|
return True, owner_id
|
||||||
|
return False, owner_id
|
||||||
|
|
||||||
|
return False, owner_id
|
||||||
Loading…
Add table
Add a link
Reference in a new issue