"""Tests for the per-year stats endpoint.""" from __future__ import annotations from collections.abc import AsyncIterator from datetime import UTC, datetime from decimal import Decimal import pytest_asyncio from httpx import ASGITransport, AsyncClient from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker from fire_planner.api.dependencies import get_session from fire_planner.app import app from fire_planner.db import McRun, ProjectionYearly, Scenario @pytest_asyncio.fixture async def client(engine: AsyncEngine, session: AsyncSession) -> AsyncIterator[AsyncClient]: factory = async_sessionmaker(engine, expire_on_commit=False) async def _override() -> AsyncIterator[AsyncSession]: async with factory() as s: yield s app.dependency_overrides[get_session] = _override transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as ac: yield ac app.dependency_overrides.clear() async def _seed_scenario_with_run(session: AsyncSession) -> int: scen = Scenario( external_id="user-yr-stats", kind="user", name="Yr Stats", jurisdiction="uk", strategy="trinity", leave_uk_year=0, glide_path="static", spending_gbp=Decimal("60000"), horizon_years=5, nw_seed_gbp=Decimal("1000000"), savings_per_year_gbp=Decimal("0"), config_json={}, ) session.add(scen) await session.commit() await session.refresh(scen) run = McRun( scenario_id=scen.id, run_at=datetime.now(UTC), n_paths=100, seed=1, success_rate=Decimal("1"), p10_ending_gbp=Decimal("900000"), p50_ending_gbp=Decimal("1100000"), p90_ending_gbp=Decimal("1300000"), median_lifetime_tax_gbp=Decimal("50000"), median_years_to_ruin=None, elapsed_seconds=Decimal("1.234"), ) session.add(run) await session.commit() await session.refresh(run) rows = [ ProjectionYearly( mc_run_id=run.id, year_idx=y, p10_portfolio_gbp=Decimal("900000"), p25_portfolio_gbp=Decimal("950000"), p50_portfolio_gbp=Decimal(str(1_000_000 + y * 50_000)), p75_portfolio_gbp=Decimal("1100000"), p90_portfolio_gbp=Decimal("1200000"), p50_withdrawal_gbp=Decimal("60000"), p50_tax_gbp=Decimal("8000"), survival_rate=Decimal("1.0"), ) for y in range(5) ] session.add_all(rows) await session.commit() return scen.id async def test_year_stats_returns_per_year_metrics( client: AsyncClient, session: AsyncSession, ) -> None: sid = await _seed_scenario_with_run(session) resp = await client.get(f"/scenarios/{sid}/year-stats?year=2") assert resp.status_code == 200, resp.text body = resp.json() assert body["year_idx"] == 2 # year 2 NW = 1_100_000; year 1 NW = 1_050_000 → change 50_000. assert body["net_worth_p50"] == "1100000.00" assert body["change_in_nw"] == "50000.00" assert body["spending"] == "60000.00" assert body["taxes"] == "8000.00" async def test_year_stats_404_when_no_run(client: AsyncClient, session: AsyncSession) -> None: scen = Scenario( external_id="user-no-run", kind="user", name="No run", jurisdiction="uk", strategy="trinity", leave_uk_year=0, glide_path="static", spending_gbp=Decimal("60000"), horizon_years=5, nw_seed_gbp=Decimal("1000000"), savings_per_year_gbp=Decimal("0"), config_json={}, ) session.add(scen) await session.commit() await session.refresh(scen) resp = await client.get(f"/scenarios/{scen.id}/year-stats?year=0") assert resp.status_code == 404