# trading/tests/services/test_learning_engine.py
#
# 545 lines
# 19 KiB
# Python

"""Tests for the Learning Engine service.
Covers trade evaluation (P&L, ROI for long/short), credit attribution
(proportional, single strategy), weight adjustment (formula, clamping,
floor), normalization (sum to 1, floor respected), minimum trade count
gating, and recency decay.
"""
from __future__ import annotations
import uuid
import pytest
from services.learning_engine.config import LearningEngineConfig
from services.learning_engine.evaluator import TradeEvaluator
from services.learning_engine.weight_adjuster import WeightAdjuster
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_config(**overrides) -> LearningEngineConfig:
    """Build a ``LearningEngineConfig`` for a test.

    A fixed baseline of settings is used, with any keyword ``overrides``
    layered on top so each test only spells out the fields it cares about.
    """
    baseline = {
        "learning_rate": 0.1,
        "min_trades_before_adjustment": 20,
        "max_weight_shift_pct": 0.10,
        "weight_floor": 0.05,
        "recency_decay": 0.95,
        "evaluation_window_hours": 1,
    }
    # Later entries win, so caller-supplied overrides replace the baseline.
    return LearningEngineConfig(**{**baseline, **overrides})
def _make_trade_id() -> uuid.UUID:
    """Return a freshly generated random UUID to use as a trade identifier."""
    trade_id = uuid.uuid4()
    return trade_id
# ---------------------------------------------------------------------------
# TradeEvaluator — profitable trade
# ---------------------------------------------------------------------------
class TestEvaluateProfitableTrade:
    """Long position closed above its entry price: PnL and ROI are positive."""

    def test_evaluate_profitable_trade(self):
        """Long 10 units from 100 to 110 gives +100 PnL and +10% ROI."""
        trade = {
            "trade_id": _make_trade_id(),
            "entry_price": 100.0,
            "exit_price": 110.0,
            "qty": 10.0,
            "direction_sign": 1.0,  # long
            "hold_duration_seconds": 3600.0,
        }
        result = TradeEvaluator().evaluate_trade(**trade)

        # Expected PnL: (110 - 100) * 10 * (+1) = 100
        expected_pnl = 100.0
        # Expected ROI: 100 / (100 * 10) * 100 = 10 percent
        expected_roi = 10.0

        assert result.realized_pnl == pytest.approx(expected_pnl)
        assert result.roi_pct == pytest.approx(expected_roi)
        assert result.was_profitable is True
        # The hold duration passed in is expected back on the outcome.
        assert result.hold_duration_seconds == pytest.approx(3600.0)
# ---------------------------------------------------------------------------
# TradeEvaluator — losing trade
# ---------------------------------------------------------------------------
class TestEvaluateLosingTrade:
    """Long position closed below its entry price: PnL and ROI are negative."""

    def test_evaluate_losing_trade(self):
        """Long 10 units from 100 to 95 gives -50 PnL and -5% ROI."""
        trade = {
            "trade_id": _make_trade_id(),
            "entry_price": 100.0,
            "exit_price": 95.0,
            "qty": 10.0,
            "direction_sign": 1.0,
            "hold_duration_seconds": 7200.0,
        }
        result = TradeEvaluator().evaluate_trade(**trade)

        # Expected PnL: (95 - 100) * 10 * (+1) = -50
        expected_pnl = -50.0
        # Expected ROI: -50 / (100 * 10) * 100 = -5 percent
        expected_roi = -5.0

        assert result.realized_pnl == pytest.approx(expected_pnl)
        assert result.roi_pct == pytest.approx(expected_roi)
        assert result.was_profitable is False
# ---------------------------------------------------------------------------
# TradeEvaluator — short trade PnL
# ---------------------------------------------------------------------------
class TestEvaluateShortTradePnl:
    """Short positions: the PnL sign is opposite to the price move."""

    def test_evaluate_short_trade_pnl(self):
        """Shorting at 100 and closing at 90 is a winning trade."""
        short_winner = {
            "trade_id": _make_trade_id(),
            "entry_price": 100.0,
            "exit_price": 90.0,
            "qty": 10.0,
            "direction_sign": -1.0,  # short
            "hold_duration_seconds": 1800.0,
        }
        result = TradeEvaluator().evaluate_trade(**short_winner)

        # Expected PnL: (90 - 100) * 10 * (-1) = +100
        assert result.realized_pnl == pytest.approx(100.0)
        # Expected ROI: 100 / (100 * 10) * 100 = +10 percent
        assert result.roi_pct == pytest.approx(10.0)
        assert result.was_profitable is True

    def test_short_trade_loss_when_price_rises(self):
        """Shorting at 100 and closing at 110 is a losing trade."""
        short_loser = {
            "trade_id": _make_trade_id(),
            "entry_price": 100.0,
            "exit_price": 110.0,
            "qty": 10.0,
            "direction_sign": -1.0,
            "hold_duration_seconds": 1800.0,
        }
        result = TradeEvaluator().evaluate_trade(**short_loser)

        # Expected PnL: (110 - 100) * 10 * (-1) = -100
        assert result.realized_pnl == pytest.approx(-100.0)
        assert result.was_profitable is False
# ---------------------------------------------------------------------------
# Credit attribution — proportional
# ---------------------------------------------------------------------------
class TestCreditAttributionProportional:
    """Each strategy's share of the ROI follows its share of signal strength."""

    def test_credit_attribution_proportional(self):
        """A 0.8 / 0.2 strength split divides a 10% ROI into 8.0 and 2.0."""
        scorer = TradeEvaluator()
        closed_trade = scorer.evaluate_trade(
            **{
                "trade_id": _make_trade_id(),
                "entry_price": 100.0,
                "exit_price": 110.0,
                "qty": 10.0,
                "direction_sign": 1.0,
                "hold_duration_seconds": 3600.0,
            }
        )
        # The trade above has an ROI of 10 percent.
        sources = ["momentum:LONG:0.8", "news_driven:LONG:0.2"]

        rewards = scorer.attribute_credit(closed_trade, sources)

        assert "momentum" in rewards
        assert "news_driven" in rewards
        # Strengths total 0.8 + 0.2 = 1.0, so the proportions are 0.8 and 0.2:
        #   momentum    -> 10.0 * 0.8 = 8.0
        #   news_driven -> 10.0 * 0.2 = 2.0
        assert rewards["momentum"] == pytest.approx(8.0)
        assert rewards["news_driven"] == pytest.approx(2.0)

    def test_credit_attribution_three_strategies(self):
        """A 0.5 / 0.3 / 0.2 split divides a 5% ROI into 2.5, 1.5 and 1.0."""
        scorer = TradeEvaluator()
        closed_trade = scorer.evaluate_trade(
            **{
                "trade_id": _make_trade_id(),
                "entry_price": 100.0,
                "exit_price": 105.0,
                "qty": 20.0,
                "direction_sign": 1.0,
                "hold_duration_seconds": 1000.0,
            }
        )
        # The trade above has an ROI of 5 percent.
        sources = [
            "momentum:LONG:0.5",
            "mean_reversion:LONG:0.3",
            "news_driven:LONG:0.2",
        ]

        rewards = scorer.attribute_credit(closed_trade, sources)

        # The individual rewards must add back up to the trade's ROI.
        combined = sum(rewards.values())
        assert combined == pytest.approx(5.0)
        assert rewards["momentum"] == pytest.approx(2.5)
        assert rewards["mean_reversion"] == pytest.approx(1.5)
        assert rewards["news_driven"] == pytest.approx(1.0)
# ---------------------------------------------------------------------------
# Credit attribution — single strategy
# ---------------------------------------------------------------------------
class TestCreditAttributionSingleStrategy:
    """Attribution with one source, a bare source name, or no sources."""

    def test_credit_attribution_single_strategy(self):
        """The only contributing strategy is credited with the whole ROI."""
        scorer = TradeEvaluator()
        closed_trade = scorer.evaluate_trade(
            **{
                "trade_id": _make_trade_id(),
                "entry_price": 100.0,
                "exit_price": 108.0,
                "qty": 10.0,
                "direction_sign": 1.0,
                "hold_duration_seconds": 3600.0,
            }
        )
        # The trade above has an ROI of 8 percent.
        rewards = scorer.attribute_credit(closed_trade, ["momentum:LONG:0.9"])

        assert len(rewards) == 1
        assert rewards["momentum"] == pytest.approx(8.0)

    def test_credit_attribution_bare_name(self):
        """A source without the colon-separated fields is treated as strength 1.0."""
        scorer = TradeEvaluator()
        closed_trade = scorer.evaluate_trade(
            **{
                "trade_id": _make_trade_id(),
                "entry_price": 100.0,
                "exit_price": 110.0,
                "qty": 10.0,
                "direction_sign": 1.0,
                "hold_duration_seconds": 3600.0,
            }
        )
        # The trade above has an ROI of 10 percent.
        rewards = scorer.attribute_credit(closed_trade, ["momentum"])

        assert rewards["momentum"] == pytest.approx(10.0)

    def test_credit_attribution_empty_sources(self):
        """With no strategy sources there is nothing to credit."""
        scorer = TradeEvaluator()
        closed_trade = scorer.evaluate_trade(
            **{
                "trade_id": _make_trade_id(),
                "entry_price": 100.0,
                "exit_price": 110.0,
                "qty": 10.0,
                "direction_sign": 1.0,
                "hold_duration_seconds": 3600.0,
            }
        )
        no_sources = []

        assert scorer.attribute_credit(closed_trade, no_sources) == {}
# ---------------------------------------------------------------------------
# Weight adjustment — formula
# ---------------------------------------------------------------------------
class TestWeightAdjustmentFormula:
    """Checks ``adjust_weight`` against the expected EMA-style update.

    Every case sets ``max_weight_shift_pct=1.0`` so that the shift clamp does
    not interfere with the raw formula result.
    """

    def test_weight_adjustment_formula(self):
        """Learning rate 0.1 moves 0.3 toward a reward of 0.5."""
        current_weight, reward = 0.3, 0.5
        adjuster = WeightAdjuster(
            _make_config(learning_rate=0.1, max_weight_shift_pct=1.0)
        )

        updated = adjuster.adjust_weight(current_weight, reward)

        # (1 - 0.1) * 0.3 + 0.1 * 0.5 = 0.27 + 0.05 = 0.32
        assert updated == pytest.approx(0.32)

    def test_weight_adjustment_negative_reward(self):
        """A negative reward pulls the weight down."""
        current_weight, reward = 0.3, -0.5
        adjuster = WeightAdjuster(
            _make_config(learning_rate=0.1, max_weight_shift_pct=1.0, weight_floor=0.0)
        )

        updated = adjuster.adjust_weight(current_weight, reward)

        # (1 - 0.1) * 0.3 + 0.1 * (-0.5) = 0.27 - 0.05 = 0.22
        assert updated == pytest.approx(0.22)

    def test_weight_adjustment_high_learning_rate(self):
        """A larger learning rate produces a larger move."""
        current_weight, reward = 0.3, 0.8
        adjuster = WeightAdjuster(
            _make_config(learning_rate=0.5, max_weight_shift_pct=1.0)
        )

        updated = adjuster.adjust_weight(current_weight, reward)

        # (1 - 0.5) * 0.3 + 0.5 * 0.8 = 0.15 + 0.40 = 0.55
        assert updated == pytest.approx(0.55)
# ---------------------------------------------------------------------------
# Weight adjustment — max shift clamped
# ---------------------------------------------------------------------------
class TestWeightAdjustmentMaxShiftClamped:
    """A single adjustment may not move a weight by more than the configured cap."""

    def test_weight_adjustment_max_shift_clamped(self):
        """An upward move larger than the cap is cut back to the cap."""
        settings = {
            "learning_rate": 0.5,
            "max_weight_shift_pct": 0.05,  # at most a 0.05 move per adjustment
        }
        adjuster = WeightAdjuster(_make_config(**settings))

        updated = adjuster.adjust_weight(0.3, 0.9)

        # Unclamped: (1 - 0.5) * 0.3 + 0.5 * 0.9 = 0.15 + 0.45 = 0.60
        # That is a +0.30 move, above the 0.05 cap, so the expected result
        # is 0.30 + 0.05 = 0.35.
        assert updated == pytest.approx(0.35)

    def test_weight_adjustment_negative_clamped(self):
        """A downward move larger than the cap is cut back to the cap."""
        settings = {
            "learning_rate": 0.5,
            "max_weight_shift_pct": 0.05,
            "weight_floor": 0.0,
        }
        adjuster = WeightAdjuster(_make_config(**settings))

        updated = adjuster.adjust_weight(0.3, -1.0)

        # Unclamped: (1 - 0.5) * 0.3 + 0.5 * (-1.0) = 0.15 - 0.50 = -0.35
        # That is a -0.65 move, beyond the 0.05 cap, so the expected result
        # is 0.30 - 0.05 = 0.25.
        assert updated == pytest.approx(0.25)
