fire-planner/fire_planner/examples/llm_extract.py

205 lines
6.4 KiB
Python

"""LLM extraction — primary qwen3-8b via llama-cpp, Tier 2 fallback to
claude-agent-service when the primary call fails, its JSON is unparseable, or
its confidence is low. The primary model id defaults to qwen3-8b but can be
overridden with the `LLM_MODEL` env var. Both backends speak the
OpenAI-compatible chat-completions API: we send a strict JSON-only prompt and
parse the `choices[0]` message into `ExtractedExample`. Tier 2 escalation
lives in `extract_with_fallback` (added in Task 8) — primary failure is silent
(returns None) so the orchestrator can choose to escalate or skip.
"""
from __future__ import annotations
import json
import logging
import os
from decimal import Decimal, InvalidOperation
from typing import Any
import httpx
from pydantic import ValidationError
from fire_planner.examples.models import ExtractedExample, RawPost
log = logging.getLogger(__name__)


def _env_or_default(name: str, default: str) -> str:
    """Return env var `name` stripped of whitespace, or `default` if unset/blank.

    Without this, a blank value (e.g. an empty `LLM_MODEL` entry in the
    deployment env) would be sent verbatim as the request's `model` field.
    """
    return os.environ.get(name, "").strip() or default


# `LLM_MODEL` lets the deployment swap to a smaller model when the GPU is
# contested. Default stays on qwen3-8b for local dev / tests. The "qwen" name
# in the constant is historical — the value can be any llama-swap model id
# (e.g. `qwen3vl-4b` when k8s-node1's VRAM is mostly held by immich-ml).
# A blank/whitespace-only `LLM_MODEL` falls back to the default.
QWEN_MODEL = _env_or_default("LLM_MODEL", "qwen3-8b")
# Model id sent to, and recorded for, the Tier 2 claude-agent-service backend.
CLAUDE_AGENT_MODEL = "claude-haiku-4-5"
# One 60 s value applied to every httpx timeout phase (connect, read, write,
# pool); passed per request by `_call_openai_chat`.
HTTP_TIMEOUT = httpx.Timeout(60.0)
# System prompt sent to both backends (qwen and claude-agent-service). The
# returned JSON object is handed to `ExtractedExample.model_validate` by
# `_parse_extracted_json` after `portfolio_native`, `annual_exp_native` and
# `confidence` are converted to Decimal, so the keys listed here presumably
# mirror the fields of `ExtractedExample` — verify against
# fire_planner/examples/models.py and keep the two in sync.
PROMPT_SYSTEM = (
"You are extracting structured FIRE-example data from a Reddit post. "
"Output ONLY a single JSON object with these keys (use null when the "
"post does not say): country, city, portfolio_native (number), "
"annual_exp_native (number), raw_currency (3-letter ISO), age (int), "
"family_size (int, default 1 if single), fi_status (one of: "
"accumulating, coastFIRE, baristaFIRE, leanFIRE, FIRE, fatFIRE, "
"unknown), is_retired (bool), confidence (0.0-1.0). "
"DO NOT include any prose or markdown — JSON only."
)
def _user_prompt(post: RawPost) -> str:
    """Build the user message: subreddit, title, then the body.

    Only the first 4000 characters of the body are included, to bound the
    prompt size.
    """
    body_limit = 4000
    header = f"Subreddit: {post.source_sub}\nTitle: {post.title}\n"
    return header + f"Body:\n{post.body[:body_limit]}"
async def extract_with_qwen(
    post: RawPost,
    llama_url: str,
    client: httpx.AsyncClient,
) -> ExtractedExample | None:
    """Extract `post` with the primary llama-cpp model.

    The model id is `QWEN_MODEL` (configurable via `LLM_MODEL`). It is both
    sent in the request and recorded on the result. Returns None when the
    call or the parse fails; see `_call_openai_chat`.
    """
    model = QWEN_MODEL
    return await _call_openai_chat(
        post=post,
        client=client,
        url=llama_url,
        model_name=model,
        record_model=model,
    )
