"""Tax-regime abstract base — every jurisdiction implements this.

Inputs are split by income source because each source carries different
tax treatment (e.g. ISA withdrawals are always 0%, capital gains may be
exempt in some jurisdictions, pension withdrawals are partially tax-free
in the UK). The regime decides how to combine them.

Outputs are split per tax type so we can attribute lifetime tax — the
Grafana panel shows e.g. "lifetime CGT paid" separately from "lifetime
income tax".
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass, field
from decimal import Decimal


@dataclass(frozen=True)
class TaxInputs:
    """Annual gross flows for a single tax year.

    All amounts are in GBP and non-negative — withdrawals are passed as
    absolute values. Nothing in this class enforces that; callers must.

    `years_since_uk_departure` lets the UK regime apply the 5-year
    Temporary Non-Residence claw-back: gains realised abroad get clawed
    back if you return within 5 years. Non-UK regimes ignore it.
    """

    earned_income: Decimal = Decimal("0")
    pension_withdrawal: Decimal = Decimal("0")
    capital_gains: Decimal = Decimal("0")
    dividends: Decimal = Decimal("0")
    isa_withdrawals: Decimal = Decimal("0")
    interest: Decimal = Decimal("0")
    years_since_uk_departure: int = 0


@dataclass(frozen=True)
class TaxBreakdown:
    """Tax due for one year, split by category.

    `total` is computed as the sum of the six category fields, so the
    `total == sum of categories` invariant used by the integrity check
    holds by construction. What every regime must guarantee instead is
    that each amount of tax it computes lands in exactly one category
    (use `other` for anything without a dedicated field) — tax that is
    left out of the categories is missing from `total` as well.

    `notes` is a tuple of strings attached to the breakdown; it is not
    part of `total`.
    """

    income_tax: Decimal = Decimal("0")
    national_insurance: Decimal = Decimal("0")
    capital_gains_tax: Decimal = Decimal("0")
    dividend_tax: Decimal = Decimal("0")
    healthcare_levy: Decimal = Decimal("0")
    other: Decimal = Decimal("0")
    # Tuple (not list) so the frozen instance stays immutable.
    notes: tuple[str, ...] = field(default_factory=tuple)

    @property
    def total(self) -> Decimal:
        """Return the sum of every tax category (excludes `notes`)."""
        return (
            self.income_tax
            + self.national_insurance
            + self.capital_gains_tax
            + self.dividend_tax
            + self.healthcare_levy
            + self.other
        )


class TaxRegime(ABC):
    """Per-jurisdiction tax engine.

    Stateless — every call gets fresh inputs. Sub-classes set `name` for
    the scenario key.
    """

    # Annotated only; concrete subclasses are expected to assign it.
    name: str

    @abstractmethod
    def compute_tax(self, inputs: TaxInputs) -> TaxBreakdown:
        """Return the year's tax due given gross income/gains/dividends."""
        raise NotImplementedError


def apply_brackets(
    amount: Decimal, brackets: Sequence[tuple[Decimal, Decimal]]
) -> Decimal:
    """Apply a progressive bracket schedule to `amount`.

    `brackets` is a sequence of `(band_top, marginal_rate)` pairs, where
    `band_top` is the cumulative upper bound of the band (use
    `Decimal("Infinity")` for the last band). Bands are evaluated in
    order from lowest to highest; each band taxes the slice of `amount`
    between the previous band's top and its own top at its rate.

    A band whose top is at or below the running lower bound — e.g. a
    threshold already consumed by an allowance, so `band_top` came out
    zero or negative — covers nothing and is skipped instead of
    contributing a negative slice.

    Any part of `amount` above the last `band_top` is not taxed, so a
    schedule without an `Infinity` top band effectively caps the tax.

    Example — UK PAYE 2026/27 above the personal allowance::

        [(Decimal(50_270 - 12_570), Decimal("0.20")),
         (Decimal(125_140 - 12_570), Decimal("0.40")),
         (Decimal("Infinity"), Decimal("0.45"))]

    where `amount` is taxable income net of the allowance.

    Returns `Decimal("0")` when `amount` is zero or negative.
    """
    if amount <= 0:
        return Decimal("0")
    tax = Decimal("0")
    prev_top = Decimal("0")
    for band_top, rate in brackets:
        if amount <= prev_top:
            break  # the whole amount is already covered by lower bands
        if band_top <= prev_top:
            # Empty band: `band_top - prev_top` would be a negative slice,
            # and moving `prev_top` down would make the next band re-tax a
            # slice that has already been taxed.
            continue
        slice_top = min(amount, band_top)
        tax += (slice_top - prev_top) * rate
        prev_top = band_top
    return tax