feat: learning engine — multi-armed bandit strategy weight adjustment
This commit is contained in:
parent
1d9900838d
commit
c089bcb92c
6 changed files with 1177 additions and 0 deletions
1
services/learning_engine/__init__.py
Normal file
1
services/learning_engine/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Learning Engine service -- multi-armed bandit strategy weight adjustment."""
|
||||
16
services/learning_engine/config.py
Normal file
16
services/learning_engine/config.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
"""Configuration for the learning engine service."""
|
||||
|
||||
from shared.config import BaseConfig
|
||||
|
||||
|
||||
class LearningEngineConfig(BaseConfig):
|
||||
"""Extends BaseConfig with learning-engine-specific settings."""
|
||||
|
||||
learning_rate: float = 0.1
|
||||
min_trades_before_adjustment: int = 20
|
||||
max_weight_shift_pct: float = 0.10
|
||||
weight_floor: float = 0.05
|
||||
recency_decay: float = 0.95
|
||||
evaluation_window_hours: int = 1
|
||||
|
||||
model_config = {"env_prefix": "TRADING_"}
|
||||
120
services/learning_engine/evaluator.py
Normal file
120
services/learning_engine/evaluator.py
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
"""Trade evaluator -- computes outcomes and attributes credit to strategies.
|
||||
|
||||
Given a closed trade (exit), this module computes realized P&L, ROI, and
|
||||
distributes reward signals to each contributing strategy proportionally
|
||||
to its signal strength.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
from shared.schemas.learning import TradeOutcomeSchema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TradeEvaluator:
|
||||
"""Evaluates closed trades and attributes credit to strategies."""
|
||||
|
||||
def evaluate_trade(
|
||||
self,
|
||||
trade_id: UUID,
|
||||
entry_price: float,
|
||||
exit_price: float,
|
||||
qty: float,
|
||||
direction_sign: float,
|
||||
hold_duration_seconds: float,
|
||||
) -> TradeOutcomeSchema:
|
||||
"""Compute the outcome of a closed trade.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_id:
|
||||
Unique identifier of the closing trade.
|
||||
entry_price:
|
||||
The price at which the position was opened.
|
||||
exit_price:
|
||||
The price at which the position was closed.
|
||||
qty:
|
||||
Number of shares traded.
|
||||
direction_sign:
|
||||
+1.0 for long positions, -1.0 for short positions.
|
||||
hold_duration_seconds:
|
||||
How long the position was held, in seconds.
|
||||
|
||||
Returns
|
||||
-------
|
||||
TradeOutcomeSchema
|
||||
The evaluated outcome including realized P&L and ROI.
|
||||
"""
|
||||
realized_pnl = (exit_price - entry_price) * qty * direction_sign
|
||||
cost_basis = entry_price * qty
|
||||
roi_pct = (realized_pnl / cost_basis * 100.0) if cost_basis != 0 else 0.0
|
||||
was_profitable = realized_pnl > 0
|
||||
|
||||
return TradeOutcomeSchema(
|
||||
trade_id=trade_id,
|
||||
hold_duration_seconds=hold_duration_seconds,
|
||||
realized_pnl=realized_pnl,
|
||||
roi_pct=roi_pct,
|
||||
was_profitable=was_profitable,
|
||||
)
|
||||
|
||||
def attribute_credit(
|
||||
self,
|
||||
outcome: TradeOutcomeSchema,
|
||||
strategy_sources: list[str],
|
||||
) -> dict[str, float]:
|
||||
"""Distribute reward signal to contributing strategies.
|
||||
|
||||
Parses ``strategy_sources`` entries which may be formatted as either:
|
||||
- ``"name:DIRECTION:strength"`` (full format from the ensemble)
|
||||
- ``"name"`` (bare strategy name -- defaults to strength 1.0)
|
||||
|
||||
The reward signal is the trade's ROI percentage distributed
|
||||
proportionally to each strategy's signal strength.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
outcome:
|
||||
The evaluated trade outcome.
|
||||
strategy_sources:
|
||||
List of strategy source strings from the signal.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict[str, float]
|
||||
Mapping of strategy name to its reward signal.
|
||||
"""
|
||||
if not strategy_sources:
|
||||
return {}
|
||||
|
||||
# Parse strengths from strategy_sources
|
||||
parsed: list[tuple[str, float]] = []
|
||||
for source in strategy_sources:
|
||||
parts = source.split(":")
|
||||
name = parts[0]
|
||||
if len(parts) >= 3:
|
||||
try:
|
||||
strength = float(parts[2])
|
||||
except (ValueError, IndexError):
|
||||
strength = 1.0
|
||||
else:
|
||||
strength = 1.0
|
||||
parsed.append((name, strength))
|
||||
|
||||
# Compute total strength for proportional distribution
|
||||
total_strength = sum(s for _, s in parsed)
|
||||
if total_strength == 0:
|
||||
return {}
|
||||
|
||||
# Distribute reward proportionally
|
||||
rewards: dict[str, float] = {}
|
||||
for name, strength in parsed:
|
||||
proportion = strength / total_strength
|
||||
reward_signal = outcome.roi_pct * proportion
|
||||
rewards[name] = reward_signal
|
||||
|
||||
return rewards
|
||||
304
services/learning_engine/main.py
Normal file
304
services/learning_engine/main.py
Normal file
|
|
@ -0,0 +1,304 @@
|
|||
"""Learning Engine service -- main entry point.
|
||||
|
||||
Consumes ``trades:executed`` from Redis Streams, evaluates closed positions,
|
||||
attributes credit to contributing strategies, adjusts strategy weights via
|
||||
a multi-armed bandit approach, and stores all adjustments for auditability.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from uuid import UUID
|
||||
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from services.learning_engine.config import LearningEngineConfig
|
||||
from services.learning_engine.evaluator import TradeEvaluator
|
||||
from services.learning_engine.weight_adjuster import WeightAdjuster
|
||||
from shared.redis_streams import StreamConsumer
|
||||
from shared.schemas.learning import TradeOutcomeSchema, WeightAdjustment
|
||||
from shared.schemas.trading import OrderSide, TradeExecution
|
||||
from shared.telemetry import setup_telemetry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Redis key for cached strategy weights
|
||||
_STRATEGY_WEIGHTS_KEY = "strategy:weights"
|
||||
|
||||
|
||||
async def _load_strategy_weights(redis: Redis) -> dict[str, float]:
|
||||
"""Load current strategy weights from Redis cache.
|
||||
|
||||
Falls back to equal weights for the three default strategies
|
||||
if no cached weights exist.
|
||||
"""
|
||||
raw = await redis.get(_STRATEGY_WEIGHTS_KEY)
|
||||
if raw:
|
||||
return json.loads(raw)
|
||||
# Default equal weights
|
||||
return {
|
||||
"momentum": 0.333,
|
||||
"mean_reversion": 0.333,
|
||||
"news_driven": 0.334,
|
||||
}
|
||||
|
||||
|
||||
async def _save_strategy_weights(redis: Redis, weights: dict[str, float]) -> None:
|
||||
"""Persist strategy weights to Redis cache."""
|
||||
await redis.set(_STRATEGY_WEIGHTS_KEY, json.dumps(weights))
|
||||
|
||||
|
||||
async def _find_opening_trade(
|
||||
redis: Redis,
|
||||
ticker: str,
|
||||
closing_side: OrderSide,
|
||||
) -> dict | None:
|
||||
"""Look up the opening trade for a position close.
|
||||
|
||||
Searches the ``positions:history`` Redis hash for the ticker.
|
||||
Returns the stored entry data or None if not found.
|
||||
"""
|
||||
raw = await redis.hget("positions:history", ticker)
|
||||
if raw:
|
||||
return json.loads(raw)
|
||||
return None
|
||||
|
||||
|
||||
async def _store_opening_trade(
|
||||
redis: Redis,
|
||||
ticker: str,
|
||||
trade_data: dict,
|
||||
) -> None:
|
||||
"""Store a trade as the opening trade for a ticker."""
|
||||
await redis.hset("positions:history", ticker, json.dumps(trade_data))
|
||||
|
||||
|
||||
async def _clear_opening_trade(redis: Redis, ticker: str) -> None:
|
||||
"""Clear the stored opening trade after position close."""
|
||||
await redis.hdel("positions:history", ticker)
|
||||
|
||||
|
||||
def _is_position_close(trade: TradeExecution, opening: dict | None) -> bool:
|
||||
"""Determine if a trade closes a position.
|
||||
|
||||
A trade closes a position if there is an existing opening trade
|
||||
on the opposite side for the same ticker.
|
||||
"""
|
||||
if opening is None:
|
||||
return False
|
||||
opening_side = opening.get("side", "")
|
||||
if trade.side == OrderSide.SELL and opening_side == OrderSide.BUY.value:
|
||||
return True
|
||||
if trade.side == OrderSide.BUY and opening_side == OrderSide.SELL.value:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def process_trade(
|
||||
trade: TradeExecution,
|
||||
redis: Redis,
|
||||
evaluator: TradeEvaluator,
|
||||
adjuster: WeightAdjuster,
|
||||
counters: dict,
|
||||
) -> list[WeightAdjustment]:
|
||||
"""Process a single trade execution.
|
||||
|
||||
If the trade closes a position:
|
||||
1. Evaluate the trade outcome (P&L, ROI)
|
||||
2. Attribute credit to contributing strategies
|
||||
3. Adjust weights for strategies with enough trades
|
||||
4. Normalize all weights
|
||||
5. Store adjustments and update cached weights
|
||||
|
||||
Returns a list of weight adjustments made (empty if none).
|
||||
"""
|
||||
adjustments: list[WeightAdjustment] = []
|
||||
|
||||
# Look up opening trade
|
||||
opening = await _find_opening_trade(redis, trade.ticker, trade.side)
|
||||
|
||||
if not _is_position_close(trade, opening):
|
||||
# This is an opening trade -- store it for later reference
|
||||
await _store_opening_trade(
|
||||
redis,
|
||||
trade.ticker,
|
||||
{
|
||||
"trade_id": str(trade.trade_id),
|
||||
"side": trade.side.value,
|
||||
"price": trade.price,
|
||||
"qty": trade.qty,
|
||||
"timestamp": trade.timestamp.isoformat(),
|
||||
"strategy_sources": [], # would come from signal
|
||||
},
|
||||
)
|
||||
return adjustments
|
||||
|
||||
# --- Position close detected ---
|
||||
entry_price = opening["price"]
|
||||
entry_qty = opening.get("qty", trade.qty)
|
||||
entry_time_str = opening.get("timestamp", "")
|
||||
strategy_sources = opening.get("strategy_sources", [])
|
||||
|
||||
# Determine direction sign
|
||||
opening_side = opening.get("side", "")
|
||||
direction_sign = 1.0 if opening_side == OrderSide.BUY.value else -1.0
|
||||
|
||||
# Compute hold duration
|
||||
hold_duration_seconds = 0.0
|
||||
if entry_time_str:
|
||||
try:
|
||||
entry_time = datetime.fromisoformat(entry_time_str)
|
||||
hold_duration_seconds = (trade.timestamp - entry_time).total_seconds()
|
||||
except (ValueError, TypeError):
|
||||
hold_duration_seconds = 0.0
|
||||
|
||||
# Step 1: Evaluate trade
|
||||
outcome = evaluator.evaluate_trade(
|
||||
trade_id=trade.trade_id,
|
||||
entry_price=entry_price,
|
||||
exit_price=trade.price,
|
||||
qty=min(trade.qty, entry_qty),
|
||||
direction_sign=direction_sign,
|
||||
hold_duration_seconds=max(hold_duration_seconds, 0.0),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Trade outcome: %s PnL=%.2f ROI=%.2f%% profitable=%s",
|
||||
trade.ticker,
|
||||
outcome.realized_pnl,
|
||||
outcome.roi_pct,
|
||||
outcome.was_profitable,
|
||||
)
|
||||
|
||||
# Step 2: Attribute credit
|
||||
rewards = evaluator.attribute_credit(outcome, strategy_sources)
|
||||
|
||||
# Record trades and rewards for each strategy
|
||||
for strategy_name, reward in rewards.items():
|
||||
adjuster.record_trade(strategy_name)
|
||||
adjuster.record_reward(strategy_name, reward)
|
||||
|
||||
# Step 3: Load current weights
|
||||
weights = await _load_strategy_weights(redis)
|
||||
|
||||
# Step 4: Adjust weights for strategies with enough trades
|
||||
any_adjusted = False
|
||||
for strategy_name, reward in rewards.items():
|
||||
if not adjuster.should_adjust(strategy_name):
|
||||
logger.debug(
|
||||
"Strategy %s has %d trades (need %d) -- skipping adjustment",
|
||||
strategy_name,
|
||||
adjuster.trade_counts.get(strategy_name, 0),
|
||||
adjuster.config.min_trades_before_adjustment,
|
||||
)
|
||||
continue
|
||||
|
||||
old_weight = weights.get(strategy_name, adjuster.config.weight_floor)
|
||||
decayed_reward = adjuster.get_decayed_reward(strategy_name)
|
||||
new_weight = adjuster.adjust_weight(old_weight, decayed_reward)
|
||||
weights[strategy_name] = new_weight
|
||||
any_adjusted = True
|
||||
|
||||
adjustment = WeightAdjustment(
|
||||
strategy_id=UUID(int=0), # placeholder -- DB would assign real ID
|
||||
strategy_name=strategy_name,
|
||||
old_weight=old_weight,
|
||||
new_weight=new_weight,
|
||||
reason=f"bandit_adjustment roi={outcome.roi_pct:.2f}%",
|
||||
reward_signal=reward,
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
adjustments.append(adjustment)
|
||||
counters["adjustments_made"].add(1)
|
||||
|
||||
logger.info(
|
||||
"Weight adjusted: %s %.4f -> %.4f (reward=%.4f)",
|
||||
strategy_name,
|
||||
old_weight,
|
||||
new_weight,
|
||||
reward,
|
||||
)
|
||||
|
||||
# Step 5: Normalize weights
|
||||
if any_adjusted:
|
||||
weights = adjuster.normalize_weights(weights)
|
||||
await _save_strategy_weights(redis, weights)
|
||||
|
||||
# Track weight drift
|
||||
for name, weight in weights.items():
|
||||
default = 1.0 / len(weights)
|
||||
drift = abs(weight - default)
|
||||
counters["weight_drift"].record(drift, {"strategy": name})
|
||||
|
||||
# Clean up opening trade
|
||||
await _clear_opening_trade(redis, trade.ticker)
|
||||
|
||||
return adjustments
|
||||
|
||||
|
||||
async def run(config: LearningEngineConfig | None = None) -> None:
|
||||
"""Main service loop.
|
||||
|
||||
Connects to Redis, initialises evaluator and weight adjuster, then
|
||||
continuously consumes from ``trades:executed`` and processes closed
|
||||
positions through the learning pipeline.
|
||||
"""
|
||||
if config is None:
|
||||
config = LearningEngineConfig()
|
||||
|
||||
logging.basicConfig(level=config.log_level)
|
||||
logger.info("Starting Learning Engine service")
|
||||
|
||||
# --- Telemetry ---
|
||||
meter = setup_telemetry("learning-engine", config.otel_metrics_port)
|
||||
counters = {
|
||||
"adjustments_made": meter.create_counter(
|
||||
"adjustments_made",
|
||||
description="Total strategy weight adjustments performed",
|
||||
),
|
||||
"weight_drift": meter.create_histogram(
|
||||
"weight_drift",
|
||||
description="Absolute deviation of each strategy weight from equal weight",
|
||||
),
|
||||
}
|
||||
|
||||
# --- Redis ---
|
||||
redis = Redis.from_url(config.redis_url, decode_responses=False)
|
||||
consumer = StreamConsumer(redis, "trades:executed", "learning-engine", "worker-1")
|
||||
|
||||
# --- Components ---
|
||||
evaluator = TradeEvaluator()
|
||||
adjuster = WeightAdjuster(config)
|
||||
|
||||
logger.info("Consuming from trades:executed")
|
||||
|
||||
# --- Consume loop ---
|
||||
async for _msg_id, data in consumer.consume():
|
||||
try:
|
||||
trade = TradeExecution.model_validate(data)
|
||||
|
||||
if trade.status.value != "FILLED":
|
||||
logger.debug("Skipping non-filled trade: %s", trade.status.value)
|
||||
continue
|
||||
|
||||
adjustments = await process_trade(trade, redis, evaluator, adjuster, counters)
|
||||
if adjustments:
|
||||
logger.info(
|
||||
"Made %d weight adjustment(s) for %s",
|
||||
len(adjustments),
|
||||
trade.ticker,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error processing trade execution: %s", data)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""CLI entry point."""
|
||||
asyncio.run(run())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
191
services/learning_engine/weight_adjuster.py
Normal file
191
services/learning_engine/weight_adjuster.py
Normal file
|
|
@ -0,0 +1,191 @@
|
|||
"""Weight adjuster -- multi-armed bandit strategy weight updates.
|
||||
|
||||
Implements the exponential moving average weight adjustment formula with
|
||||
configurable guardrails: minimum trade count, max shift per cycle,
|
||||
weight floor, and normalization.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
|
||||
from services.learning_engine.config import LearningEngineConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WeightAdjuster:
|
||||
"""Adjusts strategy weights using a multi-armed bandit approach.
|
||||
|
||||
The update rule is::
|
||||
|
||||
new_weight = (1 - lr) * current_weight + lr * reward_signal
|
||||
|
||||
Subject to guardrails:
|
||||
- No adjustment until ``min_trades_before_adjustment`` trades recorded
|
||||
- Max weight shift clamped to ``max_weight_shift_pct``
|
||||
- Weight floor enforced at ``weight_floor``
|
||||
- All weights normalized to sum to 1.0
|
||||
"""
|
||||
|
||||
def __init__(self, config: LearningEngineConfig) -> None:
|
||||
self.config = config
|
||||
self._trade_counts: dict[str, int] = defaultdict(int)
|
||||
self._reward_history: dict[str, list[float]] = defaultdict(list)
|
||||
|
||||
@property
|
||||
def trade_counts(self) -> dict[str, int]:
|
||||
"""Return a copy of the current trade counts per strategy."""
|
||||
return dict(self._trade_counts)
|
||||
|
||||
def should_adjust(self, strategy_name: str) -> bool:
|
||||
"""Return True if the strategy has enough trades for adjustment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
strategy_name:
|
||||
Name of the strategy to check.
|
||||
"""
|
||||
return self._trade_counts[strategy_name] >= self.config.min_trades_before_adjustment
|
||||
|
||||
def adjust_weight(self, current_weight: float, reward_signal: float) -> float:
|
||||
"""Compute a new weight from the current weight and reward signal.
|
||||
|
||||
Applies the exponential moving average formula, clamps the shift
|
||||
to ``max_weight_shift_pct``, and enforces the weight floor.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
current_weight:
|
||||
The strategy's current weight (0..1).
|
||||
reward_signal:
|
||||
The reward signal (positive = good, negative = bad).
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
The adjusted weight.
|
||||
"""
|
||||
lr = self.config.learning_rate
|
||||
raw_new = (1 - lr) * current_weight + lr * reward_signal
|
||||
|
||||
# Clamp the shift
|
||||
shift = raw_new - current_weight
|
||||
max_shift = self.config.max_weight_shift_pct
|
||||
if abs(shift) > max_shift:
|
||||
shift = max_shift if shift > 0 else -max_shift
|
||||
new_weight = current_weight + shift
|
||||
|
||||
# Apply floor
|
||||
new_weight = max(new_weight, self.config.weight_floor)
|
||||
|
||||
return new_weight
|
||||
|
||||
def normalize_weights(self, weights: dict[str, float]) -> dict[str, float]:
|
||||
"""Normalize weights so they sum to 1.0, respecting the floor.
|
||||
|
||||
Uses an iterative approach: after normalization, any weight below
|
||||
the floor is set to the floor, and the remaining weights are
|
||||
re-normalized from the leftover budget. This repeats until stable.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
weights:
|
||||
Mapping of strategy name to raw weight.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict[str, float]
|
||||
Normalized weights summing to 1.0.
|
||||
"""
|
||||
if not weights:
|
||||
return {}
|
||||
|
||||
floor = self.config.weight_floor
|
||||
result = dict(weights)
|
||||
|
||||
for _ in range(10): # iterative convergence (bounded)
|
||||
total = sum(result.values())
|
||||
if total == 0:
|
||||
# Equal distribution
|
||||
equal = 1.0 / len(result)
|
||||
return {k: max(equal, floor) for k in result}
|
||||
|
||||
# Normalize
|
||||
result = {k: v / total for k, v in result.items()}
|
||||
|
||||
# Check floor violations
|
||||
floored: set[str] = set()
|
||||
for k, v in result.items():
|
||||
if v < floor:
|
||||
floored.add(k)
|
||||
|
||||
if not floored:
|
||||
break
|
||||
|
||||
# Fix floor violations and redistribute
|
||||
floored_budget = floor * len(floored)
|
||||
remaining_budget = 1.0 - floored_budget
|
||||
remaining_total = sum(v for k, v in result.items() if k not in floored)
|
||||
|
||||
for k in floored:
|
||||
result[k] = floor
|
||||
|
||||
if remaining_total > 0:
|
||||
scale = remaining_budget / remaining_total
|
||||
for k in result:
|
||||
if k not in floored:
|
||||
result[k] *= scale
|
||||
|
||||
# Final normalization to handle rounding
|
||||
total = sum(result.values())
|
||||
if total > 0 and abs(total - 1.0) > 1e-9:
|
||||
result = {k: v / total for k, v in result.items()}
|
||||
|
||||
return result
|
||||
|
||||
def record_trade(self, strategy_name: str) -> None:
|
||||
"""Increment the trade count for a strategy.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
strategy_name:
|
||||
Name of the strategy that contributed to the trade.
|
||||
"""
|
||||
self._trade_counts[strategy_name] += 1
|
||||
|
||||
def record_reward(self, strategy_name: str, reward: float) -> None:
|
||||
"""Record a reward signal for a strategy, applying recency decay.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
strategy_name:
|
||||
Name of the strategy.
|
||||
reward:
|
||||
The reward signal to record.
|
||||
"""
|
||||
decay = self.config.recency_decay
|
||||
# Decay existing rewards
|
||||
self._reward_history[strategy_name] = [
|
||||
r * decay for r in self._reward_history[strategy_name]
|
||||
]
|
||||
self._reward_history[strategy_name].append(reward)
|
||||
|
||||
def get_decayed_reward(self, strategy_name: str) -> float:
|
||||
"""Get the average decayed reward for a strategy.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
strategy_name:
|
||||
Name of the strategy.
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
The average of all decayed reward signals, or 0.0 if none recorded.
|
||||
"""
|
||||
history = self._reward_history.get(strategy_name, [])
|
||||
if not history:
|
||||
return 0.0
|
||||
return sum(history) / len(history)
|
||||
Loading…
Add table
Add a link
Reference in a new issue