examples: service.upsert_example + summary_for_country
All checks were successful
ci/woodpecker/push/woodpecker Pipeline was successful
All checks were successful
ci/woodpecker/push/woodpecker Pipeline was successful
This commit is contained in:
parent
0d442de918
commit
a10d7fe2a6
2 changed files with 194 additions and 0 deletions
111
fire_planner/examples/service.py
Normal file
111
fire_planner/examples/service.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
"""Persistence + read-side queries for fire_example.
|
||||
|
||||
`upsert_example(...)` does an INSERT ... ON CONFLICT DO NOTHING by
|
||||
reddit_id. Returns True when a new row was inserted, False when it was
|
||||
already present (idempotent re-runs are a feature, not a bug).
|
||||
|
||||
`summary_for_country(...)` computes count + median/p25/p75 of
|
||||
portfolio_gbp + annual_exp_gbp + up to 5 sample post URLs. Runs as
|
||||
plain SQL — SQLAlchemy expression API — so it works on both Postgres
|
||||
and SQLite (which the tests use).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import statistics
|
||||
from decimal import Decimal
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import CursorResult, select
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from fire_planner.db import FireExample
|
||||
from fire_planner.examples.llm_extract import to_gbp
|
||||
from fire_planner.examples.models import ExtractedExample, RawPost, Summary, SummaryStats
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
EXCERPT_LEN = 500
|
||||
|
||||
|
||||
def _dialect_insert(session: AsyncSession) -> Any:
|
||||
bind = session.get_bind()
|
||||
if bind.dialect.name == "sqlite":
|
||||
return sqlite_insert
|
||||
return pg_insert
|
||||
|
||||
|
||||
async def upsert_example(
|
||||
session: AsyncSession,
|
||||
post: RawPost,
|
||||
extracted: ExtractedExample,
|
||||
fx_rates: dict[str, Decimal],
|
||||
) -> bool:
|
||||
"""INSERT ... ON CONFLICT DO NOTHING. Returns True on insert, False on conflict."""
|
||||
portfolio_gbp = to_gbp(extracted.portfolio_native, extracted.raw_currency, fx_rates)
|
||||
annual_exp_gbp = to_gbp(extracted.annual_exp_native, extracted.raw_currency, fx_rates)
|
||||
values = {
|
||||
"reddit_id": post.reddit_id,
|
||||
"source_sub": post.source_sub,
|
||||
"post_url": post.url,
|
||||
"post_date": post.created_at,
|
||||
"post_title": post.title,
|
||||
"country": extracted.country,
|
||||
"city": extracted.city,
|
||||
"portfolio_gbp": portfolio_gbp,
|
||||
"annual_exp_gbp": annual_exp_gbp,
|
||||
"age": extracted.age,
|
||||
"family_size": extracted.family_size,
|
||||
"fi_status": str(extracted.fi_status) if extracted.fi_status else None,
|
||||
"is_retired": extracted.is_retired,
|
||||
"raw_currency": extracted.raw_currency,
|
||||
"raw_excerpt": (post.title + "\n" + post.body)[:EXCERPT_LEN],
|
||||
"llm_model": extracted.llm_model,
|
||||
"llm_confidence": extracted.confidence,
|
||||
}
|
||||
insert_fn = _dialect_insert(session)
|
||||
stmt = insert_fn(FireExample).values(**values)
|
||||
stmt = stmt.on_conflict_do_nothing(index_elements=["reddit_id"])
|
||||
# session.execute returns a CursorResult for DML; cast so mypy can see
|
||||
# .rowcount (the base Result class hides it).
|
||||
result = cast(CursorResult[Any], await session.execute(stmt))
|
||||
await session.commit()
|
||||
return (result.rowcount or 0) > 0
|
||||
|
||||
|
||||
def _quartiles(values: list[Decimal]) -> SummaryStats:
|
||||
if not values:
|
||||
return SummaryStats(median=None, p25=None, p75=None)
|
||||
floats = [float(v) for v in values]
|
||||
median = statistics.median(floats)
|
||||
if len(floats) >= 4:
|
||||
# method="inclusive" matches NumPy/pandas linear interpolation —
|
||||
# the test expects p25/p75 to land on actual data points when n is
|
||||
# small (e.g. p25=200_000 for [100k, 200k, 300k, 400k, 500k]).
|
||||
quants = statistics.quantiles(floats, n=4, method="inclusive")
|
||||
p25, p75 = quants[0], quants[2]
|
||||
else:
|
||||
# Too few samples for quartiles; fall back to min/max bounds.
|
||||
p25, p75 = min(floats), max(floats)
|
||||
return SummaryStats(
|
||||
median=Decimal(f"{median:.2f}"),
|
||||
p25=Decimal(f"{p25:.2f}"),
|
||||
p75=Decimal(f"{p75:.2f}"),
|
||||
)
|
||||
|
||||
|
||||
async def summary_for_country(session: AsyncSession, country: str) -> Summary:
|
||||
stmt = select(FireExample).where(FireExample.country == country)
|
||||
rows = (await session.execute(stmt)).scalars().all()
|
||||
portfolios = [r.portfolio_gbp for r in rows if r.portfolio_gbp is not None]
|
||||
expenses = [r.annual_exp_gbp for r in rows if r.annual_exp_gbp is not None]
|
||||
sample_links = [r.post_url for r in rows[:5]]
|
||||
return Summary(
|
||||
country=country,
|
||||
count=len(rows),
|
||||
portfolio_gbp=_quartiles(portfolios),
|
||||
annual_exp_gbp=_quartiles(expenses),
|
||||
sample_links=sample_links,
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue