From 445692229437f2ba2199564853a469a99cf90827 Mon Sep 17 00:00:00 2001 From: Viktor Barzin Date: Sun, 15 Mar 2026 11:15:14 +0000 Subject: [PATCH] =?UTF-8?q?add=20fallback=20chain=20for=20judge:=20claude?= =?UTF-8?q?=20CLI=20=E2=86=92=20ollama=20=E2=86=92=20heuristic?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - claude CLI: run from /tmp to avoid internet-mode-used marker prompts - ollama: try small local models (qwen2.5:3b, llama3.2:3b, etc.) - heuristic: pattern matching for corrections, preferences, decisions - better JSON extraction: handles markdown fences and surrounding text --- hooks/auto-learn.py | 168 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 151 insertions(+), 17 deletions(-) diff --git a/hooks/auto-learn.py b/hooks/auto-learn.py index bc9d688..69cc530 100644 --- a/hooks/auto-learn.py +++ b/hooks/auto-learn.py @@ -173,10 +173,12 @@ def _parse_transcript(transcript_path: str, max_exchanges: int = 1) -> list[dict entry = json.loads(line) except json.JSONDecodeError: continue - role = entry.get("role", "") + # Transcript format: role can be at top level or nested in message + msg = entry.get("message", entry) + role = msg.get("role", "") or entry.get("type", "") if role not in ("user", "assistant"): continue - content = entry.get("content", "") + content = msg.get("content", "") if isinstance(content, list): content = " ".join( b.get("text", "") for b in content @@ -277,26 +279,158 @@ def _append_to_auto_memory(content: str, event_type: str) -> None: f.write(f"- [{now}] **{event_type}**: {content}\n") -def _call_judge(prompt: str) -> list[dict]: - """Call haiku as judge and return extracted events.""" +def _parse_llm_response(response_text: str) -> list[dict]: + """Parse LLM response text into events list.""" + response_text = response_text.strip() + # Strip markdown code fences if present + if response_text.startswith("```"): + lines = response_text.split("\n") + lines = [l for l in lines if not l.strip().startswith("```")] + response_text = "\n".join(lines).strip() + # Try to extract JSON from the response + # Sometimes the LLM adds text before/after the JSON + start = response_text.find("{") + end = response_text.rfind("}") + 1 + if start >= 0 and end > start: + response_text = response_text[start:end] + judge_result = json.loads(response_text) + return judge_result.get("events", []) + + +def _call_judge_claude(prompt: str) -> list[dict] | None: + """Try claude CLI as judge. Returns None if unavailable.""" + if not shutil.which("claude"): + return None try: result = subprocess.run( ["claude", "-p", prompt, "--model", "haiku"], - capture_output=True, text=True, timeout=45, + capture_output=True, text=True, timeout=60, + # Run from /tmp to avoid internet-mode-used marker prompts + # Clear CLAUDECODE to prevent recursion + cwd="/tmp", env={**os.environ, "CLAUDECODE": ""}, ) if result.returncode != 0: - return [] - response_text = result.stdout.strip() - # Strip markdown code fences if present - if response_text.startswith("```"): - lines = response_text.split("\n") - lines = [l for l in lines if not l.strip().startswith("```")] - response_text = "\n".join(lines).strip() - judge_result = json.loads(response_text) - return judge_result.get("events", []) + return None + return _parse_llm_response(result.stdout) except (subprocess.TimeoutExpired, json.JSONDecodeError, OSError): - return [] + return None + + +def _call_judge_ollama(prompt: str) -> list[dict] | None: + """Try local ollama as judge. Returns None if unavailable.""" + ollama_url = os.environ.get("OLLAMA_HOST", "http://localhost:11434") + # Prefer small models for speed + models_to_try = ["qwen2.5:3b", "llama3.2:3b", "gemma2:2b", "phi3:mini"] + for model in models_to_try: + try: + data = json.dumps({ + "model": model, + "prompt": prompt, + "stream": False, + "options": {"temperature": 0, "num_predict": 512}, + }).encode() + req = urllib.request.Request( + f"{ollama_url}/api/generate", + data=data, method="POST", + headers={"Content-Type": "application/json"}, + ) + with urllib.request.urlopen(req, timeout=30) as resp: + result = json.loads(resp.read().decode()) + return _parse_llm_response(result.get("response", "")) + except Exception: + continue + return None + + +def _call_judge_heuristic(entries: list[dict]) -> list[dict]: + """ + Heuristic fallback: extract learnings via pattern matching. + Less accurate than LLM but works without any external dependencies. + """ + events = [] + correction_patterns = [ + "actually", "that's wrong", "no,", "not correct", "instead of", + "don't use", "never use", "always use", "the correct way", + "the issue was", "the problem was", "root cause", + ] + preference_patterns = [ + "i prefer", "i like", "i want", "please always", "please never", + "remember to", "from now on", "going forward", + ] + decision_patterns = [ + "let's go with", "we decided", "the approach is", + "we'll use", "switched to", "migrated to", + ] + + for entry in entries: + if entry["role"] != "user": + continue + text_lower = entry["content"].lower() + + for pattern in correction_patterns: + if pattern in text_lower: + # Extract the sentence containing the pattern + for sentence in entry["content"].replace("\n", ". ").split(". "): + if pattern in sentence.lower() and len(sentence) > 20: + events.append({ + "type": "correction", + "content": sentence.strip()[:200], + "importance": 0.8, + "tags": "auto-learned,heuristic,correction", + "expanded_keywords": " ".join(sentence.lower().split()[:10]), + }) + break + break + + for pattern in preference_patterns: + if pattern in text_lower: + for sentence in entry["content"].replace("\n", ". ").split(". "): + if pattern in sentence.lower() and len(sentence) > 15: + events.append({ + "type": "preference", + "content": sentence.strip()[:200], + "importance": 0.7, + "tags": "auto-learned,heuristic,preference", + "expanded_keywords": " ".join(sentence.lower().split()[:10]), + }) + break + break + + for pattern in decision_patterns: + if pattern in text_lower: + for sentence in entry["content"].replace("\n", ". ").split(". "): + if pattern in sentence.lower() and len(sentence) > 20: + events.append({ + "type": "decision", + "content": sentence.strip()[:200], + "importance": 0.7, + "tags": "auto-learned,heuristic,decision", + "expanded_keywords": " ".join(sentence.lower().split()[:10]), + }) + break + break + + return events[:5] # Max 5 events + + +def _call_judge(prompt: str, entries: list[dict] | None = None) -> list[dict]: + """Call judge with fallback chain: claude CLI → ollama → heuristic.""" + # Try claude CLI first + result = _call_judge_claude(prompt) + if result is not None: + return result + + # Try ollama + result = _call_judge_ollama(prompt) + if result is not None: + return result + + # Fall back to heuristic (only for deep extraction with entries) + if entries: + return _call_judge_heuristic(entries) + + return [] def _format_conversation(entries: list[dict]) -> str: @@ -409,7 +543,7 @@ def main() -> None: n_exchanges=n_exchanges, conversation=conversation[:8000], # Cap total context ) - events = _call_judge(prompt) + events = _call_judge(prompt, entries) state["last_deep_turn"] = turn_count else: # Single-turn extraction: just the last exchange @@ -434,7 +568,7 @@ def main() -> None: user_message=user_msg, assistant_response=assistant_msg[:2000], ) - events = _call_judge(prompt) + events = _call_judge(prompt, entries) # Store events if events: