93 lines
3.1 KiB
Python
93 lines
3.1 KiB
Python
"""Postgres reporter — write_run round-trips into the schema."""
|
|
from decimal import Decimal
|
|
|
|
import numpy as np
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from fire_planner.db import McRun, ProjectionYearly, ScenarioSummary
|
|
from fire_planner.glide_path import static
|
|
from fire_planner.reporters.pg import write_run
|
|
from fire_planner.scenarios import ScenarioSpec
|
|
from fire_planner.simulator import simulate
|
|
from fire_planner.strategies.trinity import TrinityStrategy
|
|
from fire_planner.tax.malaysia import MalaysiaTaxRegime
|
|
|
|
|
|
def fixed_paths(n_paths: int, n_years: int) -> np.ndarray:
    """Return a deterministic float64 array of shape ``(n_paths, n_years, 3)``.

    Every (path, year) cell holds the same triple: 0.05 in channel 0, 0.03
    in channel 1 and 0.02 in channel 2. What each channel means is decided
    by ``simulate`` (presumably per-year asset returns / inflation — confirm).
    """
    constant_triple = [0.05, 0.03, 0.02]
    # np.full broadcasts the 3-element fill value across the trailing axis.
    return np.full((n_paths, n_years, 3), constant_triple)
|
|
|
|
|
|
async def test_write_run_persists_summary_run_and_projection(session: AsyncSession) -> None:
    """``write_run`` persists one McRun, per-year projections and one summary.

    Besides the row counts, this checks that the rows are linked. The stored
    run's id and the summary row's ``mc_run_id`` must both equal the
    ``mc_run_id`` returned by ``write_run``.
    """
    n_paths = 50
    horizon_years = 20
    spec = ScenarioSpec(
        jurisdiction="cyprus",
        strategy="trinity",
        leave_uk_year=3,
        glide_path="rising",
        spending_gbp=Decimal("100000"),
        nw_seed_gbp=Decimal("1000000"),
        horizon_years=horizon_years,
    )
    # NOTE(review): the simulation inputs below (40k spending target,
    # static(0.7) glide, MalaysiaTaxRegime) differ from the spec (100k,
    # "rising", cyprus). This test only checks persistence, so the mismatch
    # looks deliberate — confirm.
    result = simulate(
        paths=fixed_paths(n_paths, horizon_years),
        initial_portfolio=1_000_000.0,
        spending_target=40_000.0,
        glide=static(0.7),
        strategy=TrinityStrategy(),
        regime=MalaysiaTaxRegime(),
        horizon_years=horizon_years,
    )

    summary = await write_run(session, spec, result, seed=42, elapsed_seconds=1.5)
    await session.commit()

    runs = (await session.execute(select(McRun))).scalars().all()
    assert len(runs) == 1
    assert runs[0].id == summary.mc_run_id
    assert runs[0].n_paths == n_paths

    projections = (await session.execute(select(ProjectionYearly))).scalars().all()
    assert len(projections) == horizon_years  # one row per year

    summaries = (await session.execute(select(ScenarioSummary))).scalars().all()
    assert len(summaries) == 1
    assert summaries[0].jurisdiction == "cyprus"
    # The summary row must reference the run that was just written.
    assert summaries[0].mc_run_id == summary.mc_run_id
|
|
|
|
|
|
async def test_write_run_idempotent_summary(session: AsyncSession) -> None:
    """Writing one scenario twice yields two runs but only one summary row.

    The surviving ``ScenarioSummary`` row must reference the most recent
    ``McRun``, i.e. the one produced by the second ``write_run`` call.
    """
    horizon_years = 20
    spec = ScenarioSpec(
        jurisdiction="bulgaria",
        strategy="vpw",
        leave_uk_year=2,
        glide_path="static_60_40",
        spending_gbp=Decimal("100000"),
        nw_seed_gbp=Decimal("1000000"),
        horizon_years=horizon_years,
    )
    # NOTE(review): spec.strategy is "vpw" but the result is simulated with
    # TrinityStrategy. The scenario identity appears to come from the spec
    # alone here — confirm.
    result = simulate(
        paths=fixed_paths(20, horizon_years),
        initial_portfolio=1_000_000.0,
        spending_target=40_000.0,
        glide=static(0.6),
        strategy=TrinityStrategy(),
        regime=MalaysiaTaxRegime(),
        horizon_years=horizon_years,
    )

    # Same spec and result both times; only the seed and timing differ.
    written = []
    for seed, elapsed in ((42, 1.0), (43, 1.5)):
        written.append(
            await write_run(session, spec, result, seed=seed, elapsed_seconds=elapsed)
        )
        await session.commit()
    first, latest = written

    assert first.scenario_id == latest.scenario_id
    assert latest.mc_run_id != first.mc_run_id

    all_runs = (await session.execute(select(McRun))).scalars().all()
    assert len(all_runs) == 2
    all_summaries = (await session.execute(select(ScenarioSummary))).scalars().all()
    assert len(all_summaries) == 1
    assert all_summaries[0].mc_run_id == latest.mc_run_id
|