118 lines
3.6 KiB
Python
118 lines
3.6 KiB
Python
"""Weighted ensemble that combines signals from multiple strategies.

Runs all registered strategies, collects non-``None`` signals, and computes
a combined strength via a weighted average. Only emits a ``TradeSignal``
when the combined strength is at least a configurable threshold.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from datetime import datetime, timezone
|
|
|
|
from shared.schemas.trading import (
|
|
MarketSnapshot,
|
|
SentimentContext,
|
|
SignalDirection,
|
|
TradeSignal,
|
|
)
|
|
from shared.strategies.base import BaseStrategy
|
|
|
|
|
|
class WeightedEnsemble:
    """Combine signals from multiple strategies using weighted averaging.

    Each strategy that emits a signal contributes
    ``strength * sign * weight`` to a net score, where ``sign`` is ``+1``
    for LONG, ``-1`` for SHORT and ``0`` for any other direction. The
    combined strength is ``abs(net score) / sum(abs(weight))`` taken over
    the contributing strategies, and the emitted direction is the sign of
    the net score. The reported strength is capped at 1.0 and rounded to
    4 decimal places.

    Parameters
    ----------
    strategies:
        The list of strategy instances to evaluate.
    threshold:
        Minimum combined strength required to emit a signal (default 0.3).
    default_weight:
        Weight used for a strategy whose name is missing from the
        ``weights`` mapping passed to :meth:`evaluate` (default 0.1).
    """

    def __init__(
        self,
        strategies: list[BaseStrategy],
        threshold: float = 0.3,
        default_weight: float = 0.1,
    ) -> None:
        self.strategies = strategies
        self.threshold = threshold
        self.default_weight = default_weight

    async def evaluate(
        self,
        ticker: str,
        market: MarketSnapshot,
        sentiment: SentimentContext | None,
        weights: dict[str, float],
    ) -> TradeSignal | None:
        """Run all strategies and return a combined signal, or ``None``.

        Parameters
        ----------
        ticker:
            The stock ticker being evaluated.
        market:
            Current market snapshot including price, SMA, RSI.
        sentiment:
            Aggregated sentiment context (may be ``None``).
        weights:
            Mapping from strategy name to its weight. Strategies missing
            from the mapping use ``self.default_weight``.

        Returns
        -------
        TradeSignal | None
            ``None`` when no strategy emits a signal, when every
            contributing weight is zero, when the combined strength is
            below ``self.threshold``, or when the net score is exactly zero.
        """
        # Step 1: run all strategies, keep the ones that produced a signal
        signals = await self._collect_signals(ticker, market, sentiment)
        if not signals:
            return None

        # Step 2: signed weighted sum of strengths
        weighted_sum = 0.0
        total_weight = 0.0
        for strategy, signal in signals:
            w = weights.get(strategy.name, self.default_weight)
            weighted_sum += signal.strength * self._direction_sign(signal) * w
            # Normalise by |w|: the numerator already flips sign for a
            # negative weight, so a signed denominator could shrink toward
            # zero (inflating the strength) or go negative (suppressing it).
            # For non-negative weights this equals the plain sum.
            total_weight += abs(w)

        if total_weight == 0:
            return None

        # Step 3: combined strength, bounded by the largest signal strength
        combined_strength = abs(weighted_sum) / total_weight
        if combined_strength < self.threshold:
            return None

        # Step 4: direction follows the sign of the weighted sum
        if weighted_sum > 0:
            direction = SignalDirection.LONG
        elif weighted_sum < 0:
            direction = SignalDirection.SHORT
        else:
            # Only reachable when threshold <= 0: a perfectly balanced
            # vote has no direction to trade.
            return None

        # Step 5: record each strategy's individual contribution
        strategy_sources = [
            f"{strategy.name}:{signal.direction.value}:{signal.strength:.4f}"
            for strategy, signal in signals
        ]

        return TradeSignal(
            ticker=ticker,
            direction=direction,
            strength=round(min(combined_strength, 1.0), 4),
            strategy_sources=strategy_sources,
            sentiment_context=self._sentiment_context(sentiment),
            timestamp=datetime.now(timezone.utc),
        )

    async def _collect_signals(
        self,
        ticker: str,
        market: MarketSnapshot,
        sentiment: SentimentContext | None,
    ) -> list[tuple[BaseStrategy, TradeSignal]]:
        """Await each strategy in registration order; keep non-``None`` signals."""
        collected: list[tuple[BaseStrategy, TradeSignal]] = []
        for strategy in self.strategies:
            signal = await strategy.evaluate(ticker, market, sentiment)
            if signal is not None:
                collected.append((strategy, signal))
        return collected

    @staticmethod
    def _direction_sign(signal: TradeSignal) -> float:
        """Map a signal's direction to +1.0 (LONG), -1.0 (SHORT) or 0.0 (other)."""
        if signal.direction == SignalDirection.LONG:
            return 1.0
        if signal.direction == SignalDirection.SHORT:
            return -1.0
        # NOTE(review): any other SignalDirection member (if one exists)
        # casts no vote but its weight still counts toward the total —
        # confirm this is the desired treatment.
        return 0.0

    @staticmethod
    def _sentiment_context(
        sentiment: SentimentContext | None,
    ) -> dict[str, object] | None:
        """Copy the sentiment summary fields into a plain dict, or ``None``."""
        if sentiment is None:
            return None
        return {
            "avg_score": sentiment.avg_score,
            "article_count": sentiment.article_count,
            "avg_confidence": sentiment.avg_confidence,
        }
|