examples: primary qwen3-8b extractor
All checks were successful
ci/woodpecker/push/woodpecker Pipeline was successful
All checks were successful
ci/woodpecker/push/woodpecker Pipeline was successful
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
8fc0fd7646
commit
d1a5da1008
2 changed files with 211 additions and 0 deletions
130
fire_planner/examples/llm_extract.py
Normal file
130
fire_planner/examples/llm_extract.py
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
"""LLM extraction — primary qwen3-8b via llama-cpp, Tier 2 fallback to
|
||||
claude-agent-service when qwen confidence is low or JSON unparseable.
|
||||
|
||||
Both backends speak the OpenAI-compatible chat-completions API. We
|
||||
issue a strict JSON-schema prompt and parse the first `choices[0]`
|
||||
message into `ExtractedExample`. Tier 2 escalation lives in
|
||||
`extract_with_fallback` (added in Task 8) — primary failure is silent
|
||||
(returns None) so the orchestrator can choose to escalate or skip.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from decimal import Decimal, InvalidOperation
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from pydantic import ValidationError
|
||||
|
||||
from fire_planner.examples.models import ExtractedExample, RawPost
|
||||
|
||||
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Model name used by `extract_with_qwen`, both as the request's "model"
# field and as the label recorded on the extracted example.
QWEN_MODEL = "qwen3-8b"
# Model name for the claude-agent-service fallback mentioned in the module
# docstring; nothing in the code shown here references it yet.
CLAUDE_AGENT_MODEL = "claude-haiku-4-5"
# Timeout configuration passed to every chat-completions POST.
HTTP_TIMEOUT = httpx.Timeout(60.0)
|
||||
|
||||
# System prompt sent with every extraction request. It is assembled from
# phrases joined by single spaces, so the resulting text is one line.
PROMPT_SYSTEM = " ".join(
    (
        "You are extracting structured FIRE-example data from a Reddit post.",
        "Output ONLY a single JSON object with these keys",
        "(use null when the post does not say):",
        "country, city, portfolio_native (number), annual_exp_native (number),",
        "raw_currency (3-letter ISO), age (int),",
        "family_size (int, default 1 if single),",
        "fi_status (one of: accumulating, coastFIRE, baristaFIRE, leanFIRE,",
        "FIRE, fatFIRE, unknown),",
        "is_retired (bool), confidence (0.0-1.0).",
        "DO NOT include any prose or markdown — JSON only.",
    )
)
|
||||
|
||||
|
||||
def _user_prompt(post: RawPost) -> str:
    """Render the user message for *post*.

    The message lists the subreddit, the title, and then the body on its
    own line(s). Only the first 4000 characters of the body are included.
    """
    body_excerpt = post.body[:4000]
    lines = [
        f"Subreddit: {post.source_sub}",
        f"Title: {post.title}",
        "Body:",
        body_excerpt,
    ]
    return "\n".join(lines)
|
||||
|
||||
|
||||
async def extract_with_qwen(
    post: RawPost,
    llama_url: str,
    client: httpx.AsyncClient,
) -> ExtractedExample | None:
    """Run the primary extractor: qwen3-8b served by llama-cpp.

    Delegates to `_call_openai_chat`, using `QWEN_MODEL` both as the model
    requested from `llama_url` and as the model recorded on the result.
    Returns whatever that call returns: an `ExtractedExample`, or `None`
    when extraction did not succeed.
    """
    extracted = await _call_openai_chat(
        post=post,
        client=client,
        url=llama_url,
        model_name=QWEN_MODEL,
        record_model=QWEN_MODEL,
    )
    return extracted
|
||||
|
||||
|
||||
async def _call_openai_chat(
    *,
    url: str,
    model_name: str,
    post: RawPost,
    client: httpx.AsyncClient,
    record_model: str,
    extra_headers: dict[str, str] | None = None,
) -> ExtractedExample | None:
    """POST one chat-completions request for *post* and parse the reply.

    Args:
        url: Full URL of the chat-completions endpoint.
        model_name: Value sent as the request's ``model`` field.
        post: Post whose subreddit, title and body form the user message.
        client: Caller-owned async HTTP client used to send the request.
        record_model: Model label handed to ``_parse_extracted_json``.
        extra_headers: Optional extra request headers.

    Returns:
        The parsed ``ExtractedExample``, or ``None`` when the HTTP call
        fails, the response does not carry a string at
        ``choices[0].message.content``, or the content cannot be parsed.
    """
    body = {
        "model": model_name,
        "messages": [
            {"role": "system", "content": PROMPT_SYSTEM},
            {"role": "user", "content": _user_prompt(post)},
        ],
        "temperature": 0.0,
        "max_tokens": 512,
    }
    try:
        resp = await client.post(
            url,
            json=body,
            timeout=HTTP_TIMEOUT,
            headers=extra_headers,
        )
        resp.raise_for_status()
    except httpx.HTTPError:
        log.warning("LLM call failed for %s via %s", post.reddit_id, url, exc_info=True)
        return None

    try:
        content = resp.json()["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError, ValueError):
        # TypeError covers valid JSON of the wrong shape, e.g. a top-level
        # list or `"choices": null`; ValueError covers a non-JSON body.
        content = None

    # A null or non-string content would otherwise crash the parser's
    # `.strip()` call instead of producing the documented None.
    if not isinstance(content, str):
        log.warning("Unexpected LLM response shape for %s", post.reddit_id)
        return None

    return _parse_extracted_json(content, record_model)
|
||||
|
||||
|
||||
def _parse_extracted_json(content: str, record_model: str) -> ExtractedExample | None:
    """Tolerant JSON parser — strip fences, parse, validate.

    Args:
        content: Raw message text returned by the model.
        record_model: Model label stored under ``llm_model`` on the result.

    Returns:
        The validated ``ExtractedExample``, or ``None`` when the text is not
        valid JSON, is valid JSON but not an object, or fails schema
        validation.
    """
    cleaned = (
        content.strip()
        .removeprefix("```json")
        .removeprefix("```")
        .removesuffix("```")
        .strip()
    )
    try:
        data: Any = json.loads(cleaned)
    except json.JSONDecodeError:
        log.warning("LLM returned unparseable JSON: %s", cleaned[:200])
        return None

    # json.loads also accepts arrays, strings, numbers and null; only an
    # object can carry the expected keys, so reject everything else here
    # rather than failing on `.get` below.
    if not isinstance(data, dict):
        log.warning("LLM returned non-object JSON: %s", cleaned[:200])
        return None

    # Convert numeric fields to Decimal where present; a value that cannot
    # be read as a Decimal is treated as missing.
    for k in ("portfolio_native", "annual_exp_native", "confidence"):
        if data.get(k) is not None:
            try:
                data[k] = Decimal(str(data[k]))
            except InvalidOperation:
                data[k] = None

    data["llm_model"] = record_model
    try:
        return ExtractedExample.model_validate(data)
    except ValidationError:
        log.warning("LLM JSON failed schema validation: %s", cleaned[:200])
        return None
|
||||
81
tests/test_examples_llm_extract.py
Normal file
81
tests/test_examples_llm_extract.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
"""Tests for LLM extraction — respx mocks the llama-cpp /completion endpoint."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import respx
|
||||
|
||||
from fire_planner.examples.llm_extract import extract_with_qwen
|
||||
from fire_planner.examples.models import RawPost
|
||||
|
||||
LLAMA_URL = "http://llama-cpp.llama-cpp.svc.cluster.local:8000/v1/chat/completions"
|
||||
|
||||
|
||||
def _post() -> RawPost:
    """Build the fixed ExpatFIRE sample post used by the tests in this file."""
    sample = RawPost(
        created_at=date(2026, 1, 1),
        body="Net worth $1.2M, living in Manila with family of 3, retired last year.",
        title="FIRE'd at 38 — Manila",
        url="u",
        source_sub="ExpatFIRE",
        reddit_id="a1",
    )
    return sample
|
||||
|
||||
|
||||
@respx.mock
@pytest.mark.asyncio
async def test_extract_with_qwen_parses_json_response() -> None:
    """A well-formed JSON reply is turned into a populated example."""
    llm_answer = {
        "country": "Philippines",
        "city": "Manila",
        "portfolio_native": 1200000,
        "annual_exp_native": 18000,
        "raw_currency": "USD",
        "age": 38,
        "family_size": 3,
        "fi_status": "FIRE",
        "is_retired": True,
        "confidence": 0.85,
    }
    # The model's answer travels as a JSON string inside the chat envelope.
    envelope = {"choices": [{"message": {"content": json.dumps(llm_answer)}}]}
    respx.post(LLAMA_URL).respond(200, json=envelope)

    async with httpx.AsyncClient() as http:
        result = await extract_with_qwen(_post(), llama_url=LLAMA_URL, client=http)

    assert result is not None
    assert result.country == "Philippines"
    assert result.portfolio_native == Decimal("1200000")
    assert result.confidence == Decimal("0.85")
    assert result.llm_model == "qwen3-8b"
|
||||
|
||||
|
||||
@respx.mock
@pytest.mark.asyncio
async def test_extract_with_qwen_returns_none_on_unparseable_json() -> None:
    """A 200 reply whose message content is not JSON yields None."""
    envelope = {"choices": [{"message": {"content": "definitely not json"}}]}
    respx.post(LLAMA_URL).respond(200, json=envelope)

    async with httpx.AsyncClient() as http:
        result = await extract_with_qwen(_post(), llama_url=LLAMA_URL, client=http)

    assert result is None
|
||||
|
||||
|
||||
@respx.mock
@pytest.mark.asyncio
async def test_extract_with_qwen_returns_none_on_http_error() -> None:
    """A 500 from the backend yields None."""
    failing_route = respx.post(LLAMA_URL)
    failing_route.respond(500)

    async with httpx.AsyncClient() as http:
        result = await extract_with_qwen(_post(), llama_url=LLAMA_URL, client=http)

    assert result is None
|
||||
Loading…
Add table
Add a link
Reference in a new issue