Initial extraction from monorepo
This commit is contained in:
commit
f7ef7ca4ab
56 changed files with 6163 additions and 0 deletions
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
36
tests/conftest.py
Normal file
36
tests/conftest.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
"""Shared pytest fixtures.
|
||||
|
||||
Tests run against an in-memory SQLite DB created via the SQLAlchemy ORM
|
||||
metadata directly — fast, deterministic, and avoids running Alembic
|
||||
end-to-end on every test (the migration is exercised separately).
|
||||
"""
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
import pytest_asyncio
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
|
||||
from fire_planner.db import SCHEMA_NAME, Base
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def engine() -> AsyncIterator[AsyncEngine]:
|
||||
eng = create_async_engine("sqlite+aiosqlite:///:memory:")
|
||||
async with eng.begin() as conn:
|
||||
# SQLite has no schema concept — attach an in-memory DB under the
|
||||
# `fire_planner` name so `__table_args__ = {"schema": ...}` resolves.
|
||||
await conn.exec_driver_sql(f"ATTACH DATABASE ':memory:' AS {SCHEMA_NAME}")
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield eng
|
||||
await eng.dispose()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def session(engine: AsyncEngine) -> AsyncIterator[AsyncSession]:
|
||||
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
async with factory() as sess:
|
||||
yield sess
|
||||
100
tests/test_cli.py
Normal file
100
tests/test_cli.py
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
"""CLI smoke tests via click's CliRunner."""
|
||||
from click.testing import CliRunner
|
||||
|
||||
from fire_planner.__main__ import cli
|
||||
|
||||
|
||||
def test_simulate_smoke() -> None:
|
||||
"""Run a tiny scenario through the CLI without writing to DB."""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"simulate",
|
||||
"--scenario=cyprus-trinity-leave-y3-glide-rising",
|
||||
"--n-paths=200",
|
||||
"--horizon=20",
|
||||
"--spending=100000",
|
||||
"--nw-seed=1000000",
|
||||
"--no-write-db",
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
assert result.exit_code == 0, result.output
|
||||
assert "Scenario: cyprus-trinity-leave-y3-glide-rising" in result.output
|
||||
assert "success_rate" in result.output
|
||||
|
||||
|
||||
def test_simulate_with_underscore_strategy() -> None:
|
||||
"""guyton_klinger contains an underscore — the parser must handle it."""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"simulate",
|
||||
"--scenario=uk-guyton_klinger-leave-y1-glide-static_60_40",
|
||||
"--n-paths=100",
|
||||
"--horizon=15",
|
||||
"--spending=80000",
|
||||
"--nw-seed=1500000",
|
||||
"--no-write-db",
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
assert result.exit_code == 0, result.output
|
||||
assert "uk-guyton_klinger-leave-y1-glide-static_60_40" in result.output
|
||||
|
||||
|
||||
def test_simulate_bad_scenario_id() -> None:
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ["simulate", "--scenario=nope"], catch_exceptions=False)
|
||||
assert result.exit_code != 0
|
||||
|
||||
|
||||
def test_simulate_vpw_floor_with_floor_flag() -> None:
|
||||
"""vpw_floor strategy + --floor=40000 should run without error."""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"simulate",
|
||||
"--scenario=cyprus-vpw_floor-leave-y2-glide-rising",
|
||||
"--n-paths=200",
|
||||
"--horizon=20",
|
||||
"--spending=60000",
|
||||
"--nw-seed=1500000",
|
||||
"--floor=40000",
|
||||
"--no-write-db",
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
assert result.exit_code == 0, result.output
|
||||
assert "cyprus-vpw_floor" in result.output
|
||||
|
||||
|
||||
def test_simulate_uae_smoke() -> None:
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"simulate",
|
||||
"--scenario=uae-vpw_floor-leave-y2-glide-rising",
|
||||
"--n-paths=200",
|
||||
"--horizon=20",
|
||||
"--spending=60000",
|
||||
"--nw-seed=1500000",
|
||||
"--floor=40000",
|
||||
"--no-write-db",
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
assert result.exit_code == 0, result.output
|
||||
assert "uae-vpw_floor" in result.output
|
||||
|
||||
|
||||
def test_help_lists_commands() -> None:
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ["--help"], catch_exceptions=False)
|
||||
assert result.exit_code == 0
|
||||
for cmd in ("ingest", "simulate", "recompute-all", "migrate", "serve"):
|
||||
assert cmd in result.output
|
||||
111
tests/test_db_schema.py
Normal file
111
tests/test_db_schema.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
"""Smoke-test the ORM schema — every table must round-trip a row."""
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from fire_planner.db import (
|
||||
AccountSnapshot,
|
||||
McPath,
|
||||
McRun,
|
||||
ProjectionYearly,
|
||||
Scenario,
|
||||
ScenarioSummary,
|
||||
)
|
||||
|
||||
|
||||
async def test_account_snapshot_roundtrip(session: AsyncSession) -> None:
|
||||
snap = AccountSnapshot(
|
||||
external_id="wealthfolio:account-1:2026-04-25",
|
||||
snapshot_date=date(2026, 4, 25),
|
||||
account_id="account-1",
|
||||
account_name="ISA",
|
||||
account_type="ISA",
|
||||
currency="GBP",
|
||||
market_value=Decimal("123456.78"),
|
||||
market_value_gbp=Decimal("123456.78"),
|
||||
)
|
||||
session.add(snap)
|
||||
await session.commit()
|
||||
result = await session.execute(select(AccountSnapshot))
|
||||
rows = result.scalars().all()
|
||||
assert len(rows) == 1
|
||||
assert rows[0].external_id == "wealthfolio:account-1:2026-04-25"
|
||||
|
||||
|
||||
async def test_scenario_roundtrip(session: AsyncSession) -> None:
|
||||
scen = Scenario(
|
||||
external_id="cyprus-vpw-leave-y3-glide-rising",
|
||||
jurisdiction="cyprus",
|
||||
strategy="vpw",
|
||||
leave_uk_year=3,
|
||||
glide_path="rising",
|
||||
spending_gbp=Decimal("100000"),
|
||||
nw_seed_gbp=Decimal("1000000"),
|
||||
savings_per_year_gbp=Decimal("100000"),
|
||||
config_json={"horizon_years": 60},
|
||||
)
|
||||
session.add(scen)
|
||||
await session.commit()
|
||||
result = await session.execute(select(Scenario))
|
||||
rows = result.scalars().all()
|
||||
assert len(rows) == 1
|
||||
assert rows[0].jurisdiction == "cyprus"
|
||||
|
||||
|
||||
async def test_mc_run_roundtrip(session: AsyncSession) -> None:
|
||||
run = McRun(
|
||||
scenario_id=1,
|
||||
n_paths=10000,
|
||||
seed=42,
|
||||
success_rate=Decimal("0.9412"),
|
||||
p10_ending_gbp=Decimal("250000"),
|
||||
p50_ending_gbp=Decimal("3500000"),
|
||||
p90_ending_gbp=Decimal("12000000"),
|
||||
median_lifetime_tax_gbp=Decimal("750000"),
|
||||
elapsed_seconds=Decimal("42.351"),
|
||||
)
|
||||
session.add(run)
|
||||
await session.commit()
|
||||
result = await session.execute(select(McRun))
|
||||
rows = result.scalars().all()
|
||||
assert len(rows) == 1
|
||||
assert rows[0].n_paths == 10000
|
||||
|
||||
|
||||
async def test_remaining_tables_smoke(session: AsyncSession) -> None:
|
||||
session.add(
|
||||
McPath(mc_run_id=1,
|
||||
path_idx=0,
|
||||
bucket="median",
|
||||
year_idx=0,
|
||||
portfolio_gbp=Decimal("1000000"),
|
||||
withdrawal_gbp=Decimal("100000"),
|
||||
tax_paid_gbp=Decimal("0"),
|
||||
real_portfolio_gbp=Decimal("1000000")))
|
||||
session.add(
|
||||
ProjectionYearly(mc_run_id=1,
|
||||
year_idx=0,
|
||||
p10_portfolio_gbp=Decimal("800000"),
|
||||
p25_portfolio_gbp=Decimal("900000"),
|
||||
p50_portfolio_gbp=Decimal("1000000"),
|
||||
p75_portfolio_gbp=Decimal("1100000"),
|
||||
p90_portfolio_gbp=Decimal("1200000"),
|
||||
p50_withdrawal_gbp=Decimal("100000"),
|
||||
p50_tax_gbp=Decimal("0"),
|
||||
survival_rate=Decimal("1")))
|
||||
session.add(
|
||||
ScenarioSummary(scenario_id=1,
|
||||
mc_run_id=1,
|
||||
jurisdiction="uk",
|
||||
strategy="trinity",
|
||||
leave_uk_year=0,
|
||||
glide_path="static",
|
||||
spending_gbp=Decimal("100000"),
|
||||
success_rate=Decimal("0.95"),
|
||||
p10_ending_gbp=Decimal("200000"),
|
||||
p50_ending_gbp=Decimal("3000000"),
|
||||
p90_ending_gbp=Decimal("10000000"),
|
||||
median_lifetime_tax_gbp=Decimal("800000")))
|
||||
await session.commit()
|
||||
113
tests/test_e2e.py
Normal file
113
tests/test_e2e.py
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
"""End-to-end smoke: scenario builder → simulator → reporter → SQLite.
|
||||
|
||||
Exercises the same pipeline `recompute-all` runs in production, but on
|
||||
SQLite (no Postgres needed). Catches integration breakage early.
|
||||
"""
|
||||
from decimal import Decimal
|
||||
|
||||
import numpy as np
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from fire_planner.db import McRun, Scenario, ScenarioSummary
|
||||
from fire_planner.glide_path import get as get_glide
|
||||
from fire_planner.reporters.pg import write_run
|
||||
from fire_planner.returns.bootstrap import block_bootstrap
|
||||
from fire_planner.returns.shiller import synthetic_returns
|
||||
from fire_planner.scenarios import build_regime_schedule, build_strategy, cartesian_scenarios
|
||||
from fire_planner.simulator import simulate
|
||||
|
||||
|
||||
async def test_full_pipeline_persists_summary_per_scenario(session: AsyncSession) -> None:
|
||||
"""Run a tiny Cartesian (2 jurisdictions × 1 strategy × 1 leave × 1 glide
|
||||
= 2 scenarios) end-to-end. Verifies scenario, mc_run, and
|
||||
scenario_summary all populate."""
|
||||
bundle = synthetic_returns(seed=1, n_years=120)
|
||||
paths = block_bootstrap(bundle,
|
||||
n_paths=200,
|
||||
n_years=20,
|
||||
block_size=5,
|
||||
rng=np.random.default_rng(0))
|
||||
specs = cartesian_scenarios(
|
||||
spending_gbp=Decimal("80000"),
|
||||
nw_seed_gbp=Decimal("1500000"),
|
||||
horizon_years=20,
|
||||
jurisdictions=("uk", "cyprus"),
|
||||
strategies=("trinity", ),
|
||||
leave_years=(2, ),
|
||||
glides=("rising", ),
|
||||
)
|
||||
assert len(specs) == 2
|
||||
for spec in specs:
|
||||
result = simulate(
|
||||
paths=paths,
|
||||
initial_portfolio=float(spec.nw_seed_gbp),
|
||||
spending_target=float(spec.spending_gbp),
|
||||
glide=get_glide(spec.glide_path),
|
||||
strategy=build_strategy(spec.strategy),
|
||||
regime=build_regime_schedule(spec.jurisdiction, spec.leave_uk_year),
|
||||
horizon_years=spec.horizon_years,
|
||||
)
|
||||
await write_run(session, spec, result, seed=42, elapsed_seconds=0.5)
|
||||
await session.commit()
|
||||
|
||||
scenarios = (await session.execute(select(Scenario))).scalars().all()
|
||||
assert {s.external_id
|
||||
for s in scenarios} == {
|
||||
"uk-trinity-leave-y2-glide-rising",
|
||||
"cyprus-trinity-leave-y2-glide-rising",
|
||||
}
|
||||
|
||||
runs = (await session.execute(select(McRun))).scalars().all()
|
||||
assert len(runs) == 2
|
||||
|
||||
summaries = (await session.execute(select(ScenarioSummary))).scalars().all()
|
||||
assert len(summaries) == 2
|
||||
|
||||
# Cyprus median_lifetime_tax should be lower than UK's for the same
|
||||
# scenario shape — the canonical Phase 8 sanity test.
|
||||
by_jur = {s.jurisdiction: s for s in summaries}
|
||||
assert by_jur["cyprus"].median_lifetime_tax_gbp < by_jur["uk"].median_lifetime_tax_gbp
|
||||
|
||||
|
||||
async def test_pipeline_handles_recompute_idempotency(session: AsyncSession) -> None:
|
||||
"""Running the same scenario twice must result in 1 scenario row,
|
||||
2 mc_run rows, and 1 scenario_summary row pointing at the latest run."""
|
||||
bundle = synthetic_returns(seed=2, n_years=60)
|
||||
paths = block_bootstrap(bundle,
|
||||
n_paths=100,
|
||||
n_years=15,
|
||||
block_size=5,
|
||||
rng=np.random.default_rng(0))
|
||||
spec = next(
|
||||
iter(
|
||||
cartesian_scenarios(
|
||||
spending_gbp=Decimal("100000"),
|
||||
nw_seed_gbp=Decimal("1000000"),
|
||||
horizon_years=15,
|
||||
jurisdictions=("bulgaria", ),
|
||||
strategies=("vpw", ),
|
||||
leave_years=(1, ),
|
||||
glides=("static_60_40", ),
|
||||
)))
|
||||
for run in range(2):
|
||||
result = simulate(
|
||||
paths=paths,
|
||||
initial_portfolio=float(spec.nw_seed_gbp),
|
||||
spending_target=float(spec.spending_gbp),
|
||||
glide=get_glide(spec.glide_path),
|
||||
strategy=build_strategy(spec.strategy),
|
||||
regime=build_regime_schedule(spec.jurisdiction, spec.leave_uk_year),
|
||||
horizon_years=spec.horizon_years,
|
||||
)
|
||||
await write_run(session, spec, result, seed=run, elapsed_seconds=0.2)
|
||||
await session.commit()
|
||||
|
||||
scenarios = (await session.execute(select(Scenario))).scalars().all()
|
||||
assert len(scenarios) == 1
|
||||
|
||||
runs = (await session.execute(select(McRun))).scalars().all()
|
||||
assert len(runs) == 2
|
||||
|
||||
summaries = (await session.execute(select(ScenarioSummary))).scalars().all()
|
||||
assert len(summaries) == 1
|
||||
97
tests/test_ingest_wealthfolio.py
Normal file
97
tests/test_ingest_wealthfolio.py
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
"""Wealthfolio ingest reads a real-shape sqlite and upserts cleanly."""
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from fire_planner.db import AccountSnapshot
|
||||
from fire_planner.ingest.wealthfolio import read_account_snapshots, upsert_snapshots
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def wealthfolio_db(tmp_path: Path) -> Path:
|
||||
"""Create a minimal sqlite mimicking Wealthfolio's schema."""
|
||||
db_path = tmp_path / "wealthfolio.db"
|
||||
conn = sqlite3.connect(db_path)
|
||||
cur = conn.cursor()
|
||||
cur.executescript("""
|
||||
CREATE TABLE accounts (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT,
|
||||
type TEXT,
|
||||
currency TEXT
|
||||
);
|
||||
CREATE TABLE holdings_snapshot (
|
||||
account_id TEXT,
|
||||
snapshot_date TEXT,
|
||||
symbol TEXT,
|
||||
market_value REAL,
|
||||
market_value_gbp REAL
|
||||
);
|
||||
INSERT INTO accounts VALUES ('acc-isa', 'ISA', 'ISA', 'GBP');
|
||||
INSERT INTO accounts VALUES ('acc-schwab', 'Schwab', 'BROKERAGE', 'USD');
|
||||
INSERT INTO holdings_snapshot VALUES ('acc-isa', '2026-04-25', 'VWRL', 200000, 200000);
|
||||
INSERT INTO holdings_snapshot VALUES ('acc-isa', '2026-04-25', 'BND', 100000, 100000);
|
||||
INSERT INTO holdings_snapshot VALUES ('acc-schwab', '2026-04-25', 'META', 800000, 640000);
|
||||
""")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return db_path
|
||||
|
||||
|
||||
def test_read_groups_holdings_per_account(wealthfolio_db: Path) -> None:
|
||||
rows = read_account_snapshots(wealthfolio_db)
|
||||
assert len(rows) == 2
|
||||
by_id = {r["account_id"]: r for r in rows}
|
||||
assert by_id["acc-isa"]["market_value_gbp"] == Decimal("300000")
|
||||
assert by_id["acc-schwab"]["market_value_gbp"] == Decimal("640000")
|
||||
assert by_id["acc-isa"]["snapshot_date"] == date(2026, 4, 25)
|
||||
|
||||
|
||||
def test_read_returns_empty_on_unknown_schema(tmp_path: Path) -> None:
|
||||
"""If the sqlite has a totally different shape, return [] rather
|
||||
than blow up — let the operator surface the warning."""
|
||||
db = tmp_path / "weird.db"
|
||||
conn = sqlite3.connect(db)
|
||||
conn.execute("CREATE TABLE foo (x INTEGER)")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
assert read_account_snapshots(db) == []
|
||||
|
||||
|
||||
def test_read_missing_file_raises(tmp_path: Path) -> None:
|
||||
with pytest.raises(FileNotFoundError):
|
||||
read_account_snapshots(tmp_path / "nope.db")
|
||||
|
||||
|
||||
async def test_upsert_inserts_new_rows(session: AsyncSession, wealthfolio_db: Path) -> None:
|
||||
rows = read_account_snapshots(wealthfolio_db)
|
||||
n = await upsert_snapshots(session, rows)
|
||||
await session.commit()
|
||||
assert n == 2
|
||||
persisted = (await session.execute(select(AccountSnapshot))).scalars().all()
|
||||
assert len(persisted) == 2
|
||||
by_id = {p.account_id: p for p in persisted}
|
||||
assert by_id["acc-isa"].market_value_gbp == Decimal("300000")
|
||||
|
||||
|
||||
async def test_upsert_is_idempotent(session: AsyncSession, wealthfolio_db: Path) -> None:
|
||||
rows = read_account_snapshots(wealthfolio_db)
|
||||
await upsert_snapshots(session, rows)
|
||||
await session.commit()
|
||||
# Run again — should still be 2 rows, not 4
|
||||
await upsert_snapshots(session, rows)
|
||||
await session.commit()
|
||||
persisted = (await session.execute(select(AccountSnapshot))).scalars().all()
|
||||
assert len(persisted) == 2
|
||||
|
||||
|
||||
async def test_upsert_zero_rows_is_noop(session: AsyncSession) -> None:
|
||||
n = await upsert_snapshots(session, [])
|
||||
assert n == 0
|
||||
93
tests/test_reporters_pg.py
Normal file
93
tests/test_reporters_pg.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
"""Postgres reporter — write_run round-trips into the schema."""
|
||||
from decimal import Decimal
|
||||
|
||||
import numpy as np
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from fire_planner.db import McRun, ProjectionYearly, ScenarioSummary
|
||||
from fire_planner.glide_path import static
|
||||
from fire_planner.reporters.pg import write_run
|
||||
from fire_planner.scenarios import ScenarioSpec
|
||||
from fire_planner.simulator import simulate
|
||||
from fire_planner.strategies.trinity import TrinityStrategy
|
||||
from fire_planner.tax.malaysia import MalaysiaTaxRegime
|
||||
|
||||
|
||||
def fixed_paths(n_paths: int, n_years: int) -> np.ndarray:
|
||||
out = np.zeros((n_paths, n_years, 3))
|
||||
out[..., 0] = 0.05
|
||||
out[..., 1] = 0.03
|
||||
out[..., 2] = 0.02
|
||||
return out
|
||||
|
||||
|
||||
async def test_write_run_persists_summary_run_and_projection(session: AsyncSession) -> None:
|
||||
spec = ScenarioSpec(
|
||||
jurisdiction="cyprus",
|
||||
strategy="trinity",
|
||||
leave_uk_year=3,
|
||||
glide_path="rising",
|
||||
spending_gbp=Decimal("100000"),
|
||||
nw_seed_gbp=Decimal("1000000"),
|
||||
horizon_years=20,
|
||||
)
|
||||
paths = fixed_paths(50, 20)
|
||||
result = simulate(
|
||||
paths=paths,
|
||||
initial_portfolio=1_000_000.0,
|
||||
spending_target=40_000.0,
|
||||
glide=static(0.7),
|
||||
strategy=TrinityStrategy(),
|
||||
regime=MalaysiaTaxRegime(),
|
||||
horizon_years=20,
|
||||
)
|
||||
summary = await write_run(session, spec, result, seed=42, elapsed_seconds=1.5)
|
||||
await session.commit()
|
||||
|
||||
runs = (await session.execute(select(McRun))).scalars().all()
|
||||
assert len(runs) == 1
|
||||
assert runs[0].id == summary.mc_run_id
|
||||
assert runs[0].n_paths == 50
|
||||
|
||||
projections = (await session.execute(select(ProjectionYearly))).scalars().all()
|
||||
assert len(projections) == 20 # one row per year
|
||||
summaries = (await session.execute(select(ScenarioSummary))).scalars().all()
|
||||
assert len(summaries) == 1
|
||||
assert summaries[0].jurisdiction == "cyprus"
|
||||
|
||||
|
||||
async def test_write_run_idempotent_summary(session: AsyncSession) -> None:
|
||||
"""Running twice for the same scenario should keep summary at one row,
|
||||
pointing at the latest run."""
|
||||
spec = ScenarioSpec(
|
||||
jurisdiction="bulgaria",
|
||||
strategy="vpw",
|
||||
leave_uk_year=2,
|
||||
glide_path="static_60_40",
|
||||
spending_gbp=Decimal("100000"),
|
||||
nw_seed_gbp=Decimal("1000000"),
|
||||
horizon_years=20,
|
||||
)
|
||||
paths = fixed_paths(20, 20)
|
||||
result = simulate(
|
||||
paths=paths,
|
||||
initial_portfolio=1_000_000.0,
|
||||
spending_target=40_000.0,
|
||||
glide=static(0.6),
|
||||
strategy=TrinityStrategy(),
|
||||
regime=MalaysiaTaxRegime(),
|
||||
horizon_years=20,
|
||||
)
|
||||
s1 = await write_run(session, spec, result, seed=42, elapsed_seconds=1.0)
|
||||
await session.commit()
|
||||
s2 = await write_run(session, spec, result, seed=43, elapsed_seconds=1.5)
|
||||
await session.commit()
|
||||
assert s1.scenario_id == s2.scenario_id
|
||||
assert s2.mc_run_id != s1.mc_run_id
|
||||
|
||||
runs = (await session.execute(select(McRun))).scalars().all()
|
||||
assert len(runs) == 2
|
||||
summaries = (await session.execute(select(ScenarioSummary))).scalars().all()
|
||||
assert len(summaries) == 1
|
||||
assert summaries[0].mc_run_id == s2.mc_run_id
|
||||
126
tests/test_returns.py
Normal file
126
tests/test_returns.py
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
"""Returns loader + bootstrap behaviour."""
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from fire_planner.returns.bootstrap import block_bootstrap
|
||||
from fire_planner.returns.shiller import ReturnsBundle, load_from_csv, synthetic_returns
|
||||
|
||||
|
||||
def test_synthetic_returns_shape() -> None:
|
||||
b = synthetic_returns(seed=1, n_years=120)
|
||||
assert b.n_years == 120
|
||||
assert b.stock_nominal.shape == (120, )
|
||||
assert b.years[0] == 1871
|
||||
|
||||
|
||||
def test_synthetic_deterministic_for_seed() -> None:
|
||||
a = synthetic_returns(seed=42, n_years=10)
|
||||
b = synthetic_returns(seed=42, n_years=10)
|
||||
np.testing.assert_array_equal(a.stock_nominal, b.stock_nominal)
|
||||
|
||||
|
||||
def test_real_return_smoke() -> None:
|
||||
b = ReturnsBundle(
|
||||
years=np.array([2020], dtype=np.int32),
|
||||
stock_nominal=np.array([0.10]),
|
||||
bond_nominal=np.array([0.04]),
|
||||
cpi=np.array([0.03]),
|
||||
)
|
||||
# (1.10 / 1.03) - 1 ≈ 0.06796
|
||||
assert abs(b.stock_real()[0] - 0.06796116505) < 1e-9
|
||||
|
||||
|
||||
def test_load_from_csv(tmp_path: Path) -> None:
|
||||
csv_path = tmp_path / "returns.csv"
|
||||
csv_path.write_text("year,stock_nominal_return,bond_nominal_return,cpi_inflation\n"
|
||||
"1990,0.05,0.07,0.025\n"
|
||||
"1991,-0.10,0.04,0.03\n")
|
||||
b = load_from_csv(csv_path)
|
||||
assert b.n_years == 2
|
||||
assert b.stock_nominal[1] == pytest.approx(-0.10)
|
||||
assert b.cpi[0] == pytest.approx(0.025)
|
||||
|
||||
|
||||
def test_returns_bundle_rejects_mismatched_lengths() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
ReturnsBundle(
|
||||
years=np.array([2020, 2021], dtype=np.int32),
|
||||
stock_nominal=np.array([0.1]),
|
||||
bond_nominal=np.array([0.04, 0.05]),
|
||||
cpi=np.array([0.03, 0.025]),
|
||||
)
|
||||
|
||||
|
||||
def test_returns_bundle_rejects_empty() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
ReturnsBundle(
|
||||
years=np.array([], dtype=np.int32),
|
||||
stock_nominal=np.array([]),
|
||||
bond_nominal=np.array([]),
|
||||
cpi=np.array([]),
|
||||
)
|
||||
|
||||
|
||||
def test_bootstrap_shape() -> None:
|
||||
bundle = synthetic_returns(seed=1, n_years=150)
|
||||
rng = np.random.default_rng(0)
|
||||
paths = block_bootstrap(bundle, n_paths=100, n_years=60, block_size=5, rng=rng)
|
||||
assert paths.shape == (100, 60, 3)
|
||||
|
||||
|
||||
def test_bootstrap_deterministic_with_seed() -> None:
|
||||
bundle = synthetic_returns(seed=1, n_years=150)
|
||||
a = block_bootstrap(bundle, n_paths=50, n_years=30, block_size=5, rng=np.random.default_rng(0))
|
||||
b = block_bootstrap(bundle, n_paths=50, n_years=30, block_size=5, rng=np.random.default_rng(0))
|
||||
np.testing.assert_array_equal(a, b)
|
||||
|
||||
|
||||
def test_bootstrap_block_size_one_is_iid() -> None:
|
||||
"""Block size 1 reduces to simple IID resampling — covariance
|
||||
structure isn't preserved, but all draws come from the source."""
|
||||
bundle = synthetic_returns(seed=2, n_years=100)
|
||||
rng = np.random.default_rng(0)
|
||||
paths = block_bootstrap(bundle, n_paths=10, n_years=20, block_size=1, rng=rng)
|
||||
src_set = set(zip(bundle.stock_nominal, bundle.bond_nominal, bundle.cpi, strict=True))
|
||||
drawn_set = set((float(s), float(b), float(c)) for path in paths for s, b, c in path)
|
||||
assert drawn_set <= src_set
|
||||
|
||||
|
||||
def test_bootstrap_preserves_block_runs() -> None:
|
||||
"""For block_size=5, every consecutive 5-year run within a path
|
||||
must equal a 5-year run from the source (mod circular)."""
|
||||
bundle = synthetic_returns(seed=3, n_years=50)
|
||||
rng = np.random.default_rng(0)
|
||||
paths = block_bootstrap(bundle, n_paths=5, n_years=15, block_size=5, rng=rng)
|
||||
src = np.stack([bundle.stock_nominal, bundle.bond_nominal, bundle.cpi], axis=-1)
|
||||
src_n = src.shape[0]
|
||||
for path in paths:
|
||||
for block_start in range(0, 15, 5):
|
||||
block = path[block_start:block_start + 5]
|
||||
# Find this block in source by matching the first row, then
|
||||
# checking consecutiveness (mod circular).
|
||||
for src_idx in range(src_n):
|
||||
circ_block = np.stack([src[(src_idx + i) % src_n] for i in range(5)])
|
||||
if np.allclose(block, circ_block):
|
||||
break
|
||||
else:
|
||||
raise AssertionError(f"block {block_start} not a circular slice of source")
|
||||
|
||||
|
||||
def test_bootstrap_rejects_zero_block_size() -> None:
|
||||
bundle = synthetic_returns(seed=1, n_years=30)
|
||||
with pytest.raises(ValueError):
|
||||
block_bootstrap(bundle, n_paths=10, n_years=10, block_size=0)
|
||||
|
||||
|
||||
def test_bootstrap_n_years_not_multiple_of_block() -> None:
|
||||
"""13 years from 5-year blocks: 3 blocks then truncate to 13."""
|
||||
bundle = synthetic_returns(seed=1, n_years=50)
|
||||
paths = block_bootstrap(bundle,
|
||||
n_paths=4,
|
||||
n_years=13,
|
||||
block_size=5,
|
||||
rng=np.random.default_rng(0))
|
||||
assert paths.shape == (4, 13, 3)
|
||||
113
tests/test_scenarios.py
Normal file
113
tests/test_scenarios.py
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
"""Cartesian scenario builder + strategy/regime factory."""
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
|
||||
from fire_planner.scenarios import (
|
||||
DEFAULT_GLIDES,
|
||||
DEFAULT_JURISDICTIONS,
|
||||
DEFAULT_LEAVE_YEARS,
|
||||
DEFAULT_STRATEGIES,
|
||||
ScenarioSpec,
|
||||
build_regime_schedule,
|
||||
build_strategy,
|
||||
cartesian_scenarios,
|
||||
)
|
||||
from fire_planner.strategies.guyton_klinger import GuytonKlingerStrategy
|
||||
from fire_planner.strategies.trinity import TrinityStrategy
|
||||
from fire_planner.strategies.vpw import VpwStrategy, VpwWithFloorStrategy
|
||||
from fire_planner.tax.bulgaria import BulgariaTaxRegime
|
||||
from fire_planner.tax.cyprus import CyprusTaxRegime
|
||||
from fire_planner.tax.uae import UaeTaxRegime
|
||||
from fire_planner.tax.uk import UkTaxRegime
|
||||
|
||||
|
||||
def test_default_cartesian_count_is_120() -> None:
|
||||
specs = cartesian_scenarios(spending_gbp=Decimal("100000"), nw_seed_gbp=Decimal("1000000"))
|
||||
expected = (len(DEFAULT_JURISDICTIONS) * len(DEFAULT_STRATEGIES) * len(DEFAULT_LEAVE_YEARS) *
|
||||
len(DEFAULT_GLIDES))
|
||||
assert expected == 120
|
||||
assert len(specs) == 120
|
||||
|
||||
|
||||
def test_external_id_format() -> None:
|
||||
spec = ScenarioSpec(
|
||||
jurisdiction="cyprus",
|
||||
strategy="vpw",
|
||||
leave_uk_year=3,
|
||||
glide_path="rising",
|
||||
spending_gbp=Decimal("100000"),
|
||||
nw_seed_gbp=Decimal("1000000"),
|
||||
)
|
||||
assert spec.external_id == "cyprus-vpw-leave-y3-glide-rising"
|
||||
|
||||
|
||||
def test_cartesian_unique_external_ids() -> None:
|
||||
specs = cartesian_scenarios(spending_gbp=Decimal("100000"), nw_seed_gbp=Decimal("1000000"))
|
||||
ids = [s.external_id for s in specs]
|
||||
assert len(ids) == len(set(ids))
|
||||
|
||||
|
||||
def test_build_strategy_dispatch() -> None:
|
||||
assert isinstance(build_strategy("trinity"), TrinityStrategy)
|
||||
assert isinstance(build_strategy("guyton_klinger"), GuytonKlingerStrategy)
|
||||
assert isinstance(build_strategy("vpw"), VpwStrategy)
|
||||
|
||||
|
||||
def test_build_strategy_vpw_floor_requires_floor() -> None:
|
||||
s = build_strategy("vpw_floor", floor=40_000.0)
|
||||
assert isinstance(s, VpwWithFloorStrategy)
|
||||
assert s.floor == 40_000.0
|
||||
|
||||
|
||||
def test_build_strategy_vpw_floor_missing_floor_raises() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
build_strategy("vpw_floor")
|
||||
|
||||
|
||||
def test_build_strategy_unknown_raises() -> None:
|
||||
with pytest.raises(KeyError):
|
||||
build_strategy("walmart")
|
||||
|
||||
|
||||
def test_build_regime_schedule_uae() -> None:
|
||||
fn = build_regime_schedule("uae", leave_uk_year=2)
|
||||
assert isinstance(fn(0), UkTaxRegime)
|
||||
assert isinstance(fn(1), UkTaxRegime)
|
||||
assert isinstance(fn(2), UaeTaxRegime)
|
||||
assert isinstance(fn(50), UaeTaxRegime)
|
||||
|
||||
|
||||
def test_build_regime_schedule_uk_constant() -> None:
|
||||
fn = build_regime_schedule("uk", leave_uk_year=3)
|
||||
# All years should resolve to UK
|
||||
assert isinstance(fn(0), UkTaxRegime)
|
||||
assert isinstance(fn(50), UkTaxRegime)
|
||||
|
||||
|
||||
def test_build_regime_schedule_cyprus_switches_at_leave_year() -> None:
|
||||
fn = build_regime_schedule("cyprus", leave_uk_year=3)
|
||||
assert isinstance(fn(0), UkTaxRegime)
|
||||
assert isinstance(fn(2), UkTaxRegime)
|
||||
assert isinstance(fn(3), CyprusTaxRegime)
|
||||
assert isinstance(fn(50), CyprusTaxRegime)
|
||||
|
||||
|
||||
def test_build_regime_schedule_bulgaria() -> None:
|
||||
fn = build_regime_schedule("bulgaria", leave_uk_year=1)
|
||||
assert isinstance(fn(0), UkTaxRegime)
|
||||
assert isinstance(fn(1), BulgariaTaxRegime)
|
||||
|
||||
|
||||
def test_build_regime_schedule_unknown_raises() -> None:
|
||||
with pytest.raises(KeyError):
|
||||
build_regime_schedule("madeupistan", leave_uk_year=3)
|
||||
|
||||
|
||||
def test_cartesian_unknown_glide_raises() -> None:
|
||||
with pytest.raises(KeyError):
|
||||
cartesian_scenarios(
|
||||
spending_gbp=Decimal("100000"),
|
||||
nw_seed_gbp=Decimal("1000000"),
|
||||
glides=("staircase", ),
|
||||
)
|
||||
259
tests/test_simulator.py
Normal file
259
tests/test_simulator.py
Normal file
|
|
@ -0,0 +1,259 @@
|
|||
"""Simulator behaviour: deterministic short-horizon checks, then
|
||||
stochastic monotonicity + cFIREsim sanity calibration."""
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from fire_planner.glide_path import static
|
||||
from fire_planner.returns.bootstrap import block_bootstrap
|
||||
from fire_planner.returns.shiller import ReturnsBundle, synthetic_returns
|
||||
from fire_planner.simulator import default_bucket_split, simulate
|
||||
from fire_planner.strategies.guyton_klinger import GuytonKlingerStrategy
|
||||
from fire_planner.strategies.trinity import TrinityStrategy
|
||||
from fire_planner.strategies.vpw import VpwStrategy
|
||||
from fire_planner.tax.bulgaria import BulgariaTaxRegime
|
||||
from fire_planner.tax.malaysia import MalaysiaTaxRegime
|
||||
from fire_planner.tax.uk import UkTaxRegime
|
||||
|
||||
|
||||
def fixed_paths(n_paths: int, n_years: int, stock_ret: float, bond_ret: float,
|
||||
cpi: float) -> np.ndarray:
|
||||
"""All-paths-identical returns — deterministic regression check."""
|
||||
out = np.zeros((n_paths, n_years, 3), dtype=np.float64)
|
||||
out[..., 0] = stock_ret
|
||||
out[..., 1] = bond_ret
|
||||
out[..., 2] = cpi
|
||||
return out
|
||||
|
||||
|
||||
def test_simulate_zero_returns_zero_inflation_drains_at_4pc() -> None:
|
||||
"""0% returns + 0% inflation, 4% Trinity, 25y horizon — withdraw
|
||||
£40k/y from £1M = drain to exactly £0 in year 25. Success because
|
||||
portfolio stays positive *during* every year (clipped to 0 at end)."""
|
||||
paths = fixed_paths(n_paths=1, n_years=25, stock_ret=0.0, bond_ret=0.0, cpi=0.0)
|
||||
res = simulate(
|
||||
paths=paths,
|
||||
initial_portfolio=1_000_000.0,
|
||||
spending_target=40_000.0,
|
||||
glide=static(0.6),
|
||||
strategy=TrinityStrategy(initial_rate=0.04),
|
||||
regime=MalaysiaTaxRegime(), # 0% to keep arithmetic clean
|
||||
)
|
||||
# Year 0 withdrawal is 40k, portfolio after = 960k
|
||||
assert res.portfolio_real[0, 1] == 960_000.0
|
||||
# 25y of £40k draws against zero growth = drain to 0 by end of y25.
|
||||
assert abs(res.portfolio_real[0, 25]) < 1.0
|
||||
|
||||
|
||||
def test_simulate_failing_path_marked_unsuccessful() -> None:
|
||||
"""6% Trinity rate against 0% real return for 25y — clearly fails."""
|
||||
paths = fixed_paths(n_paths=1, n_years=25, stock_ret=0.0, bond_ret=0.0, cpi=0.0)
|
||||
res = simulate(
|
||||
paths=paths,
|
||||
initial_portfolio=1_000_000.0,
|
||||
spending_target=60_000.0,
|
||||
glide=static(1.0),
|
||||
strategy=TrinityStrategy(initial_rate=0.06),
|
||||
regime=MalaysiaTaxRegime(),
|
||||
)
|
||||
assert not res.success_mask[0]
|
||||
|
||||
|
||||
def test_simulate_growing_portfolio_succeeds() -> None:
|
||||
"""5% real return, 4% draw — classic surplus case."""
|
||||
paths = fixed_paths(n_paths=1, n_years=30, stock_ret=0.05, bond_ret=0.05, cpi=0.0)
|
||||
res = simulate(
|
||||
paths=paths,
|
||||
initial_portfolio=1_000_000.0,
|
||||
spending_target=40_000.0,
|
||||
glide=static(1.0),
|
||||
strategy=TrinityStrategy(initial_rate=0.04),
|
||||
regime=MalaysiaTaxRegime(),
|
||||
)
|
||||
assert res.success_mask[0]
|
||||
# Portfolio should grow above starting value
|
||||
assert res.portfolio_real[0, 30] > 1_000_000.0
|
||||
|
||||
|
||||
def test_savings_phase_increases_portfolio() -> None:
|
||||
"""5y of savings @ £100k / 0% return → portfolio grows."""
|
||||
paths = fixed_paths(n_paths=1, n_years=5, stock_ret=0.0, bond_ret=0.0, cpi=0.0)
|
||||
res = simulate(
|
||||
paths=paths,
|
||||
initial_portfolio=1_000_000.0,
|
||||
spending_target=0.0, # not drawing during accumulation
|
||||
glide=static(1.0),
|
||||
strategy=TrinityStrategy(initial_rate=0.0),
|
||||
regime=MalaysiaTaxRegime(),
|
||||
annual_savings=np.full(5, 100_000.0),
|
||||
)
|
||||
# 1M + 5×100k = 1.5M, no growth
|
||||
assert res.portfolio_real[0, 5] == 1_500_000.0
|
||||
|
||||
|
||||
def test_uk_tax_increases_failure_rate_vs_no_tax() -> None:
|
||||
"""Same scenario, UK regime should produce more or equal failures
|
||||
than Malaysia (zero tax) — paths are identical."""
|
||||
bundle = synthetic_returns(seed=1, n_years=120)
|
||||
rng = np.random.default_rng(0)
|
||||
paths = block_bootstrap(bundle, n_paths=200, n_years=30, block_size=5, rng=rng)
|
||||
common = dict(
|
||||
paths=paths,
|
||||
initial_portfolio=600_000.0, # tighter so tax matters
|
||||
spending_target=40_000.0,
|
||||
glide=static(0.7),
|
||||
strategy=TrinityStrategy(initial_rate=0.04),
|
||||
)
|
||||
msy = simulate(**common, regime=MalaysiaTaxRegime()) # type: ignore[arg-type]
|
||||
uk = simulate(**common, regime=UkTaxRegime()) # type: ignore[arg-type]
|
||||
assert uk.success_rate <= msy.success_rate
|
||||
assert uk.median_lifetime_tax() > msy.median_lifetime_tax()
|
||||
|
||||
|
||||
def test_vpw_never_runs_out() -> None:
|
||||
"""VPW scales withdrawal with portfolio — should never fully ruin."""
|
||||
bundle = synthetic_returns(seed=2, n_years=120)
|
||||
rng = np.random.default_rng(0)
|
||||
paths = block_bootstrap(bundle, n_paths=200, n_years=60, block_size=5, rng=rng)
|
||||
res = simulate(
|
||||
paths=paths,
|
||||
initial_portfolio=1_000_000.0,
|
||||
spending_target=50_000.0,
|
||||
glide=static(0.7),
|
||||
strategy=VpwStrategy(),
|
||||
regime=MalaysiaTaxRegime(),
|
||||
)
|
||||
# Every path should keep some portfolio > 0 throughout (until last year).
|
||||
# Year `n-1` end may be tiny but >= 0 since VPW caps drain at 100% in y=H-1.
|
||||
assert res.portfolio_real[:, 1:-1].min() > 0
|
||||
|
||||
|
||||
def test_simulator_deterministic_with_same_paths() -> None:
|
||||
paths = fixed_paths(n_paths=10, n_years=30, stock_ret=0.05, bond_ret=0.03, cpi=0.02)
|
||||
common = dict(
|
||||
paths=paths,
|
||||
initial_portfolio=1_000_000.0,
|
||||
spending_target=40_000.0,
|
||||
glide=static(0.7),
|
||||
strategy=GuytonKlingerStrategy(),
|
||||
regime=BulgariaTaxRegime(),
|
||||
)
|
||||
a = simulate(**common) # type: ignore[arg-type]
|
||||
b = simulate(**common) # type: ignore[arg-type]
|
||||
np.testing.assert_array_equal(a.portfolio_real, b.portfolio_real)
|
||||
|
||||
|
||||
def test_success_rate_monotone_in_portfolio() -> None:
|
||||
"""More starting wealth → higher (or equal) success rate."""
|
||||
bundle = synthetic_returns(seed=3, n_years=120)
|
||||
rng = np.random.default_rng(0)
|
||||
paths = block_bootstrap(bundle, n_paths=300, n_years=30, block_size=5, rng=rng)
|
||||
common = dict(
|
||||
paths=paths,
|
||||
spending_target=40_000.0,
|
||||
glide=static(0.7),
|
||||
strategy=TrinityStrategy(initial_rate=0.04),
|
||||
regime=MalaysiaTaxRegime(),
|
||||
)
|
||||
low = simulate(**common, initial_portfolio=600_000.0) # type: ignore[arg-type]
|
||||
high = simulate(**common, initial_portfolio=1_500_000.0) # type: ignore[arg-type]
|
||||
assert high.success_rate >= low.success_rate
|
||||
|
||||
|
||||
def test_success_rate_monotone_in_spending() -> None:
|
||||
"""Less spending → higher success rate."""
|
||||
bundle = synthetic_returns(seed=4, n_years=120)
|
||||
rng = np.random.default_rng(0)
|
||||
paths = block_bootstrap(bundle, n_paths=300, n_years=30, block_size=5, rng=rng)
|
||||
common = dict(
|
||||
paths=paths,
|
||||
initial_portfolio=1_000_000.0,
|
||||
glide=static(0.7),
|
||||
strategy=TrinityStrategy(initial_rate=0.04),
|
||||
regime=MalaysiaTaxRegime(),
|
||||
)
|
||||
cheap = simulate(**common, spending_target=30_000.0) # type: ignore[arg-type]
|
||||
fat = simulate(**common, spending_target=80_000.0) # type: ignore[arg-type]
|
||||
assert cheap.success_rate >= fat.success_rate
|
||||
|
||||
|
||||
def test_fan_quantiles_shape() -> None:
|
||||
bundle = synthetic_returns(seed=5, n_years=120)
|
||||
paths = block_bootstrap(bundle,
|
||||
n_paths=100,
|
||||
n_years=20,
|
||||
block_size=5,
|
||||
rng=np.random.default_rng(0))
|
||||
res = simulate(
|
||||
paths=paths,
|
||||
initial_portfolio=1_000_000.0,
|
||||
spending_target=40_000.0,
|
||||
glide=static(0.7),
|
||||
strategy=TrinityStrategy(),
|
||||
regime=MalaysiaTaxRegime(),
|
||||
)
|
||||
p10 = res.fan_quantiles(10)
|
||||
assert p10.shape == (21, ) # n_years + 1
|
||||
|
||||
|
||||
def test_perf_under_60s_for_10k_paths_60y() -> None:
|
||||
"""Stretch goal — at 10k paths × 60y the simulator should finish
|
||||
in well under a minute on commodity hardware. Test allows 60s
|
||||
(generous; CI can vary)."""
|
||||
bundle = synthetic_returns(seed=6, n_years=150)
|
||||
paths = block_bootstrap(bundle,
|
||||
n_paths=10_000,
|
||||
n_years=60,
|
||||
block_size=5,
|
||||
rng=np.random.default_rng(0))
|
||||
start = time.perf_counter()
|
||||
simulate(
|
||||
paths=paths,
|
||||
initial_portfolio=1_000_000.0,
|
||||
spending_target=40_000.0,
|
||||
glide=static(0.7),
|
||||
strategy=TrinityStrategy(),
|
||||
regime=MalaysiaTaxRegime(),
|
||||
)
|
||||
elapsed = time.perf_counter() - start
|
||||
assert elapsed < 60, f"too slow: {elapsed:.2f}s"
|
||||
|
||||
|
||||
def test_convergence_5k_vs_50k_paths() -> None:
|
||||
"""Success rate should be stable to within ±1.5% between 5k and
|
||||
50k paths (Monte Carlo SE ~0.5% at 10k samples)."""
|
||||
bundle = synthetic_returns(seed=7, n_years=150)
|
||||
paths_small = block_bootstrap(bundle,
|
||||
n_paths=5_000,
|
||||
n_years=30,
|
||||
block_size=5,
|
||||
rng=np.random.default_rng(0))
|
||||
paths_large = block_bootstrap(bundle,
|
||||
n_paths=50_000,
|
||||
n_years=30,
|
||||
block_size=5,
|
||||
rng=np.random.default_rng(0))
|
||||
common = dict(
|
||||
initial_portfolio=1_000_000.0,
|
||||
spending_target=40_000.0,
|
||||
glide=static(0.7),
|
||||
strategy=TrinityStrategy(),
|
||||
regime=MalaysiaTaxRegime(),
|
||||
)
|
||||
small = simulate(paths=paths_small, **common) # type: ignore[arg-type]
|
||||
large = simulate(paths=paths_large, **common) # type: ignore[arg-type]
|
||||
assert abs(small.success_rate - large.success_rate) < 0.015
|
||||
|
||||
|
||||
def test_default_bucket_split_smoke() -> None:
|
||||
inputs = default_bucket_split(50_000.0, year_idx=5)
|
||||
assert inputs.capital_gains == 50000
|
||||
|
||||
|
||||
def test_returns_bundle_supplies_ie_data_columns() -> None:
|
||||
"""Sanity: the bundle has stock/bond/cpi correctly aligned."""
|
||||
b = synthetic_returns(seed=8, n_years=10)
|
||||
assert isinstance(b, ReturnsBundle)
|
||||
assert len(b.stock_nominal) == 10
|
||||
155
tests/test_strategies.py
Normal file
155
tests/test_strategies.py
Normal file
|
|
@ -0,0 +1,155 @@
|
|||
"""Withdrawal-strategy + glide-path behaviour."""
|
||||
from fire_planner import glide_path
|
||||
from fire_planner.strategies.base import StrategyState
|
||||
from fire_planner.strategies.guyton_klinger import GuytonKlingerStrategy
|
||||
from fire_planner.strategies.trinity import TrinityStrategy
|
||||
from fire_planner.strategies.vpw import VpwStrategy, VpwWithFloorStrategy, pmt_rate
|
||||
|
||||
|
||||
def state(**overrides: float | int) -> StrategyState:
|
||||
base = dict(
|
||||
portfolio=1_000_000.0,
|
||||
initial_portfolio=1_000_000.0,
|
||||
initial_withdrawal=40_000.0,
|
||||
year_idx=0,
|
||||
horizon_years=60,
|
||||
last_withdrawal=40_000.0,
|
||||
expected_real_return=0.04,
|
||||
)
|
||||
base.update(overrides)
|
||||
return StrategyState(**base) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_trinity_year_zero_uses_initial_rate() -> None:
|
||||
s = TrinityStrategy(initial_rate=0.04)
|
||||
assert s.propose_withdrawal(state()) == 40_000.0
|
||||
|
||||
|
||||
def test_trinity_holds_constant_in_real_terms() -> None:
|
||||
s = TrinityStrategy()
|
||||
assert s.propose_withdrawal(state(year_idx=10, last_withdrawal=40_000.0)) == 40_000.0
|
||||
|
||||
|
||||
def test_trinity_doesnt_increase_with_portfolio_growth() -> None:
|
||||
s = TrinityStrategy()
|
||||
assert s.propose_withdrawal(state(year_idx=5, portfolio=2_000_000.0,
|
||||
last_withdrawal=40_000.0)) == 40_000.0
|
||||
|
||||
|
||||
def test_gk_year_zero_uses_initial_rate() -> None:
|
||||
s = GuytonKlingerStrategy(initial_rate=0.055)
|
||||
# 5.5% of 1M = 55,000
|
||||
assert s.propose_withdrawal(state()) == 55_000.0
|
||||
|
||||
|
||||
def test_gk_capital_preservation_cut() -> None:
|
||||
"""Portfolio crashed: current rate now > 120% of 5.5% = 6.6%; > 15y left → cut 10%."""
|
||||
s = GuytonKlingerStrategy(initial_rate=0.055)
|
||||
# last_w = 55,000; portfolio = 700,000 → rate = 7.86% > 6.6%
|
||||
out = s.propose_withdrawal(state(year_idx=5, portfolio=700_000.0, last_withdrawal=55_000.0))
|
||||
assert abs(out - 49_500.0) < 0.01
|
||||
|
||||
|
||||
def test_gk_no_cut_when_horizon_under_15y_left() -> None:
|
||||
"""Same crash, only 10y left — no cut applies."""
|
||||
s = GuytonKlingerStrategy(initial_rate=0.055)
|
||||
out = s.propose_withdrawal(
|
||||
state(year_idx=50, portfolio=700_000.0, last_withdrawal=55_000.0, horizon_years=60))
|
||||
assert out == 55_000.0
|
||||
|
||||
|
||||
def test_gk_prosperity_bump() -> None:
|
||||
"""Big bull market: current rate < 80% of 5.5% = 4.4% → bump 10%."""
|
||||
s = GuytonKlingerStrategy(initial_rate=0.055)
|
||||
out = s.propose_withdrawal(state(year_idx=5, portfolio=2_000_000.0, last_withdrawal=55_000.0))
|
||||
assert abs(out - 60_500.0) < 0.01
|
||||
|
||||
|
||||
def test_pmt_rate_uniform_amortisation_at_zero_rate() -> None:
|
||||
assert abs(pmt_rate(years_remaining=60, real_rate=0.0) - 1 / 60) < 1e-12
|
||||
|
||||
|
||||
def test_pmt_rate_full_drain_when_years_zero() -> None:
|
||||
assert pmt_rate(years_remaining=0, real_rate=0.04) == 1.0
|
||||
|
||||
|
||||
def test_pmt_rate_bogleheads_table_60y() -> None:
|
||||
"""Bogleheads VPW table: at 5% real, 60y, the published rate is
|
||||
5.28% (within £1/£10k of 5.2828% on a 60-year amortisation)."""
|
||||
assert abs(pmt_rate(60, 0.05) - 0.052828) < 1e-4
|
||||
|
||||
|
||||
def test_pmt_rate_bogleheads_table_30y() -> None:
|
||||
"""At 5% real, 30y → 6.51%."""
|
||||
assert abs(pmt_rate(30, 0.05) - 0.06505) < 1e-4
|
||||
|
||||
|
||||
def test_pmt_rate_bogleheads_table_15y() -> None:
|
||||
"""At 5% real, 15y → 9.63%."""
|
||||
assert abs(pmt_rate(15, 0.05) - 0.09634) < 1e-4
|
||||
|
||||
|
||||
def test_vpw_year_zero_at_60y_horizon() -> None:
|
||||
"""1M portfolio × pmt_rate(60, 0.05) = 1M × 0.0528 = 52,828.20."""
|
||||
s = VpwStrategy(expected_real_return=0.05)
|
||||
out = s.propose_withdrawal(state(horizon_years=60, year_idx=0))
|
||||
assert abs(out - 52_828.0) < 5 # within a few quid
|
||||
|
||||
|
||||
def test_vpw_drain_at_horizon_end() -> None:
|
||||
"""Last year: withdraw the entire portfolio."""
|
||||
s = VpwStrategy()
|
||||
out = s.propose_withdrawal(state(year_idx=59, horizon_years=60, portfolio=100_000.0))
|
||||
assert abs(out - 100_000.0) < 1
|
||||
|
||||
|
||||
def test_vpw_with_floor_lifts_to_floor_when_vpw_proposes_less() -> None:
|
||||
"""VPW on a 500k portfolio with 60y left at 5% would propose
|
||||
500k × 0.0528 ≈ 26,400. Floor=40k overrides — withdraw the floor."""
|
||||
s = VpwWithFloorStrategy(floor=40_000.0, expected_real_return=0.05)
|
||||
out = s.propose_withdrawal(state(portfolio=500_000.0, horizon_years=60, year_idx=0))
|
||||
assert out == 40_000.0
|
||||
|
||||
|
||||
def test_vpw_with_floor_uses_vpw_when_above_floor() -> None:
|
||||
"""VPW on a 2M portfolio with 60y left ≈ 105,656. Above floor=40k → use VPW."""
|
||||
s = VpwWithFloorStrategy(floor=40_000.0, expected_real_return=0.05)
|
||||
out = s.propose_withdrawal(state(portfolio=2_000_000.0, horizon_years=60, year_idx=0))
|
||||
assert abs(out - 105_656.0) < 50
|
||||
|
||||
|
||||
def test_vpw_with_floor_clips_to_portfolio_when_portfolio_below_floor() -> None:
|
||||
"""Terminal sequence: portfolio crashed below the floor — withdraw what's left."""
|
||||
s = VpwWithFloorStrategy(floor=40_000.0)
|
||||
out = s.propose_withdrawal(state(portfolio=15_000.0, horizon_years=60, year_idx=30))
|
||||
assert out == 15_000.0
|
||||
|
||||
|
||||
def test_vpw_with_floor_zero_portfolio() -> None:
|
||||
s = VpwWithFloorStrategy(floor=40_000.0)
|
||||
out = s.propose_withdrawal(state(portfolio=0.0))
|
||||
assert out == 0.0
|
||||
|
||||
|
||||
def test_vpw_with_floor_name() -> None:
|
||||
assert VpwWithFloorStrategy(floor=40_000.0).name == "vpw_floor"
|
||||
|
||||
|
||||
def test_glide_rising_default_shape() -> None:
|
||||
g = glide_path.rising_equity()
|
||||
assert g(0) == 0.30
|
||||
assert abs(g(15) - 0.70) < 1e-9
|
||||
assert abs(g(30) - 0.70) < 1e-9
|
||||
# Halfway through the ramp
|
||||
assert abs(g(7) - (0.30 + 0.40 * 7 / 15)) < 1e-9
|
||||
|
||||
|
||||
def test_glide_static() -> None:
|
||||
g = glide_path.static(0.60)
|
||||
assert g(0) == 0.60
|
||||
assert g(50) == 0.60
|
||||
|
||||
|
||||
def test_glide_lookup() -> None:
|
||||
assert glide_path.get("rising")(0) == 0.30
|
||||
assert glide_path.get("static_60_40")(50) == 0.60
|
||||
70
tests/test_tax_base.py
Normal file
70
tests/test_tax_base.py
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
"""Bracket-arithmetic and breakdown invariants."""
|
||||
from decimal import Decimal
|
||||
|
||||
from hypothesis import given
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from fire_planner.tax.base import TaxBreakdown, TaxInputs, apply_brackets
|
||||
|
||||
|
||||
def test_apply_brackets_zero_input() -> None:
|
||||
assert apply_brackets(Decimal("0"), [(Decimal("100"), Decimal("0.2"))]) == Decimal("0")
|
||||
|
||||
|
||||
def test_apply_brackets_negative_input() -> None:
|
||||
# Negative income shouldn't generate a refund — clamp to zero.
|
||||
assert apply_brackets(Decimal("-1000"), [(Decimal("100"), Decimal("0.2"))]) == Decimal("0")
|
||||
|
||||
|
||||
def test_apply_brackets_within_first_band() -> None:
|
||||
brackets = [(Decimal("100"), Decimal("0.2")), (Decimal("Infinity"), Decimal("0.4"))]
|
||||
assert apply_brackets(Decimal("50"), brackets) == Decimal("10")
|
||||
|
||||
|
||||
def test_apply_brackets_spans_two_bands() -> None:
|
||||
# 100 @ 20% = 20; next 50 @ 40% = 20 → total 40
|
||||
brackets = [(Decimal("100"), Decimal("0.2")), (Decimal("Infinity"), Decimal("0.4"))]
|
||||
assert apply_brackets(Decimal("150"), brackets) == Decimal("40")
|
||||
|
||||
|
||||
def test_apply_brackets_uk_paye_2026_smoke() -> None:
|
||||
# Taxable income £80,000 (gross £92,570 less £12,570 PA):
|
||||
# £37,700 @ 20% = £7,540
|
||||
# £42,300 @ 40% = £16,920
|
||||
# total = £24,460
|
||||
brackets = [
|
||||
(Decimal("37700"), Decimal("0.20")),
|
||||
(Decimal("112570"), Decimal("0.40")),
|
||||
(Decimal("Infinity"), Decimal("0.45")),
|
||||
]
|
||||
assert apply_brackets(Decimal("80000"), brackets) == Decimal("24460")
|
||||
|
||||
|
||||
@given(amount=st.decimals(min_value=0, max_value=10_000_000, allow_nan=False, allow_infinity=False))
|
||||
def test_apply_brackets_monotone_in_amount(amount: Decimal) -> None:
|
||||
"""More taxable income → never less tax."""
|
||||
brackets = [
|
||||
(Decimal("37700"), Decimal("0.20")),
|
||||
(Decimal("112570"), Decimal("0.40")),
|
||||
(Decimal("Infinity"), Decimal("0.45")),
|
||||
]
|
||||
extra = Decimal("100")
|
||||
assert apply_brackets(amount + extra, brackets) >= apply_brackets(amount, brackets)
|
||||
|
||||
|
||||
def test_breakdown_total_is_sum_of_components() -> None:
|
||||
b = TaxBreakdown(
|
||||
income_tax=Decimal("10000"),
|
||||
national_insurance=Decimal("3000"),
|
||||
capital_gains_tax=Decimal("500"),
|
||||
dividend_tax=Decimal("200"),
|
||||
healthcare_levy=Decimal("100"),
|
||||
other=Decimal("50"),
|
||||
)
|
||||
assert b.total == Decimal("13850")
|
||||
|
||||
|
||||
def test_inputs_default_to_zero() -> None:
|
||||
i = TaxInputs()
|
||||
assert i.earned_income == Decimal("0")
|
||||
assert i.years_since_uk_departure == 0
|
||||
150
tests/test_tax_other_regimes.py
Normal file
150
tests/test_tax_other_regimes.py
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
"""Nomad, Malaysia, Thailand, Cyprus, Bulgaria, UAE regimes."""
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
|
||||
from fire_planner.tax.base import TaxInputs, TaxRegime
|
||||
from fire_planner.tax.bulgaria import BulgariaTaxRegime
|
||||
from fire_planner.tax.cyprus import CyprusTaxRegime
|
||||
from fire_planner.tax.malaysia import MalaysiaTaxRegime
|
||||
from fire_planner.tax.nomad import NomadTaxRegime
|
||||
from fire_planner.tax.thailand import ThailandTaxRegime
|
||||
from fire_planner.tax.uae import UaeTaxRegime
|
||||
|
||||
|
||||
def test_nomad_zero_inputs() -> None:
|
||||
assert NomadTaxRegime().compute_tax(TaxInputs()).total == Decimal("0")
|
||||
|
||||
|
||||
def test_nomad_one_pc_premium() -> None:
|
||||
b = NomadTaxRegime().compute_tax(
|
||||
TaxInputs(capital_gains=Decimal("100000"), dividends=Decimal("20000")))
|
||||
assert b.other == Decimal("1200")
|
||||
assert b.total == Decimal("1200")
|
||||
|
||||
|
||||
def test_nomad_isa_excluded_from_premium() -> None:
|
||||
b = NomadTaxRegime().compute_tax(TaxInputs(isa_withdrawals=Decimal("100000")))
|
||||
assert b.total == Decimal("0")
|
||||
|
||||
|
||||
def test_malaysia_zero_on_foreign_income() -> None:
|
||||
b = MalaysiaTaxRegime().compute_tax(
|
||||
TaxInputs(capital_gains=Decimal("500000"), dividends=Decimal("50000")))
|
||||
assert b.total == Decimal("0")
|
||||
|
||||
|
||||
def test_thailand_zero_on_foreign_income() -> None:
|
||||
b = ThailandTaxRegime().compute_tax(
|
||||
TaxInputs(capital_gains=Decimal("500000"), dividends=Decimal("50000")))
|
||||
assert b.total == Decimal("0")
|
||||
|
||||
|
||||
def test_cyprus_gesy_below_cap() -> None:
|
||||
# £100k chargeable, below €180k cap (~£154,800 default)
|
||||
# 2.65% × £100,000 = £2,650
|
||||
b = CyprusTaxRegime().compute_tax(TaxInputs(dividends=Decimal("100000")))
|
||||
assert b.healthcare_levy == Decimal("2650.0000")
|
||||
assert b.income_tax == Decimal("0")
|
||||
assert b.capital_gains_tax == Decimal("0")
|
||||
|
||||
|
||||
def test_cyprus_gesy_above_cap() -> None:
|
||||
# £200k chargeable; cap GBP = £154,800 (€180k × 0.86)
|
||||
# 2.65% × £154,800 = £4,102.20
|
||||
b = CyprusTaxRegime().compute_tax(TaxInputs(dividends=Decimal("200000")))
|
||||
assert b.healthcare_levy == Decimal("4102.2000")
|
||||
|
||||
|
||||
def test_cyprus_custom_fx() -> None:
|
||||
# Cap = 180,000 × 0.90 = 162,000
|
||||
b = CyprusTaxRegime(gbp_per_eur=Decimal("0.90")).compute_tax(
|
||||
TaxInputs(dividends=Decimal("200000")))
|
||||
assert b.healthcare_levy == Decimal("4293.0000")
|
||||
|
||||
|
||||
def test_uae_zero_on_all_personal_income() -> None:
|
||||
"""UAE has 0% PIT — capital gains, dividends, earned income all 0."""
|
||||
b = UaeTaxRegime().compute_tax(
|
||||
TaxInputs(
|
||||
earned_income=Decimal("60000"),
|
||||
capital_gains=Decimal("500000"),
|
||||
dividends=Decimal("80000"),
|
||||
interest=Decimal("5000"),
|
||||
))
|
||||
assert b.total == Decimal("0")
|
||||
assert b.income_tax == Decimal("0")
|
||||
assert b.capital_gains_tax == Decimal("0")
|
||||
assert b.dividend_tax == Decimal("0")
|
||||
assert b.healthcare_levy == Decimal("0")
|
||||
assert b.other == Decimal("0")
|
||||
|
||||
|
||||
def test_uae_no_regulatory_premium() -> None:
|
||||
"""Unlike NomadTaxRegime, UAE charges no premium — it's a real
|
||||
tax residence with a treaty network."""
|
||||
b = UaeTaxRegime().compute_tax(TaxInputs(capital_gains=Decimal("100000")))
|
||||
assert b.total == Decimal("0")
|
||||
|
||||
|
||||
def test_uae_zero_inputs() -> None:
|
||||
assert UaeTaxRegime().compute_tax(TaxInputs()).total == Decimal("0")
|
||||
|
||||
|
||||
def test_bulgaria_flat_10_pc() -> None:
|
||||
b = BulgariaTaxRegime().compute_tax(
|
||||
TaxInputs(
|
||||
earned_income=Decimal("50000"),
|
||||
capital_gains=Decimal("30000"),
|
||||
dividends=Decimal("10000"),
|
||||
))
|
||||
assert b.income_tax == Decimal("5000.00")
|
||||
assert b.capital_gains_tax == Decimal("3000.00")
|
||||
assert b.dividend_tax == Decimal("1000.00")
|
||||
assert b.total == Decimal("9000.00")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("regime", [
|
||||
NomadTaxRegime(),
|
||||
MalaysiaTaxRegime(),
|
||||
ThailandTaxRegime(),
|
||||
CyprusTaxRegime(),
|
||||
BulgariaTaxRegime(),
|
||||
UaeTaxRegime(),
|
||||
])
|
||||
def test_total_equals_sum(regime: TaxRegime) -> None:
|
||||
inputs = TaxInputs(
|
||||
earned_income=Decimal("60000"),
|
||||
capital_gains=Decimal("15000"),
|
||||
dividends=Decimal("8000"),
|
||||
interest=Decimal("500"),
|
||||
)
|
||||
b = regime.compute_tax(inputs)
|
||||
assert (b.total == b.income_tax + b.national_insurance + b.capital_gains_tax + b.dividend_tax +
|
||||
b.healthcare_levy + b.other)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("regime", [
|
||||
NomadTaxRegime(),
|
||||
MalaysiaTaxRegime(),
|
||||
ThailandTaxRegime(),
|
||||
CyprusTaxRegime(),
|
||||
BulgariaTaxRegime(),
|
||||
UaeTaxRegime(),
|
||||
])
|
||||
def test_each_regime_has_a_name(regime: TaxRegime) -> None:
|
||||
assert regime.name
|
||||
assert isinstance(regime.name, str)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("regime", [
|
||||
BulgariaTaxRegime(),
|
||||
NomadTaxRegime(),
|
||||
CyprusTaxRegime(),
|
||||
])
|
||||
def test_lower_spend_lower_tax(regime: TaxRegime) -> None:
|
||||
"""Sanity: more chargeable income → never less tax (for the
|
||||
regimes that actually charge)."""
|
||||
less = regime.compute_tax(TaxInputs(dividends=Decimal("10000")))
|
||||
more = regime.compute_tax(TaxInputs(dividends=Decimal("100000")))
|
||||
assert more.total >= less.total
|
||||
147
tests/test_tax_uk.py
Normal file
147
tests/test_tax_uk.py
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
"""UK tax regime — bands, allowances, tapers."""
|
||||
from decimal import Decimal
|
||||
|
||||
from hypothesis import given
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from fire_planner.tax.base import TaxInputs
|
||||
from fire_planner.tax.uk import (
|
||||
PA_TAPER_CEILING,
|
||||
PERSONAL_ALLOWANCE,
|
||||
UkTaxRegime,
|
||||
taper_personal_allowance,
|
||||
)
|
||||
|
||||
|
||||
def test_pa_no_taper_below_100k() -> None:
|
||||
assert taper_personal_allowance(Decimal("80000")) == PERSONAL_ALLOWANCE
|
||||
|
||||
|
||||
def test_pa_full_taper_at_ceiling() -> None:
|
||||
assert taper_personal_allowance(PA_TAPER_CEILING) == Decimal("0")
|
||||
|
||||
|
||||
def test_pa_partial_taper_at_110k() -> None:
|
||||
# £10k above floor → £5k reduction off PA
|
||||
assert taper_personal_allowance(Decimal("110000")) == PERSONAL_ALLOWANCE - Decimal("5000")
|
||||
|
||||
|
||||
def test_zero_income_zero_tax() -> None:
|
||||
b = UkTaxRegime().compute_tax(TaxInputs())
|
||||
assert b.total == Decimal("0")
|
||||
|
||||
|
||||
def test_isa_only_zero_tax() -> None:
|
||||
b = UkTaxRegime().compute_tax(TaxInputs(isa_withdrawals=Decimal("100000")))
|
||||
assert b.total == Decimal("0")
|
||||
|
||||
|
||||
def test_below_pa_zero_tax() -> None:
|
||||
b = UkTaxRegime().compute_tax(TaxInputs(earned_income=Decimal("12000")))
|
||||
# NI primary threshold matches PA so NI is zero too.
|
||||
assert b.total == Decimal("0")
|
||||
|
||||
|
||||
def test_basic_rate_paye_smoke() -> None:
|
||||
# £30k earned: £17,430 taxable @ 20% = £3,486 income tax
|
||||
# NI: £17,430 @ 8% = £1,394.40
|
||||
b = UkTaxRegime().compute_tax(TaxInputs(earned_income=Decimal("30000")))
|
||||
assert b.income_tax == Decimal("3486.00")
|
||||
assert b.national_insurance == Decimal("1394.40")
|
||||
|
||||
|
||||
def test_higher_rate_paye_100k() -> None:
|
||||
# £100k earned, PA still full (taper starts strictly above £100k):
|
||||
# taxable = £87,430
|
||||
# £37,700 @ 20% = £7,540
|
||||
# £49,730 @ 40% = £19,892
|
||||
# total income tax = £27,432
|
||||
b = UkTaxRegime().compute_tax(TaxInputs(earned_income=Decimal("100000")))
|
||||
assert b.income_tax == Decimal("27432.00")
|
||||
|
||||
|
||||
def test_pa_taper_at_125k() -> None:
|
||||
# £125,000: PA = 12,570 - (25,000/2) = 12,570 - 12,500 = 70
|
||||
# taxable = 125,000 - 70 = 124,930
|
||||
# £37,700 @ 20% = £7,540
|
||||
# £87,230 @ 40% = £34,892
|
||||
# total = £42,432
|
||||
b = UkTaxRegime().compute_tax(TaxInputs(earned_income=Decimal("125000")))
|
||||
assert b.income_tax == Decimal("42432.00")
|
||||
|
||||
|
||||
def test_additional_rate_above_125k() -> None:
|
||||
# £200k earned: PA fully tapered.
|
||||
# taxable income = £200,000
|
||||
# £37,700 @ 20% = £7,540
|
||||
# £87,440 @ 40% = £34,976
|
||||
# £74,860 @ 45% = £33,687
|
||||
# total = £76,203
|
||||
b = UkTaxRegime().compute_tax(TaxInputs(earned_income=Decimal("200000")))
|
||||
assert b.income_tax == Decimal("76203.00")
|
||||
|
||||
|
||||
def test_cgt_basic_rate_only() -> None:
|
||||
# No earned income, £20k gains:
|
||||
# exempt £3k → £17k taxable @ 18% (basic band has plenty of room)
|
||||
# = £3,060
|
||||
b = UkTaxRegime().compute_tax(TaxInputs(capital_gains=Decimal("20000")))
|
||||
assert b.capital_gains_tax == Decimal("3060.00")
|
||||
|
||||
|
||||
def test_cgt_spans_into_higher_band() -> None:
|
||||
# £30k earned (taxable income £17,430 — well under £37,700 band top)
|
||||
# £40k gains:
|
||||
# exempt £3k → £37k taxable
|
||||
# basic band remaining = 37,700 - 17,430 = 20,270 → @ 18% = £3,648.60
|
||||
# higher = 37,000 - 20,270 = 16,730 → @ 24% = £4,015.20
|
||||
# total CGT = £7,663.80
|
||||
b = UkTaxRegime().compute_tax(
|
||||
TaxInputs(earned_income=Decimal("30000"), capital_gains=Decimal("40000")))
|
||||
assert b.capital_gains_tax == Decimal("7663.80")
|
||||
|
||||
|
||||
def test_dividend_basic_rate() -> None:
|
||||
# No other income, £10k dividends:
|
||||
# allowance £500 → £9,500 taxable
|
||||
# Stacked on top of taxable_ordinary=0, so basic band has £37,700 room.
|
||||
# All £9,500 @ 8.75% = £831.25
|
||||
b = UkTaxRegime().compute_tax(TaxInputs(dividends=Decimal("10000")))
|
||||
assert b.dividend_tax == Decimal("831.2500")
|
||||
|
||||
|
||||
def test_pension_25pc_tax_free() -> None:
|
||||
# £40k pension drawdown, no other income:
|
||||
# PCLS = £10k tax-free
|
||||
# Taxable pension = £30k → £17,430 taxable @ 20% = £3,486
|
||||
b = UkTaxRegime().compute_tax(TaxInputs(pension_withdrawal=Decimal("40000")))
|
||||
assert b.income_tax == Decimal("3486.00")
|
||||
assert b.national_insurance == Decimal("0") # NI not on pension
|
||||
|
||||
|
||||
def test_total_equals_sum_of_components() -> None:
|
||||
inputs = TaxInputs(
|
||||
earned_income=Decimal("60000"),
|
||||
capital_gains=Decimal("15000"),
|
||||
dividends=Decimal("8000"),
|
||||
)
|
||||
b = UkTaxRegime().compute_tax(inputs)
|
||||
assert (b.total == b.income_tax + b.national_insurance + b.capital_gains_tax + b.dividend_tax +
|
||||
b.healthcare_levy + b.other)
|
||||
|
||||
|
||||
@given(income=st.decimals(
|
||||
min_value=0, max_value=500_000, places=2, allow_nan=False, allow_infinity=False))
|
||||
def test_tax_monotone_in_earned_income(income: Decimal) -> None:
|
||||
"""Adding earned income never decreases total tax."""
|
||||
base = UkTaxRegime().compute_tax(TaxInputs(earned_income=income))
|
||||
plus = UkTaxRegime().compute_tax(TaxInputs(earned_income=income + Decimal("1000")))
|
||||
assert plus.total >= base.total
|
||||
|
||||
|
||||
@given(gains=st.decimals(
|
||||
min_value=0, max_value=500_000, places=2, allow_nan=False, allow_infinity=False))
|
||||
def test_cgt_monotone_in_gains(gains: Decimal) -> None:
|
||||
base = UkTaxRegime().compute_tax(TaxInputs(capital_gains=gains))
|
||||
plus = UkTaxRegime().compute_tax(TaxInputs(capital_gains=gains + Decimal("1000")))
|
||||
assert plus.capital_gains_tax >= base.capital_gains_tax
|
||||
Loading…
Add table
Add a link
Reference in a new issue