feat: learning engine — multi-armed bandit strategy weight adjustment
This commit is contained in:
parent
1d9900838d
commit
c089bcb92c
6 changed files with 1177 additions and 0 deletions
191
services/learning_engine/weight_adjuster.py
Normal file
191
services/learning_engine/weight_adjuster.py
Normal file
|
|
@ -0,0 +1,191 @@
|
|||
"""Weight adjuster -- multi-armed bandit strategy weight updates.
|
||||
|
||||
Implements the exponential moving average weight adjustment formula with
|
||||
configurable guardrails: minimum trade count, max shift per cycle,
|
||||
weight floor, and normalization.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
|
||||
from services.learning_engine.config import LearningEngineConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WeightAdjuster:
|
||||
"""Adjusts strategy weights using a multi-armed bandit approach.
|
||||
|
||||
The update rule is::
|
||||
|
||||
new_weight = (1 - lr) * current_weight + lr * reward_signal
|
||||
|
||||
Subject to guardrails:
|
||||
- No adjustment until ``min_trades_before_adjustment`` trades recorded
|
||||
- Max weight shift clamped to ``max_weight_shift_pct``
|
||||
- Weight floor enforced at ``weight_floor``
|
||||
- All weights normalized to sum to 1.0
|
||||
"""
|
||||
|
||||
def __init__(self, config: LearningEngineConfig) -> None:
|
||||
self.config = config
|
||||
self._trade_counts: dict[str, int] = defaultdict(int)
|
||||
self._reward_history: dict[str, list[float]] = defaultdict(list)
|
||||
|
||||
@property
|
||||
def trade_counts(self) -> dict[str, int]:
|
||||
"""Return a copy of the current trade counts per strategy."""
|
||||
return dict(self._trade_counts)
|
||||
|
||||
def should_adjust(self, strategy_name: str) -> bool:
|
||||
"""Return True if the strategy has enough trades for adjustment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
strategy_name:
|
||||
Name of the strategy to check.
|
||||
"""
|
||||
return self._trade_counts[strategy_name] >= self.config.min_trades_before_adjustment
|
||||
|
||||
def adjust_weight(self, current_weight: float, reward_signal: float) -> float:
|
||||
"""Compute a new weight from the current weight and reward signal.
|
||||
|
||||
Applies the exponential moving average formula, clamps the shift
|
||||
to ``max_weight_shift_pct``, and enforces the weight floor.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
current_weight:
|
||||
The strategy's current weight (0..1).
|
||||
reward_signal:
|
||||
The reward signal (positive = good, negative = bad).
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
The adjusted weight.
|
||||
"""
|
||||
lr = self.config.learning_rate
|
||||
raw_new = (1 - lr) * current_weight + lr * reward_signal
|
||||
|
||||
# Clamp the shift
|
||||
shift = raw_new - current_weight
|
||||
max_shift = self.config.max_weight_shift_pct
|
||||
if abs(shift) > max_shift:
|
||||
shift = max_shift if shift > 0 else -max_shift
|
||||
new_weight = current_weight + shift
|
||||
|
||||
# Apply floor
|
||||
new_weight = max(new_weight, self.config.weight_floor)
|
||||
|
||||
return new_weight
|
||||
|
||||
def normalize_weights(self, weights: dict[str, float]) -> dict[str, float]:
|
||||
"""Normalize weights so they sum to 1.0, respecting the floor.
|
||||
|
||||
Uses an iterative approach: after normalization, any weight below
|
||||
the floor is set to the floor, and the remaining weights are
|
||||
re-normalized from the leftover budget. This repeats until stable.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
weights:
|
||||
Mapping of strategy name to raw weight.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict[str, float]
|
||||
Normalized weights summing to 1.0.
|
||||
"""
|
||||
if not weights:
|
||||
return {}
|
||||
|
||||
floor = self.config.weight_floor
|
||||
result = dict(weights)
|
||||
|
||||
for _ in range(10): # iterative convergence (bounded)
|
||||
total = sum(result.values())
|
||||
if total == 0:
|
||||
# Equal distribution
|
||||
equal = 1.0 / len(result)
|
||||
return {k: max(equal, floor) for k in result}
|
||||
|
||||
# Normalize
|
||||
result = {k: v / total for k, v in result.items()}
|
||||
|
||||
# Check floor violations
|
||||
floored: set[str] = set()
|
||||
for k, v in result.items():
|
||||
if v < floor:
|
||||
floored.add(k)
|
||||
|
||||
if not floored:
|
||||
break
|
||||
|
||||
# Fix floor violations and redistribute
|
||||
floored_budget = floor * len(floored)
|
||||
remaining_budget = 1.0 - floored_budget
|
||||
remaining_total = sum(v for k, v in result.items() if k not in floored)
|
||||
|
||||
for k in floored:
|
||||
result[k] = floor
|
||||
|
||||
if remaining_total > 0:
|
||||
scale = remaining_budget / remaining_total
|
||||
for k in result:
|
||||
if k not in floored:
|
||||
result[k] *= scale
|
||||
|
||||
# Final normalization to handle rounding
|
||||
total = sum(result.values())
|
||||
if total > 0 and abs(total - 1.0) > 1e-9:
|
||||
result = {k: v / total for k, v in result.items()}
|
||||
|
||||
return result
|
||||
|
||||
def record_trade(self, strategy_name: str) -> None:
|
||||
"""Increment the trade count for a strategy.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
strategy_name:
|
||||
Name of the strategy that contributed to the trade.
|
||||
"""
|
||||
self._trade_counts[strategy_name] += 1
|
||||
|
||||
def record_reward(self, strategy_name: str, reward: float) -> None:
|
||||
"""Record a reward signal for a strategy, applying recency decay.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
strategy_name:
|
||||
Name of the strategy.
|
||||
reward:
|
||||
The reward signal to record.
|
||||
"""
|
||||
decay = self.config.recency_decay
|
||||
# Decay existing rewards
|
||||
self._reward_history[strategy_name] = [
|
||||
r * decay for r in self._reward_history[strategy_name]
|
||||
]
|
||||
self._reward_history[strategy_name].append(reward)
|
||||
|
||||
def get_decayed_reward(self, strategy_name: str) -> float:
|
||||
"""Get the average decayed reward for a strategy.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
strategy_name:
|
||||
Name of the strategy.
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
The average of all decayed reward signals, or 0.0 if none recorded.
|
||||
"""
|
||||
history = self._reward_history.get(strategy_name, [])
|
||||
if not history:
|
||||
return 0.0
|
||||
return sum(history) / len(history)
|
||||
Loading…
Add table
Add a link
Reference in a new issue