fire-planner/tests/test_reporters_pg.py
2026-05-07 17:06:19 +00:00

93 lines
3.1 KiB
Python

"""Postgres reporter — write_run round-trips into the schema."""
from decimal import Decimal
import numpy as np
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from fire_planner.db import McRun, ProjectionYearly, ScenarioSummary
from fire_planner.glide_path import static
from fire_planner.reporters.pg import write_run
from fire_planner.scenarios import ScenarioSpec
from fire_planner.simulator import simulate
from fire_planner.strategies.trinity import TrinityStrategy
from fire_planner.tax.malaysia import MalaysiaTaxRegime
def fixed_paths(n_paths: int, n_years: int) -> np.ndarray:
    """Build a deterministic ``(n_paths, n_years, 3)`` float array for tests.

    Every path and every year carries the same three values along the last
    axis: ``0.05``, ``0.03`` and ``0.02``.

    NOTE(review): the meaning of the three channels is defined by
    ``simulate`` (presumably per-asset returns / inflation) — confirm there.
    """
    per_year = np.array([0.05, 0.03, 0.02])
    # A 1-D vector tiled with reps (paths, years, 1) is promoted to
    # (1, 1, 3) and repeated, yielding a fresh writable array.
    return np.tile(per_year, (n_paths, n_years, 1))
async def test_write_run_persists_summary_run_and_projection(session: AsyncSession) -> None:
    """One ``write_run`` call persists one McRun, one ProjectionYearly row per
    horizon year, and one ScenarioSummary linked to that run.

    NOTE(review): ``simulate`` is driven with a Malaysia tax regime, a
    ``static(0.7)`` glide and 40k spending, which differ from the spec's
    jurisdiction / glide_path / spending. This test only checks what gets
    persisted, not that the result matches the spec.
    """
    n_paths = 50
    spec = ScenarioSpec(
        jurisdiction="cyprus",
        strategy="trinity",
        leave_uk_year=3,
        glide_path="rising",
        spending_gbp=Decimal("100000"),
        nw_seed_gbp=Decimal("1000000"),
        horizon_years=20,
    )
    # Tie simulation dimensions to the spec so the row-count assertions
    # below cannot silently drift from the inputs.
    result = simulate(
        paths=fixed_paths(n_paths, spec.horizon_years),
        initial_portfolio=float(spec.nw_seed_gbp),
        spending_target=40_000.0,
        glide=static(0.7),
        strategy=TrinityStrategy(),
        regime=MalaysiaTaxRegime(),
        horizon_years=spec.horizon_years,
    )
    summary = await write_run(session, spec, result, seed=42, elapsed_seconds=1.5)
    await session.commit()

    runs = (await session.execute(select(McRun))).scalars().all()
    assert len(runs) == 1
    assert runs[0].id == summary.mc_run_id
    assert runs[0].n_paths == n_paths

    projections = (await session.execute(select(ProjectionYearly))).scalars().all()
    assert len(projections) == spec.horizon_years  # one row per year

    summaries = (await session.execute(select(ScenarioSummary))).scalars().all()
    assert len(summaries) == 1
    assert summaries[0].jurisdiction == spec.jurisdiction
    # The persisted summary must point at the run that was just written.
    assert summaries[0].mc_run_id == summary.mc_run_id
async def test_write_run_idempotent_summary(session: AsyncSession) -> None:
    """Two ``write_run`` calls for one scenario create two McRun rows but
    keep a single ScenarioSummary row, which points at the latest run."""
    spec = ScenarioSpec(
        jurisdiction="bulgaria",
        strategy="vpw",
        leave_uk_year=2,
        glide_path="static_60_40",
        spending_gbp=Decimal("100000"),
        nw_seed_gbp=Decimal("1000000"),
        horizon_years=20,
    )
    # NOTE(review): spec.strategy is "vpw" but the result is produced by
    # TrinityStrategy — only the summary upsert behaviour is exercised here.
    result = simulate(
        paths=fixed_paths(20, 20),
        initial_portfolio=1_000_000.0,
        spending_target=40_000.0,
        glide=static(0.6),
        strategy=TrinityStrategy(),
        regime=MalaysiaTaxRegime(),
        horizon_years=20,
    )

    # Persist the same scenario twice, committing after each write.
    written = []
    for run_seed, run_elapsed in ((42, 1.0), (43, 1.5)):
        written.append(
            await write_run(session, spec, result, seed=run_seed, elapsed_seconds=run_elapsed)
        )
        await session.commit()
    first, latest = written

    assert first.scenario_id == latest.scenario_id
    assert latest.mc_run_id != first.mc_run_id

    run_rows = (await session.execute(select(McRun))).scalars().all()
    assert len(run_rows) == 2

    summary_rows = (await session.execute(select(ScenarioSummary))).scalars().all()
    assert len(summary_rows) == 1
    assert summary_rows[0].mc_run_id == latest.mc_run_id