Initial extraction from monorepo

This commit is contained in:
Viktor Barzin 2026-05-07 17:06:19 +00:00
commit f7ef7ca4ab
56 changed files with 6163 additions and 0 deletions

0
tests/__init__.py Normal file
View file

36
tests/conftest.py Normal file
View file

@ -0,0 +1,36 @@
"""Shared pytest fixtures.
Tests run against an in-memory SQLite DB created via the SQLAlchemy ORM
metadata directly fast, deterministic, and avoids running Alembic
end-to-end on every test (the migration is exercised separately).
"""
from collections.abc import AsyncIterator
import pytest_asyncio
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from fire_planner.db import SCHEMA_NAME, Base
@pytest_asyncio.fixture
async def engine() -> AsyncIterator[AsyncEngine]:
eng = create_async_engine("sqlite+aiosqlite:///:memory:")
async with eng.begin() as conn:
# SQLite has no schema concept — attach an in-memory DB under the
# `fire_planner` name so `__table_args__ = {"schema": ...}` resolves.
await conn.exec_driver_sql(f"ATTACH DATABASE ':memory:' AS {SCHEMA_NAME}")
await conn.run_sync(Base.metadata.create_all)
yield eng
await eng.dispose()
@pytest_asyncio.fixture
async def session(engine: AsyncEngine) -> AsyncIterator[AsyncSession]:
factory = async_sessionmaker(engine, expire_on_commit=False)
async with factory() as sess:
yield sess

100
tests/test_cli.py Normal file
View file

@ -0,0 +1,100 @@
"""CLI smoke tests via click's CliRunner."""
from click.testing import CliRunner
from fire_planner.__main__ import cli
def test_simulate_smoke() -> None:
"""Run a tiny scenario through the CLI without writing to DB."""
runner = CliRunner()
result = runner.invoke(
cli,
[
"simulate",
"--scenario=cyprus-trinity-leave-y3-glide-rising",
"--n-paths=200",
"--horizon=20",
"--spending=100000",
"--nw-seed=1000000",
"--no-write-db",
],
catch_exceptions=False,
)
assert result.exit_code == 0, result.output
assert "Scenario: cyprus-trinity-leave-y3-glide-rising" in result.output
assert "success_rate" in result.output
def test_simulate_with_underscore_strategy() -> None:
"""guyton_klinger contains an underscore — the parser must handle it."""
runner = CliRunner()
result = runner.invoke(
cli,
[
"simulate",
"--scenario=uk-guyton_klinger-leave-y1-glide-static_60_40",
"--n-paths=100",
"--horizon=15",
"--spending=80000",
"--nw-seed=1500000",
"--no-write-db",
],
catch_exceptions=False,
)
assert result.exit_code == 0, result.output
assert "uk-guyton_klinger-leave-y1-glide-static_60_40" in result.output
def test_simulate_bad_scenario_id() -> None:
runner = CliRunner()
result = runner.invoke(cli, ["simulate", "--scenario=nope"], catch_exceptions=False)
assert result.exit_code != 0
def test_simulate_vpw_floor_with_floor_flag() -> None:
"""vpw_floor strategy + --floor=40000 should run without error."""
runner = CliRunner()
result = runner.invoke(
cli,
[
"simulate",
"--scenario=cyprus-vpw_floor-leave-y2-glide-rising",
"--n-paths=200",
"--horizon=20",
"--spending=60000",
"--nw-seed=1500000",
"--floor=40000",
"--no-write-db",
],
catch_exceptions=False,
)
assert result.exit_code == 0, result.output
assert "cyprus-vpw_floor" in result.output
def test_simulate_uae_smoke() -> None:
runner = CliRunner()
result = runner.invoke(
cli,
[
"simulate",
"--scenario=uae-vpw_floor-leave-y2-glide-rising",
"--n-paths=200",
"--horizon=20",
"--spending=60000",
"--nw-seed=1500000",
"--floor=40000",
"--no-write-db",
],
catch_exceptions=False,
)
assert result.exit_code == 0, result.output
assert "uae-vpw_floor" in result.output
def test_help_lists_commands() -> None:
runner = CliRunner()
result = runner.invoke(cli, ["--help"], catch_exceptions=False)
assert result.exit_code == 0
for cmd in ("ingest", "simulate", "recompute-all", "migrate", "serve"):
assert cmd in result.output

111
tests/test_db_schema.py Normal file
View file

@ -0,0 +1,111 @@
"""Smoke-test the ORM schema — every table must round-trip a row."""
from datetime import date
from decimal import Decimal
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from fire_planner.db import (
AccountSnapshot,
McPath,
McRun,
ProjectionYearly,
Scenario,
ScenarioSummary,
)
async def test_account_snapshot_roundtrip(session: AsyncSession) -> None:
snap = AccountSnapshot(
external_id="wealthfolio:account-1:2026-04-25",
snapshot_date=date(2026, 4, 25),
account_id="account-1",
account_name="ISA",
account_type="ISA",
currency="GBP",
market_value=Decimal("123456.78"),
market_value_gbp=Decimal("123456.78"),
)
session.add(snap)
await session.commit()
result = await session.execute(select(AccountSnapshot))
rows = result.scalars().all()
assert len(rows) == 1
assert rows[0].external_id == "wealthfolio:account-1:2026-04-25"
async def test_scenario_roundtrip(session: AsyncSession) -> None:
scen = Scenario(
external_id="cyprus-vpw-leave-y3-glide-rising",
jurisdiction="cyprus",
strategy="vpw",
leave_uk_year=3,
glide_path="rising",
spending_gbp=Decimal("100000"),
nw_seed_gbp=Decimal("1000000"),
savings_per_year_gbp=Decimal("100000"),
config_json={"horizon_years": 60},
)
session.add(scen)
await session.commit()
result = await session.execute(select(Scenario))
rows = result.scalars().all()
assert len(rows) == 1
assert rows[0].jurisdiction == "cyprus"
async def test_mc_run_roundtrip(session: AsyncSession) -> None:
run = McRun(
scenario_id=1,
n_paths=10000,
seed=42,
success_rate=Decimal("0.9412"),
p10_ending_gbp=Decimal("250000"),
p50_ending_gbp=Decimal("3500000"),
p90_ending_gbp=Decimal("12000000"),
median_lifetime_tax_gbp=Decimal("750000"),
elapsed_seconds=Decimal("42.351"),
)
session.add(run)
await session.commit()
result = await session.execute(select(McRun))
rows = result.scalars().all()
assert len(rows) == 1
assert rows[0].n_paths == 10000
async def test_remaining_tables_smoke(session: AsyncSession) -> None:
session.add(
McPath(mc_run_id=1,
path_idx=0,
bucket="median",
year_idx=0,
portfolio_gbp=Decimal("1000000"),
withdrawal_gbp=Decimal("100000"),
tax_paid_gbp=Decimal("0"),
real_portfolio_gbp=Decimal("1000000")))
session.add(
ProjectionYearly(mc_run_id=1,
year_idx=0,
p10_portfolio_gbp=Decimal("800000"),
p25_portfolio_gbp=Decimal("900000"),
p50_portfolio_gbp=Decimal("1000000"),
p75_portfolio_gbp=Decimal("1100000"),
p90_portfolio_gbp=Decimal("1200000"),
p50_withdrawal_gbp=Decimal("100000"),
p50_tax_gbp=Decimal("0"),
survival_rate=Decimal("1")))
session.add(
ScenarioSummary(scenario_id=1,
mc_run_id=1,
jurisdiction="uk",
strategy="trinity",
leave_uk_year=0,
glide_path="static",
spending_gbp=Decimal("100000"),
success_rate=Decimal("0.95"),
p10_ending_gbp=Decimal("200000"),
p50_ending_gbp=Decimal("3000000"),
p90_ending_gbp=Decimal("10000000"),
median_lifetime_tax_gbp=Decimal("800000")))
await session.commit()

113
tests/test_e2e.py Normal file
View file

@ -0,0 +1,113 @@
"""End-to-end smoke: scenario builder → simulator → reporter → SQLite.
Exercises the same pipeline `recompute-all` runs in production, but on
SQLite (no Postgres needed). Catches integration breakage early.
"""
from decimal import Decimal
import numpy as np
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from fire_planner.db import McRun, Scenario, ScenarioSummary
from fire_planner.glide_path import get as get_glide
from fire_planner.reporters.pg import write_run
from fire_planner.returns.bootstrap import block_bootstrap
from fire_planner.returns.shiller import synthetic_returns
from fire_planner.scenarios import build_regime_schedule, build_strategy, cartesian_scenarios
from fire_planner.simulator import simulate
async def test_full_pipeline_persists_summary_per_scenario(session: AsyncSession) -> None:
"""Run a tiny Cartesian (2 jurisdictions × 1 strategy × 1 leave × 1 glide
= 2 scenarios) end-to-end. Verifies scenario, mc_run, and
scenario_summary all populate."""
bundle = synthetic_returns(seed=1, n_years=120)
paths = block_bootstrap(bundle,
n_paths=200,
n_years=20,
block_size=5,
rng=np.random.default_rng(0))
specs = cartesian_scenarios(
spending_gbp=Decimal("80000"),
nw_seed_gbp=Decimal("1500000"),
horizon_years=20,
jurisdictions=("uk", "cyprus"),
strategies=("trinity", ),
leave_years=(2, ),
glides=("rising", ),
)
assert len(specs) == 2
for spec in specs:
result = simulate(
paths=paths,
initial_portfolio=float(spec.nw_seed_gbp),
spending_target=float(spec.spending_gbp),
glide=get_glide(spec.glide_path),
strategy=build_strategy(spec.strategy),
regime=build_regime_schedule(spec.jurisdiction, spec.leave_uk_year),
horizon_years=spec.horizon_years,
)
await write_run(session, spec, result, seed=42, elapsed_seconds=0.5)
await session.commit()
scenarios = (await session.execute(select(Scenario))).scalars().all()
assert {s.external_id
for s in scenarios} == {
"uk-trinity-leave-y2-glide-rising",
"cyprus-trinity-leave-y2-glide-rising",
}
runs = (await session.execute(select(McRun))).scalars().all()
assert len(runs) == 2
summaries = (await session.execute(select(ScenarioSummary))).scalars().all()
assert len(summaries) == 2
# Cyprus median_lifetime_tax should be lower than UK's for the same
# scenario shape — the canonical Phase 8 sanity test.
by_jur = {s.jurisdiction: s for s in summaries}
assert by_jur["cyprus"].median_lifetime_tax_gbp < by_jur["uk"].median_lifetime_tax_gbp
async def test_pipeline_handles_recompute_idempotency(session: AsyncSession) -> None:
"""Running the same scenario twice must result in 1 scenario row,
2 mc_run rows, and 1 scenario_summary row pointing at the latest run."""
bundle = synthetic_returns(seed=2, n_years=60)
paths = block_bootstrap(bundle,
n_paths=100,
n_years=15,
block_size=5,
rng=np.random.default_rng(0))
spec = next(
iter(
cartesian_scenarios(
spending_gbp=Decimal("100000"),
nw_seed_gbp=Decimal("1000000"),
horizon_years=15,
jurisdictions=("bulgaria", ),
strategies=("vpw", ),
leave_years=(1, ),
glides=("static_60_40", ),
)))
for run in range(2):
result = simulate(
paths=paths,
initial_portfolio=float(spec.nw_seed_gbp),
spending_target=float(spec.spending_gbp),
glide=get_glide(spec.glide_path),
strategy=build_strategy(spec.strategy),
regime=build_regime_schedule(spec.jurisdiction, spec.leave_uk_year),
horizon_years=spec.horizon_years,
)
await write_run(session, spec, result, seed=run, elapsed_seconds=0.2)
await session.commit()
scenarios = (await session.execute(select(Scenario))).scalars().all()
assert len(scenarios) == 1
runs = (await session.execute(select(McRun))).scalars().all()
assert len(runs) == 2
summaries = (await session.execute(select(ScenarioSummary))).scalars().all()
assert len(summaries) == 1

View file

@ -0,0 +1,97 @@
"""Wealthfolio ingest reads a real-shape sqlite and upserts cleanly."""
from __future__ import annotations
import sqlite3
from datetime import date
from decimal import Decimal
from pathlib import Path
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from fire_planner.db import AccountSnapshot
from fire_planner.ingest.wealthfolio import read_account_snapshots, upsert_snapshots
@pytest.fixture
def wealthfolio_db(tmp_path: Path) -> Path:
"""Create a minimal sqlite mimicking Wealthfolio's schema."""
db_path = tmp_path / "wealthfolio.db"
conn = sqlite3.connect(db_path)
cur = conn.cursor()
cur.executescript("""
CREATE TABLE accounts (
id TEXT PRIMARY KEY,
name TEXT,
type TEXT,
currency TEXT
);
CREATE TABLE holdings_snapshot (
account_id TEXT,
snapshot_date TEXT,
symbol TEXT,
market_value REAL,
market_value_gbp REAL
);
INSERT INTO accounts VALUES ('acc-isa', 'ISA', 'ISA', 'GBP');
INSERT INTO accounts VALUES ('acc-schwab', 'Schwab', 'BROKERAGE', 'USD');
INSERT INTO holdings_snapshot VALUES ('acc-isa', '2026-04-25', 'VWRL', 200000, 200000);
INSERT INTO holdings_snapshot VALUES ('acc-isa', '2026-04-25', 'BND', 100000, 100000);
INSERT INTO holdings_snapshot VALUES ('acc-schwab', '2026-04-25', 'META', 800000, 640000);
""")
conn.commit()
conn.close()
return db_path
def test_read_groups_holdings_per_account(wealthfolio_db: Path) -> None:
rows = read_account_snapshots(wealthfolio_db)
assert len(rows) == 2
by_id = {r["account_id"]: r for r in rows}
assert by_id["acc-isa"]["market_value_gbp"] == Decimal("300000")
assert by_id["acc-schwab"]["market_value_gbp"] == Decimal("640000")
assert by_id["acc-isa"]["snapshot_date"] == date(2026, 4, 25)
def test_read_returns_empty_on_unknown_schema(tmp_path: Path) -> None:
"""If the sqlite has a totally different shape, return [] rather
than blow up let the operator surface the warning."""
db = tmp_path / "weird.db"
conn = sqlite3.connect(db)
conn.execute("CREATE TABLE foo (x INTEGER)")
conn.commit()
conn.close()
assert read_account_snapshots(db) == []
def test_read_missing_file_raises(tmp_path: Path) -> None:
with pytest.raises(FileNotFoundError):
read_account_snapshots(tmp_path / "nope.db")
async def test_upsert_inserts_new_rows(session: AsyncSession, wealthfolio_db: Path) -> None:
rows = read_account_snapshots(wealthfolio_db)
n = await upsert_snapshots(session, rows)
await session.commit()
assert n == 2
persisted = (await session.execute(select(AccountSnapshot))).scalars().all()
assert len(persisted) == 2
by_id = {p.account_id: p for p in persisted}
assert by_id["acc-isa"].market_value_gbp == Decimal("300000")
async def test_upsert_is_idempotent(session: AsyncSession, wealthfolio_db: Path) -> None:
rows = read_account_snapshots(wealthfolio_db)
await upsert_snapshots(session, rows)
await session.commit()
# Run again — should still be 2 rows, not 4
await upsert_snapshots(session, rows)
await session.commit()
persisted = (await session.execute(select(AccountSnapshot))).scalars().all()
assert len(persisted) == 2
async def test_upsert_zero_rows_is_noop(session: AsyncSession) -> None:
n = await upsert_snapshots(session, [])
assert n == 0

View file

@ -0,0 +1,93 @@
"""Postgres reporter — write_run round-trips into the schema."""
from decimal import Decimal
import numpy as np
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from fire_planner.db import McRun, ProjectionYearly, ScenarioSummary
from fire_planner.glide_path import static
from fire_planner.reporters.pg import write_run
from fire_planner.scenarios import ScenarioSpec
from fire_planner.simulator import simulate
from fire_planner.strategies.trinity import TrinityStrategy
from fire_planner.tax.malaysia import MalaysiaTaxRegime
def fixed_paths(n_paths: int, n_years: int) -> np.ndarray:
out = np.zeros((n_paths, n_years, 3))
out[..., 0] = 0.05
out[..., 1] = 0.03
out[..., 2] = 0.02
return out
async def test_write_run_persists_summary_run_and_projection(session: AsyncSession) -> None:
spec = ScenarioSpec(
jurisdiction="cyprus",
strategy="trinity",
leave_uk_year=3,
glide_path="rising",
spending_gbp=Decimal("100000"),
nw_seed_gbp=Decimal("1000000"),
horizon_years=20,
)
paths = fixed_paths(50, 20)
result = simulate(
paths=paths,
initial_portfolio=1_000_000.0,
spending_target=40_000.0,
glide=static(0.7),
strategy=TrinityStrategy(),
regime=MalaysiaTaxRegime(),
horizon_years=20,
)
summary = await write_run(session, spec, result, seed=42, elapsed_seconds=1.5)
await session.commit()
runs = (await session.execute(select(McRun))).scalars().all()
assert len(runs) == 1
assert runs[0].id == summary.mc_run_id
assert runs[0].n_paths == 50
projections = (await session.execute(select(ProjectionYearly))).scalars().all()
assert len(projections) == 20 # one row per year
summaries = (await session.execute(select(ScenarioSummary))).scalars().all()
assert len(summaries) == 1
assert summaries[0].jurisdiction == "cyprus"
async def test_write_run_idempotent_summary(session: AsyncSession) -> None:
"""Running twice for the same scenario should keep summary at one row,
pointing at the latest run."""
spec = ScenarioSpec(
jurisdiction="bulgaria",
strategy="vpw",
leave_uk_year=2,
glide_path="static_60_40",
spending_gbp=Decimal("100000"),
nw_seed_gbp=Decimal("1000000"),
horizon_years=20,
)
paths = fixed_paths(20, 20)
result = simulate(
paths=paths,
initial_portfolio=1_000_000.0,
spending_target=40_000.0,
glide=static(0.6),
strategy=TrinityStrategy(),
regime=MalaysiaTaxRegime(),
horizon_years=20,
)
s1 = await write_run(session, spec, result, seed=42, elapsed_seconds=1.0)
await session.commit()
s2 = await write_run(session, spec, result, seed=43, elapsed_seconds=1.5)
await session.commit()
assert s1.scenario_id == s2.scenario_id
assert s2.mc_run_id != s1.mc_run_id
runs = (await session.execute(select(McRun))).scalars().all()
assert len(runs) == 2
summaries = (await session.execute(select(ScenarioSummary))).scalars().all()
assert len(summaries) == 1
assert summaries[0].mc_run_id == s2.mc_run_id

126
tests/test_returns.py Normal file
View file

@ -0,0 +1,126 @@
"""Returns loader + bootstrap behaviour."""
from pathlib import Path
import numpy as np
import pytest
from fire_planner.returns.bootstrap import block_bootstrap
from fire_planner.returns.shiller import ReturnsBundle, load_from_csv, synthetic_returns
def test_synthetic_returns_shape() -> None:
b = synthetic_returns(seed=1, n_years=120)
assert b.n_years == 120
assert b.stock_nominal.shape == (120, )
assert b.years[0] == 1871
def test_synthetic_deterministic_for_seed() -> None:
a = synthetic_returns(seed=42, n_years=10)
b = synthetic_returns(seed=42, n_years=10)
np.testing.assert_array_equal(a.stock_nominal, b.stock_nominal)
def test_real_return_smoke() -> None:
b = ReturnsBundle(
years=np.array([2020], dtype=np.int32),
stock_nominal=np.array([0.10]),
bond_nominal=np.array([0.04]),
cpi=np.array([0.03]),
)
# (1.10 / 1.03) - 1 ≈ 0.06796
assert abs(b.stock_real()[0] - 0.06796116505) < 1e-9
def test_load_from_csv(tmp_path: Path) -> None:
csv_path = tmp_path / "returns.csv"
csv_path.write_text("year,stock_nominal_return,bond_nominal_return,cpi_inflation\n"
"1990,0.05,0.07,0.025\n"
"1991,-0.10,0.04,0.03\n")
b = load_from_csv(csv_path)
assert b.n_years == 2
assert b.stock_nominal[1] == pytest.approx(-0.10)
assert b.cpi[0] == pytest.approx(0.025)
def test_returns_bundle_rejects_mismatched_lengths() -> None:
with pytest.raises(ValueError):
ReturnsBundle(
years=np.array([2020, 2021], dtype=np.int32),
stock_nominal=np.array([0.1]),
bond_nominal=np.array([0.04, 0.05]),
cpi=np.array([0.03, 0.025]),
)
def test_returns_bundle_rejects_empty() -> None:
with pytest.raises(ValueError):
ReturnsBundle(
years=np.array([], dtype=np.int32),
stock_nominal=np.array([]),
bond_nominal=np.array([]),
cpi=np.array([]),
)
def test_bootstrap_shape() -> None:
bundle = synthetic_returns(seed=1, n_years=150)
rng = np.random.default_rng(0)
paths = block_bootstrap(bundle, n_paths=100, n_years=60, block_size=5, rng=rng)
assert paths.shape == (100, 60, 3)
def test_bootstrap_deterministic_with_seed() -> None:
bundle = synthetic_returns(seed=1, n_years=150)
a = block_bootstrap(bundle, n_paths=50, n_years=30, block_size=5, rng=np.random.default_rng(0))
b = block_bootstrap(bundle, n_paths=50, n_years=30, block_size=5, rng=np.random.default_rng(0))
np.testing.assert_array_equal(a, b)
def test_bootstrap_block_size_one_is_iid() -> None:
"""Block size 1 reduces to simple IID resampling — covariance
structure isn't preserved, but all draws come from the source."""
bundle = synthetic_returns(seed=2, n_years=100)
rng = np.random.default_rng(0)
paths = block_bootstrap(bundle, n_paths=10, n_years=20, block_size=1, rng=rng)
src_set = set(zip(bundle.stock_nominal, bundle.bond_nominal, bundle.cpi, strict=True))
drawn_set = set((float(s), float(b), float(c)) for path in paths for s, b, c in path)
assert drawn_set <= src_set
def test_bootstrap_preserves_block_runs() -> None:
"""For block_size=5, every consecutive 5-year run within a path
must equal a 5-year run from the source (mod circular)."""
bundle = synthetic_returns(seed=3, n_years=50)
rng = np.random.default_rng(0)
paths = block_bootstrap(bundle, n_paths=5, n_years=15, block_size=5, rng=rng)
src = np.stack([bundle.stock_nominal, bundle.bond_nominal, bundle.cpi], axis=-1)
src_n = src.shape[0]
for path in paths:
for block_start in range(0, 15, 5):
block = path[block_start:block_start + 5]
# Find this block in source by matching the first row, then
# checking consecutiveness (mod circular).
for src_idx in range(src_n):
circ_block = np.stack([src[(src_idx + i) % src_n] for i in range(5)])
if np.allclose(block, circ_block):
break
else:
raise AssertionError(f"block {block_start} not a circular slice of source")
def test_bootstrap_rejects_zero_block_size() -> None:
bundle = synthetic_returns(seed=1, n_years=30)
with pytest.raises(ValueError):
block_bootstrap(bundle, n_paths=10, n_years=10, block_size=0)
def test_bootstrap_n_years_not_multiple_of_block() -> None:
"""13 years from 5-year blocks: 3 blocks then truncate to 13."""
bundle = synthetic_returns(seed=1, n_years=50)
paths = block_bootstrap(bundle,
n_paths=4,
n_years=13,
block_size=5,
rng=np.random.default_rng(0))
assert paths.shape == (4, 13, 3)

113
tests/test_scenarios.py Normal file
View file

@ -0,0 +1,113 @@
"""Cartesian scenario builder + strategy/regime factory."""
from decimal import Decimal
import pytest
from fire_planner.scenarios import (
DEFAULT_GLIDES,
DEFAULT_JURISDICTIONS,
DEFAULT_LEAVE_YEARS,
DEFAULT_STRATEGIES,
ScenarioSpec,
build_regime_schedule,
build_strategy,
cartesian_scenarios,
)
from fire_planner.strategies.guyton_klinger import GuytonKlingerStrategy
from fire_planner.strategies.trinity import TrinityStrategy
from fire_planner.strategies.vpw import VpwStrategy, VpwWithFloorStrategy
from fire_planner.tax.bulgaria import BulgariaTaxRegime
from fire_planner.tax.cyprus import CyprusTaxRegime
from fire_planner.tax.uae import UaeTaxRegime
from fire_planner.tax.uk import UkTaxRegime
def test_default_cartesian_count_is_120() -> None:
specs = cartesian_scenarios(spending_gbp=Decimal("100000"), nw_seed_gbp=Decimal("1000000"))
expected = (len(DEFAULT_JURISDICTIONS) * len(DEFAULT_STRATEGIES) * len(DEFAULT_LEAVE_YEARS) *
len(DEFAULT_GLIDES))
assert expected == 120
assert len(specs) == 120
def test_external_id_format() -> None:
spec = ScenarioSpec(
jurisdiction="cyprus",
strategy="vpw",
leave_uk_year=3,
glide_path="rising",
spending_gbp=Decimal("100000"),
nw_seed_gbp=Decimal("1000000"),
)
assert spec.external_id == "cyprus-vpw-leave-y3-glide-rising"
def test_cartesian_unique_external_ids() -> None:
specs = cartesian_scenarios(spending_gbp=Decimal("100000"), nw_seed_gbp=Decimal("1000000"))
ids = [s.external_id for s in specs]
assert len(ids) == len(set(ids))
def test_build_strategy_dispatch() -> None:
assert isinstance(build_strategy("trinity"), TrinityStrategy)
assert isinstance(build_strategy("guyton_klinger"), GuytonKlingerStrategy)
assert isinstance(build_strategy("vpw"), VpwStrategy)
def test_build_strategy_vpw_floor_requires_floor() -> None:
s = build_strategy("vpw_floor", floor=40_000.0)
assert isinstance(s, VpwWithFloorStrategy)
assert s.floor == 40_000.0
def test_build_strategy_vpw_floor_missing_floor_raises() -> None:
with pytest.raises(ValueError):
build_strategy("vpw_floor")
def test_build_strategy_unknown_raises() -> None:
with pytest.raises(KeyError):
build_strategy("walmart")
def test_build_regime_schedule_uae() -> None:
fn = build_regime_schedule("uae", leave_uk_year=2)
assert isinstance(fn(0), UkTaxRegime)
assert isinstance(fn(1), UkTaxRegime)
assert isinstance(fn(2), UaeTaxRegime)
assert isinstance(fn(50), UaeTaxRegime)
def test_build_regime_schedule_uk_constant() -> None:
fn = build_regime_schedule("uk", leave_uk_year=3)
# All years should resolve to UK
assert isinstance(fn(0), UkTaxRegime)
assert isinstance(fn(50), UkTaxRegime)
def test_build_regime_schedule_cyprus_switches_at_leave_year() -> None:
fn = build_regime_schedule("cyprus", leave_uk_year=3)
assert isinstance(fn(0), UkTaxRegime)
assert isinstance(fn(2), UkTaxRegime)
assert isinstance(fn(3), CyprusTaxRegime)
assert isinstance(fn(50), CyprusTaxRegime)
def test_build_regime_schedule_bulgaria() -> None:
fn = build_regime_schedule("bulgaria", leave_uk_year=1)
assert isinstance(fn(0), UkTaxRegime)
assert isinstance(fn(1), BulgariaTaxRegime)
def test_build_regime_schedule_unknown_raises() -> None:
with pytest.raises(KeyError):
build_regime_schedule("madeupistan", leave_uk_year=3)
def test_cartesian_unknown_glide_raises() -> None:
with pytest.raises(KeyError):
cartesian_scenarios(
spending_gbp=Decimal("100000"),
nw_seed_gbp=Decimal("1000000"),
glides=("staircase", ),
)

259
tests/test_simulator.py Normal file
View file

@ -0,0 +1,259 @@
"""Simulator behaviour: deterministic short-horizon checks, then
stochastic monotonicity + cFIREsim sanity calibration."""
from __future__ import annotations
import time
import numpy as np
from fire_planner.glide_path import static
from fire_planner.returns.bootstrap import block_bootstrap
from fire_planner.returns.shiller import ReturnsBundle, synthetic_returns
from fire_planner.simulator import default_bucket_split, simulate
from fire_planner.strategies.guyton_klinger import GuytonKlingerStrategy
from fire_planner.strategies.trinity import TrinityStrategy
from fire_planner.strategies.vpw import VpwStrategy
from fire_planner.tax.bulgaria import BulgariaTaxRegime
from fire_planner.tax.malaysia import MalaysiaTaxRegime
from fire_planner.tax.uk import UkTaxRegime
def fixed_paths(n_paths: int, n_years: int, stock_ret: float, bond_ret: float,
cpi: float) -> np.ndarray:
"""All-paths-identical returns — deterministic regression check."""
out = np.zeros((n_paths, n_years, 3), dtype=np.float64)
out[..., 0] = stock_ret
out[..., 1] = bond_ret
out[..., 2] = cpi
return out
def test_simulate_zero_returns_zero_inflation_drains_at_4pc() -> None:
"""0% returns + 0% inflation, 4% Trinity, 25y horizon — withdraw
£40k/y from £1M = drain to exactly £0 in year 25. Success because
portfolio stays positive *during* every year (clipped to 0 at end)."""
paths = fixed_paths(n_paths=1, n_years=25, stock_ret=0.0, bond_ret=0.0, cpi=0.0)
res = simulate(
paths=paths,
initial_portfolio=1_000_000.0,
spending_target=40_000.0,
glide=static(0.6),
strategy=TrinityStrategy(initial_rate=0.04),
regime=MalaysiaTaxRegime(), # 0% to keep arithmetic clean
)
# Year 0 withdrawal is 40k, portfolio after = 960k
assert res.portfolio_real[0, 1] == 960_000.0
# 25y of £40k draws against zero growth = drain to 0 by end of y25.
assert abs(res.portfolio_real[0, 25]) < 1.0
def test_simulate_failing_path_marked_unsuccessful() -> None:
"""6% Trinity rate against 0% real return for 25y — clearly fails."""
paths = fixed_paths(n_paths=1, n_years=25, stock_ret=0.0, bond_ret=0.0, cpi=0.0)
res = simulate(
paths=paths,
initial_portfolio=1_000_000.0,
spending_target=60_000.0,
glide=static(1.0),
strategy=TrinityStrategy(initial_rate=0.06),
regime=MalaysiaTaxRegime(),
)
assert not res.success_mask[0]
def test_simulate_growing_portfolio_succeeds() -> None:
"""5% real return, 4% draw — classic surplus case."""
paths = fixed_paths(n_paths=1, n_years=30, stock_ret=0.05, bond_ret=0.05, cpi=0.0)
res = simulate(
paths=paths,
initial_portfolio=1_000_000.0,
spending_target=40_000.0,
glide=static(1.0),
strategy=TrinityStrategy(initial_rate=0.04),
regime=MalaysiaTaxRegime(),
)
assert res.success_mask[0]
# Portfolio should grow above starting value
assert res.portfolio_real[0, 30] > 1_000_000.0
def test_savings_phase_increases_portfolio() -> None:
"""5y of savings @ £100k / 0% return → portfolio grows."""
paths = fixed_paths(n_paths=1, n_years=5, stock_ret=0.0, bond_ret=0.0, cpi=0.0)
res = simulate(
paths=paths,
initial_portfolio=1_000_000.0,
spending_target=0.0, # not drawing during accumulation
glide=static(1.0),
strategy=TrinityStrategy(initial_rate=0.0),
regime=MalaysiaTaxRegime(),
annual_savings=np.full(5, 100_000.0),
)
# 1M + 5×100k = 1.5M, no growth
assert res.portfolio_real[0, 5] == 1_500_000.0
def test_uk_tax_increases_failure_rate_vs_no_tax() -> None:
"""Same scenario, UK regime should produce more or equal failures
than Malaysia (zero tax) paths are identical."""
bundle = synthetic_returns(seed=1, n_years=120)
rng = np.random.default_rng(0)
paths = block_bootstrap(bundle, n_paths=200, n_years=30, block_size=5, rng=rng)
common = dict(
paths=paths,
initial_portfolio=600_000.0, # tighter so tax matters
spending_target=40_000.0,
glide=static(0.7),
strategy=TrinityStrategy(initial_rate=0.04),
)
msy = simulate(**common, regime=MalaysiaTaxRegime()) # type: ignore[arg-type]
uk = simulate(**common, regime=UkTaxRegime()) # type: ignore[arg-type]
assert uk.success_rate <= msy.success_rate
assert uk.median_lifetime_tax() > msy.median_lifetime_tax()
def test_vpw_never_runs_out() -> None:
"""VPW scales withdrawal with portfolio — should never fully ruin."""
bundle = synthetic_returns(seed=2, n_years=120)
rng = np.random.default_rng(0)
paths = block_bootstrap(bundle, n_paths=200, n_years=60, block_size=5, rng=rng)
res = simulate(
paths=paths,
initial_portfolio=1_000_000.0,
spending_target=50_000.0,
glide=static(0.7),
strategy=VpwStrategy(),
regime=MalaysiaTaxRegime(),
)
# Every path should keep some portfolio > 0 throughout (until last year).
# Year `n-1` end may be tiny but >= 0 since VPW caps drain at 100% in y=H-1.
assert res.portfolio_real[:, 1:-1].min() > 0
def test_simulator_deterministic_with_same_paths() -> None:
paths = fixed_paths(n_paths=10, n_years=30, stock_ret=0.05, bond_ret=0.03, cpi=0.02)
common = dict(
paths=paths,
initial_portfolio=1_000_000.0,
spending_target=40_000.0,
glide=static(0.7),
strategy=GuytonKlingerStrategy(),
regime=BulgariaTaxRegime(),
)
a = simulate(**common) # type: ignore[arg-type]
b = simulate(**common) # type: ignore[arg-type]
np.testing.assert_array_equal(a.portfolio_real, b.portfolio_real)
def test_success_rate_monotone_in_portfolio() -> None:
"""More starting wealth → higher (or equal) success rate."""
bundle = synthetic_returns(seed=3, n_years=120)
rng = np.random.default_rng(0)
paths = block_bootstrap(bundle, n_paths=300, n_years=30, block_size=5, rng=rng)
common = dict(
paths=paths,
spending_target=40_000.0,
glide=static(0.7),
strategy=TrinityStrategy(initial_rate=0.04),
regime=MalaysiaTaxRegime(),
)
low = simulate(**common, initial_portfolio=600_000.0) # type: ignore[arg-type]
high = simulate(**common, initial_portfolio=1_500_000.0) # type: ignore[arg-type]
assert high.success_rate >= low.success_rate
def test_success_rate_monotone_in_spending() -> None:
"""Less spending → higher success rate."""
bundle = synthetic_returns(seed=4, n_years=120)
rng = np.random.default_rng(0)
paths = block_bootstrap(bundle, n_paths=300, n_years=30, block_size=5, rng=rng)
common = dict(
paths=paths,
initial_portfolio=1_000_000.0,
glide=static(0.7),
strategy=TrinityStrategy(initial_rate=0.04),
regime=MalaysiaTaxRegime(),
)
cheap = simulate(**common, spending_target=30_000.0) # type: ignore[arg-type]
fat = simulate(**common, spending_target=80_000.0) # type: ignore[arg-type]
assert cheap.success_rate >= fat.success_rate
def test_fan_quantiles_shape() -> None:
bundle = synthetic_returns(seed=5, n_years=120)
paths = block_bootstrap(bundle,
n_paths=100,
n_years=20,
block_size=5,
rng=np.random.default_rng(0))
res = simulate(
paths=paths,
initial_portfolio=1_000_000.0,
spending_target=40_000.0,
glide=static(0.7),
strategy=TrinityStrategy(),
regime=MalaysiaTaxRegime(),
)
p10 = res.fan_quantiles(10)
assert p10.shape == (21, ) # n_years + 1
def test_perf_under_60s_for_10k_paths_60y() -> None:
"""Stretch goal — at 10k paths × 60y the simulator should finish
in well under a minute on commodity hardware. Test allows 60s
(generous; CI can vary)."""
bundle = synthetic_returns(seed=6, n_years=150)
paths = block_bootstrap(bundle,
n_paths=10_000,
n_years=60,
block_size=5,
rng=np.random.default_rng(0))
start = time.perf_counter()
simulate(
paths=paths,
initial_portfolio=1_000_000.0,
spending_target=40_000.0,
glide=static(0.7),
strategy=TrinityStrategy(),
regime=MalaysiaTaxRegime(),
)
elapsed = time.perf_counter() - start
assert elapsed < 60, f"too slow: {elapsed:.2f}s"
def test_convergence_5k_vs_50k_paths() -> None:
"""Success rate should be stable to within ±1.5% between 5k and
50k paths (Monte Carlo SE ~0.5% at 10k samples)."""
bundle = synthetic_returns(seed=7, n_years=150)
paths_small = block_bootstrap(bundle,
n_paths=5_000,
n_years=30,
block_size=5,
rng=np.random.default_rng(0))
paths_large = block_bootstrap(bundle,
n_paths=50_000,
n_years=30,
block_size=5,
rng=np.random.default_rng(0))
common = dict(
initial_portfolio=1_000_000.0,
spending_target=40_000.0,
glide=static(0.7),
strategy=TrinityStrategy(),
regime=MalaysiaTaxRegime(),
)
small = simulate(paths=paths_small, **common) # type: ignore[arg-type]
large = simulate(paths=paths_large, **common) # type: ignore[arg-type]
assert abs(small.success_rate - large.success_rate) < 0.015
def test_default_bucket_split_smoke() -> None:
inputs = default_bucket_split(50_000.0, year_idx=5)
assert inputs.capital_gains == 50000
def test_returns_bundle_supplies_ie_data_columns() -> None:
"""Sanity: the bundle has stock/bond/cpi correctly aligned."""
b = synthetic_returns(seed=8, n_years=10)
assert isinstance(b, ReturnsBundle)
assert len(b.stock_nominal) == 10

155
tests/test_strategies.py Normal file
View file

@ -0,0 +1,155 @@
"""Withdrawal-strategy + glide-path behaviour."""
from fire_planner import glide_path
from fire_planner.strategies.base import StrategyState
from fire_planner.strategies.guyton_klinger import GuytonKlingerStrategy
from fire_planner.strategies.trinity import TrinityStrategy
from fire_planner.strategies.vpw import VpwStrategy, VpwWithFloorStrategy, pmt_rate
def state(**overrides: float | int) -> StrategyState:
base = dict(
portfolio=1_000_000.0,
initial_portfolio=1_000_000.0,
initial_withdrawal=40_000.0,
year_idx=0,
horizon_years=60,
last_withdrawal=40_000.0,
expected_real_return=0.04,
)
base.update(overrides)
return StrategyState(**base) # type: ignore[arg-type]
def test_trinity_year_zero_uses_initial_rate() -> None:
s = TrinityStrategy(initial_rate=0.04)
assert s.propose_withdrawal(state()) == 40_000.0
def test_trinity_holds_constant_in_real_terms() -> None:
s = TrinityStrategy()
assert s.propose_withdrawal(state(year_idx=10, last_withdrawal=40_000.0)) == 40_000.0
def test_trinity_doesnt_increase_with_portfolio_growth() -> None:
s = TrinityStrategy()
assert s.propose_withdrawal(state(year_idx=5, portfolio=2_000_000.0,
last_withdrawal=40_000.0)) == 40_000.0
def test_gk_year_zero_uses_initial_rate() -> None:
s = GuytonKlingerStrategy(initial_rate=0.055)
# 5.5% of 1M = 55,000
assert s.propose_withdrawal(state()) == 55_000.0
def test_gk_capital_preservation_cut() -> None:
"""Portfolio crashed: current rate now > 120% of 5.5% = 6.6%; > 15y left → cut 10%."""
s = GuytonKlingerStrategy(initial_rate=0.055)
# last_w = 55,000; portfolio = 700,000 → rate = 7.86% > 6.6%
out = s.propose_withdrawal(state(year_idx=5, portfolio=700_000.0, last_withdrawal=55_000.0))
assert abs(out - 49_500.0) < 0.01
def test_gk_no_cut_when_horizon_under_15y_left() -> None:
"""Same crash, only 10y left — no cut applies."""
s = GuytonKlingerStrategy(initial_rate=0.055)
out = s.propose_withdrawal(
state(year_idx=50, portfolio=700_000.0, last_withdrawal=55_000.0, horizon_years=60))
assert out == 55_000.0
def test_gk_prosperity_bump() -> None:
"""Big bull market: current rate < 80% of 5.5% = 4.4% → bump 10%."""
s = GuytonKlingerStrategy(initial_rate=0.055)
out = s.propose_withdrawal(state(year_idx=5, portfolio=2_000_000.0, last_withdrawal=55_000.0))
assert abs(out - 60_500.0) < 0.01
def test_pmt_rate_uniform_amortisation_at_zero_rate() -> None:
assert abs(pmt_rate(years_remaining=60, real_rate=0.0) - 1 / 60) < 1e-12
def test_pmt_rate_full_drain_when_years_zero() -> None:
assert pmt_rate(years_remaining=0, real_rate=0.04) == 1.0
def test_pmt_rate_bogleheads_table_60y() -> None:
"""Bogleheads VPW table: at 5% real, 60y, the published rate is
5.28% (within £1/£10k of 5.2828% on a 60-year amortisation)."""
assert abs(pmt_rate(60, 0.05) - 0.052828) < 1e-4
def test_pmt_rate_bogleheads_table_30y() -> None:
"""At 5% real, 30y → 6.51%."""
assert abs(pmt_rate(30, 0.05) - 0.06505) < 1e-4
def test_pmt_rate_bogleheads_table_15y() -> None:
"""At 5% real, 15y → 9.63%."""
assert abs(pmt_rate(15, 0.05) - 0.09634) < 1e-4
def test_vpw_year_zero_at_60y_horizon() -> None:
"""1M portfolio × pmt_rate(60, 0.05) = 1M × 0.0528 = 52,828.20."""
s = VpwStrategy(expected_real_return=0.05)
out = s.propose_withdrawal(state(horizon_years=60, year_idx=0))
assert abs(out - 52_828.0) < 5 # within a few quid
def test_vpw_drain_at_horizon_end() -> None:
"""Last year: withdraw the entire portfolio."""
s = VpwStrategy()
out = s.propose_withdrawal(state(year_idx=59, horizon_years=60, portfolio=100_000.0))
assert abs(out - 100_000.0) < 1
def test_vpw_with_floor_lifts_to_floor_when_vpw_proposes_less() -> None:
"""VPW on a 500k portfolio with 60y left at 5% would propose
500k × 0.0528 26,400. Floor=40k overrides withdraw the floor."""
s = VpwWithFloorStrategy(floor=40_000.0, expected_real_return=0.05)
out = s.propose_withdrawal(state(portfolio=500_000.0, horizon_years=60, year_idx=0))
assert out == 40_000.0
def test_vpw_with_floor_uses_vpw_when_above_floor() -> None:
"""VPW on a 2M portfolio with 60y left ≈ 105,656. Above floor=40k → use VPW."""
s = VpwWithFloorStrategy(floor=40_000.0, expected_real_return=0.05)
out = s.propose_withdrawal(state(portfolio=2_000_000.0, horizon_years=60, year_idx=0))
assert abs(out - 105_656.0) < 50
def test_vpw_with_floor_clips_to_portfolio_when_portfolio_below_floor() -> None:
"""Terminal sequence: portfolio crashed below the floor — withdraw what's left."""
s = VpwWithFloorStrategy(floor=40_000.0)
out = s.propose_withdrawal(state(portfolio=15_000.0, horizon_years=60, year_idx=30))
assert out == 15_000.0
def test_vpw_with_floor_zero_portfolio() -> None:
s = VpwWithFloorStrategy(floor=40_000.0)
out = s.propose_withdrawal(state(portfolio=0.0))
assert out == 0.0
def test_vpw_with_floor_name() -> None:
assert VpwWithFloorStrategy(floor=40_000.0).name == "vpw_floor"
def test_glide_rising_default_shape() -> None:
g = glide_path.rising_equity()
assert g(0) == 0.30
assert abs(g(15) - 0.70) < 1e-9
assert abs(g(30) - 0.70) < 1e-9
# Halfway through the ramp
assert abs(g(7) - (0.30 + 0.40 * 7 / 15)) < 1e-9
def test_glide_static() -> None:
g = glide_path.static(0.60)
assert g(0) == 0.60
assert g(50) == 0.60
def test_glide_lookup() -> None:
assert glide_path.get("rising")(0) == 0.30
assert glide_path.get("static_60_40")(50) == 0.60

70
tests/test_tax_base.py Normal file
View file

@ -0,0 +1,70 @@
"""Bracket-arithmetic and breakdown invariants."""
from decimal import Decimal
from hypothesis import given
from hypothesis import strategies as st
from fire_planner.tax.base import TaxBreakdown, TaxInputs, apply_brackets
def test_apply_brackets_zero_input() -> None:
assert apply_brackets(Decimal("0"), [(Decimal("100"), Decimal("0.2"))]) == Decimal("0")
def test_apply_brackets_negative_input() -> None:
# Negative income shouldn't generate a refund — clamp to zero.
assert apply_brackets(Decimal("-1000"), [(Decimal("100"), Decimal("0.2"))]) == Decimal("0")
def test_apply_brackets_within_first_band() -> None:
brackets = [(Decimal("100"), Decimal("0.2")), (Decimal("Infinity"), Decimal("0.4"))]
assert apply_brackets(Decimal("50"), brackets) == Decimal("10")
def test_apply_brackets_spans_two_bands() -> None:
# 100 @ 20% = 20; next 50 @ 40% = 20 → total 40
brackets = [(Decimal("100"), Decimal("0.2")), (Decimal("Infinity"), Decimal("0.4"))]
assert apply_brackets(Decimal("150"), brackets) == Decimal("40")
def test_apply_brackets_uk_paye_2026_smoke() -> None:
# Taxable income £80,000 (gross £92,570 less £12,570 PA):
# £37,700 @ 20% = £7,540
# £42,300 @ 40% = £16,920
# total = £24,460
brackets = [
(Decimal("37700"), Decimal("0.20")),
(Decimal("112570"), Decimal("0.40")),
(Decimal("Infinity"), Decimal("0.45")),
]
assert apply_brackets(Decimal("80000"), brackets) == Decimal("24460")
@given(amount=st.decimals(min_value=0, max_value=10_000_000, allow_nan=False, allow_infinity=False))
def test_apply_brackets_monotone_in_amount(amount: Decimal) -> None:
"""More taxable income → never less tax."""
brackets = [
(Decimal("37700"), Decimal("0.20")),
(Decimal("112570"), Decimal("0.40")),
(Decimal("Infinity"), Decimal("0.45")),
]
extra = Decimal("100")
assert apply_brackets(amount + extra, brackets) >= apply_brackets(amount, brackets)
def test_breakdown_total_is_sum_of_components() -> None:
b = TaxBreakdown(
income_tax=Decimal("10000"),
national_insurance=Decimal("3000"),
capital_gains_tax=Decimal("500"),
dividend_tax=Decimal("200"),
healthcare_levy=Decimal("100"),
other=Decimal("50"),
)
assert b.total == Decimal("13850")
def test_inputs_default_to_zero() -> None:
i = TaxInputs()
assert i.earned_income == Decimal("0")
assert i.years_since_uk_departure == 0

View file

@ -0,0 +1,150 @@
"""Nomad, Malaysia, Thailand, Cyprus, Bulgaria, UAE regimes."""
from decimal import Decimal
import pytest
from fire_planner.tax.base import TaxInputs, TaxRegime
from fire_planner.tax.bulgaria import BulgariaTaxRegime
from fire_planner.tax.cyprus import CyprusTaxRegime
from fire_planner.tax.malaysia import MalaysiaTaxRegime
from fire_planner.tax.nomad import NomadTaxRegime
from fire_planner.tax.thailand import ThailandTaxRegime
from fire_planner.tax.uae import UaeTaxRegime
def test_nomad_zero_inputs() -> None:
assert NomadTaxRegime().compute_tax(TaxInputs()).total == Decimal("0")
def test_nomad_one_pc_premium() -> None:
b = NomadTaxRegime().compute_tax(
TaxInputs(capital_gains=Decimal("100000"), dividends=Decimal("20000")))
assert b.other == Decimal("1200")
assert b.total == Decimal("1200")
def test_nomad_isa_excluded_from_premium() -> None:
b = NomadTaxRegime().compute_tax(TaxInputs(isa_withdrawals=Decimal("100000")))
assert b.total == Decimal("0")
def test_malaysia_zero_on_foreign_income() -> None:
b = MalaysiaTaxRegime().compute_tax(
TaxInputs(capital_gains=Decimal("500000"), dividends=Decimal("50000")))
assert b.total == Decimal("0")
def test_thailand_zero_on_foreign_income() -> None:
b = ThailandTaxRegime().compute_tax(
TaxInputs(capital_gains=Decimal("500000"), dividends=Decimal("50000")))
assert b.total == Decimal("0")
def test_cyprus_gesy_below_cap() -> None:
# £100k chargeable, below €180k cap (~£154,800 default)
# 2.65% × £100,000 = £2,650
b = CyprusTaxRegime().compute_tax(TaxInputs(dividends=Decimal("100000")))
assert b.healthcare_levy == Decimal("2650.0000")
assert b.income_tax == Decimal("0")
assert b.capital_gains_tax == Decimal("0")
def test_cyprus_gesy_above_cap() -> None:
# £200k chargeable; cap GBP = £154,800 (€180k × 0.86)
# 2.65% × £154,800 = £4,102.20
b = CyprusTaxRegime().compute_tax(TaxInputs(dividends=Decimal("200000")))
assert b.healthcare_levy == Decimal("4102.2000")
def test_cyprus_custom_fx() -> None:
# Cap = 180,000 × 0.90 = 162,000
b = CyprusTaxRegime(gbp_per_eur=Decimal("0.90")).compute_tax(
TaxInputs(dividends=Decimal("200000")))
assert b.healthcare_levy == Decimal("4293.0000")
def test_uae_zero_on_all_personal_income() -> None:
"""UAE has 0% PIT — capital gains, dividends, earned income all 0."""
b = UaeTaxRegime().compute_tax(
TaxInputs(
earned_income=Decimal("60000"),
capital_gains=Decimal("500000"),
dividends=Decimal("80000"),
interest=Decimal("5000"),
))
assert b.total == Decimal("0")
assert b.income_tax == Decimal("0")
assert b.capital_gains_tax == Decimal("0")
assert b.dividend_tax == Decimal("0")
assert b.healthcare_levy == Decimal("0")
assert b.other == Decimal("0")
def test_uae_no_regulatory_premium() -> None:
"""Unlike NomadTaxRegime, UAE charges no premium — it's a real
tax residence with a treaty network."""
b = UaeTaxRegime().compute_tax(TaxInputs(capital_gains=Decimal("100000")))
assert b.total == Decimal("0")
def test_uae_zero_inputs() -> None:
assert UaeTaxRegime().compute_tax(TaxInputs()).total == Decimal("0")
def test_bulgaria_flat_10_pc() -> None:
b = BulgariaTaxRegime().compute_tax(
TaxInputs(
earned_income=Decimal("50000"),
capital_gains=Decimal("30000"),
dividends=Decimal("10000"),
))
assert b.income_tax == Decimal("5000.00")
assert b.capital_gains_tax == Decimal("3000.00")
assert b.dividend_tax == Decimal("1000.00")
assert b.total == Decimal("9000.00")
@pytest.mark.parametrize("regime", [
NomadTaxRegime(),
MalaysiaTaxRegime(),
ThailandTaxRegime(),
CyprusTaxRegime(),
BulgariaTaxRegime(),
UaeTaxRegime(),
])
def test_total_equals_sum(regime: TaxRegime) -> None:
inputs = TaxInputs(
earned_income=Decimal("60000"),
capital_gains=Decimal("15000"),
dividends=Decimal("8000"),
interest=Decimal("500"),
)
b = regime.compute_tax(inputs)
assert (b.total == b.income_tax + b.national_insurance + b.capital_gains_tax + b.dividend_tax +
b.healthcare_levy + b.other)
@pytest.mark.parametrize("regime", [
NomadTaxRegime(),
MalaysiaTaxRegime(),
ThailandTaxRegime(),
CyprusTaxRegime(),
BulgariaTaxRegime(),
UaeTaxRegime(),
])
def test_each_regime_has_a_name(regime: TaxRegime) -> None:
assert regime.name
assert isinstance(regime.name, str)
@pytest.mark.parametrize("regime", [
BulgariaTaxRegime(),
NomadTaxRegime(),
CyprusTaxRegime(),
])
def test_lower_spend_lower_tax(regime: TaxRegime) -> None:
"""Sanity: more chargeable income → never less tax (for the
regimes that actually charge)."""
less = regime.compute_tax(TaxInputs(dividends=Decimal("10000")))
more = regime.compute_tax(TaxInputs(dividends=Decimal("100000")))
assert more.total >= less.total

147
tests/test_tax_uk.py Normal file
View file

@ -0,0 +1,147 @@
"""UK tax regime — bands, allowances, tapers."""
from decimal import Decimal
from hypothesis import given
from hypothesis import strategies as st
from fire_planner.tax.base import TaxInputs
from fire_planner.tax.uk import (
PA_TAPER_CEILING,
PERSONAL_ALLOWANCE,
UkTaxRegime,
taper_personal_allowance,
)
def test_pa_no_taper_below_100k() -> None:
assert taper_personal_allowance(Decimal("80000")) == PERSONAL_ALLOWANCE
def test_pa_full_taper_at_ceiling() -> None:
assert taper_personal_allowance(PA_TAPER_CEILING) == Decimal("0")
def test_pa_partial_taper_at_110k() -> None:
# £10k above floor → £5k reduction off PA
assert taper_personal_allowance(Decimal("110000")) == PERSONAL_ALLOWANCE - Decimal("5000")
def test_zero_income_zero_tax() -> None:
b = UkTaxRegime().compute_tax(TaxInputs())
assert b.total == Decimal("0")
def test_isa_only_zero_tax() -> None:
b = UkTaxRegime().compute_tax(TaxInputs(isa_withdrawals=Decimal("100000")))
assert b.total == Decimal("0")
def test_below_pa_zero_tax() -> None:
b = UkTaxRegime().compute_tax(TaxInputs(earned_income=Decimal("12000")))
# NI primary threshold matches PA so NI is zero too.
assert b.total == Decimal("0")
def test_basic_rate_paye_smoke() -> None:
# £30k earned: £17,430 taxable @ 20% = £3,486 income tax
# NI: £17,430 @ 8% = £1,394.40
b = UkTaxRegime().compute_tax(TaxInputs(earned_income=Decimal("30000")))
assert b.income_tax == Decimal("3486.00")
assert b.national_insurance == Decimal("1394.40")
def test_higher_rate_paye_100k() -> None:
# £100k earned, PA still full (taper starts strictly above £100k):
# taxable = £87,430
# £37,700 @ 20% = £7,540
# £49,730 @ 40% = £19,892
# total income tax = £27,432
b = UkTaxRegime().compute_tax(TaxInputs(earned_income=Decimal("100000")))
assert b.income_tax == Decimal("27432.00")
def test_pa_taper_at_125k() -> None:
# £125,000: PA = 12,570 - (25,000/2) = 12,570 - 12,500 = 70
# taxable = 125,000 - 70 = 124,930
# £37,700 @ 20% = £7,540
# £87,230 @ 40% = £34,892
# total = £42,432
b = UkTaxRegime().compute_tax(TaxInputs(earned_income=Decimal("125000")))
assert b.income_tax == Decimal("42432.00")
def test_additional_rate_above_125k() -> None:
# £200k earned: PA fully tapered.
# taxable income = £200,000
# £37,700 @ 20% = £7,540
# £87,440 @ 40% = £34,976
# £74,860 @ 45% = £33,687
# total = £76,203
b = UkTaxRegime().compute_tax(TaxInputs(earned_income=Decimal("200000")))
assert b.income_tax == Decimal("76203.00")
def test_cgt_basic_rate_only() -> None:
# No earned income, £20k gains:
# exempt £3k → £17k taxable @ 18% (basic band has plenty of room)
# = £3,060
b = UkTaxRegime().compute_tax(TaxInputs(capital_gains=Decimal("20000")))
assert b.capital_gains_tax == Decimal("3060.00")
def test_cgt_spans_into_higher_band() -> None:
# £30k earned (taxable income £17,430 — well under £37,700 band top)
# £40k gains:
# exempt £3k → £37k taxable
# basic band remaining = 37,700 - 17,430 = 20,270 → @ 18% = £3,648.60
# higher = 37,000 - 20,270 = 16,730 → @ 24% = £4,015.20
# total CGT = £7,663.80
b = UkTaxRegime().compute_tax(
TaxInputs(earned_income=Decimal("30000"), capital_gains=Decimal("40000")))
assert b.capital_gains_tax == Decimal("7663.80")
def test_dividend_basic_rate() -> None:
# No other income, £10k dividends:
# allowance £500 → £9,500 taxable
# Stacked on top of taxable_ordinary=0, so basic band has £37,700 room.
# All £9,500 @ 8.75% = £831.25
b = UkTaxRegime().compute_tax(TaxInputs(dividends=Decimal("10000")))
assert b.dividend_tax == Decimal("831.2500")
def test_pension_25pc_tax_free() -> None:
# £40k pension drawdown, no other income:
# PCLS = £10k tax-free
# Taxable pension = £30k → £17,430 taxable @ 20% = £3,486
b = UkTaxRegime().compute_tax(TaxInputs(pension_withdrawal=Decimal("40000")))
assert b.income_tax == Decimal("3486.00")
assert b.national_insurance == Decimal("0") # NI not on pension
def test_total_equals_sum_of_components() -> None:
inputs = TaxInputs(
earned_income=Decimal("60000"),
capital_gains=Decimal("15000"),
dividends=Decimal("8000"),
)
b = UkTaxRegime().compute_tax(inputs)
assert (b.total == b.income_tax + b.national_insurance + b.capital_gains_tax + b.dividend_tax +
b.healthcare_levy + b.other)
@given(income=st.decimals(
min_value=0, max_value=500_000, places=2, allow_nan=False, allow_infinity=False))
def test_tax_monotone_in_earned_income(income: Decimal) -> None:
"""Adding earned income never decreases total tax."""
base = UkTaxRegime().compute_tax(TaxInputs(earned_income=income))
plus = UkTaxRegime().compute_tax(TaxInputs(earned_income=income + Decimal("1000")))
assert plus.total >= base.total
@given(gains=st.decimals(
min_value=0, max_value=500_000, places=2, allow_nan=False, allow_infinity=False))
def test_cgt_monotone_in_gains(gains: Decimal) -> None:
base = UkTaxRegime().compute_tax(TaxInputs(capital_gains=gains))
plus = UkTaxRegime().compute_tax(TaxInputs(capital_gains=gains + Decimal("1000")))
assert plus.capital_gains_tax >= base.capital_gains_tax