"""Tests for the Progress overlay endpoint.""" from __future__ import annotations from collections.abc import AsyncIterator from datetime import UTC, date, datetime from decimal import Decimal import pytest_asyncio from httpx import ASGITransport, AsyncClient from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker from fire_planner.api.dependencies import get_session from fire_planner.app import app from fire_planner.db import ( AccountSnapshot, McRun, ProjectionYearly, Scenario, ) @pytest_asyncio.fixture async def client(engine: AsyncEngine, session: AsyncSession) -> AsyncIterator[AsyncClient]: factory = async_sessionmaker(engine, expire_on_commit=False) async def _override() -> AsyncIterator[AsyncSession]: async with factory() as s: yield s app.dependency_overrides[get_session] = _override transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as ac: yield ac app.dependency_overrides.clear() async def _seed_full(session: AsyncSession) -> int: scen = Scenario( external_id="user-prog", kind="user", name="Progress test", jurisdiction="uk", strategy="trinity", leave_uk_year=0, glide_path="static", spending_gbp=Decimal("60000"), horizon_years=5, nw_seed_gbp=Decimal("1000000"), savings_per_year_gbp=Decimal("0"), config_json={}, ) session.add(scen) await session.commit() await session.refresh(scen) run = McRun( scenario_id=scen.id, run_at=datetime.now(UTC), n_paths=100, seed=1, success_rate=Decimal("1"), p10_ending_gbp=Decimal("0"), p50_ending_gbp=Decimal("0"), p90_ending_gbp=Decimal("0"), median_lifetime_tax_gbp=Decimal("0"), median_years_to_ruin=None, elapsed_seconds=Decimal("0"), ) session.add(run) await session.commit() await session.refresh(run) yearly = [ ProjectionYearly( mc_run_id=run.id, year_idx=y, p10_portfolio_gbp=Decimal("900000"), p25_portfolio_gbp=Decimal("950000"), p50_portfolio_gbp=Decimal(str(1_000_000 + y * 50_000)), p75_portfolio_gbp=Decimal("1100000"), p90_portfolio_gbp=Decimal("1200000"), p50_withdrawal_gbp=Decimal("60000"), p50_tax_gbp=Decimal("8000"), survival_rate=Decimal("1"), ) for y in range(3) ] session.add_all(yearly) # Two snapshots a year apart snap_a = AccountSnapshot( external_id="wf:a:2024-01-01", snapshot_date=date(2024, 1, 1), account_id="a", account_name="Stocks", account_type="brokerage", currency="GBP", market_value=Decimal("1000000"), market_value_gbp=Decimal("1000000"), ) snap_b = AccountSnapshot( external_id="wf:a:2025-01-01", snapshot_date=date(2025, 1, 1), account_id="a", account_name="Stocks", account_type="brokerage", currency="GBP", market_value=Decimal("1080000"), market_value_gbp=Decimal("1080000"), ) session.add_all([snap_a, snap_b]) await session.commit() return scen.id async def test_progress_returns_actual_and_projected( client: AsyncClient, session: AsyncSession, ) -> None: sid = await _seed_full(session) resp = await client.get(f"/scenarios/{sid}/progress") assert resp.status_code == 200, resp.text body = resp.json() assert body["scenario_id"] == sid assert body["alignment_anchor"] == "2024-01-01" assert len(body["actual"]) == 2 assert len(body["projected"]) == 3 # year_idx 1 has actual £1.08M vs projected £1.05M → +30k variance. variance_y1 = next(v for v in body["variance"] if v["year_idx"] == 1) assert Decimal(variance_y1["delta_gbp"]) == Decimal("30000.00") async def test_progress_handles_empty_snapshots( client: AsyncClient, session: AsyncSession, ) -> None: scen = Scenario( external_id="user-empty", kind="user", name="No snapshots", jurisdiction="uk", strategy="trinity", leave_uk_year=0, glide_path="static", spending_gbp=Decimal("60000"), horizon_years=5, nw_seed_gbp=Decimal("1000000"), savings_per_year_gbp=Decimal("0"), config_json={}, ) session.add(scen) await session.commit() await session.refresh(scen) resp = await client.get(f"/scenarios/{scen.id}/progress") assert resp.status_code == 200 body = resp.json() assert body["actual"] == [] assert body["variance"] == []