# ---------------------------------------------------------------------------
# Weight adjustment — floor applied
# ---------------------------------------------------------------------------
class TestWeightAdjustmentFloorApplied:
    """Weight should never go below the configured floor."""

    def test_weight_adjustment_floor_applied(self):
        """A raw result below the floor is raised to the floor."""
        config = _make_config(
            learning_rate=0.5,
            max_weight_shift_pct=1.0,  # no clamping
            weight_floor=0.05,
        )
        adjuster = WeightAdjuster(config)

        # Raw: (1 - 0.5) * 0.1 + 0.5 * (-1.0) = 0.05 - 0.5 = -0.45
        # -0.45 is below the 0.05 floor, so the result is the floor itself.
        new_weight = adjuster.adjust_weight(0.1, -1.0)

        assert new_weight == pytest.approx(0.05)

    def test_weight_at_floor_stays_at_floor(self):
        """Weight already at floor with negative reward stays at floor."""
        config = _make_config(weight_floor=0.05, max_weight_shift_pct=1.0)
        adjuster = WeightAdjuster(config)

        # With the helper's default learning_rate of 0.1:
        # Raw: (1 - 0.1) * 0.05 + 0.1 * (-1.0) = 0.045 - 0.1 = -0.055
        # That is below the floor, so the weight must remain exactly at 0.05.
        new_weight = adjuster.adjust_weight(0.05, -1.0)

        # Lower bound: the floor is never breached.
        assert new_weight >= 0.05
        # Upper bound: a negative reward must not lift the weight off the
        # floor. A bare ">=" check would let an increase pass unnoticed.
        assert new_weight == pytest.approx(0.05)
# ---------------------------------------------------------------------------
# Normalize weights — sums to one
# ---------------------------------------------------------------------------
class TestNormalizeWeightsSumsToOne:
    """``normalize_weights`` output adds up to 1.0."""

    def test_normalize_weights_sums_to_one(self):
        """Three ordinary weights normalize to a total of 1.0."""
        adjuster = WeightAdjuster(_make_config(weight_floor=0.05))

        result = adjuster.normalize_weights({"a": 0.5, "b": 0.3, "c": 0.2})

        assert sum(result.values()) == pytest.approx(1.0)

    def test_normalize_unequal_weights(self):
        """Widely spread inputs still total 1.0 and keep their ordering."""
        adjuster = WeightAdjuster(_make_config(weight_floor=0.01))

        result = adjuster.normalize_weights({"a": 10.0, "b": 1.0, "c": 0.1})

        assert sum(result.values()) == pytest.approx(1.0)
        # The relative ranking of the inputs must survive normalization.
        largest, middle, smallest = result["a"], result["b"], result["c"]
        assert largest > middle > smallest

    def test_normalize_empty_weights(self):
        """An empty mapping normalizes to an empty mapping."""
        adjuster = WeightAdjuster(_make_config())
        nothing = {}

        assert adjuster.normalize_weights(nothing) == {}
# ---------------------------------------------------------------------------
# Normalize weights — respects floor
# ---------------------------------------------------------------------------
class TestNormalizeWeightsRespectsFloor:
    """After normalization every weight is at least the configured floor."""

    def test_normalize_weights_respects_floor(self):
        """Two tiny weights next to a dominant one are lifted to the floor."""
        adjuster = WeightAdjuster(_make_config(weight_floor=0.10))
        # "b" and "c" start far below the 0.10 floor.
        skewed = {"a": 1.0, "b": 0.01, "c": 0.01}

        normalized = adjuster.normalize_weights(skewed)

        assert sum(normalized.values()) == pytest.approx(1.0)
        for name, w in normalized.items():
            # The 1e-9 slack absorbs floating-point rounding.
            assert w >= 0.10 - 1e-9, f"Weight for {name} is {w}, below floor 0.10"

    def test_normalize_all_below_floor(self):
        """When every input is tiny, each output still reaches the floor."""
        adjuster = WeightAdjuster(_make_config(weight_floor=0.10))
        tiny = {"a": 0.001, "b": 0.001, "c": 0.001}

        normalized = adjuster.normalize_weights(tiny)

        assert sum(normalized.values()) == pytest.approx(1.0)
        for weight in normalized.values():
            assert weight >= 0.10 - 1e-9
