191 lines
5.5 KiB
Python
191 lines
5.5 KiB
Python
"""Tests for LLM extraction and GBP conversion — respx mocks the llama-cpp and Claude-fallback /v1/chat/completions endpoints."""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from datetime import date
|
|
from decimal import Decimal
|
|
|
|
import httpx
|
|
import pytest
|
|
import respx
|
|
|
|
from fire_planner.examples.llm_extract import (
|
|
extract_with_fallback,
|
|
extract_with_qwen,
|
|
to_gbp,
|
|
)
|
|
from fire_planner.examples.models import RawPost
|
|
|
|
# Chat-completions endpoints for the llama-cpp service and the Claude fallback
# service; every async test below registers respx routes against these URLs.
LLAMA_URL = (
    "http://llama-cpp.llama-cpp.svc.cluster.local:8000"
    "/v1/chat/completions"
)
CLAUDE_URL = (
    "http://claude-agent-service.claude-agent.svc.cluster.local:8080"
    "/v1/chat/completions"
)
|
|
|
|
|
|
def _post() -> RawPost:
    """Return the fixed Manila FIRE post fed to every extraction test."""
    body_text = "Net worth $1.2M, living in Manila with family of 3, retired last year."
    posted_on = date(2026, 1, 1)
    return RawPost(
        reddit_id="a1",
        source_sub="ExpatFIRE",
        url="u",
        title="FIRE'd at 38 — Manila",
        body=body_text,
        created_at=posted_on,
    )
|
|
|
|
|
|
@respx.mock
@pytest.mark.asyncio
async def test_extract_with_qwen_parses_json_response() -> None:
    """A well-formed JSON reply from qwen is parsed into an extraction result."""
    extracted = dict(
        country="Philippines",
        city="Manila",
        portfolio_native=1200000,
        annual_exp_native=18000,
        raw_currency="USD",
        age=38,
        family_size=3,
        fi_status="FIRE",
        is_retired=True,
        confidence=0.85,
    )
    # Chat-completions style envelope: the model's answer is the message content.
    reply = {"choices": [{"message": {"content": json.dumps(extracted)}}]}
    respx.post(LLAMA_URL).respond(200, json=reply)

    async with httpx.AsyncClient() as client:
        result = await extract_with_qwen(_post(), llama_url=LLAMA_URL, client=client)

    assert result is not None
    assert result.country == "Philippines"
    assert result.portfolio_native == Decimal("1200000")
    assert result.confidence == Decimal("0.85")
    assert result.llm_model == "qwen3-8b"
|
|
|
|
|
|
@respx.mock
@pytest.mark.asyncio
async def test_extract_with_qwen_returns_none_on_unparseable_json() -> None:
    """Non-JSON model output makes extract_with_qwen return None.

    The route-called check guards against a vacuous pass: an implementation
    that returned None without ever contacting llama-cpp would otherwise also
    satisfy ``out is None``.
    """
    llama_route = respx.post(LLAMA_URL).respond(
        200,
        json={"choices": [{"message": {"content": "definitely not json"}}]},
    )

    async with httpx.AsyncClient() as client:
        out = await extract_with_qwen(_post(), llama_url=LLAMA_URL, client=client)

    assert llama_route.called
    assert out is None
|
|
|
|
|
|
@respx.mock
@pytest.mark.asyncio
async def test_extract_with_qwen_returns_none_on_http_error() -> None:
    """A 500 from llama-cpp makes extract_with_qwen return None.

    Asserting the route was hit ensures the None comes from handling the HTTP
    error, not from the function bailing out before making the request.
    """
    llama_route = respx.post(LLAMA_URL).respond(500)

    async with httpx.AsyncClient() as client:
        out = await extract_with_qwen(_post(), llama_url=LLAMA_URL, client=client)

    assert llama_route.called
    assert out is None
|
|
|
|
|
|
@respx.mock
@pytest.mark.asyncio
async def test_fallback_escalates_when_qwen_returns_none() -> None:
    """When qwen fails, extract_with_fallback returns Claude's extraction.

    Both routes are asserted as called so the test proves an actual escalation
    (qwen attempted, then Claude) rather than a path that skips qwen entirely.
    """
    llama_route = respx.post(LLAMA_URL).respond(500)  # qwen down
    claude_payload = {
        "country": "Philippines",
        "city": "Manila",
        "confidence": 0.95,
    }
    claude_route = respx.post(CLAUDE_URL).respond(
        200,
        json={"choices": [{"message": {"content": json.dumps(claude_payload)}}]},
    )

    async with httpx.AsyncClient() as client:
        out = await extract_with_fallback(
            _post(),
            llama_url=LLAMA_URL,
            claude_url=CLAUDE_URL,
            claude_bearer="t",
            client=client,
        )

    assert llama_route.called  # qwen was attempted
    assert claude_route.called  # ...and its failure escalated to Claude
    assert out is not None
    assert out.llm_model == "claude-haiku-4-5"
    assert out.country == "Philippines"
