"""Tests for the spending-profile endpoint.""" from __future__ import annotations from collections.abc import AsyncIterator from datetime import UTC, datetime from decimal import Decimal import pytest_asyncio from httpx import ASGITransport, AsyncClient from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker from fire_planner.api.dependencies import get_session from fire_planner.app import app from fire_planner.db import LifeEvent, McRun, ProjectionYearly, Scenario @pytest_asyncio.fixture async def client(engine: AsyncEngine, session: AsyncSession) -> AsyncIterator[AsyncClient]: factory = async_sessionmaker(engine, expire_on_commit=False) async def _override() -> AsyncIterator[AsyncSession]: async with factory() as s: yield s app.dependency_overrides[get_session] = _override transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as ac: yield ac app.dependency_overrides.clear() async def _seed(session: AsyncSession, flex_rules: list[dict] | None = None) -> int: config: dict = {} if flex_rules: config["flex_rules"] = flex_rules scen = Scenario( external_id="user-sp", kind="user", name="SP test", jurisdiction="uk", strategy="trinity", leave_uk_year=0, glide_path="static", spending_gbp=Decimal("60000"), horizon_years=5, nw_seed_gbp=Decimal("1000000"), savings_per_year_gbp=Decimal("0"), config_json=config, ) session.add(scen) await session.commit() await session.refresh(scen) # One persistent essential life event (kid at home), one # discretionary (travel), one income inflow. session.add_all([ LifeEvent( scenario_id=scen.id, kind="kid_at_home", name="Kid 1", year_start=0, year_end=4, delta_gbp_per_year=Decimal("-15000"), category="essential", enabled=True, ), LifeEvent( scenario_id=scen.id, kind="travel", name="Travel", year_start=0, year_end=4, delta_gbp_per_year=Decimal("-10000"), category="discretionary", enabled=True, ), LifeEvent( scenario_id=scen.id, kind="rental", name="Rental", year_start=0, year_end=4, delta_gbp_per_year=Decimal("8000"), category="essential", enabled=True, ), ]) await session.commit() return scen.id async def test_spending_profile_with_no_run( client: AsyncClient, session: AsyncSession, ) -> None: sid = await _seed(session) resp = await client.get(f"/scenarios/{sid}/spending-profile") assert resp.status_code == 200, resp.text body = resp.json() assert body["horizon_years"] == 5 assert len(body["points"]) == 5 p0 = body["points"][0] # base = 60000 - 8000 inflow = 52000 assert Decimal(p0["base_gbp"]) == Decimal("52000") assert Decimal(p0["essential_gbp"]) == Decimal("15000") assert Decimal(p0["discretionary_gbp"]) == Decimal("10000") # No projection yet → no flex cut. assert Decimal(p0["flex_cut_gbp"]) == Decimal("0") # total = 52000 + 15000 + 10000 = 77000 assert Decimal(p0["total_gbp"]) == Decimal("77000") async def test_spending_profile_with_flex_rules( client: AsyncClient, session: AsyncSession, ) -> None: flex = [{"from_ath_pct": 0.20, "cut_discretionary_pct": 0.50}] sid = await _seed(session, flex_rules=flex) # Persist a fan that drops to 70% of seed (i.e. 30% drawdown vs ATH). run = McRun( scenario_id=sid, run_at=datetime.now(UTC), n_paths=10, seed=1, success_rate=Decimal("1"), p10_ending_gbp=Decimal("0"), p50_ending_gbp=Decimal("0"), p90_ending_gbp=Decimal("0"), median_lifetime_tax_gbp=Decimal("0"), median_years_to_ruin=None, elapsed_seconds=Decimal("0"), ) session.add(run) await session.commit() await session.refresh(run) rows = [ ProjectionYearly( mc_run_id=run.id, year_idx=y, p10_portfolio_gbp=Decimal("0"), p25_portfolio_gbp=Decimal("0"), # year 0 = 1M (ATH); year 1 = 700k (down 30% — flex fires); # years 2-4 = 800k (still down 20% from ATH 1M). p50_portfolio_gbp=Decimal( str([1_000_000, 700_000, 800_000, 800_000, 800_000][y])), p75_portfolio_gbp=Decimal("0"), p90_portfolio_gbp=Decimal("0"), p50_withdrawal_gbp=Decimal("0"), p50_tax_gbp=Decimal("0"), survival_rate=Decimal("1"), ) for y in range(5) ] session.add_all(rows) await session.commit() resp = await client.get(f"/scenarios/{sid}/spending-profile") assert resp.status_code == 200 pts = resp.json()["points"] # Year 0: portfolio == ATH → no cut. assert Decimal(pts[0]["flex_cut_gbp"]) == Decimal("0") # Year 1: drawdown 30% → 50% cut on £10k discretionary = £5k. assert Decimal(pts[1]["flex_cut_gbp"]) == Decimal("5000.00") # Year 1 total = 52000 + 15000 + 10000 - 5000 = 72000 assert Decimal(pts[1]["total_gbp"]) == Decimal("72000.00")