83 lines
2.8 KiB
Python
83 lines
2.8 KiB
Python
"""Tests for service.upsert_example and service.summary_for_country."""
|
|
from __future__ import annotations
|
|
|
|
from datetime import date
|
|
from decimal import Decimal
|
|
|
|
import pytest
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from fire_planner.db import FireExample
|
|
from fire_planner.examples.models import ExtractedExample, FiStatus, RawPost
|
|
from fire_planner.examples.service import summary_for_country, upsert_example
|
|
|
|
|
|
def _post(reddit_id: str = "abc1") -> RawPost:
    """Return a minimal ``RawPost`` fixture identified by *reddit_id*.

    The URL is derived from the id, so posts built with different ids are
    distinct; every other field is a fixed placeholder.
    """
    permalink = f"https://reddit.com/{reddit_id}"
    return RawPost(
        created_at=date(2026, 1, 1),
        reddit_id=reddit_id,
        url=permalink,
        source_sub="ExpatFIRE",
        title="t",
        body="b",
    )
|
|
|
|
|
|
def _ex(conf: Decimal = Decimal("0.8")) -> ExtractedExample:
    """Return a fully populated ``ExtractedExample`` fixture.

    Describes a retired (``FiStatus.FIRE``) 38-year-old, family of three, in
    Manila, Philippines, with amounts quoted in USD. Only the extraction
    confidence is configurable via *conf*.
    """
    return ExtractedExample(
        llm_model="qwen3-8b",
        confidence=conf,
        raw_currency="USD",
        portfolio_native=Decimal("1200000"),
        annual_exp_native=Decimal("18000"),
        country="Philippines",
        city="Manila",
        age=38,
        family_size=3,
        is_retired=True,
        fi_status=FiStatus.FIRE,
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_upsert_inserts_new_row(session: AsyncSession) -> None:
    """A first upsert reports an insert and stores the portfolio in GBP."""
    # Per-currency multipliers into GBP: 1_200_000 USD * 0.80 -> 960_000 GBP.
    fx_to_gbp = {"GBP": Decimal("1"), "USD": Decimal("0.80")}

    assert await upsert_example(session, _post(), _ex(), fx_to_gbp) is True

    result = await session.execute(select(FireExample))
    stored = result.scalars().all()
    assert len(stored) == 1

    row = stored[0]
    assert row.portfolio_gbp == Decimal("960000.00")
    assert row.country == "Philippines"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_upsert_is_idempotent_by_reddit_id(session: AsyncSession) -> None:
    """Upserting the same ``reddit_id`` twice inserts exactly once.

    Both return values are asserted. The first call must report an insert
    (``True``) and the second a no-op (``False``). Checking only the second
    call, plus the final row count, would not pin the boolean contract for
    the insert path.
    """
    rates = {"GBP": Decimal("1"), "USD": Decimal("0.80")}

    first = await upsert_example(session, _post("abc1"), _ex(), rates)
    assert first is True  # fresh reddit_id -> row inserted

    second = await upsert_example(session, _post("abc1"), _ex(), rates)
    assert second is False  # second call is no-op

    rows = (await session.execute(select(FireExample))).scalars().all()
    assert len(rows) == 1
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_summary_for_country_returns_quartiles(session: AsyncSession) -> None:
    """Five evenly spaced GBP portfolios yield the expected median and quartiles."""
    # Identity rates: inputs are already GBP, so no conversion skews the stats.
    identity_rates = {"GBP": Decimal("1"), "USD": Decimal("1")}
    amounts = [100_000, 200_000, 300_000, 400_000, 500_000]

    examples = [
        ExtractedExample(
            country="Philippines",
            portfolio_native=Decimal(amount),
            raw_currency="GBP",
            confidence=Decimal("0.9"),
            llm_model="qwen3-8b",
        )
        for amount in amounts
    ]
    for n, example in enumerate(examples):
        await upsert_example(session, _post(f"id{n}"), example, identity_rates)

    summary = await summary_for_country(session, "Philippines")
    assert summary.count == 5

    stats = summary.portfolio_gbp
    assert stats.median == Decimal("300000.00")
    assert stats.p25 == Decimal("200000.00")
    assert stats.p75 == Decimal("400000.00")

    assert len(summary.sample_links) <= 5
|