diff --git a/fire_planner/examples/llm_extract.py b/fire_planner/examples/llm_extract.py
index fd7072e..6e65b23 100644
--- a/fire_planner/examples/llm_extract.py
+++ b/fire_planner/examples/llm_extract.py
@@ -128,3 +128,68 @@ def _parse_extracted_json(content: str, record_model: str) -> ExtractedExample |
     except ValidationError:
         log.warning("LLM JSON failed schema validation: %s", cleaned[:200])
         return None
+
+
+# Minimum confidence a qwen result needs to be returned without escalating
+# to claude.
+DEFAULT_CONFIDENCE_THRESHOLD = Decimal("0.5")
+
+
+async def extract_with_claude(
+    post: RawPost,
+    claude_url: str,
+    bearer: str,
+    client: httpx.AsyncClient,
+) -> ExtractedExample | None:
+    """Call claude-agent-service. Returns None on any failure.
+
+    Sends `post` to `claude_url` with `bearer` as an `Authorization: Bearer`
+    header, passing `CLAUDE_AGENT_MODEL` as both `model_name` and
+    `record_model`.
+    """
+    return await _call_openai_chat(
+        url=claude_url,
+        model_name=CLAUDE_AGENT_MODEL,
+        post=post,
+        client=client,
+        record_model=CLAUDE_AGENT_MODEL,
+        extra_headers={"Authorization": f"Bearer {bearer}"},
+    )
+
+
+async def extract_with_fallback(
+    post: RawPost,
+    *,
+    llama_url: str,
+    claude_url: str,
+    claude_bearer: str,
+    client: httpx.AsyncClient,
+    confidence_threshold: Decimal = DEFAULT_CONFIDENCE_THRESHOLD,
+) -> ExtractedExample | None:
+    """Try qwen first; escalate to claude on failure or low confidence.
+
+    A qwen result whose confidence is at least `confidence_threshold` is
+    returned as-is. Otherwise claude is called and its result is preferred;
+    if claude fails, a low-confidence qwen result is still returned rather
+    than discarded.
+
+    Returns None only when both backends fail (the orchestrator drops
+    the post and increments `fire_examples_extract_failed_total`).
+    """
+    primary = await extract_with_qwen(post, llama_url=llama_url, client=client)
+    if primary is not None and primary.confidence >= confidence_threshold:
+        return primary
+    log.info(
+        "Escalating %s to Tier 2 (primary=%s)",
+        post.reddit_id,
+        "none" if primary is None else f"conf={primary.confidence}",
+    )
+    secondary = await extract_with_claude(
+        post,
+        claude_url=claude_url,
+        bearer=claude_bearer,
+        client=client,
+    )
+    # Test for None explicitly so the choice depends on whether claude
+    # answered, not on the truthiness of the returned object.
+    return secondary if secondary is not None else primary
diff --git a/tests/test_examples_llm_extract.py b/tests/test_examples_llm_extract.py
index 727e6f4..4ebc5ac 100644
--- a/tests/test_examples_llm_extract.py
+++ b/tests/test_examples_llm_extract.py
@@ -9,10 +9,14 @@ import httpx
 import pytest
 import respx
 
-from fire_planner.examples.llm_extract import extract_with_qwen
+from fire_planner.examples.llm_extract import (
+    extract_with_fallback,
+    extract_with_qwen,
+)
 from fire_planner.examples.models import RawPost
 
 LLAMA_URL = "http://llama-cpp.llama-cpp.svc.cluster.local:8000/v1/chat/completions"
+CLAUDE_URL = "http://claude-agent-service.claude-agent.svc.cluster.local:8080/v1/chat/completions"
 
 
 def _post() -> RawPost:
@@ -79,3 +83,91 @@ async def test_extract_with_qwen_returns_none_on_http_error() -> None:
         out = await extract_with_qwen(_post(), llama_url=LLAMA_URL, client=client)
 
     assert out is None
+
+
+@respx.mock
+@pytest.mark.asyncio
+async def test_fallback_escalates_when_qwen_returns_none() -> None:
+    respx.post(LLAMA_URL).respond(500)  # qwen down
+    claude_payload = {
+        "country": "Philippines",
+        "city": "Manila",
+        "confidence": 0.95,
+    }
+    respx.post(CLAUDE_URL).respond(
+        200,
+        json={"choices": [{"message": {"content": json.dumps(claude_payload)}}]},
+    )
+
+    async with httpx.AsyncClient() as client:
+        out = await extract_with_fallback(
+            _post(),
+            llama_url=LLAMA_URL,
+            claude_url=CLAUDE_URL,
+            claude_bearer="t",
+            client=client,
+        )
+
+    assert out is not None
+    assert out.llm_model == "claude-haiku-4-5"
+    assert out.country == "Philippines"
+
+
+@respx.mock
+@pytest.mark.asyncio
+async def test_fallback_escalates_on_low_confidence() -> None:
+    """A qwen answer below the threshold is replaced by claude's answer."""
+    weak_answer = json.dumps({"country": None, "confidence": 0.2})
+    strong_answer = json.dumps({"country": "Thailand", "city": "Bangkok", "confidence": 0.9})
+    respx.post(LLAMA_URL).respond(
+        200,
+        json={"choices": [{"message": {"content": weak_answer}}]},
+    )
+    respx.post(CLAUDE_URL).respond(
+        200,
+        json={"choices": [{"message": {"content": strong_answer}}]},
+    )
+
+    async with httpx.AsyncClient() as http:
+        result = await extract_with_fallback(
+            _post(),
+            llama_url=LLAMA_URL,
+            claude_url=CLAUDE_URL,
+            claude_bearer="t",
+            client=http,
+            confidence_threshold=Decimal("0.5"),
+        )
+
+    assert result is not None
+    assert result.country == "Thailand"
+    assert result.llm_model == "claude-haiku-4-5"
+
+
+@respx.mock
+@pytest.mark.asyncio
+async def test_fallback_keeps_high_confidence_qwen_result() -> None:
+    """A qwen answer above the threshold is returned as-is; claude is never called."""
+    confident_answer = json.dumps({"country": "Philippines", "confidence": 0.9})
+    respx.post(LLAMA_URL).respond(
+        200,
+        json={"choices": [{"message": {"content": confident_answer}}]},
+    )
+    claude_route = respx.post(CLAUDE_URL).respond(
+        200,
+        json={"choices": [{"message": {"content": "{}"}}]},
+    )
+
+    async with httpx.AsyncClient() as http:
+        result = await extract_with_fallback(
+            _post(),
+            llama_url=LLAMA_URL,
+            claude_url=CLAUDE_URL,
+            claude_bearer="t",
+            client=http,
+            confidence_threshold=Decimal("0.5"),
+        )
+
+    assert result is not None
+    assert result.llm_model == "qwen3-8b"
+    # high-confidence qwen → claude not hit
+    assert claude_route.called is False