From d1a5da100848331de45c08bc019b17b601208d98 Mon Sep 17 00:00:00 2001
From: Viktor Barzin
Date: Thu, 28 May 2026 22:19:32 +0000
Subject: [PATCH] examples: primary qwen3-8b extractor

Treat a malformed chat-completions reply (null or non-subscriptable
levels, non-string content, JSON that is not an object) as a failed
extraction that returns None instead of raising.

Co-Authored-By: Claude Opus 4.7
---
Reviewer note: the hunks below were edited by hand after export (response
shape hardening plus two regression tests); the blob ids on the `index`
lines are from the original export and no longer match the contents.

 fire_planner/examples/llm_extract.py | 153 +++++++++++++++++++++++++++
 tests/test_examples_llm_extract.py   | 106 +++++++++++++++++++
 2 files changed, 259 insertions(+)
 create mode 100644 fire_planner/examples/llm_extract.py
 create mode 100644 tests/test_examples_llm_extract.py

diff --git a/fire_planner/examples/llm_extract.py b/fire_planner/examples/llm_extract.py
new file mode 100644
index 0000000..fd7072e
--- /dev/null
+++ b/fire_planner/examples/llm_extract.py
@@ -0,0 +1,153 @@
+"""LLM extraction — primary qwen3-8b via llama-cpp, Tier 2 fallback to
+claude-agent-service when qwen confidence is low or JSON unparseable.
+
+Both backends speak the OpenAI-compatible chat-completions API. We
+issue a strict JSON-schema prompt and parse the first `choices[0]`
+message into `ExtractedExample`. Tier 2 escalation lives in
+`extract_with_fallback` (added in Task 8) — primary failure is silent
+(returns None) so the orchestrator can choose to escalate or skip.
+"""
+from __future__ import annotations
+
+import json
+import logging
+from decimal import Decimal, InvalidOperation
+from typing import Any
+
+import httpx
+from pydantic import ValidationError
+
+from fire_planner.examples.models import ExtractedExample, RawPost
+
+log = logging.getLogger(__name__)
+
+QWEN_MODEL = "qwen3-8b"
+CLAUDE_AGENT_MODEL = "claude-haiku-4-5"
+HTTP_TIMEOUT = httpx.Timeout(60.0)
+
+PROMPT_SYSTEM = (
+    "You are extracting structured FIRE-example data from a Reddit post. "
+    "Output ONLY a single JSON object with these keys (use null when the "
+    "post does not say): country, city, portfolio_native (number), "
+    "annual_exp_native (number), raw_currency (3-letter ISO), age (int), "
+    "family_size (int, default 1 if single), fi_status (one of: "
+    "accumulating, coastFIRE, baristaFIRE, leanFIRE, FIRE, fatFIRE, "
+    "unknown), is_retired (bool), confidence (0.0-1.0). "
+    "DO NOT include any prose or markdown — JSON only."
+)
+
+
+def _user_prompt(post: RawPost) -> str:
+    """Render the user message; the body is capped at 4000 characters."""
+    return (
+        f"Subreddit: {post.source_sub}\n"
+        f"Title: {post.title}\n"
+        f"Body:\n{post.body[:4000]}"
+    )
+
+
+async def extract_with_qwen(
+    post: RawPost,
+    llama_url: str,
+    client: httpx.AsyncClient,
+) -> ExtractedExample | None:
+    """Call qwen3-8b via llama-cpp. Returns None on any failure."""
+    return await _call_openai_chat(
+        url=llama_url,
+        model_name=QWEN_MODEL,
+        post=post,
+        client=client,
+        record_model=QWEN_MODEL,
+    )
+
+
+async def _call_openai_chat(
+    *,
+    url: str,
+    model_name: str,
+    post: RawPost,
+    client: httpx.AsyncClient,
+    record_model: str,
+    extra_headers: dict[str, str] | None = None,
+) -> ExtractedExample | None:
+    """POST one chat-completions request and parse the first choice.
+
+    Returns None (after logging a warning) when the HTTP call fails, the
+    response body does not have the expected shape, or the message
+    content cannot be parsed into an `ExtractedExample`.
+    """
+    body = {
+        "model": model_name,
+        "messages": [
+            {"role": "system", "content": PROMPT_SYSTEM},
+            {"role": "user", "content": _user_prompt(post)},
+        ],
+        "temperature": 0.0,
+        "max_tokens": 512,
+    }
+    try:
+        resp = await client.post(
+            url,
+            json=body,
+            timeout=HTTP_TIMEOUT,
+            headers=extra_headers,
+        )
+        resp.raise_for_status()
+    except httpx.HTTPError:
+        log.warning("LLM call failed for %s via %s", post.reddit_id, url, exc_info=True)
+        return None
+
+    try:
+        content = resp.json()["choices"][0]["message"]["content"]
+    except (KeyError, IndexError, TypeError, ValueError):
+        # TypeError covers a null or otherwise non-subscriptable level,
+        # e.g. {"choices": null} or a top-level JSON array.
+        log.warning("Unexpected LLM response shape for %s", post.reddit_id)
+        return None
+    if not isinstance(content, str):
+        # A null or structured `content` is not text the parser can strip.
+        log.warning("Unexpected LLM response shape for %s", post.reddit_id)
+        return None
+
+    return _parse_extracted_json(content, record_model)
+
+
+def _parse_extracted_json(content: str, record_model: str) -> ExtractedExample | None:
+    """Tolerant JSON parser — strip fences, parse, validate.
+
+    Returns None when `content` is not valid JSON, is valid JSON but not
+    an object, or fails `ExtractedExample` validation. `record_model` is
+    stored under the `llm_model` key before validation.
+    """
+    cleaned = (
+        content.strip()
+        .removeprefix("```json")
+        .removeprefix("```")
+        .removesuffix("```")
+        .strip()
+    )
+    try:
+        data: Any = json.loads(cleaned)
+    except json.JSONDecodeError:
+        log.warning("LLM returned unparseable JSON: %s", cleaned[:200])
+        return None
+    if not isinstance(data, dict):
+        # Valid JSON that is a list/string/number has no `.get`; treat it
+        # like any other malformed reply instead of raising.
+        log.warning("LLM returned non-object JSON: %s", cleaned[:200])
+        return None
+
+    # Convert numeric fields to Decimal where present.
+    for k in ("portfolio_native", "annual_exp_native", "confidence"):
+        if data.get(k) is not None:
+            try:
+                data[k] = Decimal(str(data[k]))
+            except InvalidOperation:
+                data[k] = None
+
+    data["llm_model"] = record_model
+    try:
+        return ExtractedExample.model_validate(data)
+    except ValidationError:
+        log.warning("LLM JSON failed schema validation: %s", cleaned[:200])
+        return None
diff --git a/tests/test_examples_llm_extract.py b/tests/test_examples_llm_extract.py
new file mode 100644
index 0000000..727e6f4
--- /dev/null
+++ b/tests/test_examples_llm_extract.py
@@ -0,0 +1,106 @@
+"""Tests for LLM extraction — respx mocks the llama-cpp chat-completions endpoint."""
+from __future__ import annotations
+
+import json
+from datetime import date
+from decimal import Decimal
+
+import httpx
+import pytest
+import respx
+
+from fire_planner.examples.llm_extract import extract_with_qwen
+from fire_planner.examples.models import RawPost
+
+LLAMA_URL = "http://llama-cpp.llama-cpp.svc.cluster.local:8000/v1/chat/completions"
+
+
+def _post() -> RawPost:
+    return RawPost(
+        reddit_id="a1",
+        source_sub="ExpatFIRE",
+        url="u",
+        title="FIRE'd at 38 — Manila",
+        body="Net worth $1.2M, living in Manila with family of 3, retired last year.",
+        created_at=date(2026, 1, 1),
+    )
+
+
+@respx.mock
+@pytest.mark.asyncio
+async def test_extract_with_qwen_parses_json_response() -> None:
+    payload = {
+        "country": "Philippines",
+        "city": "Manila",
+        "portfolio_native": 1200000,
+        "annual_exp_native": 18000,
+        "raw_currency": "USD",
+        "age": 38,
+        "family_size": 3,
+        "fi_status": "FIRE",
+        "is_retired": True,
+        "confidence": 0.85,
+    }
+    respx.post(LLAMA_URL).respond(
+        200,
+        json={"choices": [{"message": {"content": json.dumps(payload)}}]},
+    )
+
+    async with httpx.AsyncClient() as client:
+        out = await extract_with_qwen(_post(), llama_url=LLAMA_URL, client=client)
+
+    assert out is not None
+    assert out.country == "Philippines"
+    assert out.portfolio_native == Decimal("1200000")
+    assert out.confidence == Decimal("0.85")
+    assert out.llm_model == "qwen3-8b"
+
+
+@respx.mock
+@pytest.mark.asyncio
+async def test_extract_with_qwen_returns_none_on_unparseable_json() -> None:
+    respx.post(LLAMA_URL).respond(
+        200,
+        json={"choices": [{"message": {"content": "definitely not json"}}]},
+    )
+
+    async with httpx.AsyncClient() as client:
+        out = await extract_with_qwen(_post(), llama_url=LLAMA_URL, client=client)
+
+    assert out is None
+
+
+@respx.mock
+@pytest.mark.asyncio
+async def test_extract_with_qwen_returns_none_on_http_error() -> None:
+    respx.post(LLAMA_URL).respond(500)
+
+    async with httpx.AsyncClient() as client:
+        out = await extract_with_qwen(_post(), llama_url=LLAMA_URL, client=client)
+
+    assert out is None
+
+
+@respx.mock
+@pytest.mark.asyncio
+async def test_extract_with_qwen_returns_none_on_non_object_json() -> None:
+    respx.post(LLAMA_URL).respond(
+        200,
+        json={"choices": [{"message": {"content": "[1, 2, 3]"}}]},
+    )
+
+    async with httpx.AsyncClient() as client:
+        out = await extract_with_qwen(_post(), llama_url=LLAMA_URL, client=client)
+
+    assert out is None
+
+
+@respx.mock
+@pytest.mark.asyncio
+async def test_extract_with_qwen_returns_none_on_null_choices() -> None:
+    respx.post(LLAMA_URL).respond(200, json={"choices": None})
+
+    async with httpx.AsyncClient() as client:
+        out = await extract_with_qwen(_post(), llama_url=LLAMA_URL, client=client)
+
+    assert out is None