"""Tests for /scenarios CRUD + projection.""" from __future__ import annotations from collections.abc import AsyncIterator from datetime import UTC, datetime from decimal import Decimal import pytest_asyncio from httpx import ASGITransport, AsyncClient from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker from fire_planner.api.dependencies import get_session from fire_planner.app import app from fire_planner.db import McRun, ProjectionYearly, Scenario @pytest_asyncio.fixture async def client(engine: AsyncEngine, session: AsyncSession) -> AsyncIterator[AsyncClient]: factory = async_sessionmaker(engine, expire_on_commit=False) async def _override() -> AsyncIterator[AsyncSession]: async with factory() as s: yield s app.dependency_overrides[get_session] = _override transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as ac: yield ac app.dependency_overrides.clear() async def _seed(session: AsyncSession) -> Scenario: scen = Scenario( external_id="cyprus-vpw-leave-y3-glide-rising", kind="cartesian", jurisdiction="cyprus", strategy="vpw", leave_uk_year=3, glide_path="rising", spending_gbp=Decimal("60000"), nw_seed_gbp=Decimal("1500000"), savings_per_year_gbp=Decimal("0"), config_json={"horizon_years": 60}, ) session.add(scen) await session.commit() await session.refresh(scen) return scen async def test_list_scenarios_empty(client: AsyncClient) -> None: resp = await client.get("/scenarios") assert resp.status_code == 200 assert resp.json() == [] async def test_list_and_filter_by_kind(client: AsyncClient, session: AsyncSession) -> None: base = await _seed(session) user = Scenario( external_id="user-abc", kind="user", name="My plan", parent_scenario_id=base.id, jurisdiction="cyprus", strategy="vpw", leave_uk_year=3, glide_path="rising", spending_gbp=Decimal("80000"), nw_seed_gbp=Decimal("1500000"), savings_per_year_gbp=Decimal("0"), config_json={}, ) session.add(user) await session.commit() all_resp = await client.get("/scenarios") assert len(all_resp.json()) == 2 user_resp = await client.get("/scenarios?kind=user") assert len(user_resp.json()) == 1 assert user_resp.json()[0]["name"] == "My plan" async def test_get_scenario(client: AsyncClient, session: AsyncSession) -> None: scen = await _seed(session) resp = await client.get(f"/scenarios/{scen.id}") assert resp.status_code == 200 assert resp.json()["jurisdiction"] == "cyprus" async def test_get_scenario_404(client: AsyncClient) -> None: resp = await client.get("/scenarios/9999") assert resp.status_code == 404 async def test_create_user_scenario(client: AsyncClient) -> None: resp = await client.post( "/scenarios", json={ "name": "Aggressive FIRE", "description": "Cyprus, lower spend", "jurisdiction": "cyprus", "strategy": "vpw", "leave_uk_year": 2, "glide_path": "rising", "spending_gbp": "50000", "horizon_years": 60, "nw_seed_gbp": "1500000", "savings_per_year_gbp": "0", }, ) assert resp.status_code == 201, resp.text body = resp.json() assert body["kind"] == "user" assert body["name"] == "Aggressive FIRE" assert body["external_id"].startswith("user-") async def test_create_with_invalid_parent_id(client: AsyncClient) -> None: resp = await client.post( "/scenarios", json={ "name": "X", "parent_scenario_id": 9999, "jurisdiction": "uk", "strategy": "trinity", "leave_uk_year": 0, "glide_path": "static", "spending_gbp": "60000", "nw_seed_gbp": "1000000", }, ) assert resp.status_code == 400 async def test_patch_user_scenario(client: AsyncClient) -> None: create = await client.post("/scenarios", json={ "name": "Plan A", "jurisdiction": "uk", "strategy": "trinity", "leave_uk_year": 0, "glide_path": "static", "spending_gbp": "60000", "nw_seed_gbp": "1000000", }) sid = create.json()["id"] resp = await client.patch(f"/scenarios/{sid}", json={"name": "Plan A v2", "leave_uk_year": 2}) assert resp.status_code == 200 body = resp.json() assert body["name"] == "Plan A v2" assert body["leave_uk_year"] == 2 async def test_patch_cartesian_core_fields_blocked( client: AsyncClient, session: AsyncSession, ) -> None: """Cartesian scenarios reject edits to fields that get rebuilt by recompute (jurisdiction/strategy/etc), but allow free-form metadata (config_json/name/description) so users can pin notes + flex_rules.""" cart = await _seed(session) # Core field — still blocked. bad = await client.patch(f"/scenarios/{cart.id}", json={"jurisdiction": "uae"}) assert bad.status_code == 400 assert "cartesian" in bad.json()["detail"] # config_json and name — allowed (preserves user edits). ok = await client.patch( f"/scenarios/{cart.id}", json={"config_json": {"flex_rules": [{"from_ath_pct": 0.2, "cut_discretionary_pct": 0.5}]}, "name": "Renamed"}, ) assert ok.status_code == 200, ok.text assert ok.json()["name"] == "Renamed" assert ok.json()["config_json"]["flex_rules"][0]["from_ath_pct"] == 0.2 async def test_delete_user_scenario(client: AsyncClient) -> None: create = await client.post("/scenarios", json={ "name": "Throwaway", "jurisdiction": "uk", "strategy": "trinity", "leave_uk_year": 0, "glide_path": "static", "spending_gbp": "60000", "nw_seed_gbp": "1000000", }) sid = create.json()["id"] resp = await client.delete(f"/scenarios/{sid}") assert resp.status_code == 204 assert (await client.get(f"/scenarios/{sid}")).status_code == 404 async def test_delete_cartesian_blocked(client: AsyncClient, session: AsyncSession) -> None: cart = await _seed(session) resp = await client.delete(f"/scenarios/{cart.id}") assert resp.status_code == 400 async def test_projection_404_when_no_run(client: AsyncClient, session: AsyncSession) -> None: scen = await _seed(session) resp = await client.get(f"/scenarios/{scen.id}/projection") assert resp.status_code == 404 async def test_projection_returns_yearly_series(client: AsyncClient, session: AsyncSession) -> None: scen = await _seed(session) run = McRun( scenario_id=scen.id, run_at=datetime(2026, 5, 1, tzinfo=UTC), n_paths=1000, seed=42, success_rate=Decimal("0.9050"), p10_ending_gbp=Decimal("100000"), p50_ending_gbp=Decimal("3000000"), p90_ending_gbp=Decimal("9000000"), median_lifetime_tax_gbp=Decimal("750000"), elapsed_seconds=Decimal("12.500"), ) session.add(run) await session.commit() await session.refresh(run) for y in range(5): session.add( ProjectionYearly( mc_run_id=run.id, year_idx=y, p10_portfolio_gbp=Decimal("900000"), p25_portfolio_gbp=Decimal("950000"), p50_portfolio_gbp=Decimal("1000000"), p75_portfolio_gbp=Decimal("1100000"), p90_portfolio_gbp=Decimal("1200000"), p50_withdrawal_gbp=Decimal("60000"), p50_tax_gbp=Decimal("8000"), survival_rate=Decimal("1.0"), )) await session.commit() resp = await client.get(f"/scenarios/{scen.id}/projection") assert resp.status_code == 200 body = resp.json() assert body["scenario_id"] == scen.id assert body["n_paths"] == 1000 assert len(body["yearly"]) == 5 assert Decimal(body["yearly"][0]["p50_portfolio_gbp"]) == Decimal("1000000")