examples: service.upsert_example + summary_for_country
All checks were successful
ci/woodpecker/push/woodpecker Pipeline was successful

This commit is contained in:
Viktor Barzin 2026-05-28 22:28:53 +00:00
parent 0d442de918
commit a10d7fe2a6
2 changed files with 194 additions and 0 deletions

View file

@ -0,0 +1,111 @@
"""Persistence + read-side queries for fire_example.
`upsert_example(...)` does an INSERT ... ON CONFLICT DO NOTHING by
reddit_id. Returns True when a new row was inserted, False when it was
already present (idempotent re-runs are a feature, not a bug).
`summary_for_country(...)` computes count + median/p25/p75 of
portfolio_gbp and annual_exp_gbp, plus up to 5 sample post URLs. Rows
are fetched with the plain SQLAlchemy expression API and the statistics
are computed in Python, so it works on both Postgres and SQLite (which
the tests use).
"""
from __future__ import annotations
import logging
import statistics
from decimal import Decimal
from typing import Any, cast
from sqlalchemy import CursorResult, select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlalchemy.ext.asyncio import AsyncSession
from fire_planner.db import FireExample
from fire_planner.examples.llm_extract import to_gbp
from fire_planner.examples.models import ExtractedExample, RawPost, Summary, SummaryStats
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Maximum number of characters of "title\nbody" kept in the raw_excerpt column.
EXCERPT_LEN = 500
def _dialect_insert(session: AsyncSession) -> Any:
    """Pick the dialect-specific ``insert`` constructor for *session*'s bind.

    Returns the SQLite ``insert`` when the bound dialect is named
    ``"sqlite"`` and the PostgreSQL ``insert`` for every other dialect.
    """
    dialect_name = session.get_bind().dialect.name
    return sqlite_insert if dialect_name == "sqlite" else pg_insert
async def upsert_example(
    session: AsyncSession,
    post: RawPost,
    extracted: ExtractedExample,
    fx_rates: dict[str, Decimal],
) -> bool:
    """Store one extracted example, skipping it when its reddit_id exists.

    The native portfolio and annual-expense amounts are passed through
    ``to_gbp`` together with ``extracted.raw_currency`` and *fx_rates*
    before being written. The statement is an
    INSERT ... ON CONFLICT DO NOTHING keyed on ``reddit_id``, and the
    session is committed before returning.

    Returns:
        True when a new row was inserted, False when the insert was
        skipped because of a conflict.
    """
    currency = extracted.raw_currency
    portfolio = to_gbp(extracted.portfolio_native, currency, fx_rates)
    annual_exp = to_gbp(extracted.annual_exp_native, currency, fx_rates)

    status = extracted.fi_status
    row = dict(
        reddit_id=post.reddit_id,
        source_sub=post.source_sub,
        post_url=post.url,
        post_date=post.created_at,
        post_title=post.title,
        country=extracted.country,
        city=extracted.city,
        portfolio_gbp=portfolio,
        annual_exp_gbp=annual_exp,
        age=extracted.age,
        family_size=extracted.family_size,
        fi_status=str(status) if status else None,
        is_retired=extracted.is_retired,
        raw_currency=currency,
        # Title and body joined by a newline, truncated to EXCERPT_LEN chars.
        raw_excerpt=(post.title + "\n" + post.body)[:EXCERPT_LEN],
        llm_model=extracted.llm_model,
        llm_confidence=extracted.confidence,
    )

    make_insert = _dialect_insert(session)
    statement = (
        make_insert(FireExample)
        .values(**row)
        .on_conflict_do_nothing(index_elements=["reddit_id"])
    )

    # DML executions yield a CursorResult; the cast lets mypy see .rowcount,
    # which the base Result type does not declare.
    outcome = cast(CursorResult[Any], await session.execute(statement))
    await session.commit()

    affected = outcome.rowcount or 0
    return affected > 0
def _quartiles(values: list[Decimal]) -> SummaryStats:
    """Summarise *values* as median / p25 / p75, each rounded to 2 decimals.

    An empty list yields a SummaryStats whose three fields are all None.
    With fewer than four values the p25 and p75 slots hold the minimum
    and maximum instead of true quartiles.
    """
    if not values:
        return SummaryStats(median=None, p25=None, p75=None)

    def _two_dp(number: float) -> Decimal:
        # Format first so the Decimal carries exactly two fractional digits.
        return Decimal(f"{number:.2f}")

    samples = [float(v) for v in values]
    middle = statistics.median(samples)

    if len(samples) < 4:
        # Too few samples for quartiles; report the observed bounds.
        lower, upper = min(samples), max(samples)
    else:
        # method="inclusive" interpolates linearly between data points, so
        # for small n the cut points land on real values
        # (e.g. p25 == 200_000 for [100k, 200k, 300k, 400k, 500k]).
        cuts = statistics.quantiles(samples, n=4, method="inclusive")
        lower, upper = cuts[0], cuts[2]

    return SummaryStats(
        median=_two_dp(middle),
        p25=_two_dp(lower),
        p75=_two_dp(upper),
    )
async def summary_for_country(session: AsyncSession, country: str) -> Summary:
    """Build the Summary for every FireExample row whose country equals *country*.

    The count covers all matching rows. The portfolio and annual-expense
    statistics ignore rows where the respective value is None. Sample
    links are the post URLs of the first five rows in whatever order the
    database returns them (the query has no ORDER BY).
    """
    query = select(FireExample).where(FireExample.country == country)
    result = await session.execute(query)
    examples = result.scalars().all()

    portfolio_values = []
    expense_values = []
    for example in examples:
        if example.portfolio_gbp is not None:
            portfolio_values.append(example.portfolio_gbp)
        if example.annual_exp_gbp is not None:
            expense_values.append(example.annual_exp_gbp)

    links = [example.post_url for example in examples[:5]]

    return Summary(
        country=country,
        count=len(examples),
        portfolio_gbp=_quartiles(portfolio_values),
        annual_exp_gbp=_quartiles(expense_values),
        sample_links=links,
    )

View file

@ -0,0 +1,83 @@
"""Tests for service.upsert_example and service.summary_for_country."""
from __future__ import annotations
from datetime import date
from decimal import Decimal
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from fire_planner.db import FireExample
from fire_planner.examples.models import ExtractedExample, FiStatus, RawPost
from fire_planner.examples.service import summary_for_country, upsert_example
def _post(reddit_id: str = "abc1") -> RawPost:
    """Build a minimal RawPost fixture whose URL is derived from *reddit_id*."""
    link = f"https://reddit.com/{reddit_id}"
    posted_on = date(2026, 1, 1)
    return RawPost(
        reddit_id=reddit_id,
        source_sub="ExpatFIRE",
        url=link,
        title="t",
        body="b",
        created_at=posted_on,
    )
def _ex(conf: Decimal = Decimal("0.8")) -> ExtractedExample:
    """Build a fully populated ExtractedExample fixture with confidence *conf*.

    The amounts are in USD: a 1,200,000 portfolio and 18,000 annual expenses.
    """
    portfolio = Decimal("1200000")
    yearly_spend = Decimal("18000")
    return ExtractedExample(
        country="Philippines",
        city="Manila",
        portfolio_native=portfolio,
        annual_exp_native=yearly_spend,
        raw_currency="USD",
        age=38,
        family_size=3,
        fi_status=FiStatus.FIRE,
        is_retired=True,
        confidence=conf,
        llm_model="qwen3-8b",
    )
@pytest.mark.asyncio
async def test_upsert_inserts_new_row(session: AsyncSession) -> None:
    """A first upsert reports an insert and stores the GBP-converted portfolio."""
    fx = {"GBP": Decimal("1"), "USD": Decimal("0.80")}

    was_inserted = await upsert_example(session, _post(), _ex(), fx)
    assert was_inserted is True

    query_result = await session.execute(select(FireExample))
    stored = query_result.scalars().all()
    assert len(stored) == 1

    only_row = stored[0]
    # 1,200,000 USD at a rate of 0.80 is 960,000 GBP.
    assert only_row.portfolio_gbp == Decimal("960000.00")
    assert only_row.country == "Philippines"
@pytest.mark.asyncio
async def test_upsert_is_idempotent_by_reddit_id(session: AsyncSession) -> None:
    """Upserting the same reddit_id twice leaves one row and reports no insert."""
    fx = {"GBP": Decimal("1"), "USD": Decimal("0.80")}

    await upsert_example(session, _post("abc1"), _ex(), fx)
    second_attempt = await upsert_example(session, _post("abc1"), _ex(), fx)
    # The repeated call must be a no-op.
    assert second_attempt is False

    query_result = await session.execute(select(FireExample))
    stored = query_result.scalars().all()
    assert len(stored) == 1
@pytest.mark.asyncio
async def test_summary_for_country_returns_quartiles(session: AsyncSession) -> None:
    """Five evenly spaced portfolios give median 300k, p25 200k and p75 400k.

    All rates are 1, so the stored GBP figures equal the native amounts.
    """
    rates = {"GBP": Decimal("1"), "USD": Decimal("1")}
    portfolios = [100_000, 200_000, 300_000, 400_000, 500_000]
    for i, p in enumerate(portfolios):
        ex = ExtractedExample(
            country="Philippines",
            portfolio_native=Decimal(p),
            raw_currency="GBP",
            confidence=Decimal("0.9"),
            llm_model="qwen3-8b",
        )
        # Distinct reddit_ids so every example is inserted as its own row.
        await upsert_example(session, _post(f"id{i}"), ex, rates)

    summary = await summary_for_country(session, "Philippines")

    assert summary.count == 5
    assert summary.portfolio_gbp.median == Decimal("300000.00")
    assert summary.portfolio_gbp.p25 == Decimal("200000.00")
    assert summary.portfolio_gbp.p75 == Decimal("400000.00")
    # Exactly five rows exist, so all five links must come back; the previous
    # "<= 5" bound would also have passed for an empty list.
    assert len(summary.sample_links) == 5