fire-planner/fire_planner/examples/service.py

112 lines
4.2 KiB
Python
Raw Normal View History

"""Persistence + read-side queries for fire_example.
`upsert_example(...)` does an INSERT ... ON CONFLICT DO NOTHING by
reddit_id. Returns True when a new row was inserted, False when it was
already present (idempotent re-runs are a feature, not a bug).
`summary_for_country(...)` returns the row count, the median/p25/p75 of
portfolio_gbp and of annual_exp_gbp, and up to 5 sample post URLs. It
uses the plain SQLAlchemy expression API (the statistics are computed in
Python) so it works on both Postgres and SQLite (which the tests use).
"""
from __future__ import annotations
import logging
import statistics
from decimal import Decimal
from typing import Any, cast
from sqlalchemy import CursorResult, select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlalchemy.ext.asyncio import AsyncSession
from fire_planner.db import FireExample
from fire_planner.examples.llm_extract import to_gbp
from fire_planner.examples.models import ExtractedExample, RawPost, Summary, SummaryStats
log: logging.Logger = logging.getLogger(__name__)
# Maximum number of characters of "<title>\n<body>" kept in the raw_excerpt column.
EXCERPT_LEN: int = 500
def _dialect_insert(session: AsyncSession) -> Any:
    """Return the dialect-specific ``insert`` construct for *session*'s bind.

    SQLite binds get ``sqlalchemy.dialects.sqlite.insert``; every other
    dialect is treated as PostgreSQL. Both constructs provide
    ``on_conflict_do_nothing``, which ``upsert_example`` relies on.
    """
    dialect_name = session.get_bind().dialect.name
    return sqlite_insert if dialect_name == "sqlite" else pg_insert
async def upsert_example(
    session: AsyncSession,
    post: RawPost,
    extracted: ExtractedExample,
    fx_rates: dict[str, Decimal],
) -> bool:
    """Store one extracted FIRE example, skipping it if ``reddit_id`` already exists.

    The native portfolio and annual-expense amounts are converted to GBP
    with ``to_gbp`` (using ``extracted.raw_currency`` and ``fx_rates``)
    before storage. The statement is ``INSERT ... ON CONFLICT (reddit_id)
    DO NOTHING`` built from the dialect-specific insert construct, and the
    session is committed right after executing it.

    Returns:
        True when a new row was written, False when a row with the same
        reddit_id was already present (so re-runs are harmless).
    """
    currency = extracted.raw_currency
    portfolio_gbp = to_gbp(extracted.portfolio_native, currency, fx_rates)
    annual_exp_gbp = to_gbp(extracted.annual_exp_native, currency, fx_rates)

    row = {
        "reddit_id": post.reddit_id,
        "source_sub": post.source_sub,
        "post_url": post.url,
        "post_date": post.created_at,
        "post_title": post.title,
        "country": extracted.country,
        "city": extracted.city,
        "portfolio_gbp": portfolio_gbp,
        "annual_exp_gbp": annual_exp_gbp,
        "age": extracted.age,
        "family_size": extracted.family_size,
        "fi_status": str(extracted.fi_status) if extracted.fi_status else None,
        "is_retired": extracted.is_retired,
        "raw_currency": currency,
        # Only the first EXCERPT_LEN characters of "<title>\n<body>" are kept.
        "raw_excerpt": "\n".join((post.title, post.body))[:EXCERPT_LEN],
        "llm_model": extracted.llm_model,
        "llm_confidence": extracted.confidence,
    }

    make_insert = _dialect_insert(session)
    stmt = make_insert(FireExample).values(**row).on_conflict_do_nothing(
        index_elements=["reddit_id"]
    )
    # DML executes return a CursorResult at runtime; the cast exposes
    # .rowcount to mypy (the base Result type does not declare it).
    cursor = cast(CursorResult[Any], await session.execute(stmt))
    await session.commit()
    inserted_rows = cursor.rowcount or 0
    return inserted_rows > 0
def _quartiles(values: list[Decimal]) -> SummaryStats:
    """Summarise *values* as a median plus p25/p75, each rounded to 2 dp.

    An empty input yields all-None stats. With fewer than four samples the
    quartiles are not interpolated; p25/p75 fall back to the minimum and
    maximum sample. The arithmetic is done in float and the results are
    converted back to ``Decimal`` via a two-decimal string.
    """
    if not values:
        return SummaryStats(median=None, p25=None, p75=None)

    def to_2dp(x: float) -> Decimal:
        return Decimal(f"{x:.2f}")

    samples = list(map(float, values))
    middle = statistics.median(samples)
    if len(samples) < 4:
        # Too few samples to estimate quartiles; use the observed range.
        low, high = min(samples), max(samples)
    else:
        # method="inclusive" is the same linear interpolation NumPy/pandas
        # use by default, so p25/p75 land on real data points for small n
        # (e.g. p25=200_000 for [100k, 200k, 300k, 400k, 500k]).
        low, _, high = statistics.quantiles(samples, n=4, method="inclusive")
    return SummaryStats(median=to_2dp(middle), p25=to_2dp(low), p75=to_2dp(high))
async def summary_for_country(session: AsyncSession, country: str) -> Summary:
    """Aggregate every stored example for *country* into a ``Summary``.

    ``count`` is the number of matching rows, including rows whose money
    fields are NULL. The portfolio/expense stats skip NULL values.
    ``sample_links`` holds the post URLs of the first 5 rows.

    Rows are ordered by ``post_date`` descending, with ``reddit_id`` as a
    tie-breaker. Without an ORDER BY the database may return rows in any
    order, so the "first five" sample links could differ between calls.
    """
    stmt = (
        select(FireExample)
        .where(FireExample.country == country)
        .order_by(FireExample.post_date.desc(), FireExample.reddit_id)
    )
    rows = (await session.execute(stmt)).scalars().all()
    portfolios = [r.portfolio_gbp for r in rows if r.portfolio_gbp is not None]
    expenses = [r.annual_exp_gbp for r in rows if r.annual_exp_gbp is not None]
    sample_links = [r.post_url for r in rows[:5]]
    return Summary(
        country=country,
        count=len(rows),
        portfolio_gbp=_quartiles(portfolios),
        annual_exp_gbp=_quartiles(expenses),
        sample_links=sample_links,
    )