class TradeEvaluator:
    """Evaluates closed trades and attributes credit to strategies.

    The class holds no state: both methods are pure functions of their
    arguments.
    """

    def evaluate_trade(
        self,
        trade_id: UUID,
        entry_price: float,
        exit_price: float,
        qty: float,
        direction_sign: float,
        hold_duration_seconds: float,
    ) -> TradeOutcomeSchema:
        """Compute the outcome of a closed trade.

        Parameters
        ----------
        trade_id:
            Unique identifier of the closing trade.
        entry_price:
            The price at which the position was opened.
        exit_price:
            The price at which the position was closed.
        qty:
            Number of shares traded.
        direction_sign:
            +1.0 for long positions, -1.0 for short positions.
        hold_duration_seconds:
            How long the position was held, in seconds.

        Returns
        -------
        TradeOutcomeSchema
            The evaluated outcome including realized P&L and ROI.
        """
        realized_pnl = (exit_price - entry_price) * qty * direction_sign
        cost_basis = entry_price * qty
        # A zero cost basis (zero price or zero quantity) has no meaningful
        # ROI; report 0.0 instead of dividing by zero.
        roi_pct = (realized_pnl / cost_basis * 100.0) if cost_basis != 0 else 0.0
        was_profitable = realized_pnl > 0

        return TradeOutcomeSchema(
            trade_id=trade_id,
            hold_duration_seconds=hold_duration_seconds,
            realized_pnl=realized_pnl,
            roi_pct=roi_pct,
            was_profitable=was_profitable,
        )

    def attribute_credit(
        self,
        outcome: TradeOutcomeSchema,
        strategy_sources: list[str],
    ) -> dict[str, float]:
        """Distribute reward signal to contributing strategies.

        Parses ``strategy_sources`` entries which may be formatted as either:
        - ``"name:DIRECTION:strength"`` (full format from the ensemble)
        - ``"name"`` (bare strategy name -- defaults to strength 1.0)

        The reward signal is the trade's ROI percentage (``outcome.roi_pct``)
        distributed proportionally to each strategy's signal strength. When
        the same strategy name appears in several entries its shares are
        summed, so the returned rewards always add up to the ROI.

        Parameters
        ----------
        outcome:
            The evaluated trade outcome; only ``roi_pct`` is read.
        strategy_sources:
            List of strategy source strings from the signal.

        Returns
        -------
        dict[str, float]
            Mapping of strategy name to its reward signal. Empty when there
            are no sources or when the strengths sum to zero.
        """
        if not strategy_sources:
            return {}

        # Parse (name, strength) pairs from the source strings.
        parsed: list[tuple[str, float]] = []
        for source in strategy_sources:
            parts = source.split(":")
            strength = 1.0
            if len(parts) >= 3:
                try:
                    strength = float(parts[2])
                except ValueError:
                    # Unparseable strength: fall back to the same default a
                    # bare strategy name gets.
                    strength = 1.0
            parsed.append((parts[0], strength))

        # Total strength is the denominator for proportional distribution.
        total_strength = sum(strength for _, strength in parsed)
        if total_strength == 0:
            return {}

        # Accumulate rather than assign: a strategy listed more than once
        # must receive the sum of its shares, not just the last one.
        rewards: dict[str, float] = {}
        for name, strength in parsed:
            share = outcome.roi_pct * (strength / total_strength)
            rewards[name] = rewards.get(name, 0.0) + share

        return rewards
logger = logging.getLogger(__name__)

# Redis key for cached strategy weights
_STRATEGY_WEIGHTS_KEY = "strategy:weights"

# Redis hash holding the stored opening trade per ticker
_POSITIONS_HISTORY_KEY = "positions:history"

# Quantities at or below this are treated as "fully closed" so float
# round-off cannot leave a phantom open position behind.
_QTY_EPSILON = 1e-9


async def _load_strategy_weights(redis: Redis) -> dict[str, float]:
    """Load current strategy weights from Redis cache.

    Falls back to (approximately) equal weights for the three default
    strategies if no cached weights exist.
    """
    raw = await redis.get(_STRATEGY_WEIGHTS_KEY)
    if raw:
        return json.loads(raw)
    # Default equal weights
    return {
        "momentum": 0.333,
        "mean_reversion": 0.333,
        "news_driven": 0.334,
    }


async def _save_strategy_weights(redis: Redis, weights: dict[str, float]) -> None:
    """Persist strategy weights to Redis cache as a JSON string."""
    await redis.set(_STRATEGY_WEIGHTS_KEY, json.dumps(weights))


async def _find_opening_trade(
    redis: Redis,
    ticker: str,
    closing_side: OrderSide,
) -> dict | None:
    """Look up the stored opening trade for a ticker.

    Reads the ticker's field from the ``positions:history`` Redis hash.
    ``closing_side`` is accepted for interface compatibility but is not used
    by the lookup; side matching is done by ``_is_position_close``.

    Returns the stored entry data or None if not found.
    """
    raw = await redis.hget(_POSITIONS_HISTORY_KEY, ticker)
    if raw:
        return json.loads(raw)
    return None


async def _store_opening_trade(
    redis: Redis,
    ticker: str,
    trade_data: dict,
) -> None:
    """Store a trade as the opening trade for a ticker (overwrites any prior entry)."""
    await redis.hset(_POSITIONS_HISTORY_KEY, ticker, json.dumps(trade_data))


async def _clear_opening_trade(redis: Redis, ticker: str) -> None:
    """Clear the stored opening trade after position close."""
    await redis.hdel(_POSITIONS_HISTORY_KEY, ticker)


def _is_position_close(trade: TradeExecution, opening: dict | None) -> bool:
    """Determine if a trade closes a position.

    A trade closes a position if there is an existing opening trade
    on the opposite side for the same ticker.
    """
    if opening is None:
        return False
    opening_side = opening.get("side", "")
    if trade.side == OrderSide.SELL:
        return opening_side == OrderSide.BUY.value
    if trade.side == OrderSide.BUY:
        return opening_side == OrderSide.SELL.value
    return False


async def process_trade(
    trade: TradeExecution,
    redis: Redis,
    evaluator: TradeEvaluator,
    adjuster: WeightAdjuster,
    counters: dict,
) -> list[WeightAdjustment]:
    """Process a single trade execution.

    If the trade does not close a stored position it is recorded as the
    opening trade for its ticker and nothing else happens.

    If the trade closes a position:
    1. Evaluate the trade outcome (P&L, ROI) on the closed quantity
    2. Attribute credit to contributing strategies
    3. Adjust weights for strategies with enough trades
    4. Normalize all weights
    5. Store adjustments and update cached weights
    6. Clear the stored opening trade, or keep the unclosed remainder when
       the closing quantity is smaller than the opening quantity

    Parameters
    ----------
    trade:
        The executed trade to process.
    redis:
        Async Redis client used for the position and weight caches.
    evaluator:
        Computes the outcome and per-strategy rewards.
    adjuster:
        Tracks per-strategy counts/rewards and computes new weights.
    counters:
        Telemetry instruments; ``"adjustments_made"`` must support ``add``
        and ``"weight_drift"`` must support ``record``.

    Returns
    -------
    list[WeightAdjustment]
        The weight adjustments made (empty if none).
    """
    adjustments: list[WeightAdjustment] = []

    # Look up opening trade
    opening = await _find_opening_trade(redis, trade.ticker, trade.side)

    if not _is_position_close(trade, opening):
        # This is an opening trade -- store it for later reference.
        # NOTE(review): strategy_sources is always stored empty here, so a
        # later close attributes no rewards and no weight is ever adjusted
        # for positions recorded by this function. Confirm where the
        # originating signal's sources should be read from.
        await _store_opening_trade(
            redis,
            trade.ticker,
            {
                "trade_id": str(trade.trade_id),
                "side": trade.side.value,
                "price": trade.price,
                "qty": trade.qty,
                "timestamp": trade.timestamp.isoformat(),
                "strategy_sources": [],  # would come from signal
            },
        )
        return adjustments

    # --- Position close detected ---
    entry_price = opening["price"]
    entry_qty = opening.get("qty", trade.qty)
    entry_time_str = opening.get("timestamp", "")
    strategy_sources = opening.get("strategy_sources", [])

    # Only the overlapping quantity is actually closed by this trade.
    closed_qty = min(trade.qty, entry_qty)

    # Determine direction sign from the side that opened the position
    opening_side = opening.get("side", "")
    direction_sign = 1.0 if opening_side == OrderSide.BUY.value else -1.0

    # Compute hold duration; malformed or naive/aware-mismatched timestamps
    # degrade to 0.0 rather than failing the whole trade.
    hold_duration_seconds = 0.0
    if entry_time_str:
        try:
            entry_time = datetime.fromisoformat(entry_time_str)
            hold_duration_seconds = (trade.timestamp - entry_time).total_seconds()
        except (ValueError, TypeError):
            hold_duration_seconds = 0.0

    # Step 1: Evaluate trade
    outcome = evaluator.evaluate_trade(
        trade_id=trade.trade_id,
        entry_price=entry_price,
        exit_price=trade.price,
        qty=closed_qty,
        direction_sign=direction_sign,
        hold_duration_seconds=max(hold_duration_seconds, 0.0),
    )

    logger.info(
        "Trade outcome: %s PnL=%.2f ROI=%.2f%% profitable=%s",
        trade.ticker,
        outcome.realized_pnl,
        outcome.roi_pct,
        outcome.was_profitable,
    )

    # Step 2: Attribute credit
    rewards = evaluator.attribute_credit(outcome, strategy_sources)

    # Record trades and rewards for each strategy
    for strategy_name, reward in rewards.items():
        adjuster.record_trade(strategy_name)
        adjuster.record_reward(strategy_name, reward)

    # Step 3: Load current weights
    weights = await _load_strategy_weights(redis)

    # Step 4: Adjust weights for strategies with enough trades
    any_adjusted = False
    for strategy_name, reward in rewards.items():
        if not adjuster.should_adjust(strategy_name):
            logger.debug(
                "Strategy %s has %d trades (need %d) -- skipping adjustment",
                strategy_name,
                adjuster.trade_counts.get(strategy_name, 0),
                adjuster.config.min_trades_before_adjustment,
            )
            continue

        # Strategies missing from the cache start from the weight floor.
        old_weight = weights.get(strategy_name, adjuster.config.weight_floor)
        decayed_reward = adjuster.get_decayed_reward(strategy_name)
        new_weight = adjuster.adjust_weight(old_weight, decayed_reward)
        weights[strategy_name] = new_weight
        any_adjusted = True

        adjustment = WeightAdjustment(
            strategy_id=UUID(int=0),  # placeholder -- DB would assign real ID
            strategy_name=strategy_name,
            old_weight=old_weight,
            new_weight=new_weight,
            reason=f"bandit_adjustment roi={outcome.roi_pct:.2f}%",
            reward_signal=reward,
            timestamp=datetime.now(timezone.utc),
        )
        adjustments.append(adjustment)
        counters["adjustments_made"].add(1)

        logger.info(
            "Weight adjusted: %s %.4f -> %.4f (reward=%.4f)",
            strategy_name,
            old_weight,
            new_weight,
            reward,
        )

    # Step 5: Normalize weights
    if any_adjusted:
        weights = adjuster.normalize_weights(weights)
        await _save_strategy_weights(redis, weights)

        # Track weight drift: distance of each weight from the equal share.
        if weights:
            equal_share = 1.0 / len(weights)
            for name, weight in weights.items():
                counters["weight_drift"].record(
                    abs(weight - equal_share), {"strategy": name}
                )

    # Step 6: Clean up. A partial close must keep the remainder open under
    # the original entry data; otherwise the next trade on the closing side
    # would be misread as opening a fresh position in the opposite direction.
    # NOTE(review): a close larger than the stored quantity simply clears the
    # entry; the excess is not recorded as a new opposite position.
    remaining_qty = entry_qty - closed_qty
    if remaining_qty > _QTY_EPSILON:
        await _store_opening_trade(
            redis, trade.ticker, {**opening, "qty": remaining_qty}
        )
    else:
        await _clear_opening_trade(redis, trade.ticker)

    return adjustments


