"""Tests for income-stream CRUD + simulator integration.""" from __future__ import annotations from collections.abc import AsyncIterator from decimal import Decimal import numpy as np import pytest_asyncio from httpx import ASGITransport, AsyncClient from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker from fire_planner.api.dependencies import get_session from fire_planner.api.schemas import IncomeStreamInput, SimulateRequest from fire_planner.api.simulate import _project from fire_planner.app import app from fire_planner.db import Scenario from fire_planner.income_streams import ( IncomeStreamInput as EngineIncomeStream, ) from fire_planner.income_streams import streams_to_arrays @pytest_asyncio.fixture async def client(engine: AsyncEngine, session: AsyncSession) -> AsyncIterator[AsyncClient]: factory = async_sessionmaker(engine, expire_on_commit=False) async def _override() -> AsyncIterator[AsyncSession]: async with factory() as s: yield s app.dependency_overrides[get_session] = _override transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as ac: yield ac app.dependency_overrides.clear() async def _seed_scenario(session: AsyncSession) -> int: scen = Scenario( external_id="user-host", kind="user", name="Host plan", jurisdiction="uk", strategy="trinity", leave_uk_year=0, glide_path="static", spending_gbp=Decimal("60000"), nw_seed_gbp=Decimal("1000000"), savings_per_year_gbp=Decimal("0"), config_json={}, ) session.add(scen) await session.commit() await session.refresh(scen) return scen.id # ── streams_to_arrays unit tests ───────────────────────────────────── def test_streams_to_arrays_with_growth() -> None: streams = [ EngineIncomeStream( kind="salary", start_year=0, end_year=2, amount_gbp_per_year=10_000, growth_pct=0.05, tax_treatment="income", enabled=True, ), ] inflows, taxable = streams_to_arrays(streams, horizon_years=5) # year 0: 10_000; year 1: 10_500; year 2: 11_025; years 3+: 0 assert inflows[0] == 10_000.0 assert inflows[1] == 10_500.0 assert inflows[2] == 11_025.0 assert inflows[3] == 0.0 # Income-treated streams add to taxable. assert taxable[0] == inflows[0] def test_streams_to_arrays_tax_free_excluded_from_taxable() -> None: streams = [ EngineIncomeStream( kind="dividend", start_year=0, end_year=0, amount_gbp_per_year=5_000, tax_treatment="tax_free", enabled=True, ), ] inflows, taxable = streams_to_arrays(streams, horizon_years=2) assert inflows[0] == 5_000.0 assert taxable[0] == 0.0 def test_streams_to_arrays_disabled_skipped() -> None: streams = [ EngineIncomeStream( kind="salary", amount_gbp_per_year=10_000, enabled=False, ), ] inflows, taxable = streams_to_arrays(streams, horizon_years=2) assert inflows.sum() == 0.0 assert taxable.sum() == 0.0 # ── CRUD endpoint tests ────────────────────────────────────────────── async def test_create_and_list_income_streams( client: AsyncClient, session: AsyncSession, ) -> None: sid = await _seed_scenario(session) create = await client.post( f"/scenarios/{sid}/income-streams", json={ "kind": "salary", "name": "Day job", "start_year": 0, "end_year": 10, "amount_gbp_per_year": "60000", "growth_pct": "0.02", "tax_treatment": "income", }, ) assert create.status_code == 201 payload = create.json() assert payload["name"] == "Day job" assert payload["scenario_id"] == sid listed = await client.get(f"/scenarios/{sid}/income-streams") assert listed.status_code == 200 rows = listed.json() assert len(rows) == 1 assert rows[0]["kind"] == "salary" async def test_patch_and_delete_income_stream( client: AsyncClient, session: AsyncSession, ) -> None: sid = await _seed_scenario(session) create = await client.post( f"/scenarios/{sid}/income-streams", json={ "kind": "rental", "name": "Flat 2", "amount_gbp_per_year": "12000", }, ) stream_id = create.json()["id"] patch = await client.patch( f"/income-streams/{stream_id}", json={"amount_gbp_per_year": "15000"}, ) assert patch.status_code == 200 assert patch.json()["amount_gbp_per_year"] == "15000.00" del_resp = await client.delete(f"/income-streams/{stream_id}") assert del_resp.status_code == 204 listed = await client.get(f"/scenarios/{sid}/income-streams") assert listed.json() == [] async def test_invalid_year_range_rejected( client: AsyncClient, session: AsyncSession, ) -> None: sid = await _seed_scenario(session) bad = await client.post( f"/scenarios/{sid}/income-streams", json={ "kind": "other", "name": "Backwards", "start_year": 5, "end_year": 2, "amount_gbp_per_year": "1000", }, ) assert bad.status_code == 400 # ── simulate integration: a £50k stream year 5-15 lifts median NW ──── def test_simulate_with_income_stream_lifts_median() -> None: paths = np.zeros((100, 30, 3), dtype=np.float64) paths[:, :, 0] = 0.07 # nominal stocks paths[:, :, 1] = 0.03 # nominal bonds paths[:, :, 2] = 0.02 # cpi base_req = SimulateRequest( jurisdiction="uae", strategy="trinity", leave_uk_year=0, spending_gbp=Decimal("20000"), nw_seed_gbp=Decimal("2000000"), horizon_years=30, n_paths=100, seed=1, rates_mode=None, ) req_with = base_req.model_copy(update={ "income_streams": [ IncomeStreamInput( kind="dividend", start_year=5, end_year=15, amount_gbp_per_year=Decimal("50000"), growth_pct=Decimal("0.02"), tax_treatment="tax_free", ), ], }) base_result, _ = _project(base_req, paths) with_result, _ = _project(req_with, paths) base_median_end = float(np.median(base_result.portfolio_real[:, -1])) with_median_end = float(np.median(with_result.portfolio_real[:, -1])) assert with_median_end > base_median_end