async def _call_openai_chat(
    *,
    url: str,
    model_name: str,
    post: RawPost,
    client: httpx.AsyncClient,
    record_model: str,
    extra_headers: dict[str, str] | None = None,
) -> ExtractedExample | None:
    """POST one chat-completions request for `post` and parse the reply.

    Args:
        url: Full chat-completions endpoint URL.
        model_name: Value sent as the request's `model` field.
        post: Reddit post rendered into the user message by `_user_prompt`.
        client: Shared async HTTP client.
        record_model: Model id stamped onto the result as `llm_model`.
        extra_headers: Optional extra request headers (e.g. auth).

    Returns:
        The parsed `ExtractedExample`, or None on any of these: an httpx
        transport/HTTP-status error, an unexpected response shape, non-text
        message content, or a reply `_parse_extracted_json` rejects. These
        failures are logged, not raised.
    """
    body = {
        "model": model_name,
        "messages": [
            {"role": "system", "content": PROMPT_SYSTEM},
            {"role": "user", "content": _user_prompt(post)},
        ],
        "temperature": 0.0,
        "max_tokens": 512,
    }
    try:
        resp = await client.post(
            url,
            json=body,
            timeout=HTTP_TIMEOUT,
            headers=extra_headers,
        )
        resp.raise_for_status()
    except httpx.HTTPError:
        log.warning("LLM call failed for %s via %s", post.reddit_id, url, exc_info=True)
        return None
    try:
        content = resp.json()["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError, ValueError):
        # ValueError: body is not JSON. TypeError: JSON has the wrong container
        # types (e.g. a top-level list or `"choices": null`), which indexing
        # with a string/int cannot handle.
        log.warning("Unexpected LLM response shape for %s", post.reddit_id)
        return None
    if not isinstance(content, str):
        # OpenAI-compatible servers may return `"content": null` (e.g. for
        # tool-call replies); the parser needs text, so treat it as a
        # shape error instead of crashing in `.strip()`.
        log.warning("Unexpected LLM response shape for %s", post.reddit_id)
        return None
    return _parse_extracted_json(content, record_model)
def _parse_extracted_json(content: str, record_model: str) -> ExtractedExample | None:
    """Tolerant JSON parser — strip wrappers, parse, coerce numbers, validate.

    Accepts the raw assistant message text and tolerates the usual ways a
    model wraps its answer:
    - a leading ``<think>…</think>`` reasoning block (qwen3 thinking mode);
    - Markdown code fences;
    - stray prose around the JSON object.

    The numeric fields `portfolio_native`, `annual_exp_native` and
    `confidence` are converted to Decimal. A value that is not a finite
    number becomes None. `llm_model` is then set to `record_model` before
    validation.

    Returns:
        The validated `ExtractedExample`, or None (after logging a warning)
        when no JSON object can be parsed or validation fails.
    """
    # Keep only what follows the last closing think tag; rpartition returns
    # the whole string as the tail when the tag is absent.
    text = content.rpartition("</think>")[2]
    cleaned = (
        text.strip()
        .removeprefix("```json")
        .removeprefix("```")
        .removesuffix("```")
        .strip()
    )
    # Trim any prose the model put around the outermost {...} span.
    start, end = cleaned.find("{"), cleaned.rfind("}")
    if start != -1 and end > start:
        cleaned = cleaned[start : end + 1]
    try:
        data: Any = json.loads(cleaned)
    except json.JSONDecodeError:
        log.warning("LLM returned unparseable JSON: %s", cleaned[:200])
        return None
    if not isinstance(data, dict):
        # e.g. `[]`, `null` or a bare number — nothing to validate.
        log.warning("LLM returned non-object JSON: %s", cleaned[:200])
        return None
    # Convert numeric fields to Decimal where present.
    for key in ("portfolio_native", "annual_exp_native", "confidence"):
        raw = data.get(key)
        if raw is None:
            continue
        try:
            number: Decimal | None = Decimal(str(raw))
        except InvalidOperation:
            number = None
        # Python's json accepts NaN/Infinity literals, and Decimal parses
        # "NaN"/"Infinity" strings. Non-finite Decimals raise on ordering
        # comparisons, so treat them as missing.
        data[key] = number if number is not None and number.is_finite() else None
    data["llm_model"] = record_model
    try:
        return ExtractedExample.model_validate(data)
    except ValidationError:
        log.warning("LLM JSON failed schema validation: %s", cleaned[:200])
        return None
# Default minimum `confidence` a primary (qwen) extraction needs for
# `extract_with_fallback` to accept it without escalating to Tier 2.
DEFAULT_CONFIDENCE_THRESHOLD = Decimal("0.5")
async def extract_with_claude(
    post: RawPost,
    claude_url: str,
    bearer: str,
    client: httpx.AsyncClient,
) -> ExtractedExample | None:
    """Extract `post` through the Tier 2 claude-agent-service backend.

    Sends `CLAUDE_AGENT_MODEL` with `bearer` as a Bearer token in the
    Authorization header, and records the same model id on the result.
    Returns None when the call or the parse fails; see `_call_openai_chat`.
    """
    auth_headers = {"Authorization": f"Bearer {bearer}"}
    return await _call_openai_chat(
        post=post,
        client=client,
        url=claude_url,
        model_name=CLAUDE_AGENT_MODEL,
        record_model=CLAUDE_AGENT_MODEL,
        extra_headers=auth_headers,
    )
async def extract_with_fallback(
    post: RawPost,
    *,
    llama_url: str,
    claude_url: str,
    claude_bearer: str,
    client: httpx.AsyncClient,
    confidence_threshold: Decimal = DEFAULT_CONFIDENCE_THRESHOLD,
) -> ExtractedExample | None:
    """Try qwen first; escalate to claude on failure or low confidence.

    The primary result is returned as-is when its `confidence` is present and
    `>= confidence_threshold`. Otherwise the post goes to Tier 2. A missing
    (None) confidence counts as low, so the post is escalated.

    Returns:
        The Tier 2 result when it succeeds. If Tier 2 fails, the (possibly
        low-confidence) primary result. None only when both backends fail
        (the orchestrator drops the post and increments
        `fire_examples_extract_failed_total`).
    """
    primary = await extract_with_qwen(post, llama_url=llama_url, client=client)
    primary_conf = None if primary is None else primary.confidence
    # Guard against a None confidence: `None >= Decimal` raises TypeError.
    if primary_conf is not None and primary_conf >= confidence_threshold:
        return primary
    log.info(
        "Escalating %s to Tier 2 (primary=%s)",
        post.reddit_id,
        "none" if primary is None else f"conf={primary_conf}",
    )
    secondary = await extract_with_claude(
        post,
        claude_url=claude_url,
        bearer=claude_bearer,
        client=client,
    )
    # Explicit None check rather than `or`, so model truthiness never matters.
    return secondary if secondary is not None else primary
def to_gbp(
    amount: Decimal | None,
    currency: str | None,
    rates: dict[str, Decimal],
) -> Decimal | None:
    """Convert `amount` in `currency` to GBP using `fx.fetch_rates` output.

    `rates[X]` = "how much GBP one unit of X is worth" — the convention
    used by `fire_planner/fx.py`. The currency code is matched
    case-insensitively, with surrounding whitespace ignored. Under this
    convention GBP itself is worth 1, so GBP amounts convert even when
    `rates` has no "GBP" entry.

    Returns:
        The GBP amount quantized to 0.01. None when amount/currency is
        missing or the (non-GBP) currency isn't in `rates`.
    """
    if amount is None or currency is None:
        return None
    code = currency.strip().upper()
    rate = rates.get(code)
    if rate is None and code == "GBP":
        # Base currency may be omitted from the rate table; GBP->GBP is 1.
        rate = Decimal(1)
    if rate is None:
        return None
    return (amount * rate).quantize(Decimal("0.01"))