52 lines
1.8 KiB
Python
52 lines
1.8 KiB
Python
"""Regression suite for LLM extraction — drives the extractor against
|
|
hand-curated fixtures and asserts the parsed JSON matches expectations.
|
|
|
|
Each fixture is `{post: RawPost, expected: dict}`. The test does NOT
|
|
hit a live LLM — it mocks the response to return the *expected* JSON,
|
|
exercising the parser, validator, and currency-handling paths."""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from pathlib import Path
|
|
|
|
import httpx
|
|
import pytest
|
|
import respx
|
|
|
|
from fire_planner.examples.llm_extract import extract_with_qwen
|
|
from fire_planner.examples.models import RawPost
|
|
|
|
# Chat-completions endpoint handed to the extractor. In these tests requests
# to it are intercepted by respx, so the cluster host is never contacted.
LLAMA_URL = "http://llama-cpp.llama-cpp.svc.cluster.local:8000/v1/chat/completions"

# Hand-curated ``example_*.json`` fixtures, stored alongside this test module.
FIXTURE_DIR = Path(__file__).parent.joinpath("fixtures", "reddit")
|
|
|
|
|
|
def _fixtures() -> list[Path]:
    """Collect every ``example_*.json`` fixture file under ``FIXTURE_DIR``.

    The result is sorted by path so the parametrized test IDs come out in a
    stable order (``glob`` itself makes no ordering guarantee).

    NOTE: if nothing matches, an empty list is returned; pytest then reports
    the parametrized test as skipped (empty parameter set), not failed.
    """
    found = list(FIXTURE_DIR.glob("example_*.json"))
    found.sort()
    return found
|
|
|
|
|
|
@respx.mock
@pytest.mark.asyncio
@pytest.mark.parametrize("fixture_path", _fixtures(), ids=lambda p: p.stem)
async def test_extractor_matches_fixture(fixture_path: Path) -> None:
    """Run one fixture's post through ``extract_with_qwen`` and check each
    ``expected`` field on the result.

    The LLM endpoint is mocked with respx to answer with the fixture's own
    ``expected`` JSON (plus a default ``confidence``). This exercises the
    extractor's response parsing/validation, not model quality.

    Args:
        fixture_path: A JSON file holding ``{"post": ..., "expected": {...}}``.
    """
    label = fixture_path.stem
    # Fixtures may contain non-ASCII text (e.g. currency symbols); decode
    # explicitly instead of relying on the platform's locale encoding.
    data = json.loads(fixture_path.read_text(encoding="utf-8"))
    post = RawPost.model_validate(data["post"])
    expected = data["expected"]
    # The default confidence goes FIRST so that a fixture pinning its own
    # ``confidence`` is not silently overridden. Otherwise the mock would
    # return 0.9 while the loop below asserts the fixture's value.
    mocked_payload = {"confidence": 0.9, **expected}

    respx.post(LLAMA_URL).respond(
        200,
        json={"choices": [{"message": {"content": json.dumps(mocked_payload)}}]},
    )

    async with httpx.AsyncClient() as client:
        out = await extract_with_qwen(post, llama_url=LLAMA_URL, client=client)

    assert out is not None, f"{label}: extractor returned None"
    for k, v in expected.items():
        actual = getattr(out, k)
        if v is not None and hasattr(actual, "__float__"):
            # Anything float-convertible (int, float, Decimal, ...) is
            # compared by numeric value rather than by type.
            assert float(actual) == float(v), f"{label}: {k}"
        else:
            # A StrEnum member compares equal to its string value directly.
            # The str() fallback covers field types whose str() form matches
            # the JSON value. Expected None also lands here, so a mismatch
            # fails as a labelled assertion instead of a TypeError from
            # float(None).
            assert actual == v or str(actual) == v, f"{label}: {k}"
|