137 lines
4.2 KiB
Python
137 lines
4.2 KiB
Python
|
|
"""Tests for the Cash Flow / Sankey endpoint."""
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from collections.abc import AsyncIterator
|
||
|
|
from datetime import UTC, datetime
|
||
|
|
from decimal import Decimal
|
||
|
|
|
||
|
|
import pytest_asyncio
|
||
|
|
from httpx import ASGITransport, AsyncClient
|
||
|
|
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
|
||
|
|
|
||
|
|
from fire_planner.api.dependencies import get_session
|
||
|
|
from fire_planner.app import app
|
||
|
|
from fire_planner.db import IncomeStream, McRun, ProjectionYearly, Scenario
|
||
|
|
|
||
|
|
|
||
|
|
@pytest_asyncio.fixture
async def client(engine: AsyncEngine,
                 session: AsyncSession) -> AsyncIterator[AsyncClient]:
    """Yield an httpx ``AsyncClient`` that calls the FastAPI ``app`` in-process.

    The app's ``get_session`` dependency is overridden so each request gets
    its own session from a sessionmaker bound to the test ``engine``.

    ``session`` is not used in the body; requesting it presumably forces the
    test-database setup fixture to run first — NOTE(review): confirm against
    conftest before removing it.
    """
    # Sessions handed to the endpoints do not expire objects on commit.
    factory = async_sessionmaker(engine, expire_on_commit=False)

    async def _override() -> AsyncIterator[AsyncSession]:
        async with factory() as s:
            yield s

    app.dependency_overrides[get_session] = _override
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            yield ac
    finally:
        # ``app`` is a module-level singleton: always undo the override, even
        # if constructing/entering/closing the client raises, so it cannot
        # leak into subsequent tests.
        app.dependency_overrides.clear()
|
||
|
|
|
||
|
|
|
||
|
|
async def _seed(session: AsyncSession) -> int:
    """Insert a scenario with one MC run, three yearly rows and a salary stream.

    Everything is committed (not merely flushed) because the ``client``
    fixture serves requests from sessions of its own, which can only see
    committed data.

    Returns:
        The primary key of the created ``Scenario``.
    """
    scen = Scenario(
        external_id="user-cf",
        kind="user",
        name="Cashflow",
        jurisdiction="uk",
        strategy="trinity",
        leave_uk_year=0,
        glide_path="static",
        spending_gbp=Decimal("60000"),
        horizon_years=5,
        nw_seed_gbp=Decimal("1000000"),
        savings_per_year_gbp=Decimal("0"),
        config_json={},
    )
    session.add(scen)
    await session.commit()
    await session.refresh(scen)
    # Capture the PK while ``scen`` is freshly loaded. Later commits expire it
    # (unless the session uses expire_on_commit=False), and re-reading an
    # expired attribute on an AsyncSession object would require implicit IO.
    scenario_id = scen.id

    run = McRun(
        scenario_id=scenario_id,
        run_at=datetime.now(UTC),
        n_paths=10,
        seed=1,
        success_rate=Decimal("1"),
        p10_ending_gbp=Decimal("0"),
        p50_ending_gbp=Decimal("0"),
        p90_ending_gbp=Decimal("0"),
        median_lifetime_tax_gbp=Decimal("0"),
        median_years_to_ruin=None,
        elapsed_seconds=Decimal("0"),
    )
    session.add(run)
    await session.commit()
    await session.refresh(run)
    run_id = run.id  # same reasoning as scenario_id above

    # Three projection years (0..2); only the median portfolio varies by year.
    yearly = [
        ProjectionYearly(
            mc_run_id=run_id,
            year_idx=y,
            p10_portfolio_gbp=Decimal("900000"),
            p25_portfolio_gbp=Decimal("950000"),
            p50_portfolio_gbp=Decimal(str(1_000_000 + y * 50_000)),
            p75_portfolio_gbp=Decimal("1100000"),
            p90_portfolio_gbp=Decimal("1200000"),
            p50_withdrawal_gbp=Decimal("60000"),
            p50_tax_gbp=Decimal("8000"),
            survival_rate=Decimal("1"),
        ) for y in range(3)
    ]
    session.add_all(yearly)

    # A salary stream covering years 0..2 so it shows up as a cashflow source.
    stream = IncomeStream(
        scenario_id=scenario_id,
        kind="salary",
        name="Day job",
        start_year=0,
        end_year=2,
        amount_gbp_per_year=Decimal("80000"),
        growth_pct=Decimal("0"),
        tax_treatment="income",
        enabled=True,
    )
    session.add(stream)
    await session.commit()
    return scenario_id
|
||
|
|
|
||
|
|
|
||
|
|
async def test_cashflow_balances(client: AsyncClient, session: AsyncSession) -> None:
    """The year-1 cashflow must balance: total sources equal total sinks."""
    sid = await _seed(session)

    resp = await client.get(f"/scenarios/{sid}/cashflow?year=1")
    assert resp.status_code == 200, resp.text
    body = resp.json()

    # Parse via str() so string amounts are taken verbatim and any JSON floats
    # become their shortest repr instead of their exact binary expansion
    # (Decimal(0.1) != Decimal("0.1")), which would make equality brittle.
    sources_total = sum(Decimal(str(v)) for v in body["sources"].values())
    sinks_total = sum(Decimal(str(v)) for v in body["sinks"].values())
    assert sources_total == sinks_total, (sources_total, sinks_total)

    # Salary should appear as a source.
    assert any(k.startswith("income:") for k in body["sources"])

    # Spending and taxes are always sinks.
    assert "spending" in body["sinks"]
    assert "taxes" in body["sinks"]
|
||
|
|
|
||
|
|
|
||
|
|
async def test_cashflow_404_when_no_run(client: AsyncClient,
                                        session: AsyncSession) -> None:
    """Requesting a cashflow for a scenario that has no MC run yields 404."""
    orphan = Scenario(
        external_id="user-no-run-cf",
        kind="user",
        name="No run cf",
        jurisdiction="uk",
        strategy="trinity",
        leave_uk_year=0,
        glide_path="static",
        spending_gbp=Decimal("60000"),
        horizon_years=5,
        nw_seed_gbp=Decimal("1000000"),
        savings_per_year_gbp=Decimal("0"),
        config_json={},
    )
    session.add(orphan)
    await session.commit()
    await session.refresh(orphan)

    # No McRun row is ever created for this scenario.
    url = f"/scenarios/{orphan.id}/cashflow?year=0"
    response = await client.get(url)
    assert response.status_code == 404
|