"""upsert_fire_target writes one row per (case, country, with_home, bar) and updates in place on re-run (idempotent recompute).""" from __future__ import annotations from decimal import Decimal from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from fire_planner.db import FireTarget from fire_planner.fire_target import SolveResult, TargetInputs from fire_planner.reporters.pg import upsert_fire_target from fire_planner.spend_model import Case def _inp(**over) -> TargetInputs: base = dict( case=Case.SOLO, country_slug="sofia", country_display="Sofia", jurisdiction="bulgaria", annual_spend_gbp=35_000.0, horizon_years=60, ) base.update(over) return TargetInputs(**base) def _res(target: float, reached: bool = True) -> SolveResult: return SolveResult(target_nw_gbp=target, success_at_target=0.992, pension_at_unlock_gbp=120_000.0, reached_bar=reached) async def test_upsert_inserts_then_updates_in_place(session: AsyncSession) -> None: await upsert_fire_target(session, _inp(), _res(900_000.0), n_paths=2_000) await session.commit() rows = (await session.execute(select(FireTarget))).scalars().all() assert len(rows) == 1 assert rows[0].target_nw_gbp == Decimal("900000.00") assert rows[0].case == "solo" # Re-running the same key updates, doesn't duplicate. expire_all() forces a # DB read past the identity map (session is expire_on_commit=False). await upsert_fire_target(session, _inp(), _res(850_000.0), n_paths=5_000) await session.commit() session.expire_all() rows = (await session.execute(select(FireTarget))).scalars().all() assert len(rows) == 1 assert rows[0].target_nw_gbp == Decimal("850000.00") assert rows[0].n_paths == 5_000 async def test_with_home_is_a_distinct_row(session: AsyncSession) -> None: await upsert_fire_target(session, _inp(with_home=False), _res(900_000.0), 2_000) await upsert_fire_target(session, _inp(with_home=True), _res(1_100_000.0), 2_000) await session.commit() rows = (await session.execute(select(FireTarget))).scalars().all() assert len(rows) == 2 by_home = {r.with_home: r.target_nw_gbp for r in rows} assert by_home[True] > by_home[False] async def test_not_reached_bar_is_persisted(session: AsyncSession) -> None: await upsert_fire_target( session, _inp(case=Case.FAMILY), _res(5_000_000.0, reached=False), 2_000) await session.commit() row = (await session.execute(select(FireTarget))).scalars().one() assert row.reached_bar is False