fix mypy across all source files, remove || true from CI
- Add type annotations to all FastAPI endpoints in api/app.py - Fix bare list/dict generics in sync.py and app.py - Fix no-any-return in vault_client.py and sync.py - Remove mypy || true from GitHub Actions CI — mypy is now clean
This commit is contained in:
parent
678d50654b
commit
d370855abf
4 changed files with 23 additions and 19 deletions
2
.github/workflows/ci.yml
vendored
2
.github/workflows/ci.yml
vendored
|
|
@ -19,7 +19,7 @@ jobs:
|
||||||
python-version: "3.12"
|
python-version: "3.12"
|
||||||
- run: pip install -e ".[api,dev]"
|
- run: pip install -e ".[api,dev]"
|
||||||
- run: ruff check src/ tests/
|
- run: ruff check src/ tests/
|
||||||
- run: mypy src/claude_memory/ || true
|
- run: mypy src/claude_memory/
|
||||||
- run: pytest tests/ -v --tb=short
|
- run: pytest tests/ -v --tb=short
|
||||||
|
|
||||||
build:
|
build:
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@
|
||||||
import logging
|
import logging
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Optional
|
from typing import Any, AsyncGenerator, Optional
|
||||||
|
|
||||||
from fastapi import Depends, FastAPI, HTTPException
|
from fastapi import Depends, FastAPI, HTTPException
|
||||||
|
|
||||||
|
|
@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||||
await init_pool()
|
await init_pool()
|
||||||
yield
|
yield
|
||||||
await close_pool()
|
await close_pool()
|
||||||
|
|
@ -55,7 +55,7 @@ def _redact_content(content: str) -> str:
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health():
|
async def health() -> dict[str, str]:
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -63,7 +63,7 @@ async def health():
|
||||||
async def sync_memories(
|
async def sync_memories(
|
||||||
since: Optional[str] = None,
|
since: Optional[str] = None,
|
||||||
user: AuthUser = Depends(get_current_user),
|
user: AuthUser = Depends(get_current_user),
|
||||||
):
|
) -> SyncResponse:
|
||||||
pool = await get_pool()
|
pool = await get_pool()
|
||||||
server_time = datetime.now(timezone.utc).isoformat()
|
server_time = datetime.now(timezone.utc).isoformat()
|
||||||
|
|
||||||
|
|
@ -113,7 +113,7 @@ async def sync_memories(
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/memories", response_model=MemoryResponse)
|
@app.post("/api/memories", response_model=MemoryResponse)
|
||||||
async def store_memory(body: MemoryStore, user: AuthUser = Depends(get_current_user)):
|
async def store_memory(body: MemoryStore, user: AuthUser = Depends(get_current_user)) -> MemoryResponse:
|
||||||
pool = await get_pool()
|
pool = await get_pool()
|
||||||
is_sensitive = body.force_sensitive or _detect_sensitive(body.content)
|
is_sensitive = body.force_sensitive or _detect_sensitive(body.content)
|
||||||
|
|
||||||
|
|
@ -146,7 +146,7 @@ async def store_memory(body: MemoryStore, user: AuthUser = Depends(get_current_u
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/memories/recall")
|
@app.post("/api/memories/recall")
|
||||||
async def recall_memories(body: MemoryRecall, user: AuthUser = Depends(get_current_user)):
|
async def recall_memories(body: MemoryRecall, user: AuthUser = Depends(get_current_user)) -> dict[str, Any]:
|
||||||
pool = await get_pool()
|
pool = await get_pool()
|
||||||
|
|
||||||
query_text = f"{body.context} {body.expanded_query}".strip()
|
query_text = f"{body.context} {body.expanded_query}".strip()
|
||||||
|
|
@ -161,7 +161,7 @@ async def recall_memories(body: MemoryRecall, user: AuthUser = Depends(get_curre
|
||||||
order_clause = "created_at DESC"
|
order_clause = "created_at DESC"
|
||||||
|
|
||||||
category_filter = ""
|
category_filter = ""
|
||||||
params: list = [user.user_id, query_text, body.limit]
|
params: list[Any] = [user.user_id, query_text, body.limit]
|
||||||
if body.category:
|
if body.category:
|
||||||
category_filter = "AND category = $4"
|
category_filter = "AND category = $4"
|
||||||
params.append(body.category)
|
params.append(body.category)
|
||||||
|
|
@ -190,7 +190,7 @@ async def recall_memories(body: MemoryRecall, user: AuthUser = Depends(get_curre
|
||||||
words = query_text.split()
|
words = query_text.split()
|
||||||
if len(words) > 1:
|
if len(words) > 1:
|
||||||
or_tsquery = " | ".join(w for w in words if w)
|
or_tsquery = " | ".join(w for w in words if w)
|
||||||
or_params: list = [user.user_id, or_tsquery, body.limit]
|
or_params: list[Any] = [user.user_id, or_tsquery, body.limit]
|
||||||
or_cat_filter = ""
|
or_cat_filter = ""
|
||||||
if body.category:
|
if body.category:
|
||||||
or_cat_filter = "AND category = $4"
|
or_cat_filter = "AND category = $4"
|
||||||
|
|
@ -241,7 +241,7 @@ async def list_memories(
|
||||||
category: Optional[str] = None,
|
category: Optional[str] = None,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
user: AuthUser = Depends(get_current_user),
|
user: AuthUser = Depends(get_current_user),
|
||||||
):
|
) -> dict[str, Any]:
|
||||||
pool = await get_pool()
|
pool = await get_pool()
|
||||||
|
|
||||||
if category:
|
if category:
|
||||||
|
|
@ -250,7 +250,7 @@ async def list_memories(
|
||||||
FROM memories WHERE user_id = $1 AND deleted_at IS NULL AND category = $2
|
FROM memories WHERE user_id = $1 AND deleted_at IS NULL AND category = $2
|
||||||
ORDER BY importance DESC LIMIT $3
|
ORDER BY importance DESC LIMIT $3
|
||||||
"""
|
"""
|
||||||
params: list = [user.user_id, category, limit]
|
params: list[Any] = [user.user_id, category, limit]
|
||||||
else:
|
else:
|
||||||
query = """
|
query = """
|
||||||
SELECT id, content, category, tags, importance, is_sensitive, created_at, updated_at
|
SELECT id, content, category, tags, importance, is_sensitive, created_at, updated_at
|
||||||
|
|
@ -284,7 +284,7 @@ async def list_memories(
|
||||||
|
|
||||||
|
|
||||||
@app.delete("/api/memories/{memory_id}")
|
@app.delete("/api/memories/{memory_id}")
|
||||||
async def delete_memory(memory_id: int, user: AuthUser = Depends(get_current_user)):
|
async def delete_memory(memory_id: int, user: AuthUser = Depends(get_current_user)) -> dict[str, Any]:
|
||||||
pool = await get_pool()
|
pool = await get_pool()
|
||||||
|
|
||||||
async with pool.acquire() as conn:
|
async with pool.acquire() as conn:
|
||||||
|
|
@ -311,7 +311,7 @@ async def delete_memory(memory_id: int, user: AuthUser = Depends(get_current_use
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/memories/{memory_id}/secret", response_model=SecretResponse)
|
@app.post("/api/memories/{memory_id}/secret", response_model=SecretResponse)
|
||||||
async def get_memory_secret(memory_id: int, user: AuthUser = Depends(get_current_user)):
|
async def get_memory_secret(memory_id: int, user: AuthUser = Depends(get_current_user)) -> SecretResponse:
|
||||||
pool = await get_pool()
|
pool = await get_pool()
|
||||||
|
|
||||||
async with pool.acquire() as conn:
|
async with pool.acquire() as conn:
|
||||||
|
|
@ -346,7 +346,7 @@ async def get_memory_secret(memory_id: int, user: AuthUser = Depends(get_current
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/memories/migrate-secrets")
|
@app.post("/api/memories/migrate-secrets")
|
||||||
async def migrate_secrets(user: AuthUser = Depends(get_current_user)):
|
async def migrate_secrets(user: AuthUser = Depends(get_current_user)) -> dict[str, int]:
|
||||||
pool = await get_pool()
|
pool = await get_pool()
|
||||||
migrated = 0
|
migrated = 0
|
||||||
|
|
||||||
|
|
@ -388,7 +388,7 @@ async def migrate_secrets(user: AuthUser = Depends(get_current_user)):
|
||||||
@app.post("/api/memories/import")
|
@app.post("/api/memories/import")
|
||||||
async def import_memories(
|
async def import_memories(
|
||||||
memories: list[MemoryStore], user: AuthUser = Depends(get_current_user)
|
memories: list[MemoryStore], user: AuthUser = Depends(get_current_user)
|
||||||
):
|
) -> list[MemoryResponse]:
|
||||||
pool = await get_pool()
|
pool = await get_pool()
|
||||||
imported = []
|
imported = []
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ import threading
|
||||||
import urllib.error
|
import urllib.error
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
import urllib.request
|
import urllib.request
|
||||||
|
from typing import Any
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
@ -117,7 +118,7 @@ class SyncEngine:
|
||||||
self._push_pending_ops()
|
self._push_pending_ops()
|
||||||
self._pull_changes()
|
self._pull_changes()
|
||||||
|
|
||||||
def _api_request(self, method: str, path: str, body: dict | None = None) -> dict:
|
def _api_request(self, method: str, path: str, body: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||||
"""Make an HTTP request to the memory API."""
|
"""Make an HTTP request to the memory API."""
|
||||||
url = f"{self.api_base_url}{path}"
|
url = f"{self.api_base_url}{path}"
|
||||||
data = json.dumps(body).encode() if body else None
|
data = json.dumps(body).encode() if body else None
|
||||||
|
|
@ -131,7 +132,8 @@ class SyncEngine:
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
with urllib.request.urlopen(req, timeout=15) as resp:
|
with urllib.request.urlopen(req, timeout=15) as resp:
|
||||||
return json.loads(resp.read().decode())
|
result: dict[str, Any] = json.loads(resp.read().decode())
|
||||||
|
return result
|
||||||
|
|
||||||
def _push_pending_ops(self) -> None:
|
def _push_pending_ops(self) -> None:
|
||||||
"""Push queued operations to the API server."""
|
"""Push queued operations to the API server."""
|
||||||
|
|
|
||||||
|
|
@ -50,7 +50,8 @@ class VaultClient:
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
with urllib.request.urlopen(req, timeout=10) as resp:
|
with urllib.request.urlopen(req, timeout=10) as resp:
|
||||||
return json.loads(resp.read().decode())
|
result: dict[str, Any] = json.loads(resp.read().decode())
|
||||||
|
return result
|
||||||
except urllib.error.HTTPError as e:
|
except urllib.error.HTTPError as e:
|
||||||
if e.code == 404:
|
if e.code == 404:
|
||||||
return {}
|
return {}
|
||||||
|
|
@ -79,6 +80,7 @@ class VaultClient:
|
||||||
"""List secrets at a path."""
|
"""List secrets at a path."""
|
||||||
try:
|
try:
|
||||||
resp = self._request("LIST", f"/v1/{self.mount}/metadata/{path}")
|
resp = self._request("LIST", f"/v1/{self.mount}/metadata/{path}")
|
||||||
return resp.get("data", {}).get("keys", [])
|
keys: list[str] = resp.get("data", {}).get("keys", [])
|
||||||
|
return keys
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
return []
|
return []
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue