"""End-to-end test that flex-spending rules survive £ in the portfolio.""" from __future__ import annotations import numpy as np from fire_planner.flex_spending import FlexRule from fire_planner.glide_path import static from fire_planner.simulator import simulate from fire_planner.strategies.trinity import TrinityStrategy from fire_planner.tax.uae import UaeTaxRegime def _flat_paths(n_paths: int, n_years: int, real_return: float = 0.0) -> np.ndarray: """Returns paths cube where real return == 0% — easy to reason about.""" paths = np.zeros((n_paths, n_years, 3), dtype=np.float64) paths[:, :, 0] = real_return # nominal stocks paths[:, :, 1] = real_return # nominal bonds paths[:, :, 2] = 0.0 # cpi return paths def test_flex_rule_saves_money_at_drawdown() -> None: """A scenario that drops below ATH triggers a discretionary cut and ends up richer than the same scenario with no flex rules.""" paths = _flat_paths(n_paths=10, n_years=5, real_return=-0.05) initial = 1_000_000.0 common = dict( paths=paths, initial_portfolio=initial, spending_target=10_000.0, glide=static(1.0), strategy=TrinityStrategy(), regime=UaeTaxRegime(), horizon_years=5, cashflow_adjustments=np.full(5, -20_000.0, dtype=np.float64), discretionary_outflows=np.full(5, 20_000.0, dtype=np.float64), ) no_flex = simulate(**common) with_flex = simulate( **common, flex_rules=[FlexRule(from_ath_pct=0.05, cut_discretionary_pct=0.50)], ) no_flex_end = float(np.median(no_flex.portfolio_real[:, -1])) with_flex_end = float(np.median(with_flex.portfolio_real[:, -1])) assert with_flex_end > no_flex_end assert no_flex_end > 0 # didn't ruin — meaningful comparison def test_flex_rule_no_op_without_drawdown() -> None: """Strong-positive returns, never below ATH → flex rules do nothing.""" paths = _flat_paths(n_paths=10, n_years=5, real_return=0.10) common = dict( paths=paths, initial_portfolio=1_000_000.0, spending_target=40_000.0, glide=static(1.0), strategy=TrinityStrategy(), regime=UaeTaxRegime(), horizon_years=5, cashflow_adjustments=np.full(5, -10_000.0, dtype=np.float64), discretionary_outflows=np.full(5, 10_000.0, dtype=np.float64), ) no_flex = simulate(**common) with_flex = simulate( **common, flex_rules=[FlexRule(from_ath_pct=0.10, cut_discretionary_pct=0.50)], ) assert np.allclose(no_flex.portfolio_real, with_flex.portfolio_real)