"""Persistence + read-side queries for fire_example.

`upsert_example(...)` does an INSERT ... ON CONFLICT DO NOTHING by reddit_id.
Returns True when a new row was inserted, False when it was already present
(idempotent re-runs are a feature, not a bug).

`summary_for_country(...)` computes count + median/p25/p75 of portfolio_gbp +
annual_exp_gbp + up to 5 sample post URLs. Queries are built with the
SQLAlchemy expression API (no hand-written, dialect-specific SQL) so they work
on both Postgres and SQLite (which the tests use); the quartiles themselves
are computed in Python with the `statistics` module.
"""

from __future__ import annotations

import logging
import statistics
from decimal import Decimal
from typing import Any, cast

from sqlalchemy import CursorResult, select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlalchemy.ext.asyncio import AsyncSession

from fire_planner.db import FireExample
from fire_planner.examples.llm_extract import to_gbp
from fire_planner.examples.models import ExtractedExample, RawPost, Summary, SummaryStats

log = logging.getLogger(__name__)

# Maximum number of characters of "title\nbody" stored in raw_excerpt.
EXCERPT_LEN = 500

# Maximum number of post URLs returned in Summary.sample_links.
SAMPLE_LINK_LIMIT = 5


def _dialect_insert(session: AsyncSession) -> Any:
    """Return the dialect-specific ``insert`` construct for *session*'s bind.

    ``on_conflict_do_nothing`` lives on the SQLite / Postgres dialect insert
    constructs, not on the generic ``sqlalchemy.insert``. The SQLite variant
    is used when the bind's dialect is "sqlite"; every other dialect gets the
    Postgres variant.
    """
    bind = session.get_bind()
    if bind.dialect.name == "sqlite":
        return sqlite_insert
    return pg_insert


async def upsert_example(
    session: AsyncSession,
    post: RawPost,
    extracted: ExtractedExample,
    fx_rates: dict[str, Decimal],
) -> bool:
    """INSERT ... ON CONFLICT DO NOTHING.

    Returns True on insert, False on conflict.

    The extracted native-currency portfolio and annual expenses are converted
    to GBP via ``to_gbp`` using *fx_rates* before being stored. The session is
    committed before returning, so the caller's pending work is committed too.
    """
    portfolio_gbp = to_gbp(extracted.portfolio_native, extracted.raw_currency, fx_rates)
    annual_exp_gbp = to_gbp(extracted.annual_exp_native, extracted.raw_currency, fx_rates)
    values = {
        "reddit_id": post.reddit_id,
        "source_sub": post.source_sub,
        "post_url": post.url,
        "post_date": post.created_at,
        "post_title": post.title,
        "country": extracted.country,
        "city": extracted.city,
        "portfolio_gbp": portfolio_gbp,
        "annual_exp_gbp": annual_exp_gbp,
        "age": extracted.age,
        "family_size": extracted.family_size,
        "fi_status": str(extracted.fi_status) if extracted.fi_status else None,
        "is_retired": extracted.is_retired,
        "raw_currency": extracted.raw_currency,
        # Keep only a bounded prefix of the post text for later inspection.
        "raw_excerpt": (post.title + "\n" + post.body)[:EXCERPT_LEN],
        "llm_model": extracted.llm_model,
        "llm_confidence": extracted.confidence,
    }
    insert_fn = _dialect_insert(session)
    stmt = insert_fn(FireExample).values(**values)
    # reddit_id is the conflict target: re-processing the same post is a no-op.
    stmt = stmt.on_conflict_do_nothing(index_elements=["reddit_id"])
    # session.execute returns a CursorResult for DML; cast so mypy can see
    # .rowcount (the base Result class hides it).
    result = cast(CursorResult[Any], await session.execute(stmt))
    await session.commit()
    # rowcount is 0 when the conflict clause skipped the row; treat a missing
    # or negative (driver-unknown) rowcount as "not inserted".
    return (result.rowcount or 0) > 0


def _quartiles(values: list[Decimal]) -> SummaryStats:
    """Return median / p25 / p75 of *values*, each rounded to 2 decimal places.

    All three fields are None for an empty list. With fewer than 4 values,
    p25/p75 fall back to the min/max of the data. Values are converted to
    float for the computation and back to Decimal via a 2-dp string.
    """
    if not values:
        return SummaryStats(median=None, p25=None, p75=None)
    floats = [float(v) for v in values]
    median = statistics.median(floats)
    if len(floats) >= 4:
        # method="inclusive" matches NumPy/pandas linear interpolation —
        # the test expects p25/p75 to land on actual data points when n is
        # small (e.g. p25=200_000 for [100k, 200k, 300k, 400k, 500k]).
        quants = statistics.quantiles(floats, n=4, method="inclusive")
        p25, p75 = quants[0], quants[2]
    else:
        # Too few samples for quartiles; fall back to min/max bounds.
        p25, p75 = min(floats), max(floats)
    return SummaryStats(
        median=Decimal(f"{median:.2f}"),
        p25=Decimal(f"{p25:.2f}"),
        p75=Decimal(f"{p75:.2f}"),
    )


async def summary_for_country(session: AsyncSession, country: str) -> Summary:
    """Summarise every stored example whose ``country`` equals *country*.

    ``count`` is the number of matching rows, including rows whose portfolio
    or expense value is NULL; NULLs are only left out of the corresponding
    quartile calculation. ``sample_links`` holds up to ``SAMPLE_LINK_LIMIT``
    post URLs.
    """
    # Project only the three columns used below instead of loading whole ORM
    # entities (which would also pull raw_excerpt, titles, etc. for each row).
    stmt = select(
        FireExample.portfolio_gbp,
        FireExample.annual_exp_gbp,
        FireExample.post_url,
    ).where(FireExample.country == country)
    rows = (await session.execute(stmt)).tuples().all()
    portfolios = [portfolio for portfolio, _, _ in rows if portfolio is not None]
    expenses = [expense for _, expense, _ in rows if expense is not None]
    # NOTE(review): there is no ORDER BY, so which URLs end up in
    # sample_links is up to the database. Add an order_by (e.g. newest
    # post_date first) if callers need stable samples.
    sample_links = [url for _, _, url in rows[:SAMPLE_LINK_LIMIT]]
    return Summary(
        country=country,
        count=len(rows),
        portfolio_gbp=_quartiles(portfolios),
        annual_exp_gbp=_quartiles(expenses),
        sample_links=sample_links,
    )