claude-memory-mcp/hooks/auto-learn.py
Viktor Barzin e47efee6b6
resilient memory sync: decouple push/pull, startup full resync, auth failure handling
- Decouple push and pull in _sync_once() so pull always runs even if push fails
- Add startup full resync to catch drift from other agents and schema changes
- Add periodic full resync every ~10 minutes for continuous drift correction
- Add auth failure detection (401/403) with graceful SQLite-only degradation
- Add /api/auth-check endpoint for lightweight key validation
- Add retry cap (5 attempts) on pending ops to prevent infinite queue buildup
- Add orphan reconciliation: push local-only records with content dedup
- Add memory_count MCP tool for sync diagnostics
- Add version-based SQLite schema migration (PRAGMA user_version)
- Fix API key in ~/.claude.json to match server
- Update README with sync resilience docs, test structure, project layout
- Add 30 new tests covering all new behaviors (155 total, all passing)
2026-03-16 18:37:59 +00:00

#!/usr/bin/env python3
"""
Stop hook (async): automatic learning extraction via haiku-as-judge.
After each Claude response, reads the recent conversation window and uses
haiku to detect learnings worth persisting:
- User corrections, preferences, decisions, facts (original scope)
- Debugging insights: error → root cause → fix mappings
- Architectural patterns and workarounds discovered during work
- Service/tool-specific operational knowledge
Features:
- Multi-turn context window (last 5 exchanges by default)
- State tracking to avoid duplicate extraction
- Judge fallback chain in _call_judge(): claude CLI (haiku) → ollama → heuristic patterns
- Writes only to the memory API, or to the local SQLite store when no API is configured
- Throttled deep extraction: full window every ~5 turns, single-turn otherwise
Runs with async: true — does NOT block the user.
"""
import hashlib
import io
import json
import logging
import os
import shutil
import subprocess
import sys
import urllib.error
import urllib.request
from datetime import datetime, timezone
from pathlib import Path
logger = logging.getLogger(__name__)
API_BASE_URL = os.environ.get("MEMORY_API_URL") or os.environ.get("CLAUDE_MEMORY_API_URL", "")
API_KEY = os.environ.get("MEMORY_API_KEY") or os.environ.get("CLAUDE_MEMORY_API_KEY", "")
# How many turns between deep (multi-turn) extractions
DEEP_EXTRACTION_INTERVAL = 5
# Max exchanges to include in deep extraction
DEEP_WINDOW_SIZE = 5
# Max chars per message in the context window
MAX_MSG_CHARS = 3000
# State directory
STATE_DIR = Path.home() / ".claude" / "auto-learn-state"
SINGLE_TURN_PROMPT = """You are a memory extraction judge. Analyze this single exchange between a user and an AI assistant.
USER MESSAGE:
{user_message}
ASSISTANT RESPONSE:
{assistant_response}
Your job: determine if any of these learning events occurred:
1. USER CORRECTION — user corrected the assistant's mistake or misunderstanding
2. PREFERENCE — user stated a preference, habit, or "I like/prefer/want" statement
3. DECISION — a decision was reached about how to do something
4. FACT — user shared a durable fact about themselves, their team, tools, or environment
If ANY learning event occurred, return JSON:
{{"events": [{{"type": "correction|preference|decision|fact", "content": "concise fact to remember (1-2 sentences, max 300 chars)", "importance": 0.7, "tags": "comma,separated,tags", "expanded_keywords": "space-separated semantically related search terms for recall (minimum 5 words)", "supersedes": null}}]}}
If NO learning event occurred, return:
{{"events": []}}
Rules:
- Only extract DURABLE facts, not transient task details ("fix this file", "run tests")
- Corrections are highest value (0.8-0.9)
- Be conservative — false negatives are better than false positives
- ONE topic per event. If multiple topics, create separate events.
- Keep each event's content under 300 characters (1-2 sentences). Include the "why" not just the "what".
- "supersedes" should be a search query to find the old outdated memory, or null
- Return ONLY valid JSON, no other text"""
DEEP_EXTRACTION_PROMPT = """You are a knowledge extraction system. Analyze this multi-turn conversation between a user and an AI assistant working on software engineering tasks.
CONVERSATION (last {n_exchanges} exchanges):
{conversation}
Extract any DURABLE knowledge worth remembering across sessions. Look for:
1. **CORRECTIONS** — user corrected a mistake or misunderstanding (importance: 0.8-0.9)
2. **PREFERENCES** — user stated how they like things done (importance: 0.7-0.8)
3. **DECISIONS** — architectural or design decisions reached (importance: 0.7-0.8)
4. **FACTS** — durable facts about user, team, tools, environment (importance: 0.6-0.8)
5. **DEBUGGING INSIGHTS** — error → root cause → fix patterns that would help next time (importance: 0.7-0.9)
6. **WORKAROUNDS** — things that didn't work and what did instead (importance: 0.7-0.8)
7. **OPERATIONAL KNOWLEDGE** — service-specific learnings, config gotchas, resource requirements (importance: 0.7-0.8)
Return JSON:
{{"events": [{{"type": "correction|preference|decision|fact|debugging|workaround|operational", "content": "concise knowledge (1-3 sentences, max 500 chars, ONE topic per event)", "importance": 0.7, "tags": "comma,separated,relevant,tags", "expanded_keywords": "space-separated semantically related search terms for recall (minimum 5 words)", "supersedes": null}}]}}
If NO durable knowledge was found, return:
{{"events": []}}
Rules:
- Only extract DURABLE knowledge, not transient task context ("reading file X", "running command Y")
- Don't extract things that are obvious from the codebase (file paths, function names)
- DO extract: "X doesn't work because Y — use Z instead", "service A needs B config", "always do X before Y"
- ONE topic per event — never combine unrelated learnings into a single event
- Keep each event's content between 100-500 characters. Include the WHY, not just the WHAT.
- If a debugging session revealed the root cause, capture the error→cause→fix chain as ONE event
- "supersedes" should be a search query to find an old outdated memory this replaces, or null
- Maximum 5 events per extraction — prioritize by importance
- Return ONLY valid JSON, no other text"""
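# Illustrative (hypothetical content) example of the JSON both prompts ask the judge to
# return; _parse_llm_response() below expects exactly this shape:
#
#   {"events": [{"type": "debugging",
#                "content": "Service X crashes on start when PORT is unset; export PORT first.",
#                "importance": 0.8,
#                "tags": "debugging,service-x,config",
#                "expanded_keywords": "port environment variable startup crash config",
#                "supersedes": null}]}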
def _get_state_path(session_id: str) -> Path:
"""Get state file path for this session."""
STATE_DIR.mkdir(parents=True, exist_ok=True)
return STATE_DIR / f"{session_id}.json"
def _load_state(session_id: str) -> dict:
"""Load extraction state for this session."""
path = _get_state_path(session_id)
if path.exists():
try:
return json.loads(path.read_text())
except (json.JSONDecodeError, OSError):
pass
return {"turn_count": 0, "extracted_hashes": [], "last_deep_turn": 0}
def _save_state(session_id: str, state: dict) -> None:
"""Save extraction state for this session."""
path = _get_state_path(session_id)
try:
path.write_text(json.dumps(state))
except OSError:
pass
def _cleanup_old_state() -> None:
"""Remove state files older than 24 hours."""
if not STATE_DIR.exists():
return
now = datetime.now().timestamp()
try:
for f in STATE_DIR.iterdir():
if f.suffix == ".json" and (now - f.stat().st_mtime) > 86400:
f.unlink(missing_ok=True)
except OSError:
pass
def _content_hash(content: str) -> str:
"""Hash content for deduplication."""
return hashlib.sha256(content.encode()).hexdigest()[:16]
def _parse_transcript(transcript_path: str, max_exchanges: int = 1) -> list[dict]:
"""
Parse the transcript and return the last N exchanges as
[{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}, ...]
"""
try:
MAX_TAIL_BYTES = max_exchanges * 100_000 # ~100KB per exchange should be plenty
with open(transcript_path, "rb") as f:
f.seek(0, io.SEEK_END)
size = f.tell()
f.seek(max(0, size - MAX_TAIL_BYTES))
tail = f.read().decode("utf-8", errors="replace")
lines = tail.split("\n")
except Exception:
return []
entries = []
for line in lines:
line = line.strip()
if not line:
continue
try:
entry = json.loads(line)
except json.JSONDecodeError:
continue
# Transcript format: role can be at top level or nested in message
msg = entry.get("message", entry)
role = msg.get("role", "") or entry.get("type", "")
if role not in ("user", "assistant"):
continue
content = msg.get("content", "")
if isinstance(content, list):
content = " ".join(
b.get("text", "") for b in content
if isinstance(b, dict) and b.get("type") == "text"
)
content = str(content)[:MAX_MSG_CHARS]
if content.strip():
entries.append({"role": role, "content": content})
# Extract the last N exchanges (user+assistant pairs)
# Walk backwards to find pairs
exchanges = []
i = len(entries) - 1
while i >= 0 and len(exchanges) < max_exchanges * 2:
exchanges.insert(0, entries[i])
i -= 1
# Trim to last N complete exchanges
result = []
pair_count = 0
for entry in reversed(exchanges):
result.insert(0, entry)
if entry["role"] == "user":
pair_count += 1
if pair_count >= max_exchanges:
break
return result
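# Rough sketch of the transcript lines _parse_transcript() understands: each JSONL entry
# may nest the message under "message", and content may be a list of typed blocks, of
# which only "text" blocks are kept. Illustrative example, not the full transcript format:
#
#   {"message": {"role": "assistant", "content": [{"type": "text", "text": "Done."}]}}
#       -> {"role": "assistant", "content": "Done."}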
def _api_request(method: str, path: str, body: dict | None = None) -> dict:
url = f"{API_BASE_URL}{path}"
data = json.dumps(body).encode() if body else None
req = urllib.request.Request(
url, data=data, method=method,
headers={"Authorization": f"Bearer {API_KEY}", "Content-Type": "application/json"},
)
with urllib.request.urlopen(req, timeout=15) as resp:
return json.loads(resp.read().decode())
def _store_via_api(content, category, tags, importance, expanded_keywords):
_api_request("POST", "/api/memories", {
"content": content, "category": category, "tags": tags,
"expanded_keywords": expanded_keywords, "importance": importance,
})
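# Sketch of the request _store_via_api() ends up issuing (endpoint, header, and field
# names as used above; server-side validation is assumed, not checked here):
#
#   POST {API_BASE_URL}/api/memories
#   Authorization: Bearer <MEMORY_API_KEY>
#   {"content": "...", "category": "preferences", "tags": "auto-learned,preference",
#    "expanded_keywords": "...", "importance": 0.7}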
def _store_via_sqlite(content, category, tags, importance, expanded_keywords):
import sqlite3
memory_home = os.environ.get("MEMORY_HOME", os.path.expanduser("~/.claude/claude-memory"))
db_path = os.path.join(memory_home, "memory", "memory.db")
if not os.path.exists(db_path):
legacy_db = os.path.join(os.path.expanduser("~/.claude/metaclaw"), "memory", "memory.db")
if os.path.exists(legacy_db):
db_path = legacy_db
conn = sqlite3.connect(db_path, timeout=10.0)
conn.execute("PRAGMA journal_mode=WAL")
now = datetime.now(timezone.utc).isoformat()
conn.execute(
"INSERT INTO memories (content, category, tags, importance, expanded_keywords, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
(content, category, tags, importance, expanded_keywords, now, now),
)
conn.commit()
conn.close()
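# _store_via_sqlite() assumes the memories table already exists; this hook never creates
# it. A minimal sketch of the columns the INSERT above relies on (the real schema, managed
# elsewhere, likely has more, e.g. an id primary key):
#
#   CREATE TABLE memories (content TEXT, category TEXT, tags TEXT, importance REAL,
#                          expanded_keywords TEXT, created_at TEXT, updated_at TEXT);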
def _parse_llm_response(response_text: str) -> list[dict]:
"""Parse LLM response text into events list."""
response_text = response_text.strip()
# Strip markdown code fences if present
if response_text.startswith("```"):
lines = response_text.split("\n")
lines = [l for l in lines if not l.strip().startswith("```")]
response_text = "\n".join(lines).strip()
# Try to extract JSON from the response
# Sometimes the LLM adds text before/after the JSON
start = response_text.find("{")
end = response_text.rfind("}") + 1
if start >= 0 and end > start:
response_text = response_text[start:end]
judge_result = json.loads(response_text)
return judge_result.get("events", [])
def _call_judge_claude(prompt: str) -> list[dict] | None:
"""Try claude CLI as judge. Returns None if unavailable."""
if not shutil.which("claude"):
return None
try:
result = subprocess.run(
["claude", "-p", prompt, "--model", "haiku"],
capture_output=True, text=True, timeout=60,
# Run from /tmp to avoid internet-mode-used marker prompts
# Clear CLAUDECODE to prevent recursion
cwd="/tmp",
env={**os.environ, "CLAUDECODE": ""},
)
if result.returncode != 0:
return None
return _parse_llm_response(result.stdout)
except (subprocess.TimeoutExpired, json.JSONDecodeError, OSError):
return None
def _call_judge_ollama(prompt: str) -> list[dict] | None:
"""Try local ollama as judge. Returns None if unavailable."""
ollama_url = os.environ.get("OLLAMA_HOST", "http://localhost:11434")
# Prefer small models for speed
models_to_try = ["qwen2.5:3b", "llama3.2:3b", "gemma2:2b", "phi3:mini"]
for model in models_to_try:
try:
data = json.dumps({
"model": model,
"prompt": prompt,
"stream": False,
"options": {"temperature": 0, "num_predict": 512},
}).encode()
req = urllib.request.Request(
f"{ollama_url}/api/generate",
data=data, method="POST",
headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req, timeout=30) as resp:
result = json.loads(resp.read().decode())
return _parse_llm_response(result.get("response", ""))
except Exception:
continue
return None
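# For reference: a non-streaming /api/generate call returns a JSON object whose
# "response" field holds the raw model output, roughly:
#
#   {"model": "qwen2.5:3b", "response": "{\"events\": []}", "done": true}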
def _call_judge_heuristic(entries: list[dict]) -> list[dict]:
"""
Heuristic fallback: extract learnings via pattern matching.
Less accurate than LLM but works without any external dependencies.
"""
events = []
correction_patterns = [
"actually", "that's wrong", "no,", "not correct", "instead of",
"don't use", "never use", "always use", "the correct way",
"the issue was", "the problem was", "root cause",
]
preference_patterns = [
"i prefer", "i like", "i want", "please always", "please never",
"remember to", "from now on", "going forward",
]
decision_patterns = [
"let's go with", "we decided", "the approach is",
"we'll use", "switched to", "migrated to",
]
for entry in entries:
if entry["role"] != "user":
continue
text_lower = entry["content"].lower()
for pattern in correction_patterns:
if pattern in text_lower:
# Extract the sentence containing the pattern
for sentence in entry["content"].replace("\n", ". ").split(". "):
if pattern in sentence.lower() and len(sentence) > 20:
events.append({
"type": "correction",
"content": sentence.strip()[:200],
"importance": 0.8,
"tags": "auto-learned,heuristic,correction",
"expanded_keywords": " ".join(sentence.lower().split()[:10]),
})
break
break
for pattern in preference_patterns:
if pattern in text_lower:
for sentence in entry["content"].replace("\n", ". ").split(". "):
if pattern in sentence.lower() and len(sentence) > 15:
events.append({
"type": "preference",
"content": sentence.strip()[:200],
"importance": 0.7,
"tags": "auto-learned,heuristic,preference",
"expanded_keywords": " ".join(sentence.lower().split()[:10]),
})
break
break
for pattern in decision_patterns:
if pattern in text_lower:
for sentence in entry["content"].replace("\n", ". ").split(". "):
if pattern in sentence.lower() and len(sentence) > 20:
events.append({
"type": "decision",
"content": sentence.strip()[:200],
"importance": 0.7,
"tags": "auto-learned,heuristic,decision",
"expanded_keywords": " ".join(sentence.lower().split()[:10]),
})
break
break
return events[:5] # Max 5 events
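# Hypothetical example: a user message "Actually, the staging DB listens on port 5433,
# not 5432" matches the "actually" correction pattern and would yield a single
# correction event with importance 0.8 and heuristic tags.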
def _call_judge(prompt: str, entries: list[dict] | None = None) -> list[dict]:
"""Call judge with fallback chain: claude CLI → ollama → heuristic."""
# Try claude CLI first
result = _call_judge_claude(prompt)
if result is not None:
return result
# Try ollama
result = _call_judge_ollama(prompt)
if result is not None:
return result
# Fall back to heuristic (only for deep extraction with entries)
if entries:
return _call_judge_heuristic(entries)
return []
def _format_conversation(entries: list[dict]) -> str:
"""Format conversation entries for the judge prompt."""
parts = []
for entry in entries:
role = "USER" if entry["role"] == "user" else "ASSISTANT"
parts.append(f"[{role}]: {entry['content']}")
return "\n\n".join(parts)
def _store_events(events: list[dict], extracted_hashes: list[str]) -> list[str]:
"""Store extracted events, return new hashes."""
category_map = {
"correction": "preferences",
"preference": "preferences",
"decision": "decisions",
"fact": "facts",
"debugging": "decisions",
"workaround": "decisions",
"operational": "facts",
}
new_hashes = []
for event in events:
content = event.get("content", "")
if not content:
continue
# Deduplication: skip if we've already extracted this
h = _content_hash(content)
if h in extracted_hashes:
continue
event_type = event.get("type", "fact")
importance = max(0.0, min(1.0, float(event.get("importance", 0.7))))
category = category_map.get(event_type, "facts")
tags = event.get("tags", f"auto-learned,{event_type}")
if "auto-learned" not in tags:
tags = f"auto-learned,{tags}"
expanded_keywords = event.get("expanded_keywords", "")
# Store to memory API or SQLite
try:
if API_KEY and API_BASE_URL:
_store_via_api(content, category, tags, importance, expanded_keywords)
else:
_store_via_sqlite(content, category, tags, importance, expanded_keywords)
except Exception:
pass
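        # The hash is recorded even when the write fails, so a persistently failing
        # backend is not retried every turn; the learning is simply dropped this session.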
new_hashes.append(h)
return new_hashes
def main() -> None:
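    # Require the claude CLI up front; with this early gate the ollama/heuristic
    # fallbacks in _call_judge() only run when the CLI exists but the haiku call fails.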
if not shutil.which("claude"):
return
try:
hook_input = json.load(sys.stdin)
except (json.JSONDecodeError, EOFError):
return
if isinstance(hook_input, dict) and hook_input.get("stop_hook_active", False):
return
transcript_path = ""
session_id = ""
if isinstance(hook_input, dict):
transcript_path = hook_input.get("transcript_path", "")
session_id = hook_input.get("session_id", "")
if not transcript_path or not os.path.exists(transcript_path):
return
# Derive session ID from transcript path if not provided
if not session_id:
session_id = hashlib.sha256(transcript_path.encode()).hexdigest()[:16]
# Load state
state = _load_state(session_id)
state["turn_count"] = state.get("turn_count", 0) + 1
turn_count = state["turn_count"]
last_deep_turn = state.get("last_deep_turn", 0)
extracted_hashes = state.get("extracted_hashes", [])
# Decide: single-turn (cheap) or deep (multi-turn) extraction
turns_since_deep = turn_count - last_deep_turn
do_deep = turns_since_deep >= DEEP_EXTRACTION_INTERVAL
if do_deep:
# Deep extraction: read last N exchanges
entries = _parse_transcript(transcript_path, max_exchanges=DEEP_WINDOW_SIZE)
if len(entries) < 2:
_save_state(session_id, state)
return
# Count actual exchanges
n_exchanges = sum(1 for e in entries if e["role"] == "user")
conversation = _format_conversation(entries)
prompt = DEEP_EXTRACTION_PROMPT.format(
n_exchanges=n_exchanges,
conversation=conversation[:8000], # Cap total context
)
events = _call_judge(prompt, entries)
state["last_deep_turn"] = turn_count
else:
# Single-turn extraction: just the last exchange
entries = _parse_transcript(transcript_path, max_exchanges=1)
if len(entries) < 2:
_save_state(session_id, state)
return
user_msg = ""
assistant_msg = ""
for entry in entries:
if entry["role"] == "user":
user_msg = entry["content"]
elif entry["role"] == "assistant":
assistant_msg = entry["content"]
if not user_msg or len(user_msg.strip()) < 10:
_save_state(session_id, state)
return
prompt = SINGLE_TURN_PROMPT.format(
user_message=user_msg,
assistant_response=assistant_msg[:2000],
)
events = _call_judge(prompt, entries)
# Store events
if events:
new_hashes = _store_events(events, extracted_hashes)
extracted_hashes.extend(new_hashes)
# Keep hash list bounded
if len(extracted_hashes) > 200:
extracted_hashes = extracted_hashes[-200:]
state["extracted_hashes"] = extracted_hashes
_save_state(session_id, state)
# Periodic cleanup of old state files
if turn_count % 20 == 0:
_cleanup_old_state()
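# Manual smoke test (hypothetical paths): feed a fake Stop-hook payload on stdin and
# point transcript_path at an existing JSONL transcript, e.g.
#
#   echo '{"transcript_path": "/tmp/t.jsonl", "session_id": "test"}' | ./auto-learn.py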
if __name__ == "__main__":
main()