async def run(config: LearningEngineConfig | None = None) -> None:
    """Main service loop.

    Connects to Redis, initialises evaluator and weight adjuster, then
    continuously consumes from ``trades:executed`` and processes closed
    positions through the learning pipeline. Errors while handling a single
    message are logged and the loop continues. The Redis client is closed
    when the loop exits for any reason.

    Parameters
    ----------
    config:
        Service configuration; a default ``LearningEngineConfig`` is built
        when omitted.
    """
    if config is None:
        config = LearningEngineConfig()

    logging.basicConfig(level=config.log_level)
    logger.info("Starting Learning Engine service")

    # --- Telemetry ---
    meter = setup_telemetry("learning-engine", config.otel_metrics_port)
    counters = {
        "adjustments_made": meter.create_counter(
            "adjustments_made",
            description="Total strategy weight adjustments performed",
        ),
        "weight_drift": meter.create_histogram(
            "weight_drift",
            description="Absolute deviation of each strategy weight from equal weight",
        ),
    }

    # --- Redis ---
    redis = Redis.from_url(config.redis_url, decode_responses=False)
    consumer = StreamConsumer(redis, "trades:executed", "learning-engine", "worker-1")

    # --- Components ---
    evaluator = TradeEvaluator()
    adjuster = WeightAdjuster(config)

    logger.info("Consuming from trades:executed")

    # --- Consume loop ---
    try:
        async for _msg_id, data in consumer.consume():
            try:
                trade = TradeExecution.model_validate(data)

                if trade.status.value != "FILLED":
                    logger.debug("Skipping non-filled trade: %s", trade.status.value)
                    continue

                adjustments = await process_trade(
                    trade, redis, evaluator, adjuster, counters
                )
                if adjustments:
                    logger.info(
                        "Made %d weight adjustment(s) for %s",
                        len(adjustments),
                        trade.ticker,
                    )
            except Exception:
                # Top-level boundary: one bad message must not stop the service.
                logger.exception("Error processing trade execution: %s", data)
    finally:
        # Release the connection pool on shutdown or cancellation. Newer
        # redis-py releases name the coroutine ``aclose``; older ones only
        # provide ``close``.
        close_redis = getattr(redis, "aclose", None) or redis.close
        await close_redis()


