fire-planner/tests/test_examples_service.py

84 lines
2.8 KiB
Python
Raw Normal View History

"""Tests for service.upsert_example and service.summary_for_country."""
from __future__ import annotations
from datetime import date
from decimal import Decimal
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from fire_planner.db import FireExample
from fire_planner.examples.models import ExtractedExample, FiStatus, RawPost
from fire_planner.examples.service import summary_for_country, upsert_example
def _post(reddit_id: str = "abc1") -> RawPost:
    """Build a minimal ``RawPost`` fixture identified by *reddit_id*.

    The URL is derived from the id, so distinct ids give distinct URLs.
    Every other field is a fixed placeholder value.
    """
    permalink = f"https://reddit.com/{reddit_id}"
    return RawPost(
        reddit_id=reddit_id,
        source_sub="ExpatFIRE",
        url=permalink,
        title="t",
        body="b",
        created_at=date(2026, 1, 1),
    )
def _ex(conf: Decimal = Decimal("0.8")) -> ExtractedExample:
    """Build a fully populated ``ExtractedExample`` fixture.

    Describes a retired (``FiStatus.FIRE``) household of three in Manila,
    Philippines, with money amounts expressed in USD. Only the extraction
    confidence varies, via *conf*.
    """
    usd_portfolio = Decimal("1200000")
    usd_annual_spend = Decimal("18000")
    return ExtractedExample(
        # Location.
        country="Philippines",
        city="Manila",
        # Money, in the native currency named by raw_currency.
        portfolio_native=usd_portfolio,
        annual_exp_native=usd_annual_spend,
        raw_currency="USD",
        # Household.
        age=38,
        family_size=3,
        # Status.
        fi_status=FiStatus.FIRE,
        is_retired=True,
        # Extraction provenance.
        confidence=conf,
        llm_model="qwen3-8b",
    )
@pytest.mark.asyncio
async def test_upsert_inserts_new_row(session: AsyncSession) -> None:
    """A first-seen post is inserted and its portfolio is converted to GBP."""
    # Rates are multipliers from each currency into GBP. The _ex() portfolio of
    # 1,200,000 USD at 0.80 is therefore expected to be stored as 960,000.00 GBP.
    fx_to_gbp = {"GBP": Decimal("1"), "USD": Decimal("0.80")}
    was_inserted = await upsert_example(session, _post(), _ex(), fx_to_gbp)
    assert was_inserted is True

    result = await session.execute(select(FireExample))
    stored = result.scalars().all()
    assert len(stored) == 1
    row = stored[0]
    assert row.portfolio_gbp == Decimal("960000.00")
    assert row.country == "Philippines"
@pytest.mark.asyncio
async def test_upsert_is_idempotent_by_reddit_id(session: AsyncSession) -> None:
    """Upserting the same reddit_id twice stores exactly one row.

    Both return values are pinned, not only the final row count. Checking
    only the count would let a regression slip through in which the first,
    genuine insert is reported as ``False``.
    """
    rates = {"GBP": Decimal("1"), "USD": Decimal("0.80")}
    first = await upsert_example(session, _post("abc1"), _ex(), rates)
    assert first is True  # fresh reddit_id -> real insert
    second = await upsert_example(session, _post("abc1"), _ex(), rates)
    assert second is False  # same reddit_id -> second call is a no-op
    rows = (await session.execute(select(FireExample))).scalars().all()
    assert len(rows) == 1
@pytest.mark.asyncio
async def test_summary_for_country_returns_quartiles(session: AsyncSession) -> None:
    """Five evenly spaced GBP portfolios produce the expected quartiles.

    The portfolios are 100k, 200k, 300k, 400k and 500k GBP. The median is the
    middle value. The expected p25/p75 of 200k/400k match an inclusive
    (linear-interpolation) quartile definition; an exclusive definition would
    give 150k/450k instead.
    """
    # An identity rate table means native GBP amounts pass through unchanged.
    identity_fx = {"GBP": Decimal("1"), "USD": Decimal("1")}
    for idx, amount in enumerate((100_000, 200_000, 300_000, 400_000, 500_000)):
        example = ExtractedExample(
            country="Philippines",
            portfolio_native=Decimal(amount),
            raw_currency="GBP",
            confidence=Decimal("0.9"),
            llm_model="qwen3-8b",
        )
        # A distinct reddit_id per row keeps the upserts from colliding.
        await upsert_example(session, _post(f"id{idx}"), example, identity_fx)

    summary = await summary_for_country(session, "Philippines")
    assert summary.count == 5
    assert summary.portfolio_gbp.median == Decimal("300000.00")
    assert summary.portfolio_gbp.p25 == Decimal("200000.00")
    assert summary.portfolio_gbp.p75 == Decimal("400000.00")
    # NOTE(review): an empty list also satisfies this bound. Pin a lower bound
    # if the service guarantees at least one sample link.
    assert len(summary.sample_links) <= 5