# ---------------------------------------------------------------------------
# Should adjust — requires min trades
# ---------------------------------------------------------------------------
class TestShouldAdjustRequiresMinTrades:
    """``should_adjust`` stays False until the minimum trade count is reached."""

    def test_should_adjust_requires_min_trades(self):
        """With a minimum of 20, neither 0 nor 19 recorded trades is enough."""
        adjuster = WeightAdjuster(_make_config(min_trades_before_adjustment=20))

        # Nothing recorded so far.
        assert adjuster.should_adjust("momentum") is False

        # One short of the threshold.
        one_short = 19
        for _ in range(one_short):
            adjuster.record_trade("momentum")
        assert adjuster.should_adjust("momentum") is False
# ---------------------------------------------------------------------------
# Should adjust — after enough trades
# ---------------------------------------------------------------------------
class TestShouldAdjustAfterEnoughTrades:
    """``should_adjust`` turns True once the minimum trade count is reached."""

    def test_should_adjust_after_enough_trades(self):
        """Recording exactly the minimum of 20 trades enables adjustment."""
        required = 20
        adjuster = WeightAdjuster(
            _make_config(min_trades_before_adjustment=required)
        )

        for _ in range(required):
            adjuster.record_trade("momentum")

        assert adjuster.should_adjust("momentum") is True

    def test_other_strategies_still_gated(self):
        """Trades recorded for one strategy do not count toward another."""
        required = 5
        adjuster = WeightAdjuster(
            _make_config(min_trades_before_adjustment=required)
        )

        for _ in range(required):
            adjuster.record_trade("momentum")

        assert adjuster.should_adjust("momentum") is True
        # No trades were recorded for this strategy.
        assert adjuster.should_adjust("news_driven") is False
# ---------------------------------------------------------------------------
# Recency decay — recent trades weighted more
# ---------------------------------------------------------------------------
class TestRecencyDecayApplied:
    """Older recorded rewards are scaled down each time a new one arrives."""

    def test_recency_decay_applied(self):
        """With decay 0.5, two rewards of 10 average to 7.5."""
        adjuster = WeightAdjuster(_make_config(recency_decay=0.5))

        # The second call is expected to halve the first entry, leaving a
        # history of [10.0 * 0.5, 10.0] = [5.0, 10.0].
        for _ in range(2):
            adjuster.record_reward("momentum", 10.0)

        decayed_avg = adjuster.get_decayed_reward("momentum")

        assert decayed_avg == pytest.approx(7.5)  # (5.0 + 10.0) / 2

    def test_recency_decay_multiple_rounds(self):
        """Decay compounds: three rewards of 8 end up as 2, 4 and 8."""
        adjuster = WeightAdjuster(_make_config(recency_decay=0.5))

        # After 1st record: [8]
        # After 2nd record: [8 * 0.5, 8] = [4, 8]
        # After 3rd record: [4 * 0.5, 8 * 0.5, 8] = [2, 4, 8]
        for _ in range(3):
            adjuster.record_reward("momentum", 8.0)

        decayed_avg = adjuster.get_decayed_reward("momentum")

        assert decayed_avg == pytest.approx((2.0 + 4.0 + 8.0) / 3)

    def test_no_decay_when_no_history(self):
        """A strategy with no recorded rewards reports 0.0."""
        adjuster = WeightAdjuster(_make_config())

        assert adjuster.get_decayed_reward("momentum") == 0.0

    def test_high_decay_preserves_more(self):
        """A decay factor near 1 keeps more of the older reward than 0.5 does."""
        adjusters = {
            decay: WeightAdjuster(_make_config(recency_decay=decay))
            for decay in (0.99, 0.5)
        }
        for adjuster in adjusters.values():
            for _ in range(2):
                adjuster.record_reward("momentum", 10.0)

        # decay 0.99: [10 * 0.99, 10] = [9.9, 10] -> average 9.95
        # decay 0.5:  [10 * 0.5, 10]  = [5.0, 10] -> average 7.5
        high_avg = adjusters[0.99].get_decayed_reward("momentum")
        low_avg = adjusters[0.5].get_decayed_reward("momentum")

        assert high_avg > low_avg
        assert high_avg == pytest.approx(9.95)
        assert low_avg == pytest.approx(7.5)