"""Tests for the Learning Engine service.
|
|
|
|
Covers trade evaluation (P&L, ROI for long/short), credit attribution
|
|
(proportional, single strategy), weight adjustment (formula, clamping,
|
|
floor), normalization (sum to 1, floor respected), minimum trade count
|
|
gating, and recency decay.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import uuid
|
|
|
|
import pytest
|
|
|
|
from services.learning_engine.config import LearningEngineConfig
|
|
from services.learning_engine.evaluator import TradeEvaluator
|
|
from services.learning_engine.weight_adjuster import WeightAdjuster
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_config(**overrides) -> LearningEngineConfig:
    """Build a LearningEngineConfig with test defaults; keyword args override fields."""
    params = {
        "learning_rate": 0.1,
        "min_trades_before_adjustment": 20,
        "max_weight_shift_pct": 0.10,
        "weight_floor": 0.05,
        "recency_decay": 0.95,
        "evaluation_window_hours": 1,
    }
    params.update(overrides)
    return LearningEngineConfig(**params)
def _make_trade_id() -> uuid.UUID:
    """Return a fresh random UUID to stand in for a trade identifier."""
    return uuid.uuid4()
# ---------------------------------------------------------------------------
# TradeEvaluator — profitable trade
# ---------------------------------------------------------------------------
class TestEvaluateProfitableTrade:
    """A long trade that gains in price should have positive PnL and ROI."""

    def test_evaluate_profitable_trade(self):
        ev = TradeEvaluator()
        result = ev.evaluate_trade(
            trade_id=_make_trade_id(),
            entry_price=100.0,
            exit_price=110.0,
            qty=10.0,
            direction_sign=1.0,  # long position
            hold_duration_seconds=3600.0,
        )

        # PnL = (exit - entry) * qty * sign = (110 - 100) * 10 * 1.0 = 100
        assert result.realized_pnl == pytest.approx(100.0)
        # ROI = pnl / (entry * qty) * 100 = 100 / 1000 * 100 = 10%
        assert result.roi_pct == pytest.approx(10.0)
        assert result.was_profitable is True
        assert result.hold_duration_seconds == pytest.approx(3600.0)
# ---------------------------------------------------------------------------
# TradeEvaluator — losing trade
# ---------------------------------------------------------------------------
class TestEvaluateLosingTrade:
    """A long trade that drops should have negative PnL."""

    def test_evaluate_losing_trade(self):
        ev = TradeEvaluator()
        result = ev.evaluate_trade(
            trade_id=_make_trade_id(),
            entry_price=100.0,
            exit_price=95.0,
            qty=10.0,
            direction_sign=1.0,
            hold_duration_seconds=7200.0,
        )

        # PnL = (95 - 100) * 10 * 1.0 = -50
        assert result.realized_pnl == pytest.approx(-50.0)
        # ROI = -50 / (100 * 10) * 100 = -5%
        assert result.roi_pct == pytest.approx(-5.0)
        assert result.was_profitable is False
# ---------------------------------------------------------------------------
# TradeEvaluator — short trade PnL
# ---------------------------------------------------------------------------
class TestEvaluateShortTradePnl:
    """A short trade profits when the price drops."""

    def test_evaluate_short_trade_pnl(self):
        ev = TradeEvaluator()
        result = ev.evaluate_trade(
            trade_id=_make_trade_id(),
            entry_price=100.0,
            exit_price=90.0,
            qty=10.0,
            direction_sign=-1.0,  # short position
            hold_duration_seconds=1800.0,
        )

        # PnL = (90 - 100) * 10 * (-1) = +100
        assert result.realized_pnl == pytest.approx(100.0)
        # ROI = 100 / (100 * 10) * 100 = 10%
        assert result.roi_pct == pytest.approx(10.0)
        assert result.was_profitable is True

    def test_short_trade_loss_when_price_rises(self):
        """A short trade loses when price rises."""
        ev = TradeEvaluator()
        result = ev.evaluate_trade(
            trade_id=_make_trade_id(),
            entry_price=100.0,
            exit_price=110.0,
            qty=10.0,
            direction_sign=-1.0,
            hold_duration_seconds=1800.0,
        )

        # PnL = (110 - 100) * 10 * (-1) = -100
        assert result.realized_pnl == pytest.approx(-100.0)
        assert result.was_profitable is False
# ---------------------------------------------------------------------------
# Credit attribution — proportional
# ---------------------------------------------------------------------------
class TestCreditAttributionProportional:
    """Credit should be distributed proportionally to signal strength."""

    def test_credit_attribution_proportional(self):
        ev = TradeEvaluator()
        result = ev.evaluate_trade(
            trade_id=_make_trade_id(),
            entry_price=100.0,
            exit_price=110.0,
            qty=10.0,
            direction_sign=1.0,
            hold_duration_seconds=3600.0,
        )  # ROI = 10%

        sources = [
            "momentum:LONG:0.8",
            "news_driven:LONG:0.2",
        ]
        rewards = ev.attribute_credit(result, sources)

        assert "momentum" in rewards
        assert "news_driven" in rewards

        # Strengths sum to 1.0, so each strategy's share of the 10% ROI is
        # its own strength: momentum 10 * 0.8 = 8, news_driven 10 * 0.2 = 2.
        assert rewards["momentum"] == pytest.approx(8.0)
        assert rewards["news_driven"] == pytest.approx(2.0)

    def test_credit_attribution_three_strategies(self):
        """Three strategies should split reward proportionally."""
        ev = TradeEvaluator()
        result = ev.evaluate_trade(
            trade_id=_make_trade_id(),
            entry_price=100.0,
            exit_price=105.0,
            qty=20.0,
            direction_sign=1.0,
            hold_duration_seconds=1000.0,
        )  # ROI = 5%

        sources = [
            "momentum:LONG:0.5",
            "mean_reversion:LONG:0.3",
            "news_driven:LONG:0.2",
        ]
        rewards = ev.attribute_credit(result, sources)

        # The rewards partition the ROI: 5 * [0.5, 0.3, 0.2].
        assert sum(rewards.values()) == pytest.approx(5.0)
        assert rewards["momentum"] == pytest.approx(2.5)
        assert rewards["mean_reversion"] == pytest.approx(1.5)
        assert rewards["news_driven"] == pytest.approx(1.0)
# ---------------------------------------------------------------------------
# Credit attribution — single strategy
# ---------------------------------------------------------------------------
class TestCreditAttributionSingleStrategy:
    """A single strategy should receive the full reward."""

    def test_credit_attribution_single_strategy(self):
        ev = TradeEvaluator()
        result = ev.evaluate_trade(
            trade_id=_make_trade_id(),
            entry_price=100.0,
            exit_price=108.0,
            qty=10.0,
            direction_sign=1.0,
            hold_duration_seconds=3600.0,
        )  # ROI = 8%

        rewards = ev.attribute_credit(result, ["momentum:LONG:0.9"])

        # Sole contributor gets the entire 8% regardless of its strength.
        assert len(rewards) == 1
        assert rewards["momentum"] == pytest.approx(8.0)

    def test_credit_attribution_bare_name(self):
        """A bare strategy name (no colon format) defaults to strength 1.0."""
        ev = TradeEvaluator()
        result = ev.evaluate_trade(
            trade_id=_make_trade_id(),
            entry_price=100.0,
            exit_price=110.0,
            qty=10.0,
            direction_sign=1.0,
            hold_duration_seconds=3600.0,
        )  # ROI = 10%

        rewards = ev.attribute_credit(result, ["momentum"])

        assert rewards["momentum"] == pytest.approx(10.0)

    def test_credit_attribution_empty_sources(self):
        """Empty strategy sources should return empty rewards."""
        ev = TradeEvaluator()
        result = ev.evaluate_trade(
            trade_id=_make_trade_id(),
            entry_price=100.0,
            exit_price=110.0,
            qty=10.0,
            direction_sign=1.0,
            hold_duration_seconds=3600.0,
        )

        assert ev.attribute_credit(result, []) == {}
# ---------------------------------------------------------------------------
# Weight adjustment — formula
# ---------------------------------------------------------------------------
class TestWeightAdjustmentFormula:
    """Verify the EMA weight adjustment formula."""

    def test_weight_adjustment_formula(self):
        cfg = _make_config(learning_rate=0.1, max_weight_shift_pct=1.0)
        adj = WeightAdjuster(cfg)

        # EMA: (1 - lr) * old + lr * reward = 0.9 * 0.3 + 0.1 * 0.5 = 0.32
        assert adj.adjust_weight(0.3, 0.5) == pytest.approx(0.32)

    def test_weight_adjustment_negative_reward(self):
        """Negative reward should decrease the weight."""
        cfg = _make_config(learning_rate=0.1, max_weight_shift_pct=1.0, weight_floor=0.0)
        adj = WeightAdjuster(cfg)

        # 0.9 * 0.3 + 0.1 * (-0.5) = 0.27 - 0.05 = 0.22
        assert adj.adjust_weight(0.3, -0.5) == pytest.approx(0.22)

    def test_weight_adjustment_high_learning_rate(self):
        """Higher learning rate should cause larger shifts."""
        cfg = _make_config(learning_rate=0.5, max_weight_shift_pct=1.0)
        adj = WeightAdjuster(cfg)

        # 0.5 * 0.3 + 0.5 * 0.8 = 0.15 + 0.4 = 0.55
        assert adj.adjust_weight(0.3, 0.8) == pytest.approx(0.55)
# ---------------------------------------------------------------------------
# Weight adjustment — max shift clamped
# ---------------------------------------------------------------------------
class TestWeightAdjustmentMaxShiftClamped:
    """Weight shift should be clamped to max_weight_shift_pct."""

    def test_weight_adjustment_max_shift_clamped(self):
        cfg = _make_config(
            learning_rate=0.5,
            max_weight_shift_pct=0.05,  # only a 5% shift is allowed
        )
        adj = WeightAdjuster(cfg)

        # Unclamped EMA: 0.5 * 0.3 + 0.5 * 0.9 = 0.60, a +0.30 shift.
        # That exceeds the +0.05 cap, so the result is 0.30 + 0.05 = 0.35.
        assert adj.adjust_weight(0.3, 0.9) == pytest.approx(0.35)

    def test_weight_adjustment_negative_clamped(self):
        """Negative shift should also be clamped."""
        cfg = _make_config(
            learning_rate=0.5,
            max_weight_shift_pct=0.05,
            weight_floor=0.0,
        )
        adj = WeightAdjuster(cfg)

        # Unclamped EMA: 0.5 * 0.3 + 0.5 * (-1.0) = -0.35, a -0.65 shift.
        # Clamped to -0.05, so the result is 0.30 - 0.05 = 0.25.
        assert adj.adjust_weight(0.3, -1.0) == pytest.approx(0.25)
# ---------------------------------------------------------------------------
# Weight adjustment — floor applied
# ---------------------------------------------------------------------------
class TestWeightAdjustmentFloorApplied:
    """Weight should never go below the configured floor."""

    def test_weight_adjustment_floor_applied(self):
        cfg = _make_config(
            learning_rate=0.5,
            max_weight_shift_pct=1.0,  # shift clamping disabled
            weight_floor=0.05,
        )
        adj = WeightAdjuster(cfg)

        # Raw EMA: 0.5 * 0.1 + 0.5 * (-1.0) = -0.45, well under the
        # 0.05 floor, so the floor value is returned.
        assert adj.adjust_weight(0.1, -1.0) == pytest.approx(0.05)

    def test_weight_at_floor_stays_at_floor(self):
        """Weight already at floor with negative reward stays at floor."""
        cfg = _make_config(weight_floor=0.05, max_weight_shift_pct=1.0)
        adj = WeightAdjuster(cfg)

        assert adj.adjust_weight(0.05, -1.0) >= 0.05
# ---------------------------------------------------------------------------
# Normalize weights — sums to one
# ---------------------------------------------------------------------------
class TestNormalizeWeightsSumsToOne:
    """Normalized weights should always sum to 1.0."""

    def test_normalize_weights_sums_to_one(self):
        adj = WeightAdjuster(_make_config(weight_floor=0.05))

        raw = {"a": 0.5, "b": 0.3, "c": 0.2}
        normalized = adj.normalize_weights(raw)

        assert sum(normalized.values()) == pytest.approx(1.0)

    def test_normalize_unequal_weights(self):
        """Very unequal weights should still normalize to 1.0."""
        adj = WeightAdjuster(_make_config(weight_floor=0.01))

        raw = {"a": 10.0, "b": 1.0, "c": 0.1}
        normalized = adj.normalize_weights(raw)

        assert sum(normalized.values()) == pytest.approx(1.0)
        # Relative ordering must survive normalization.
        assert normalized["a"] > normalized["b"] > normalized["c"]

    def test_normalize_empty_weights(self):
        """Empty weights dict should return empty."""
        adj = WeightAdjuster(_make_config())

        assert adj.normalize_weights({}) == {}
# ---------------------------------------------------------------------------
# Normalize weights — respects floor
# ---------------------------------------------------------------------------
class TestNormalizeWeightsRespectsFloor:
    """No strategy weight should drop below the floor after normalization."""

    def test_normalize_weights_respects_floor(self):
        adj = WeightAdjuster(_make_config(weight_floor=0.10))

        # Two weights are tiny relative to the first.
        normalized = adj.normalize_weights({"a": 1.0, "b": 0.01, "c": 0.01})

        assert sum(normalized.values()) == pytest.approx(1.0)
        for name, value in normalized.items():
            assert value >= 0.10 - 1e-9, f"Weight for {name} is {value}, below floor 0.10"

    def test_normalize_all_below_floor(self):
        """If all weights are tiny, they should all get at least the floor."""
        adj = WeightAdjuster(_make_config(weight_floor=0.10))

        normalized = adj.normalize_weights({"a": 0.001, "b": 0.001, "c": 0.001})

        assert sum(normalized.values()) == pytest.approx(1.0)
        for value in normalized.values():
            assert value >= 0.10 - 1e-9
# ---------------------------------------------------------------------------
# Should adjust — requires min trades
# ---------------------------------------------------------------------------
class TestShouldAdjustRequiresMinTrades:
    """Weight adjustment should be gated by minimum trade count."""

    def test_should_adjust_requires_min_trades(self):
        adj = WeightAdjuster(_make_config(min_trades_before_adjustment=20))

        # Nothing recorded yet: gated.
        assert adj.should_adjust("momentum") is False

        # One short of the threshold: still gated.
        for _ in range(19):
            adj.record_trade("momentum")
        assert adj.should_adjust("momentum") is False
# ---------------------------------------------------------------------------
# Should adjust — after enough trades
# ---------------------------------------------------------------------------
class TestShouldAdjustAfterEnoughTrades:
    """Once enough trades are recorded, adjustment should be allowed."""

    def test_should_adjust_after_enough_trades(self):
        adj = WeightAdjuster(_make_config(min_trades_before_adjustment=20))

        # Exactly at the threshold: adjustment unlocked.
        for _ in range(20):
            adj.record_trade("momentum")

        assert adj.should_adjust("momentum") is True

    def test_other_strategies_still_gated(self):
        """Other strategies should still be gated independently."""
        adj = WeightAdjuster(_make_config(min_trades_before_adjustment=5))

        for _ in range(5):
            adj.record_trade("momentum")

        # Trade counts are tracked per strategy, not globally.
        assert adj.should_adjust("momentum") is True
        assert adj.should_adjust("news_driven") is False
# ---------------------------------------------------------------------------
# Recency decay — recent trades weighted more
# ---------------------------------------------------------------------------
class TestRecencyDecayApplied:
    """Recent trades should carry more weight than older ones."""

    def test_recency_decay_applied(self):
        adj = WeightAdjuster(_make_config(recency_decay=0.5))

        # Two identical rewards; the earlier one is halved when the
        # second is recorded, leaving a history of [5.0, 10.0].
        adj.record_reward("momentum", 10.0)
        adj.record_reward("momentum", 10.0)

        # Mean of [5.0, 10.0] is 7.5.
        assert adj.get_decayed_reward("momentum") == pytest.approx(7.5)

    def test_recency_decay_multiple_rounds(self):
        """Multiple rounds of decay should compound."""
        adj = WeightAdjuster(_make_config(recency_decay=0.5))

        # Each record halves everything already in the history:
        #   after 1st: [8]
        #   after 2nd: [4, 8]
        #   after 3rd: [2, 4, 8]
        for _ in range(3):
            adj.record_reward("momentum", 8.0)

        assert adj.get_decayed_reward("momentum") == pytest.approx((2.0 + 4.0 + 8.0) / 3)

    def test_no_decay_when_no_history(self):
        """Getting decayed reward with no history should return 0."""
        adj = WeightAdjuster(_make_config())

        assert adj.get_decayed_reward("momentum") == 0.0

    def test_high_decay_preserves_more(self):
        """A high decay factor (close to 1) preserves older rewards more."""
        adj_high = WeightAdjuster(_make_config(recency_decay=0.99))
        adj_low = WeightAdjuster(_make_config(recency_decay=0.5))

        for adj in (adj_high, adj_low):
            adj.record_reward("momentum", 10.0)
            adj.record_reward("momentum", 10.0)

        # High decay: [9.9, 10.0] -> avg 9.95; low decay: [5.0, 10.0] -> avg 7.5.
        high_avg = adj_high.get_decayed_reward("momentum")
        low_avg = adj_low.get_decayed_reward("momentum")

        assert high_avg > low_avg
        assert high_avg == pytest.approx(9.95)
        assert low_avg == pytest.approx(7.5)