"""Weight adjuster -- multi-armed bandit strategy weight updates.

Implements the exponential moving average weight adjustment formula with
configurable guardrails: minimum trade count, max shift per cycle, weight
floor, and normalization.
"""

from __future__ import annotations

import logging
from collections import defaultdict

from services.learning_engine.config import LearningEngineConfig

logger = logging.getLogger(__name__)


class WeightAdjuster:
    """Adjusts strategy weights using a multi-armed bandit approach.

    The update rule is::

        new_weight = (1 - lr) * current_weight + lr * reward_signal

    Subject to guardrails:

    - No adjustment until ``min_trades_before_adjustment`` trades recorded
    - Max weight shift clamped to ``max_weight_shift_pct``
    - Weight floor enforced at ``weight_floor``
    - All weights normalized to sum to 1.0
    """

    def __init__(self, config: LearningEngineConfig) -> None:
        self.config = config
        # Number of trades each strategy has contributed to (see record_trade).
        self._trade_counts: dict[str, int] = defaultdict(int)
        # Per-strategy reward signals, newest last. Each older entry has been
        # multiplied by ``recency_decay`` once per reward recorded after it.
        self._reward_history: dict[str, list[float]] = defaultdict(list)

    @property
    def trade_counts(self) -> dict[str, int]:
        """Return a copy of the current trade counts per strategy."""
        return dict(self._trade_counts)

    def should_adjust(self, strategy_name: str) -> bool:
        """Return True if the strategy has enough trades for adjustment.

        Parameters
        ----------
        strategy_name:
            Name of the strategy to check.
        """
        # ``.get`` instead of indexing: indexing the defaultdict would insert a
        # zero entry for an unknown strategy, so a read-only check would leak
        # into ``trade_counts``.
        count = self._trade_counts.get(strategy_name, 0)
        return count >= self.config.min_trades_before_adjustment

    def adjust_weight(self, current_weight: float, reward_signal: float) -> float:
        """Compute a new weight from the current weight and reward signal.

        Applies the exponential moving average formula, clamps the shift to
        ``max_weight_shift_pct``, and enforces the weight floor.

        ``max_weight_shift_pct`` bounds ``new_weight - current_weight``
        directly, in weight units. It is not a fraction of ``current_weight``.
        No upper bound is applied, so the result can exceed 1.0. Use
        :meth:`normalize_weights` to rescale a full set of weights.

        Parameters
        ----------
        current_weight:
            The strategy's current weight (0..1).
        reward_signal:
            The reward signal (positive = good, negative = bad).

        Returns
        -------
        float
            The adjusted weight.
        """
        lr = self.config.learning_rate
        raw_new = (1 - lr) * current_weight + lr * reward_signal

        # Clamp the shift to +/- max_shift around the current weight.
        shift = raw_new - current_weight
        max_shift = self.config.max_weight_shift_pct
        if abs(shift) > max_shift:
            shift = max_shift if shift > 0 else -max_shift
        new_weight = current_weight + shift

        # Apply floor
        new_weight = max(new_weight, self.config.weight_floor)

        return new_weight

    def normalize_weights(self, weights: dict[str, float]) -> dict[str, float]:
        """Normalize weights so they sum to 1.0, respecting the floor.

        The weights are first scaled to sum to 1.0. Any strategy whose share
        would fall below ``weight_floor`` is pinned at the floor. The remaining
        budget (``1 - floor * n_pinned``) is then shared among the unpinned
        strategies in proportion to their weights.

        Pinning is sticky: once pinned, a strategy stays at the floor, so
        rescaling the others can never push it back below. Each pass either
        pins at least one new strategy or stops. The loop therefore ends
        after at most ``len(weights) + 1`` passes.

        If the floor cannot be met for every strategy
        (``weight_floor * len(weights) >= 1``), or the weights sum to zero,
        an equal split is returned instead. This keeps the sum at 1.0.

        Parameters
        ----------
        weights:
            Mapping of strategy name to raw weight.

        Returns
        -------
        dict[str, float]
            Normalized weights summing to 1.0, in the same key order as
            *weights*.
        """
        if not weights:
            return {}

        floor = self.config.weight_floor
        count = len(weights)
        equal_split = {k: 1.0 / count for k in weights}

        # With floor * n >= 1 the only way to sum to 1.0 is the equal split.
        # At exactly 1 that split is also the unique floor-respecting answer.
        # Above 1 no assignment satisfies both constraints.
        if floor * count >= 1.0:
            return equal_split

        total = sum(weights.values())
        if total == 0:
            return equal_split

        shares = {k: v / total for k, v in weights.items()}
        pinned: set[str] = set()
        while True:
            free_total = sum(v for k, v in shares.items() if k not in pinned)
            if free_total <= 0:
                # Defensive: nothing left to scale (degenerate input).
                return equal_split
            scale = (1.0 - floor * len(pinned)) / free_total
            newly_pinned = {
                k
                for k, v in shares.items()
                if k not in pinned and v * scale < floor
            }
            if not newly_pinned:
                break
            pinned |= newly_pinned

        # Pinned keys sit exactly at the floor. The free keys share the
        # remaining budget, so the total is 1.0 up to float rounding.
        return {k: floor if k in pinned else v * scale for k, v in shares.items()}

    def record_trade(self, strategy_name: str) -> None:
        """Increment the trade count for a strategy.

        Parameters
        ----------
        strategy_name:
            Name of the strategy that contributed to the trade.
        """
        self._trade_counts[strategy_name] += 1

    def record_reward(self, strategy_name: str, reward: float) -> None:
        """Record a reward signal for a strategy, applying recency decay.

        Every reward already stored for the strategy is multiplied by
        ``recency_decay`` before the new reward is appended. An entry that is
        ``k`` rewards old has therefore been scaled by ``recency_decay ** k``.

        Parameters
        ----------
        strategy_name:
            Name of the strategy.
        reward:
            The reward signal to record.
        """
        decay = self.config.recency_decay
        # Decay existing rewards.
        # NOTE(review): the history is never trimmed, and this rebuild touches
        # every stored entry. Each call is therefore O(len(history)), and
        # memory grows with the number of rewards recorded.
        self._reward_history[strategy_name] = [
            r * decay for r in self._reward_history[strategy_name]
        ]
        self._reward_history[strategy_name].append(reward)

    def get_decayed_reward(self, strategy_name: str) -> float:
        """Get the recency-weighted average reward for a strategy.

        :meth:`record_reward` stores a reward that is ``k`` steps old scaled by
        ``decay ** k``. The decayed sum is therefore divided by the sum of
        those weights, not by the number of entries. The result is a proper
        exponentially weighted mean: a strategy that always earns ``r`` scores
        ``r`` however long its history is. Dividing by the raw count would
        instead drag every score toward 0 as history grows. With
        ``recency_decay == 1`` this is the plain arithmetic mean.

        Parameters
        ----------
        strategy_name:
            Name of the strategy.

        Returns
        -------
        float
            The recency-weighted mean of the recorded rewards, or 0.0 if none
            recorded.
        """
        history = self._reward_history.get(strategy_name, [])
        if not history:
            return 0.0
        decay = self.config.recency_decay
        # Assumes ``recency_decay`` has not changed since these rewards were
        # recorded; the weights are rebuilt from the current config value.
        total_weight = sum(decay**age for age in range(len(history)))
        if total_weight == 0:
            # Only reachable with an unusual (negative) decay; fall back to
            # the unweighted mean rather than dividing by zero.
            return sum(history) / len(history)
        return sum(history) / total_weight