"""Postgres reporter — write_run round-trips into the schema.""" from decimal import Decimal import numpy as np from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from fire_planner.db import McRun, ProjectionYearly, ScenarioSummary from fire_planner.glide_path import static from fire_planner.reporters.pg import write_run from fire_planner.scenarios import ScenarioSpec from fire_planner.simulator import simulate from fire_planner.strategies.trinity import TrinityStrategy from fire_planner.tax.malaysia import MalaysiaTaxRegime def fixed_paths(n_paths: int, n_years: int) -> np.ndarray: out = np.zeros((n_paths, n_years, 3)) out[..., 0] = 0.05 out[..., 1] = 0.03 out[..., 2] = 0.02 return out async def test_write_run_persists_summary_run_and_projection(session: AsyncSession) -> None: spec = ScenarioSpec( jurisdiction="cyprus", strategy="trinity", leave_uk_year=3, glide_path="rising", spending_gbp=Decimal("100000"), nw_seed_gbp=Decimal("1000000"), horizon_years=20, ) paths = fixed_paths(50, 20) result = simulate( paths=paths, initial_portfolio=1_000_000.0, spending_target=40_000.0, glide=static(0.7), strategy=TrinityStrategy(), regime=MalaysiaTaxRegime(), horizon_years=20, ) summary = await write_run(session, spec, result, seed=42, elapsed_seconds=1.5) await session.commit() runs = (await session.execute(select(McRun))).scalars().all() assert len(runs) == 1 assert runs[0].id == summary.mc_run_id assert runs[0].n_paths == 50 projections = (await session.execute(select(ProjectionYearly))).scalars().all() assert len(projections) == 20 # one row per year summaries = (await session.execute(select(ScenarioSummary))).scalars().all() assert len(summaries) == 1 assert summaries[0].jurisdiction == "cyprus" async def test_write_run_idempotent_summary(session: AsyncSession) -> None: """Running twice for the same scenario should keep summary at one row, pointing at the latest run.""" spec = ScenarioSpec( jurisdiction="bulgaria", strategy="vpw", leave_uk_year=2, glide_path="static_60_40", spending_gbp=Decimal("100000"), nw_seed_gbp=Decimal("1000000"), horizon_years=20, ) paths = fixed_paths(20, 20) result = simulate( paths=paths, initial_portfolio=1_000_000.0, spending_target=40_000.0, glide=static(0.6), strategy=TrinityStrategy(), regime=MalaysiaTaxRegime(), horizon_years=20, ) s1 = await write_run(session, spec, result, seed=42, elapsed_seconds=1.0) await session.commit() s2 = await write_run(session, spec, result, seed=43, elapsed_seconds=1.5) await session.commit() assert s1.scenario_id == s2.scenario_id assert s2.mc_run_id != s1.mc_run_id runs = (await session.execute(select(McRun))).scalars().all() assert len(runs) == 2 summaries = (await session.execute(select(ScenarioSummary))).scalars().all() assert len(summaries) == 1 assert summaries[0].mc_run_id == s2.mc_run_id