|
|
|
|
|
|
@respx.mock
@pytest.mark.asyncio
async def test_fallback_escalates_on_low_confidence() -> None:
    """A qwen result below the confidence threshold is replaced by Claude's.

    Asserting the qwen route was called confirms the low-confidence answer was
    actually obtained and then rejected, not that qwen was bypassed.
    """
    qwen_payload = {"country": None, "confidence": 0.2}
    llama_route = respx.post(LLAMA_URL).respond(
        200,
        json={"choices": [{"message": {"content": json.dumps(qwen_payload)}}]},
    )
    claude_payload = {"country": "Thailand", "city": "Bangkok", "confidence": 0.9}
    claude_route = respx.post(CLAUDE_URL).respond(
        200,
        json={"choices": [{"message": {"content": json.dumps(claude_payload)}}]},
    )

    async with httpx.AsyncClient() as client:
        out = await extract_with_fallback(
            _post(),
            llama_url=LLAMA_URL,
            claude_url=CLAUDE_URL,
            claude_bearer="t",
            client=client,
            confidence_threshold=Decimal("0.5"),
        )

    assert llama_route.called  # qwen answered (with confidence 0.2 < 0.5)
    assert claude_route.called  # ...so the fallback escalated
    assert out is not None
    assert out.country == "Thailand"
    assert out.llm_model == "claude-haiku-4-5"
|
|
|
|
|
|
@respx.mock
@pytest.mark.asyncio
async def test_fallback_keeps_high_confidence_qwen_result() -> None:
    """A qwen result at or above the threshold is returned without calling Claude.

    The Claude route is registered with an empty-object reply only so that any
    accidental escalation is observable through ``claude_route.called``.
    """
    payload = {
        "country": "Philippines",
        "confidence": 0.9,
    }
    llama_route = respx.post(LLAMA_URL).respond(
        200,
        json={"choices": [{"message": {"content": json.dumps(payload)}}]},
    )
    claude_route = respx.post(CLAUDE_URL).respond(
        200,
        json={"choices": [{"message": {"content": "{}"}}]},
    )

    async with httpx.AsyncClient() as client:
        out = await extract_with_fallback(
            _post(),
            llama_url=LLAMA_URL,
            claude_url=CLAUDE_URL,
            claude_bearer="t",
            client=client,
            confidence_threshold=Decimal("0.5"),
        )

    assert llama_route.called
    assert out is not None
    assert out.llm_model == "qwen3-8b"
    assert out.country == "Philippines"  # qwen's answer is the one returned
    assert not claude_route.called  # high-confidence qwen → claude not hit
|
|
|
|
|
|
def test_to_gbp_converts_usd() -> None:
    """A USD amount is scaled by the USD entry of the rate table (100 × 0.80 = 80)."""
    usd_rate = Decimal("0.80")
    converted = to_gbp(Decimal("100"), "USD", {"GBP": Decimal("1"), "USD": usd_rate})
    assert converted == Decimal("80.00")
|
|
|
|
|
|
def test_to_gbp_passes_through_gbp() -> None:
    """A GBP amount with a unit GBP rate comes back numerically unchanged."""
    amount = Decimal("100")
    unit_rates = {"GBP": Decimal("1")}
    assert to_gbp(amount, "GBP", unit_rates) == Decimal("100.00")
|
|
|
|
|
|
def test_to_gbp_returns_none_for_unknown_currency() -> None:
    """A currency missing from the rate table yields None."""
    known_rates = {"GBP": Decimal("1"), "USD": Decimal("0.8")}
    result = to_gbp(Decimal("100"), "XYZ", known_rates)
    assert result is None
|
|
|
|
|
|
def test_to_gbp_returns_none_for_none_amount() -> None:
    """A missing (None) amount yields None even when the currency is known."""
    result = to_gbp(None, "USD", {"USD": Decimal("0.8")})
    assert result is None
|