def main() -> None:
    """CLI entry point."""
    asyncio.run(run())


if __name__ == "__main__":
    main()
class WeightAdjuster:
    """Adjusts strategy weights using a multi-armed bandit approach.

    The update rule is::

        new_weight = (1 - lr) * current_weight + lr * reward_signal

    Subject to guardrails:
    - No adjustment until ``min_trades_before_adjustment`` trades recorded
    - Max weight shift clamped to ``max_weight_shift_pct``
    - Weight floor enforced at ``weight_floor``
    - All weights normalized to sum to 1.0

    Trade counts and reward history are kept in memory only.

    NOTE(review): the reward history grows by one entry per recorded reward
    and is never pruned -- confirm this is acceptable for a long-running
    process.
    """

    def __init__(self, config: LearningEngineConfig) -> None:
        self.config = config
        self._trade_counts: dict[str, int] = defaultdict(int)
        self._reward_history: dict[str, list[float]] = defaultdict(list)

    @property
    def trade_counts(self) -> dict[str, int]:
        """Return a copy of the current trade counts per strategy."""
        return dict(self._trade_counts)

    def should_adjust(self, strategy_name: str) -> bool:
        """Return True if the strategy has enough trades for adjustment.

        Parameters
        ----------
        strategy_name:
            Name of the strategy to check.
        """
        # ``.get`` keeps this a pure read: indexing the defaultdict would
        # create a zero-count entry for every name that is merely queried.
        recorded = self._trade_counts.get(strategy_name, 0)
        return recorded >= self.config.min_trades_before_adjustment

    def adjust_weight(self, current_weight: float, reward_signal: float) -> float:
        """Compute a new weight from the current weight and reward signal.

        Applies the exponential moving average formula, clamps the shift
        to ``max_weight_shift_pct``, and enforces the weight floor.

        Parameters
        ----------
        current_weight:
            The strategy's current weight (0..1).
        reward_signal:
            The reward signal (positive = good, negative = bad).

        Returns
        -------
        float
            The adjusted weight.
        """
        lr = self.config.learning_rate
        raw_new = (1 - lr) * current_weight + lr * reward_signal

        # Clamp the shift to +/- max_weight_shift_pct (an absolute amount).
        shift = raw_new - current_weight
        max_shift = self.config.max_weight_shift_pct
        if abs(shift) > max_shift:
            shift = max_shift if shift > 0 else -max_shift
        new_weight = current_weight + shift

        # Apply floor
        return max(new_weight, self.config.weight_floor)

    def normalize_weights(self, weights: dict[str, float]) -> dict[str, float]:
        """Normalize weights so they sum to 1.0, respecting the floor.

        After proportional normalization, every weight below the floor is
        pinned to exactly the floor and the remaining weights are rescaled
        to share the leftover budget. Pinned weights stay pinned, so each
        pass can only add to the pinned set and the loop finishes in at most
        ``len(weights)`` passes.

        Degenerate inputs fall back to an equal split: a non-positive total
        (no proportions to preserve) or a floor that cannot be met by every
        strategy at once (``floor * len(weights) >= 1``).

        Parameters
        ----------
        weights:
            Mapping of strategy name to raw weight.

        Returns
        -------
        dict[str, float]
            Normalized weights summing to 1.0.
        """
        if not weights:
            return {}

        floor = self.config.weight_floor
        count = len(weights)
        total = sum(weights.values())

        if total <= 0 or floor * count >= 1.0:
            return dict.fromkeys(weights, 1.0 / count)

        result = {name: value / total for name, value in weights.items()}
        pinned: set[str] = set()

        for _ in range(count):
            violators = {
                name
                for name, value in result.items()
                if name not in pinned and value < floor
            }
            if not violators:
                break

            # Mass still held by weights that stay free after this pass,
            # measured before the violators are overwritten.
            free_total = sum(
                value
                for name, value in result.items()
                if name not in pinned and name not in violators
            )
            pinned |= violators
            for name in violators:
                result[name] = floor

            # Budget left for the free weights once every pinned weight
            # has been granted the floor.
            free_budget = 1.0 - floor * len(pinned)
            if free_total > 0:
                scale = free_budget / free_total
                for name in result:
                    if name not in pinned:
                        result[name] *= scale

        return result

    def record_trade(self, strategy_name: str) -> None:
        """Increment the trade count for a strategy.

        Parameters
        ----------
        strategy_name:
            Name of the strategy that contributed to the trade.
        """
        self._trade_counts[strategy_name] += 1

    def record_reward(self, strategy_name: str, reward: float) -> None:
        """Record a reward signal for a strategy, applying recency decay.

        Every previously stored reward is multiplied by ``recency_decay``
        before the new reward is appended undecayed.

        Parameters
        ----------
        strategy_name:
            Name of the strategy.
        reward:
            The reward signal to record.
        """
        decay = self.config.recency_decay
        decayed = [past * decay for past in self._reward_history[strategy_name]]
        decayed.append(reward)
        self._reward_history[strategy_name] = decayed

    def get_decayed_reward(self, strategy_name: str) -> float:
        """Get the average decayed reward for a strategy.

        Parameters
        ----------
        strategy_name:
            Name of the strategy.

        Returns
        -------
        float
            The plain mean of the stored (already decayed) reward signals,
            or 0.0 if none recorded.
        """
        history = self._reward_history.get(strategy_name, [])
        if not history:
            return 0.0
        return sum(history) / len(history)
