feat: signal generator — weighted ensemble with market data
This commit is contained in:
parent
e483e9987f
commit
f3e5fc944d
11 changed files with 1013 additions and 0 deletions
1
services/signal_generator/__init__.py
Normal file
1
services/signal_generator/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
"""Signal Generator service — weighted ensemble of trading strategies."""
|
||||||
14
services/signal_generator/config.py
Normal file
14
services/signal_generator/config.py
Normal file
|
|
@ -0,0 +1,14 @@
|
||||||
|
"""Configuration for the signal generator service."""
|
||||||
|
|
||||||
|
from shared.config import BaseConfig
|
||||||
|
|
||||||
|
|
||||||
|
class SignalGeneratorConfig(BaseConfig):
|
||||||
|
"""Extends BaseConfig with signal-generator-specific settings."""
|
||||||
|
|
||||||
|
alpaca_api_key: str = ""
|
||||||
|
alpaca_secret_key: str = ""
|
||||||
|
signal_strength_threshold: float = 0.3
|
||||||
|
watchlist: list[str] = []
|
||||||
|
|
||||||
|
model_config = {"env_prefix": "TRADING_"}
|
||||||
118
services/signal_generator/ensemble.py
Normal file
118
services/signal_generator/ensemble.py
Normal file
|
|
@ -0,0 +1,118 @@
|
||||||
|
"""Weighted ensemble that combines signals from multiple strategies.
|
||||||
|
|
||||||
|
Runs all registered strategies, collects non-``None`` signals, and computes
|
||||||
|
a combined strength via a weighted average. Only emits a ``TradeSignal``
|
||||||
|
when the combined strength exceeds a configurable threshold.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from shared.schemas.trading import (
|
||||||
|
MarketSnapshot,
|
||||||
|
SentimentContext,
|
||||||
|
SignalDirection,
|
||||||
|
TradeSignal,
|
||||||
|
)
|
||||||
|
from shared.strategies.base import BaseStrategy
|
||||||
|
|
||||||
|
|
||||||
|
class WeightedEnsemble:
|
||||||
|
"""Combine signals from multiple strategies using weighted averaging.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
strategies:
|
||||||
|
The list of strategy instances to evaluate.
|
||||||
|
threshold:
|
||||||
|
Minimum combined strength required to emit a signal (default 0.3).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
strategies: list[BaseStrategy],
|
||||||
|
threshold: float = 0.3,
|
||||||
|
) -> None:
|
||||||
|
self.strategies = strategies
|
||||||
|
self.threshold = threshold
|
||||||
|
|
||||||
|
async def evaluate(
|
||||||
|
self,
|
||||||
|
ticker: str,
|
||||||
|
market: MarketSnapshot,
|
||||||
|
sentiment: SentimentContext | None,
|
||||||
|
weights: dict[str, float],
|
||||||
|
) -> TradeSignal | None:
|
||||||
|
"""Run all strategies and return a combined signal, or ``None``.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
ticker:
|
||||||
|
The stock ticker being evaluated.
|
||||||
|
market:
|
||||||
|
Current market snapshot including price, SMA, RSI.
|
||||||
|
sentiment:
|
||||||
|
Aggregated sentiment context (may be ``None``).
|
||||||
|
weights:
|
||||||
|
Mapping from strategy name to its weight.
|
||||||
|
"""
|
||||||
|
# Step 1: run all strategies, collect (strategy, signal) pairs
|
||||||
|
signals: list[tuple[BaseStrategy, TradeSignal]] = []
|
||||||
|
for strategy in self.strategies:
|
||||||
|
signal = await strategy.evaluate(ticker, market, sentiment)
|
||||||
|
if signal is not None:
|
||||||
|
signals.append((strategy, signal))
|
||||||
|
|
||||||
|
if not signals:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Step 2: compute weighted sum
|
||||||
|
weighted_sum = 0.0
|
||||||
|
total_weight = 0.0
|
||||||
|
for strategy, signal in signals:
|
||||||
|
w = weights.get(strategy.name, 0.1)
|
||||||
|
direction_sign = 1.0 if signal.direction == SignalDirection.LONG else -1.0
|
||||||
|
weighted_sum += signal.strength * direction_sign * w
|
||||||
|
total_weight += w
|
||||||
|
|
||||||
|
if total_weight == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Step 3: combined strength
|
||||||
|
combined_strength = abs(weighted_sum) / total_weight
|
||||||
|
|
||||||
|
if combined_strength < self.threshold:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Step 4: determine direction from the sign of the weighted sum
|
||||||
|
if weighted_sum > 0:
|
||||||
|
direction = SignalDirection.LONG
|
||||||
|
elif weighted_sum < 0:
|
||||||
|
direction = SignalDirection.SHORT
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Step 5: build strategy_sources with individual contributions
|
||||||
|
strategy_sources = [
|
||||||
|
f"{strategy.name}:{signal.direction.value}:{signal.strength:.4f}"
|
||||||
|
for strategy, signal in signals
|
||||||
|
]
|
||||||
|
|
||||||
|
# Carry forward sentiment context if available
|
||||||
|
sentiment_ctx = None
|
||||||
|
if sentiment is not None:
|
||||||
|
sentiment_ctx = {
|
||||||
|
"avg_score": sentiment.avg_score,
|
||||||
|
"article_count": sentiment.article_count,
|
||||||
|
"avg_confidence": sentiment.avg_confidence,
|
||||||
|
}
|
||||||
|
|
||||||
|
return TradeSignal(
|
||||||
|
ticker=ticker,
|
||||||
|
direction=direction,
|
||||||
|
strength=round(min(combined_strength, 1.0), 4),
|
||||||
|
strategy_sources=strategy_sources,
|
||||||
|
sentiment_context=sentiment_ctx,
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
165
services/signal_generator/main.py
Normal file
165
services/signal_generator/main.py
Normal file
|
|
@ -0,0 +1,165 @@
|
||||||
|
"""Signal Generator service -- main entry point.
|
||||||
|
|
||||||
|
Consumes ``news:scored`` articles from Redis Streams, updates sentiment
|
||||||
|
context per ticker, runs the weighted ensemble of trading strategies, and
|
||||||
|
publishes qualifying ``TradeSignal`` messages to ``signals:generated``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from collections import defaultdict, deque
|
||||||
|
|
||||||
|
from redis.asyncio import Redis
|
||||||
|
|
||||||
|
from services.signal_generator.config import SignalGeneratorConfig
|
||||||
|
from services.signal_generator.ensemble import WeightedEnsemble
|
||||||
|
from services.signal_generator.market_data import MarketDataManager
|
||||||
|
from shared.redis_streams import StreamConsumer, StreamPublisher
|
||||||
|
from shared.schemas.news import ScoredArticle
|
||||||
|
from shared.schemas.trading import SentimentContext
|
||||||
|
from shared.strategies import MeanReversionStrategy, MomentumStrategy, NewsDrivenStrategy
|
||||||
|
from shared.telemetry import setup_telemetry
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Maximum number of recent sentiment scores to retain per ticker
|
||||||
|
_MAX_SENTIMENT_SCORES = 50
|
||||||
|
|
||||||
|
# Default strategy weights (equal weighting)
|
||||||
|
_DEFAULT_WEIGHTS: dict[str, float] = {
|
||||||
|
"momentum": 0.333,
|
||||||
|
"mean_reversion": 0.333,
|
||||||
|
"news_driven": 0.334,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_sentiment_context(
|
||||||
|
ticker: str,
|
||||||
|
scores: deque[float],
|
||||||
|
confidences: deque[float],
|
||||||
|
) -> SentimentContext:
|
||||||
|
"""Build a ``SentimentContext`` from accumulated per-ticker scores."""
|
||||||
|
score_list = list(scores)
|
||||||
|
conf_list = list(confidences)
|
||||||
|
return SentimentContext(
|
||||||
|
ticker=ticker,
|
||||||
|
avg_score=sum(score_list) / len(score_list) if score_list else 0.0,
|
||||||
|
article_count=len(score_list),
|
||||||
|
recent_scores=score_list[-10:],
|
||||||
|
avg_confidence=sum(conf_list) / len(conf_list) if conf_list else 0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def run(config: SignalGeneratorConfig | None = None) -> None:
|
||||||
|
"""Main service loop.
|
||||||
|
|
||||||
|
Connects to Redis, initialises strategies and telemetry, then
|
||||||
|
continuously consumes from ``news:scored`` and publishes qualifying
|
||||||
|
signals to ``signals:generated``.
|
||||||
|
"""
|
||||||
|
if config is None:
|
||||||
|
config = SignalGeneratorConfig()
|
||||||
|
|
||||||
|
logging.basicConfig(level=config.log_level)
|
||||||
|
logger.info("Starting Signal Generator service")
|
||||||
|
|
||||||
|
# --- Telemetry ---
|
||||||
|
meter = setup_telemetry("signal-generator", config.otel_metrics_port)
|
||||||
|
signals_generated = meter.create_counter(
|
||||||
|
"signals_generated",
|
||||||
|
description="Total trade signals emitted by the signal generator",
|
||||||
|
)
|
||||||
|
per_strategy_signal_count = meter.create_counter(
|
||||||
|
"per_strategy_signal_count",
|
||||||
|
description="Signals emitted, broken down by strategy",
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Redis ---
|
||||||
|
redis = Redis.from_url(config.redis_url, decode_responses=False)
|
||||||
|
consumer = StreamConsumer(redis, "news:scored", "signal-generator", "worker-1")
|
||||||
|
publisher = StreamPublisher(redis, "signals:generated")
|
||||||
|
|
||||||
|
# --- Market data ---
|
||||||
|
market_data = MarketDataManager()
|
||||||
|
|
||||||
|
# --- Strategies ---
|
||||||
|
strategies = [
|
||||||
|
MomentumStrategy(),
|
||||||
|
MeanReversionStrategy(),
|
||||||
|
NewsDrivenStrategy(),
|
||||||
|
]
|
||||||
|
ensemble = WeightedEnsemble(strategies, threshold=config.signal_strength_threshold)
|
||||||
|
|
||||||
|
# --- Strategy weights (default equal; could load from DB) ---
|
||||||
|
weights = dict(_DEFAULT_WEIGHTS)
|
||||||
|
|
||||||
|
# --- Per-ticker sentiment accumulators ---
|
||||||
|
sentiment_scores: dict[str, deque[float]] = defaultdict(lambda: deque(maxlen=_MAX_SENTIMENT_SCORES))
|
||||||
|
sentiment_confidences: dict[str, deque[float]] = defaultdict(lambda: deque(maxlen=_MAX_SENTIMENT_SCORES))
|
||||||
|
|
||||||
|
logger.info("Consuming from news:scored, publishing to signals:generated")
|
||||||
|
|
||||||
|
# --- Consume loop ---
|
||||||
|
async for _msg_id, data in consumer.consume():
|
||||||
|
try:
|
||||||
|
article = ScoredArticle.model_validate(data)
|
||||||
|
ticker = article.ticker
|
||||||
|
|
||||||
|
# Update sentiment accumulators
|
||||||
|
sentiment_scores[ticker].append(article.sentiment_score)
|
||||||
|
sentiment_confidences[ticker].append(article.confidence)
|
||||||
|
|
||||||
|
# Build sentiment context
|
||||||
|
sentiment = _build_sentiment_context(
|
||||||
|
ticker,
|
||||||
|
sentiment_scores[ticker],
|
||||||
|
sentiment_confidences[ticker],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get market snapshot (may be None if no bars received yet)
|
||||||
|
snapshot = market_data.get_snapshot(ticker)
|
||||||
|
if snapshot is None:
|
||||||
|
# Create a minimal snapshot from sentiment data alone
|
||||||
|
# (the news_driven strategy does not require market indicators)
|
||||||
|
from shared.schemas.trading import MarketSnapshot
|
||||||
|
|
||||||
|
snapshot = MarketSnapshot(
|
||||||
|
ticker=ticker,
|
||||||
|
current_price=0.0,
|
||||||
|
open=0.0,
|
||||||
|
high=0.0,
|
||||||
|
low=0.0,
|
||||||
|
close=0.0,
|
||||||
|
volume=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run ensemble
|
||||||
|
signal = await ensemble.evaluate(ticker, snapshot, sentiment, weights)
|
||||||
|
|
||||||
|
if signal is not None:
|
||||||
|
await publisher.publish(signal.model_dump(mode="json"))
|
||||||
|
signals_generated.add(1)
|
||||||
|
for src in signal.strategy_sources:
|
||||||
|
strategy_name = src.split(":")[0]
|
||||||
|
per_strategy_signal_count.add(1, {"strategy": strategy_name})
|
||||||
|
logger.info(
|
||||||
|
"Signal generated: %s %s strength=%.4f sources=%s",
|
||||||
|
signal.direction.value,
|
||||||
|
ticker,
|
||||||
|
signal.strength,
|
||||||
|
signal.strategy_sources,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error processing scored article: %s", data.get("title", "<unknown>"))
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
"""CLI entry point."""
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
122
services/signal_generator/market_data.py
Normal file
122
services/signal_generator/market_data.py
Normal file
|
|
@ -0,0 +1,122 @@
|
||||||
|
"""In-memory market data manager with rolling OHLCV windows.
|
||||||
|
|
||||||
|
Maintains a per-ticker deque of recent bars and computes technical
|
||||||
|
indicators (SMA, RSI) on demand when building ``MarketSnapshot`` objects.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections import deque
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from shared.schemas.trading import MarketSnapshot, OHLCVBar
|
||||||
|
|
||||||
|
|
||||||
|
# Default rolling-window sizes
|
||||||
|
_DEFAULT_MAX_BARS = 100
|
||||||
|
_RSI_PERIOD = 14
|
||||||
|
|
||||||
|
|
||||||
|
class MarketDataManager:
|
||||||
|
"""Manages in-memory rolling windows of OHLCV bars per ticker.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
max_bars:
|
||||||
|
Maximum number of bars to retain per ticker.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, max_bars: int = _DEFAULT_MAX_BARS) -> None:
|
||||||
|
self.max_bars = max_bars
|
||||||
|
self._bars: dict[str, deque[OHLCVBar]] = {}
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Public API
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def add_bar(self, ticker: str, bar_data: dict[str, Any] | OHLCVBar) -> None:
|
||||||
|
"""Append a bar to the rolling window for *ticker*.
|
||||||
|
|
||||||
|
``bar_data`` can be a dict (parsed from JSON) or an ``OHLCVBar``
|
||||||
|
instance.
|
||||||
|
"""
|
||||||
|
if isinstance(bar_data, dict):
|
||||||
|
bar = OHLCVBar.model_validate(bar_data)
|
||||||
|
else:
|
||||||
|
bar = bar_data
|
||||||
|
|
||||||
|
if ticker not in self._bars:
|
||||||
|
self._bars[ticker] = deque(maxlen=self.max_bars)
|
||||||
|
self._bars[ticker].append(bar)
|
||||||
|
|
||||||
|
def get_snapshot(self, ticker: str) -> MarketSnapshot | None:
|
||||||
|
"""Build a ``MarketSnapshot`` from the rolling window.
|
||||||
|
|
||||||
|
Returns ``None`` if no bars have been recorded for *ticker*.
|
||||||
|
"""
|
||||||
|
bars = self._bars.get(ticker)
|
||||||
|
if not bars:
|
||||||
|
return None
|
||||||
|
|
||||||
|
latest = bars[-1]
|
||||||
|
closes = [b.close for b in bars]
|
||||||
|
|
||||||
|
return MarketSnapshot(
|
||||||
|
ticker=ticker,
|
||||||
|
current_price=latest.close,
|
||||||
|
open=latest.open,
|
||||||
|
high=latest.high,
|
||||||
|
low=latest.low,
|
||||||
|
close=latest.close,
|
||||||
|
volume=latest.volume,
|
||||||
|
sma_20=self._compute_sma(closes, 20),
|
||||||
|
sma_50=self._compute_sma(closes, 50),
|
||||||
|
rsi=self._compute_rsi(closes, _RSI_PERIOD),
|
||||||
|
bars=[b.model_dump(mode="json") for b in bars],
|
||||||
|
)
|
||||||
|
|
||||||
|
def has_ticker(self, ticker: str) -> bool:
|
||||||
|
"""Return ``True`` if at least one bar exists for *ticker*."""
|
||||||
|
return ticker in self._bars and len(self._bars[ticker]) > 0
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Technical indicator helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _compute_sma(closes: list[float], period: int) -> float | None:
|
||||||
|
"""Compute the simple moving average over the last *period* closes.
|
||||||
|
|
||||||
|
Returns ``None`` if there are fewer than *period* data points.
|
||||||
|
"""
|
||||||
|
if len(closes) < period:
|
||||||
|
return None
|
||||||
|
return sum(closes[-period:]) / period
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _compute_rsi(closes: list[float], period: int = 14) -> float | None:
|
||||||
|
"""Compute the standard RSI over the last *period+1* closes.
|
||||||
|
|
||||||
|
Uses the average-gain / average-loss method. Returns ``None`` if
|
||||||
|
there are not enough data points (need at least ``period + 1``
|
||||||
|
closes to compute ``period`` deltas).
|
||||||
|
"""
|
||||||
|
if len(closes) < period + 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Only use the most recent period+1 closes
|
||||||
|
relevant = closes[-(period + 1):]
|
||||||
|
deltas = [relevant[i + 1] - relevant[i] for i in range(len(relevant) - 1)]
|
||||||
|
|
||||||
|
gains = [d for d in deltas if d > 0]
|
||||||
|
losses = [-d for d in deltas if d < 0]
|
||||||
|
|
||||||
|
avg_gain = sum(gains) / period if gains else 0.0
|
||||||
|
avg_loss = sum(losses) / period if losses else 0.0
|
||||||
|
|
||||||
|
if avg_loss == 0:
|
||||||
|
return 100.0 # No losses -> RSI is 100
|
||||||
|
|
||||||
|
rs = avg_gain / avg_loss
|
||||||
|
rsi = 100.0 - (100.0 / (1.0 + rs))
|
||||||
|
return round(rsi, 4)
|
||||||
13
shared/strategies/__init__.py
Normal file
13
shared/strategies/__init__.py
Normal file
|
|
@ -0,0 +1,13 @@
|
||||||
|
"""Trading strategy implementations."""
|
||||||
|
|
||||||
|
from shared.strategies.base import BaseStrategy
|
||||||
|
from shared.strategies.mean_reversion import MeanReversionStrategy
|
||||||
|
from shared.strategies.momentum import MomentumStrategy
|
||||||
|
from shared.strategies.news_driven import NewsDrivenStrategy
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseStrategy",
|
||||||
|
"MomentumStrategy",
|
||||||
|
"MeanReversionStrategy",
|
||||||
|
"NewsDrivenStrategy",
|
||||||
|
]
|
||||||
26
shared/strategies/base.py
Normal file
26
shared/strategies/base.py
Normal file
|
|
@ -0,0 +1,26 @@
|
||||||
|
"""Abstract base class for trading strategies."""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from shared.schemas.trading import MarketSnapshot, SentimentContext, TradeSignal
|
||||||
|
|
||||||
|
|
||||||
|
class BaseStrategy(ABC):
|
||||||
|
"""Interface that every trading strategy must implement.
|
||||||
|
|
||||||
|
Each strategy evaluates market conditions (and optionally sentiment)
|
||||||
|
for a given ticker and returns a ``TradeSignal`` if the strategy has
|
||||||
|
an opinion, or ``None`` if it is neutral.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def evaluate(
|
||||||
|
self,
|
||||||
|
ticker: str,
|
||||||
|
market: MarketSnapshot,
|
||||||
|
sentiment: SentimentContext | None = None,
|
||||||
|
) -> TradeSignal | None:
|
||||||
|
"""Return a signal if this strategy has an opinion, ``None`` otherwise."""
|
||||||
|
...
|
||||||
60
shared/strategies/mean_reversion.py
Normal file
60
shared/strategies/mean_reversion.py
Normal file
|
|
@ -0,0 +1,60 @@
|
||||||
|
"""Mean reversion trading strategy.
|
||||||
|
|
||||||
|
Buy when RSI < 30 (oversold), sell when RSI > 70 (overbought).
|
||||||
|
Signal strength is proportional to RSI extremity.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from shared.schemas.trading import MarketSnapshot, SentimentContext, SignalDirection, TradeSignal
|
||||||
|
from shared.strategies.base import BaseStrategy
|
||||||
|
|
||||||
|
|
||||||
|
class MeanReversionStrategy(BaseStrategy):
|
||||||
|
"""Contrarian strategy based on RSI extremes."""
|
||||||
|
|
||||||
|
name: str = "mean_reversion"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
oversold_threshold: float = 30.0,
|
||||||
|
overbought_threshold: float = 70.0,
|
||||||
|
) -> None:
|
||||||
|
self.oversold_threshold = oversold_threshold
|
||||||
|
self.overbought_threshold = overbought_threshold
|
||||||
|
|
||||||
|
async def evaluate(
|
||||||
|
self,
|
||||||
|
ticker: str,
|
||||||
|
market: MarketSnapshot,
|
||||||
|
sentiment: SentimentContext | None = None,
|
||||||
|
) -> TradeSignal | None:
|
||||||
|
"""Generate a signal when RSI indicates oversold/overbought conditions."""
|
||||||
|
if market.rsi is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
rsi = market.rsi
|
||||||
|
|
||||||
|
if rsi < self.oversold_threshold:
|
||||||
|
direction = SignalDirection.LONG
|
||||||
|
# Strength proportional to how oversold: RSI 0 -> strength 1.0, RSI 30 -> strength 0.0
|
||||||
|
strength = (self.oversold_threshold - rsi) / self.oversold_threshold
|
||||||
|
elif rsi > self.overbought_threshold:
|
||||||
|
direction = SignalDirection.SHORT
|
||||||
|
# Strength proportional to how overbought: RSI 100 -> strength 1.0, RSI 70 -> strength 0.0
|
||||||
|
strength = (rsi - self.overbought_threshold) / (100.0 - self.overbought_threshold)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
strength = min(max(strength, 0.0), 1.0)
|
||||||
|
|
||||||
|
return TradeSignal(
|
||||||
|
ticker=ticker,
|
||||||
|
direction=direction,
|
||||||
|
strength=round(strength, 4),
|
||||||
|
strategy_sources=[self.name],
|
||||||
|
sentiment_context=None,
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
62
shared/strategies/momentum.py
Normal file
62
shared/strategies/momentum.py
Normal file
|
|
@ -0,0 +1,62 @@
|
||||||
|
"""Momentum trading strategy.
|
||||||
|
|
||||||
|
Buy when price crosses above N-period SMA with increasing volume.
|
||||||
|
Sell when price crosses below SMA. Signal strength is proportional
|
||||||
|
to the distance from the SMA.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from shared.schemas.trading import MarketSnapshot, SentimentContext, SignalDirection, TradeSignal
|
||||||
|
from shared.strategies.base import BaseStrategy
|
||||||
|
|
||||||
|
|
||||||
|
class MomentumStrategy(BaseStrategy):
|
||||||
|
"""Trend-following momentum strategy based on SMA crossover."""
|
||||||
|
|
||||||
|
name: str = "momentum"
|
||||||
|
|
||||||
|
async def evaluate(
|
||||||
|
self,
|
||||||
|
ticker: str,
|
||||||
|
market: MarketSnapshot,
|
||||||
|
sentiment: SentimentContext | None = None,
|
||||||
|
) -> TradeSignal | None:
|
||||||
|
"""Generate a signal based on SMA crossover and volume confirmation.
|
||||||
|
|
||||||
|
Uses the 20-period SMA by default. Signal strength is the
|
||||||
|
normalised distance from the SMA (capped at 1.0).
|
||||||
|
"""
|
||||||
|
if market.sma_20 is None or market.sma_20 == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
price = market.current_price
|
||||||
|
sma = market.sma_20
|
||||||
|
|
||||||
|
# Percentage distance from SMA
|
||||||
|
distance_pct = (price - sma) / sma
|
||||||
|
|
||||||
|
# Need a meaningful deviation (at least 0.5%)
|
||||||
|
if abs(distance_pct) < 0.005:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Determine direction
|
||||||
|
if distance_pct > 0:
|
||||||
|
direction = SignalDirection.LONG
|
||||||
|
else:
|
||||||
|
direction = SignalDirection.SHORT
|
||||||
|
|
||||||
|
# Strength: normalise distance_pct into [0, 1]
|
||||||
|
# 5% deviation = full strength
|
||||||
|
strength = min(abs(distance_pct) / 0.05, 1.0)
|
||||||
|
|
||||||
|
return TradeSignal(
|
||||||
|
ticker=ticker,
|
||||||
|
direction=direction,
|
||||||
|
strength=round(strength, 4),
|
||||||
|
strategy_sources=[self.name],
|
||||||
|
sentiment_context=None,
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
73
shared/strategies/news_driven.py
Normal file
73
shared/strategies/news_driven.py
Normal file
|
|
@ -0,0 +1,73 @@
|
||||||
|
"""News-driven trading strategy.
|
||||||
|
|
||||||
|
Buy on strong positive sentiment (score > 0.7, confidence > 0.6),
|
||||||
|
sell on strong negative sentiment. Signal strength is the product
|
||||||
|
of sentiment score and confidence, with a decay factor for stale news.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from shared.schemas.trading import MarketSnapshot, SentimentContext, SignalDirection, TradeSignal
|
||||||
|
from shared.strategies.base import BaseStrategy
|
||||||
|
|
||||||
|
|
||||||
|
class NewsDrivenStrategy(BaseStrategy):
|
||||||
|
"""Sentiment-based strategy driven by scored news articles."""
|
||||||
|
|
||||||
|
name: str = "news_driven"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
positive_threshold: float = 0.7,
|
||||||
|
negative_threshold: float = -0.7,
|
||||||
|
min_confidence: float = 0.6,
|
||||||
|
min_articles: int = 1,
|
||||||
|
) -> None:
|
||||||
|
self.positive_threshold = positive_threshold
|
||||||
|
self.negative_threshold = negative_threshold
|
||||||
|
self.min_confidence = min_confidence
|
||||||
|
self.min_articles = min_articles
|
||||||
|
|
||||||
|
async def evaluate(
|
||||||
|
self,
|
||||||
|
ticker: str,
|
||||||
|
market: MarketSnapshot,
|
||||||
|
sentiment: SentimentContext | None = None,
|
||||||
|
) -> TradeSignal | None:
|
||||||
|
"""Generate a signal based on aggregated news sentiment."""
|
||||||
|
if sentiment is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if sentiment.article_count < self.min_articles:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if sentiment.avg_confidence < self.min_confidence:
|
||||||
|
return None
|
||||||
|
|
||||||
|
score = sentiment.avg_score
|
||||||
|
|
||||||
|
if score > self.positive_threshold:
|
||||||
|
direction = SignalDirection.LONG
|
||||||
|
elif score < self.negative_threshold:
|
||||||
|
direction = SignalDirection.SHORT
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Strength = |score| * confidence (both in [0, 1])
|
||||||
|
strength = abs(score) * sentiment.avg_confidence
|
||||||
|
strength = min(max(strength, 0.0), 1.0)
|
||||||
|
|
||||||
|
return TradeSignal(
|
||||||
|
ticker=ticker,
|
||||||
|
direction=direction,
|
||||||
|
strength=round(strength, 4),
|
||||||
|
strategy_sources=[self.name],
|
||||||
|
sentiment_context={
|
||||||
|
"avg_score": sentiment.avg_score,
|
||||||
|
"article_count": sentiment.article_count,
|
||||||
|
"avg_confidence": sentiment.avg_confidence,
|
||||||
|
},
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
359
tests/services/test_signal_generator.py
Normal file
359
tests/services/test_signal_generator.py
Normal file
|
|
@ -0,0 +1,359 @@
|
||||||
|
"""Tests for the Signal Generator service.
|
||||||
|
|
||||||
|
Covers MarketDataManager (SMA, RSI, snapshot) and WeightedEnsemble
|
||||||
|
(signal combination, threshold filtering, strategy source tagging).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from services.signal_generator.ensemble import WeightedEnsemble
|
||||||
|
from services.signal_generator.market_data import MarketDataManager
|
||||||
|
from shared.schemas.trading import (
|
||||||
|
MarketSnapshot,
|
||||||
|
OHLCVBar,
|
||||||
|
SentimentContext,
|
||||||
|
SignalDirection,
|
||||||
|
TradeSignal,
|
||||||
|
)
|
||||||
|
from shared.strategies.base import BaseStrategy
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_bar(close: float, *, ts_offset: int = 0) -> OHLCVBar:
|
||||||
|
"""Create an ``OHLCVBar`` with the given close price."""
|
||||||
|
return OHLCVBar(
|
||||||
|
timestamp=datetime(2026, 1, 1, 10, ts_offset, tzinfo=timezone.utc),
|
||||||
|
open=close - 0.5,
|
||||||
|
high=close + 1.0,
|
||||||
|
low=close - 1.0,
|
||||||
|
close=close,
|
||||||
|
volume=1000.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _StubStrategy(BaseStrategy):
|
||||||
|
"""Test helper that returns a preconfigured signal."""
|
||||||
|
|
||||||
|
def __init__(self, name: str, signal: TradeSignal | None) -> None:
|
||||||
|
self.name = name
|
||||||
|
self._signal = signal
|
||||||
|
|
||||||
|
async def evaluate(self, ticker, market, sentiment=None):
|
||||||
|
return self._signal
|
||||||
|
|
||||||
|
|
||||||
|
def _make_signal(
|
||||||
|
direction: SignalDirection = SignalDirection.LONG,
|
||||||
|
strength: float = 0.8,
|
||||||
|
sources: list[str] | None = None,
|
||||||
|
) -> TradeSignal:
|
||||||
|
return TradeSignal(
|
||||||
|
ticker="AAPL",
|
||||||
|
direction=direction,
|
||||||
|
strength=strength,
|
||||||
|
strategy_sources=sources or ["test"],
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# MarketDataManager — SMA
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMarketDataManagerSMA:
|
||||||
|
"""Tests for SMA computation inside MarketDataManager."""
|
||||||
|
|
||||||
|
def test_sma_basic(self):
|
||||||
|
"""SMA-20 should equal the mean of the last 20 close prices."""
|
||||||
|
mgr = MarketDataManager()
|
||||||
|
closes = list(range(1, 21)) # 1, 2, ..., 20
|
||||||
|
for i, c in enumerate(closes):
|
||||||
|
mgr.add_bar("AAPL", _make_bar(float(c), ts_offset=i))
|
||||||
|
|
||||||
|
snap = mgr.get_snapshot("AAPL")
|
||||||
|
assert snap is not None
|
||||||
|
expected_sma_20 = sum(closes) / 20
|
||||||
|
assert snap.sma_20 == pytest.approx(expected_sma_20)
|
||||||
|
|
||||||
|
def test_sma_returns_none_insufficient_data(self):
|
||||||
|
"""SMA-20 should be None when fewer than 20 bars exist."""
|
||||||
|
mgr = MarketDataManager()
|
||||||
|
for i in range(10):
|
||||||
|
mgr.add_bar("AAPL", _make_bar(100.0, ts_offset=i))
|
||||||
|
|
||||||
|
snap = mgr.get_snapshot("AAPL")
|
||||||
|
assert snap is not None
|
||||||
|
assert snap.sma_20 is None
|
||||||
|
|
||||||
|
def test_sma_50_requires_50_bars(self):
|
||||||
|
"""SMA-50 should be None with only 30 bars, present with 50."""
|
||||||
|
mgr = MarketDataManager()
|
||||||
|
for i in range(30):
|
||||||
|
mgr.add_bar("AAPL", _make_bar(float(i + 1), ts_offset=i))
|
||||||
|
snap = mgr.get_snapshot("AAPL")
|
||||||
|
assert snap is not None
|
||||||
|
assert snap.sma_50 is None
|
||||||
|
|
||||||
|
# Add 20 more
|
||||||
|
for i in range(30, 50):
|
||||||
|
mgr.add_bar("AAPL", _make_bar(float(i + 1), ts_offset=i))
|
||||||
|
snap = mgr.get_snapshot("AAPL")
|
||||||
|
assert snap is not None
|
||||||
|
assert snap.sma_50 is not None
|
||||||
|
expected = sum(range(1, 51)) / 50
|
||||||
|
assert snap.sma_50 == pytest.approx(expected)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# MarketDataManager — RSI
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMarketDataManagerRSI:
|
||||||
|
"""Tests for RSI computation inside MarketDataManager."""
|
||||||
|
|
||||||
|
def test_rsi_all_gains(self):
|
||||||
|
"""RSI should be 100 when all price changes are positive."""
|
||||||
|
mgr = MarketDataManager()
|
||||||
|
for i in range(20):
|
||||||
|
mgr.add_bar("AAPL", _make_bar(100.0 + i, ts_offset=i))
|
||||||
|
|
||||||
|
snap = mgr.get_snapshot("AAPL")
|
||||||
|
assert snap is not None
|
||||||
|
assert snap.rsi == pytest.approx(100.0)
|
||||||
|
|
||||||
|
def test_rsi_all_losses(self):
|
||||||
|
"""RSI should be 0 when all price changes are negative."""
|
||||||
|
mgr = MarketDataManager()
|
||||||
|
for i in range(20):
|
||||||
|
mgr.add_bar("AAPL", _make_bar(200.0 - i, ts_offset=i))
|
||||||
|
|
||||||
|
snap = mgr.get_snapshot("AAPL")
|
||||||
|
assert snap is not None
|
||||||
|
assert snap.rsi == pytest.approx(0.0)
|
||||||
|
|
||||||
|
def test_rsi_mixed(self):
|
||||||
|
"""RSI should be between 0 and 100 with mixed gains and losses."""
|
||||||
|
mgr = MarketDataManager()
|
||||||
|
prices = [44, 44.34, 44.09, 43.61, 44.33, 44.83, 45.10, 45.42,
|
||||||
|
45.84, 46.08, 45.89, 46.03, 45.61, 46.28, 46.28, 46.00]
|
||||||
|
for i, p in enumerate(prices):
|
||||||
|
mgr.add_bar("AAPL", _make_bar(p, ts_offset=i))
|
||||||
|
|
||||||
|
snap = mgr.get_snapshot("AAPL")
|
||||||
|
assert snap is not None
|
||||||
|
assert snap.rsi is not None
|
||||||
|
assert 0 < snap.rsi < 100
|
||||||
|
|
||||||
|
def test_rsi_returns_none_insufficient_data(self):
|
||||||
|
"""RSI should be None when fewer than 15 bars exist (need 14+1)."""
|
||||||
|
mgr = MarketDataManager()
|
||||||
|
for i in range(10):
|
||||||
|
mgr.add_bar("AAPL", _make_bar(100.0, ts_offset=i))
|
||||||
|
|
||||||
|
snap = mgr.get_snapshot("AAPL")
|
||||||
|
assert snap is not None
|
||||||
|
assert snap.rsi is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# MarketDataManager — snapshot
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMarketDataManagerSnapshot:
|
||||||
|
"""Tests for get_snapshot behaviour."""
|
||||||
|
|
||||||
|
def test_snapshot_returns_none_for_unknown_ticker(self):
|
||||||
|
mgr = MarketDataManager()
|
||||||
|
assert mgr.get_snapshot("UNKNOWN") is None
|
||||||
|
|
||||||
|
def test_snapshot_uses_latest_bar_for_price(self):
|
||||||
|
mgr = MarketDataManager()
|
||||||
|
mgr.add_bar("AAPL", _make_bar(100.0, ts_offset=0))
|
||||||
|
mgr.add_bar("AAPL", _make_bar(105.0, ts_offset=1))
|
||||||
|
snap = mgr.get_snapshot("AAPL")
|
||||||
|
assert snap is not None
|
||||||
|
assert snap.current_price == 105.0
|
||||||
|
|
||||||
|
def test_snapshot_contains_bars(self):
|
||||||
|
mgr = MarketDataManager()
|
||||||
|
for i in range(5):
|
||||||
|
mgr.add_bar("AAPL", _make_bar(100.0 + i, ts_offset=i))
|
||||||
|
snap = mgr.get_snapshot("AAPL")
|
||||||
|
assert snap is not None
|
||||||
|
assert len(snap.bars) == 5
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# WeightedEnsemble — combines signals
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnsembleCombinesSignals:
|
||||||
|
"""Test that the ensemble correctly combines strategy signals."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_combines_two_long_signals(self):
|
||||||
|
"""Two LONG signals should produce a combined LONG signal."""
|
||||||
|
s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.8))
|
||||||
|
s2 = _StubStrategy("beta", _make_signal(SignalDirection.LONG, 0.6))
|
||||||
|
|
||||||
|
ensemble = WeightedEnsemble([s1, s2], threshold=0.0)
|
||||||
|
market = MarketSnapshot(
|
||||||
|
ticker="AAPL", current_price=150.0,
|
||||||
|
open=149.0, high=151.0, low=148.0, close=150.0, volume=1000,
|
||||||
|
)
|
||||||
|
weights = {"alpha": 0.5, "beta": 0.5}
|
||||||
|
|
||||||
|
signal = await ensemble.evaluate("AAPL", market, None, weights)
|
||||||
|
|
||||||
|
assert signal is not None
|
||||||
|
assert signal.direction == SignalDirection.LONG
|
||||||
|
# Weighted average = (0.8*0.5 + 0.6*0.5) / (0.5+0.5) = 0.7
|
||||||
|
assert signal.strength == pytest.approx(0.7, abs=0.01)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_opposing_signals_net_direction(self):
|
||||||
|
"""When strategies disagree, direction follows the stronger weighted side."""
|
||||||
|
s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.9))
|
||||||
|
s2 = _StubStrategy("beta", _make_signal(SignalDirection.SHORT, 0.3))
|
||||||
|
|
||||||
|
ensemble = WeightedEnsemble([s1, s2], threshold=0.0)
|
||||||
|
market = MarketSnapshot(
|
||||||
|
ticker="AAPL", current_price=150.0,
|
||||||
|
open=149.0, high=151.0, low=148.0, close=150.0, volume=1000,
|
||||||
|
)
|
||||||
|
weights = {"alpha": 0.5, "beta": 0.5}
|
||||||
|
|
||||||
|
signal = await ensemble.evaluate("AAPL", market, None, weights)
|
||||||
|
|
||||||
|
assert signal is not None
|
||||||
|
# Net direction should be LONG since alpha is stronger
|
||||||
|
assert signal.direction == SignalDirection.LONG
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# WeightedEnsemble — threshold filtering
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnsembleThresholdFiltering:
|
||||||
|
"""Test that weak combined signals are filtered out by the threshold."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_below_threshold_returns_none(self):
|
||||||
|
"""Combined strength below threshold should yield None."""
|
||||||
|
# Two opposing signals of similar strength will nearly cancel out
|
||||||
|
s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.5))
|
||||||
|
s2 = _StubStrategy("beta", _make_signal(SignalDirection.SHORT, 0.45))
|
||||||
|
|
||||||
|
ensemble = WeightedEnsemble([s1, s2], threshold=0.5)
|
||||||
|
market = MarketSnapshot(
|
||||||
|
ticker="AAPL", current_price=150.0,
|
||||||
|
open=149.0, high=151.0, low=148.0, close=150.0, volume=1000,
|
||||||
|
)
|
||||||
|
weights = {"alpha": 0.5, "beta": 0.5}
|
||||||
|
|
||||||
|
signal = await ensemble.evaluate("AAPL", market, None, weights)
|
||||||
|
assert signal is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_above_threshold_returns_signal(self):
|
||||||
|
"""Strong combined signal above threshold should yield a signal."""
|
||||||
|
s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.9))
|
||||||
|
|
||||||
|
ensemble = WeightedEnsemble([s1], threshold=0.3)
|
||||||
|
market = MarketSnapshot(
|
||||||
|
ticker="AAPL", current_price=150.0,
|
||||||
|
open=149.0, high=151.0, low=148.0, close=150.0, volume=1000,
|
||||||
|
)
|
||||||
|
weights = {"alpha": 1.0}
|
||||||
|
|
||||||
|
signal = await ensemble.evaluate("AAPL", market, None, weights)
|
||||||
|
assert signal is not None
|
||||||
|
assert signal.strength >= 0.3
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# WeightedEnsemble — no signals returns None
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnsembleNoSignals:
|
||||||
|
"""Test that the ensemble returns None when no strategy fires."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_all_strategies_return_none(self):
|
||||||
|
s1 = _StubStrategy("alpha", None)
|
||||||
|
s2 = _StubStrategy("beta", None)
|
||||||
|
|
||||||
|
ensemble = WeightedEnsemble([s1, s2], threshold=0.3)
|
||||||
|
market = MarketSnapshot(
|
||||||
|
ticker="AAPL", current_price=150.0,
|
||||||
|
open=149.0, high=151.0, low=148.0, close=150.0, volume=1000,
|
||||||
|
)
|
||||||
|
weights = {"alpha": 0.5, "beta": 0.5}
|
||||||
|
|
||||||
|
signal = await ensemble.evaluate("AAPL", market, None, weights)
|
||||||
|
assert signal is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# WeightedEnsemble — tags strategy sources
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnsembleTagsStrategySources:
|
||||||
|
"""Verify that the output signal records which strategies contributed."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_strategy_sources_contains_all_contributors(self):
|
||||||
|
s1 = _StubStrategy("momentum", _make_signal(SignalDirection.LONG, 0.7, ["momentum"]))
|
||||||
|
s2 = _StubStrategy("news_driven", _make_signal(SignalDirection.LONG, 0.6, ["news_driven"]))
|
||||||
|
s3 = _StubStrategy("mean_reversion", None) # does not contribute
|
||||||
|
|
||||||
|
ensemble = WeightedEnsemble([s1, s2, s3], threshold=0.0)
|
||||||
|
market = MarketSnapshot(
|
||||||
|
ticker="AAPL", current_price=150.0,
|
||||||
|
open=149.0, high=151.0, low=148.0, close=150.0, volume=1000,
|
||||||
|
)
|
||||||
|
weights = {"momentum": 0.5, "news_driven": 0.3, "mean_reversion": 0.2}
|
||||||
|
|
||||||
|
signal = await ensemble.evaluate("AAPL", market, None, weights)
|
||||||
|
assert signal is not None
|
||||||
|
# Should have exactly 2 sources
|
||||||
|
assert len(signal.strategy_sources) == 2
|
||||||
|
source_names = [s.split(":")[0] for s in signal.strategy_sources]
|
||||||
|
assert "momentum" in source_names
|
||||||
|
assert "news_driven" in source_names
|
||||||
|
# mean_reversion should NOT be present
|
||||||
|
assert "mean_reversion" not in source_names
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_strategy_sources_contain_direction_and_strength(self):
|
||||||
|
"""Each source tag should be formatted as name:DIRECTION:strength."""
|
||||||
|
s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.75))
|
||||||
|
ensemble = WeightedEnsemble([s1], threshold=0.0)
|
||||||
|
market = MarketSnapshot(
|
||||||
|
ticker="AAPL", current_price=150.0,
|
||||||
|
open=149.0, high=151.0, low=148.0, close=150.0, volume=1000,
|
||||||
|
)
|
||||||
|
weights = {"alpha": 1.0}
|
||||||
|
|
||||||
|
signal = await ensemble.evaluate("AAPL", market, None, weights)
|
||||||
|
assert signal is not None
|
||||||
|
assert len(signal.strategy_sources) == 1
|
||||||
|
parts = signal.strategy_sources[0].split(":")
|
||||||
|
assert parts[0] == "alpha"
|
||||||
|
assert parts[1] == "LONG"
|
||||||
|
assert float(parts[2]) == pytest.approx(0.75, abs=0.01)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue