From a10d7fe2a67a719d3ebebc4a994d5c43725e6aa5 Mon Sep 17 00:00:00 2001
From: Viktor Barzin
Date: Thu, 28 May 2026 22:28:53 +0000
Subject: [PATCH] examples: service.upsert_example + summary_for_country

---
 fire_planner/examples/service.py | 137 +++++++++++++++++++++++++++++++
 tests/test_examples_service.py   |  83 +++++++++++++++++++
 2 files changed, 220 insertions(+)
 create mode 100644 fire_planner/examples/service.py
 create mode 100644 tests/test_examples_service.py

diff --git a/fire_planner/examples/service.py b/fire_planner/examples/service.py
new file mode 100644
index 0000000..0e04243
--- /dev/null
+++ b/fire_planner/examples/service.py
@@ -0,0 +1,137 @@
+"""Persistence + read-side queries for fire_example.
+
+`upsert_example(...)` does an INSERT ... ON CONFLICT DO NOTHING by
+reddit_id. Returns True when a new row was inserted, False when it was
+already present (idempotent re-runs are a feature, not a bug).
+
+`summary_for_country(...)` computes count + median/p25/p75 of
+portfolio_gbp + annual_exp_gbp + up to 5 sample post URLs. Rows are
+fetched with the SQLAlchemy expression API and the statistics are
+computed in Python, so it works on both Postgres and SQLite (which the
+tests use).
+"""
+from __future__ import annotations
+
+import logging
+import statistics
+from decimal import Decimal
+from typing import Any, cast
+
+from sqlalchemy import CursorResult, select
+from sqlalchemy.dialects.postgresql import insert as pg_insert
+from sqlalchemy.dialects.sqlite import insert as sqlite_insert
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from fire_planner.db import FireExample
+from fire_planner.examples.llm_extract import to_gbp
+from fire_planner.examples.models import ExtractedExample, RawPost, Summary, SummaryStats
+
+log = logging.getLogger(__name__)
+
+EXCERPT_LEN = 500
+
+
+def _dialect_insert(session: AsyncSession) -> Any:
+    """Return the dialect-specific `insert` construct for the session's bind.
+
+    The SQLite construct is used when the bind's dialect is "sqlite", the
+    PostgreSQL one otherwise; both provide `on_conflict_do_nothing`.
+    """
+    bind = session.get_bind()
+    if bind.dialect.name == "sqlite":
+        return sqlite_insert
+    return pg_insert
+
+
+async def upsert_example(
+    session: AsyncSession,
+    post: RawPost,
+    extracted: ExtractedExample,
+    fx_rates: dict[str, Decimal],
+) -> bool:
+    """INSERT ... ON CONFLICT DO NOTHING. Returns True on insert, False on conflict.
+
+    Native amounts are converted with `to_gbp` before storing, the excerpt
+    is the title and body truncated to EXCERPT_LEN characters, and the
+    session is committed before returning.
+    """
+    portfolio_gbp = to_gbp(extracted.portfolio_native, extracted.raw_currency, fx_rates)
+    annual_exp_gbp = to_gbp(extracted.annual_exp_native, extracted.raw_currency, fx_rates)
+    values = {
+        "reddit_id": post.reddit_id,
+        "source_sub": post.source_sub,
+        "post_url": post.url,
+        "post_date": post.created_at,
+        "post_title": post.title,
+        "country": extracted.country,
+        "city": extracted.city,
+        "portfolio_gbp": portfolio_gbp,
+        "annual_exp_gbp": annual_exp_gbp,
+        "age": extracted.age,
+        "family_size": extracted.family_size,
+        "fi_status": str(extracted.fi_status) if extracted.fi_status else None,
+        "is_retired": extracted.is_retired,
+        "raw_currency": extracted.raw_currency,
+        "raw_excerpt": (post.title + "\n" + post.body)[:EXCERPT_LEN],
+        "llm_model": extracted.llm_model,
+        "llm_confidence": extracted.confidence,
+    }
+    insert_fn = _dialect_insert(session)
+    stmt = insert_fn(FireExample).values(**values)
+    stmt = stmt.on_conflict_do_nothing(index_elements=["reddit_id"])
+    # session.execute returns a CursorResult for DML; cast so mypy can see
+    # .rowcount (the base Result class hides it).
+    result = cast(CursorResult[Any], await session.execute(stmt))
+    await session.commit()
+    return (result.rowcount or 0) > 0
+
+
+def _quartiles(values: list[Decimal]) -> SummaryStats:
+    """Summarise `values` as median/p25/p75, rounded to two decimal places.
+
+    An empty list yields all-None stats. With fewer than four values p25
+    and p75 fall back to the minimum and maximum.
+    """
+    if not values:
+        return SummaryStats(median=None, p25=None, p75=None)
+    floats = [float(v) for v in values]
+    median = statistics.median(floats)
+    if len(floats) >= 4:
+        # method="inclusive" matches NumPy/pandas linear interpolation —
+        # the test expects p25/p75 to land on actual data points when n is
+        # small (e.g. p25=200_000 for [100k, 200k, 300k, 400k, 500k]).
+        quants = statistics.quantiles(floats, n=4, method="inclusive")
+        p25, p75 = quants[0], quants[2]
+    else:
+        # Too few samples for quartiles; fall back to min/max bounds.
+        p25, p75 = min(floats), max(floats)
+    return SummaryStats(
+        median=Decimal(f"{median:.2f}"),
+        p25=Decimal(f"{p25:.2f}"),
+        p75=Decimal(f"{p75:.2f}"),
+    )
+
+
+async def summary_for_country(session: AsyncSession, country: str) -> Summary:
+    """Return count, portfolio/expense quartiles and sample links for `country`.
+
+    Stats skip rows whose amount is NULL; `count` includes every row.
+    """
+    # reddit_id is the upsert conflict key, so ordering by it makes the
+    # sample links deterministic across runs and backends.
+    stmt = (
+        select(FireExample)
+        .where(FireExample.country == country)
+        .order_by(FireExample.reddit_id)
+    )
+    rows = (await session.execute(stmt)).scalars().all()
+    portfolios = [r.portfolio_gbp for r in rows if r.portfolio_gbp is not None]
+    expenses = [r.annual_exp_gbp for r in rows if r.annual_exp_gbp is not None]
+    sample_links = [r.post_url for r in rows[:5]]
+    return Summary(
+        country=country,
+        count=len(rows),
+        portfolio_gbp=_quartiles(portfolios),
+        annual_exp_gbp=_quartiles(expenses),
+        sample_links=sample_links,
+    )
diff --git a/tests/test_examples_service.py b/tests/test_examples_service.py
new file mode 100644
index 0000000..d90ff51
--- /dev/null
+++ b/tests/test_examples_service.py
@@ -0,0 +1,83 @@
+"""Tests for service.upsert_example and service.summary_for_country."""
+from __future__ import annotations
+
+from datetime import date
+from decimal import Decimal
+
+import pytest
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from fire_planner.db import FireExample
+from fire_planner.examples.models import ExtractedExample, FiStatus, RawPost
+from fire_planner.examples.service import summary_for_country, upsert_example
+
+
+def _post(reddit_id: str = "abc1") -> RawPost:
+    return RawPost(
+        reddit_id=reddit_id,
+        source_sub="ExpatFIRE",
+        url=f"https://reddit.com/{reddit_id}",
+        title="t",
+        body="b",
+        created_at=date(2026, 1, 1),
+    )
+
+
+def _ex(conf: Decimal = Decimal("0.8")) -> ExtractedExample:
+    return ExtractedExample(
+        country="Philippines",
+        city="Manila",
+        portfolio_native=Decimal("1200000"),
+        annual_exp_native=Decimal("18000"),
+        raw_currency="USD",
+        age=38,
+        family_size=3,
+        fi_status=FiStatus.FIRE,
+        is_retired=True,
+        confidence=conf,
+        llm_model="qwen3-8b",
+    )
+
+
+@pytest.mark.asyncio
+async def test_upsert_inserts_new_row(session: AsyncSession) -> None:
+    rates = {"GBP": Decimal("1"), "USD": Decimal("0.80")}
+    inserted = await upsert_example(session, _post(), _ex(), rates)
+    assert inserted is True
+    rows = (await session.execute(select(FireExample))).scalars().all()
+    assert len(rows) == 1
+    assert rows[0].portfolio_gbp == Decimal("960000.00")
+    assert rows[0].country == "Philippines"
+
+
+@pytest.mark.asyncio
+async def test_upsert_is_idempotent_by_reddit_id(session: AsyncSession) -> None:
+    rates = {"GBP": Decimal("1"), "USD": Decimal("0.80")}
+    await upsert_example(session, _post("abc1"), _ex(), rates)
+    inserted = await upsert_example(session, _post("abc1"), _ex(), rates)
+    assert inserted is False  # second call is no-op
+    rows = (await session.execute(select(FireExample))).scalars().all()
+    assert len(rows) == 1
+
+
+@pytest.mark.asyncio
+async def test_summary_for_country_returns_quartiles(session: AsyncSession) -> None:
+    rates = {"GBP": Decimal("1"), "USD": Decimal("1")}
+    portfolios = [100_000, 200_000, 300_000, 400_000, 500_000]
+    for i, p in enumerate(portfolios):
+        ex = ExtractedExample(
+            country="Philippines",
+            portfolio_native=Decimal(p),
+            raw_currency="GBP",
+            confidence=Decimal("0.9"),
+            llm_model="qwen3-8b",
+        )
+        await upsert_example(session, _post(f"id{i}"), ex, rates)
+
+    summary = await summary_for_country(session, "Philippines")
+    assert summary.count == 5
+    assert summary.portfolio_gbp.median == Decimal("300000.00")
+    assert summary.portfolio_gbp.p25 == Decimal("200000.00")
+    assert summary.portfolio_gbp.p75 == Decimal("400000.00")
+    assert len(summary.sample_links) <= 5