+""" + +from __future__ import annotations + +import uuid + +import pytest + +from services.learning_engine.config import LearningEngineConfig +from services.learning_engine.evaluator import TradeEvaluator +from services.learning_engine.weight_adjuster import WeightAdjuster + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_config(**overrides) -> LearningEngineConfig: + defaults = dict( + learning_rate=0.1, + min_trades_before_adjustment=20, + max_weight_shift_pct=0.10, + weight_floor=0.05, + recency_decay=0.95, + evaluation_window_hours=1, + ) + defaults.update(overrides) + return LearningEngineConfig(**defaults) + + +def _make_trade_id() -> uuid.UUID: + return uuid.uuid4() + + +# --------------------------------------------------------------------------- +# TradeEvaluator — profitable trade +# --------------------------------------------------------------------------- + + +class TestEvaluateProfitableTrade: + """A long trade that gains in price should have positive PnL and ROI.""" + + def test_evaluate_profitable_trade(self): + evaluator = TradeEvaluator() + outcome = evaluator.evaluate_trade( + trade_id=_make_trade_id(), + entry_price=100.0, + exit_price=110.0, + qty=10.0, + direction_sign=1.0, # long + hold_duration_seconds=3600.0, + ) + + # PnL = (110 - 100) * 10 * 1.0 = 100 + assert outcome.realized_pnl == pytest.approx(100.0) + # ROI = 100 / (100 * 10) * 100 = 10% + assert outcome.roi_pct == pytest.approx(10.0) + assert outcome.was_profitable is True + assert outcome.hold_duration_seconds == pytest.approx(3600.0) + + +# --------------------------------------------------------------------------- +# TradeEvaluator — losing trade +# --------------------------------------------------------------------------- + + +class TestEvaluateLosingTrade: + """A long trade that drops should have negative PnL.""" + + def 
test_evaluate_losing_trade(self): + evaluator = TradeEvaluator() + outcome = evaluator.evaluate_trade( + trade_id=_make_trade_id(), + entry_price=100.0, + exit_price=95.0, + qty=10.0, + direction_sign=1.0, + hold_duration_seconds=7200.0, + ) + + # PnL = (95 - 100) * 10 * 1.0 = -50 + assert outcome.realized_pnl == pytest.approx(-50.0) + # ROI = -50 / (100 * 10) * 100 = -5% + assert outcome.roi_pct == pytest.approx(-5.0) + assert outcome.was_profitable is False + + +# --------------------------------------------------------------------------- +# TradeEvaluator — short trade PnL +# --------------------------------------------------------------------------- + + +class TestEvaluateShortTradePnl: + """A short trade profits when the price drops.""" + + def test_evaluate_short_trade_pnl(self): + evaluator = TradeEvaluator() + outcome = evaluator.evaluate_trade( + trade_id=_make_trade_id(), + entry_price=100.0, + exit_price=90.0, + qty=10.0, + direction_sign=-1.0, # short + hold_duration_seconds=1800.0, + ) + + # PnL = (90 - 100) * 10 * (-1) = 100 + assert outcome.realized_pnl == pytest.approx(100.0) + # ROI = 100 / (100 * 10) * 100 = 10% + assert outcome.roi_pct == pytest.approx(10.0) + assert outcome.was_profitable is True + + def test_short_trade_loss_when_price_rises(self): + """A short trade loses when price rises.""" + evaluator = TradeEvaluator() + outcome = evaluator.evaluate_trade( + trade_id=_make_trade_id(), + entry_price=100.0, + exit_price=110.0, + qty=10.0, + direction_sign=-1.0, + hold_duration_seconds=1800.0, + ) + + # PnL = (110 - 100) * 10 * (-1) = -100 + assert outcome.realized_pnl == pytest.approx(-100.0) + assert outcome.was_profitable is False + + +# --------------------------------------------------------------------------- +# Credit attribution — proportional +# --------------------------------------------------------------------------- + + +class TestCreditAttributionProportional: + """Credit should be distributed proportionally to signal 
strength.""" + + def test_credit_attribution_proportional(self): + evaluator = TradeEvaluator() + outcome = evaluator.evaluate_trade( + trade_id=_make_trade_id(), + entry_price=100.0, + exit_price=110.0, + qty=10.0, + direction_sign=1.0, + hold_duration_seconds=3600.0, + ) + # ROI = 10% + + strategy_sources = [ + "momentum:LONG:0.8", + "news_driven:LONG:0.2", + ] + rewards = evaluator.attribute_credit(outcome, strategy_sources) + + assert "momentum" in rewards + assert "news_driven" in rewards + + # Total strength = 0.8 + 0.2 = 1.0 + # momentum proportion = 0.8 / 1.0 = 0.8 + # news_driven proportion = 0.2 / 1.0 = 0.2 + # momentum reward = 10.0 * 0.8 = 8.0 + # news_driven reward = 10.0 * 0.2 = 2.0 + assert rewards["momentum"] == pytest.approx(8.0) + assert rewards["news_driven"] == pytest.approx(2.0) + + def test_credit_attribution_three_strategies(self): + """Three strategies should split reward proportionally.""" + evaluator = TradeEvaluator() + outcome = evaluator.evaluate_trade( + trade_id=_make_trade_id(), + entry_price=100.0, + exit_price=105.0, + qty=20.0, + direction_sign=1.0, + hold_duration_seconds=1000.0, + ) + # ROI = 5% + + strategy_sources = [ + "momentum:LONG:0.5", + "mean_reversion:LONG:0.3", + "news_driven:LONG:0.2", + ] + rewards = evaluator.attribute_credit(outcome, strategy_sources) + + total = sum(rewards.values()) + assert total == pytest.approx(5.0) # sum of rewards = ROI + assert rewards["momentum"] == pytest.approx(2.5) + assert rewards["mean_reversion"] == pytest.approx(1.5) + assert rewards["news_driven"] == pytest.approx(1.0) + + +# --------------------------------------------------------------------------- +# Credit attribution — single strategy +# --------------------------------------------------------------------------- + + +class TestCreditAttributionSingleStrategy: + """A single strategy should receive the full reward.""" + + def test_credit_attribution_single_strategy(self): + evaluator = TradeEvaluator() + outcome = 
evaluator.evaluate_trade( + trade_id=_make_trade_id(), + entry_price=100.0, + exit_price=108.0, + qty=10.0, + direction_sign=1.0, + hold_duration_seconds=3600.0, + ) + # ROI = 8% + + strategy_sources = ["momentum:LONG:0.9"] + rewards = evaluator.attribute_credit(outcome, strategy_sources) + + assert len(rewards) == 1 + assert rewards["momentum"] == pytest.approx(8.0) + + def test_credit_attribution_bare_name(self): + """A bare strategy name (no colon format) defaults to strength 1.0.""" + evaluator = TradeEvaluator() + outcome = evaluator.evaluate_trade( + trade_id=_make_trade_id(), + entry_price=100.0, + exit_price=110.0, + qty=10.0, + direction_sign=1.0, + hold_duration_seconds=3600.0, + ) + # ROI = 10% + + strategy_sources = ["momentum"] + rewards = evaluator.attribute_credit(outcome, strategy_sources) + + assert rewards["momentum"] == pytest.approx(10.0) + + def test_credit_attribution_empty_sources(self): + """Empty strategy sources should return empty rewards.""" + evaluator = TradeEvaluator() + outcome = evaluator.evaluate_trade( + trade_id=_make_trade_id(), + entry_price=100.0, + exit_price=110.0, + qty=10.0, + direction_sign=1.0, + hold_duration_seconds=3600.0, + ) + rewards = evaluator.attribute_credit(outcome, []) + assert rewards == {} + + +# --------------------------------------------------------------------------- +# Weight adjustment — formula +# --------------------------------------------------------------------------- + + +class TestWeightAdjustmentFormula: + """Verify the EMA weight adjustment formula.""" + + def test_weight_adjustment_formula(self): + config = _make_config(learning_rate=0.1, max_weight_shift_pct=1.0) + adjuster = WeightAdjuster(config) + + # new_weight = (1 - 0.1) * 0.3 + 0.1 * 0.5 = 0.27 + 0.05 = 0.32 + new_weight = adjuster.adjust_weight(0.3, 0.5) + assert new_weight == pytest.approx(0.32) + + def test_weight_adjustment_negative_reward(self): + """Negative reward should decrease the weight.""" + config = 
_make_config(learning_rate=0.1, max_weight_shift_pct=1.0, weight_floor=0.0) + adjuster = WeightAdjuster(config) + + # new_weight = (1 - 0.1) * 0.3 + 0.1 * (-0.5) = 0.27 + (-0.05) = 0.22 + new_weight = adjuster.adjust_weight(0.3, -0.5) + assert new_weight == pytest.approx(0.22) + + def test_weight_adjustment_high_learning_rate(self): + """Higher learning rate should cause larger shifts.""" + config = _make_config(learning_rate=0.5, max_weight_shift_pct=1.0) + adjuster = WeightAdjuster(config) + + # new_weight = (1 - 0.5) * 0.3 + 0.5 * 0.8 = 0.15 + 0.4 = 0.55 + new_weight = adjuster.adjust_weight(0.3, 0.8) + assert new_weight == pytest.approx(0.55) + + +# --------------------------------------------------------------------------- +# Weight adjustment — max shift clamped +# --------------------------------------------------------------------------- + + +class TestWeightAdjustmentMaxShiftClamped: + """Weight shift should be clamped to max_weight_shift_pct.""" + + def test_weight_adjustment_max_shift_clamped(self): + config = _make_config( + learning_rate=0.5, + max_weight_shift_pct=0.05, # only 5% shift allowed + ) + adjuster = WeightAdjuster(config) + + # Raw: (1-0.5)*0.3 + 0.5*0.9 = 0.15 + 0.45 = 0.60 + # Shift = 0.60 - 0.30 = 0.30 > 0.05, so clamp to +0.05 + # Result = 0.30 + 0.05 = 0.35 + new_weight = adjuster.adjust_weight(0.3, 0.9) + assert new_weight == pytest.approx(0.35) + + def test_weight_adjustment_negative_clamped(self): + """Negative shift should also be clamped.""" + config = _make_config( + learning_rate=0.5, + max_weight_shift_pct=0.05, + weight_floor=0.0, + ) + adjuster = WeightAdjuster(config) + + # Raw: (1-0.5)*0.3 + 0.5*(-1.0) = 0.15 - 0.5 = -0.35 + # Shift = -0.35 - 0.30 = -0.65 < -0.05, so clamp to -0.05 + # Result = 0.30 - 0.05 = 0.25 + new_weight = adjuster.adjust_weight(0.3, -1.0) + assert new_weight == pytest.approx(0.25) + + +# --------------------------------------------------------------------------- +# Weight adjustment — floor applied +# 
--------------------------------------------------------------------------- + + +class TestWeightAdjustmentFloorApplied: + """Weight should never go below the configured floor.""" + + def test_weight_adjustment_floor_applied(self): + config = _make_config( + learning_rate=0.5, + max_weight_shift_pct=1.0, # no clamping + weight_floor=0.05, + ) + adjuster = WeightAdjuster(config) + + # Raw: (1-0.5)*0.1 + 0.5*(-1.0) = 0.05 - 0.5 = -0.45 + # -0.45 < 0.05 floor, so clamp to 0.05 + new_weight = adjuster.adjust_weight(0.1, -1.0) + assert new_weight == pytest.approx(0.05) + + def test_weight_at_floor_stays_at_floor(self): + """Weight already at floor with negative reward stays at floor.""" + config = _make_config(weight_floor=0.05, max_weight_shift_pct=1.0) + adjuster = WeightAdjuster(config) + + new_weight = adjuster.adjust_weight(0.05, -1.0) + assert new_weight >= 0.05 + + +# --------------------------------------------------------------------------- +# Normalize weights — sums to one +# --------------------------------------------------------------------------- + + +class TestNormalizeWeightsSumsToOne: + """Normalized weights should always sum to 1.0.""" + + def test_normalize_weights_sums_to_one(self): + config = _make_config(weight_floor=0.05) + adjuster = WeightAdjuster(config) + + weights = {"a": 0.5, "b": 0.3, "c": 0.2} + normalized = adjuster.normalize_weights(weights) + + assert sum(normalized.values()) == pytest.approx(1.0) + + def test_normalize_unequal_weights(self): + """Very unequal weights should still normalize to 1.0.""" + config = _make_config(weight_floor=0.01) + adjuster = WeightAdjuster(config) + + weights = {"a": 10.0, "b": 1.0, "c": 0.1} + normalized = adjuster.normalize_weights(weights) + + assert sum(normalized.values()) == pytest.approx(1.0) + # Proportions should be maintained + assert normalized["a"] > normalized["b"] > normalized["c"] + + def test_normalize_empty_weights(self): + """Empty weights dict should return empty.""" + config = 
_make_config() + adjuster = WeightAdjuster(config) + + assert adjuster.normalize_weights({}) == {} + + +# --------------------------------------------------------------------------- +# Normalize weights — respects floor +# --------------------------------------------------------------------------- + + +class TestNormalizeWeightsRespectsFloor: + """No strategy weight should drop below the floor after normalization.""" + + def test_normalize_weights_respects_floor(self): + config = _make_config(weight_floor=0.10) + adjuster = WeightAdjuster(config) + + # One very small weight + weights = {"a": 1.0, "b": 0.01, "c": 0.01} + normalized = adjuster.normalize_weights(weights) + + assert sum(normalized.values()) == pytest.approx(1.0) + for name, w in normalized.items(): + assert w >= 0.10 - 1e-9, f"Weight for {name} is {w}, below floor 0.10" + + def test_normalize_all_below_floor(self): + """If all weights are tiny, they should all get at least the floor.""" + config = _make_config(weight_floor=0.10) + adjuster = WeightAdjuster(config) + + weights = {"a": 0.001, "b": 0.001, "c": 0.001} + normalized = adjuster.normalize_weights(weights) + + assert sum(normalized.values()) == pytest.approx(1.0) + for w in normalized.values(): + assert w >= 0.10 - 1e-9 + + +# --------------------------------------------------------------------------- +# Should adjust — requires min trades +# --------------------------------------------------------------------------- + + +class TestShouldAdjustRequiresMinTrades: + """Weight adjustment should be gated by minimum trade count.""" + + def test_should_adjust_requires_min_trades(self): + config = _make_config(min_trades_before_adjustment=20) + adjuster = WeightAdjuster(config) + + # No trades recorded yet + assert adjuster.should_adjust("momentum") is False + + # Record 19 trades -- still not enough + for _ in range(19): + adjuster.record_trade("momentum") + assert adjuster.should_adjust("momentum") is False + + +# 
# ---------------------------------------------------------------------------
# Should adjust — after enough trades
# ---------------------------------------------------------------------------


class TestShouldAdjustAfterEnoughTrades:
    """Adjustment becomes available once the trade-count threshold is met."""

    def test_should_adjust_after_enough_trades(self):
        """Recording exactly the configured minimum opens the gate."""
        threshold = 20
        adjuster = WeightAdjuster(
            _make_config(min_trades_before_adjustment=threshold)
        )

        for _trade in range(threshold):
            adjuster.record_trade("momentum")

        assert adjuster.should_adjust("momentum") is True

    def test_other_strategies_still_gated(self):
        """Trade counts are tracked per strategy, not globally."""
        threshold = 5
        adjuster = WeightAdjuster(
            _make_config(min_trades_before_adjustment=threshold)
        )

        for _trade in range(threshold):
            adjuster.record_trade("momentum")

        # Only the strategy that actually traded clears the gate.
        assert adjuster.should_adjust("momentum") is True
        assert adjuster.should_adjust("news_driven") is False


# ---------------------------------------------------------------------------
# Recency decay — recent trades weighted more
# ---------------------------------------------------------------------------


class TestRecencyDecayApplied:
    """Older rewards are discounted relative to newer ones."""

    def test_recency_decay_applied(self):
        """Two equal rewards average below their face value after decay."""
        adjuster = WeightAdjuster(_make_config(recency_decay=0.5))

        # Recording the second reward multiplies the first by 0.5.
        for reward in (10.0, 10.0):
            adjuster.record_reward("momentum", reward)

        # Expected history: [10.0 * 0.5, 10.0] = [5.0, 10.0] -> mean 7.5
        decayed_mean = adjuster.get_decayed_reward("momentum")
        assert decayed_mean == pytest.approx(7.5)

    def test_recency_decay_multiple_rounds(self):
        """Each new reward decays every earlier one again, so decay compounds."""
        adjuster = WeightAdjuster(_make_config(recency_decay=0.5))

        # After 1st: [8]
        # After 2nd: [4, 8]
        # After 3rd: [2, 4, 8]
        for reward in (8.0, 8.0, 8.0):
            adjuster.record_reward("momentum", reward)

        expected_mean = (2.0 + 4.0 + 8.0) / 3
        decayed_mean = adjuster.get_decayed_reward("momentum")
        assert decayed_mean == pytest.approx(expected_mean)

    def test_no_decay_when_no_history(self):
        """A strategy with no recorded rewards reports a decayed reward of 0."""
        adjuster = WeightAdjuster(_make_config())

        assert adjuster.get_decayed_reward("momentum") == 0.0

    def test_high_decay_preserves_more(self):
        """A decay factor near 1 keeps more of the older rewards than 0.5 does."""
        adjuster_high = WeightAdjuster(_make_config(recency_decay=0.99))
        adjuster_low = WeightAdjuster(_make_config(recency_decay=0.5))

        for adjuster in (adjuster_high, adjuster_low):
            for _round in range(2):
                adjuster.record_reward("momentum", 10.0)

        # decay 0.99: [10*0.99, 10] = [9.9, 10] -> mean 9.95
        # decay 0.50: [10*0.5,  10] = [5.0, 10] -> mean 7.5
        high_avg = adjuster_high.get_decayed_reward("momentum")
        low_avg = adjuster_low.get_decayed_reward("momentum")

        assert high_avg > low_avg
        assert high_avg == pytest.approx(9.95)
        assert low_avg == pytest.approx(7.5)