feat: signal generator — weighted ensemble with market data
This commit is contained in:
parent
e483e9987f
commit
f3e5fc944d
11 changed files with 1013 additions and 0 deletions
1
services/signal_generator/__init__.py
Normal file
1
services/signal_generator/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Signal Generator service — weighted ensemble of trading strategies."""
|
||||
14
services/signal_generator/config.py
Normal file
14
services/signal_generator/config.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
"""Configuration for the signal generator service."""
|
||||
|
||||
from shared.config import BaseConfig
|
||||
|
||||
|
||||
class SignalGeneratorConfig(BaseConfig):
    """Settings for the signal generator service, layered on ``BaseConfig``.

    Every field has a default, so the class can be instantiated with no
    arguments.
    """

    # NOTE(review): presumably read by a pydantic-settings style base class
    # so that fields are populated from ``TRADING_``-prefixed environment
    # variables -- confirm against ``shared.config.BaseConfig``.
    model_config = {"env_prefix": "TRADING_"}

    # Alpaca credentials; empty strings when nothing is configured.
    alpaca_api_key: str = ""
    alpaca_secret_key: str = ""

    # Minimum combined ensemble strength a signal needs before it is emitted.
    signal_strength_threshold: float = 0.3

    # Tickers of interest; empty by default.
    watchlist: list[str] = []
|
||||
118
services/signal_generator/ensemble.py
Normal file
118
services/signal_generator/ensemble.py
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
"""Weighted ensemble that combines signals from multiple strategies.
|
||||
|
||||
Runs all registered strategies, collects non-``None`` signals, and computes
|
||||
a combined strength via a weighted average. Only emits a ``TradeSignal``
|
||||
when the combined strength exceeds a configurable threshold.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from shared.schemas.trading import (
|
||||
MarketSnapshot,
|
||||
SentimentContext,
|
||||
SignalDirection,
|
||||
TradeSignal,
|
||||
)
|
||||
from shared.strategies.base import BaseStrategy
|
||||
|
||||
|
||||
class WeightedEnsemble:
    """Combine signals from multiple strategies using weighted averaging.

    Parameters
    ----------
    strategies:
        The list of strategy instances to evaluate.
    threshold:
        Minimum combined strength required to emit a signal (default 0.3).
    default_weight:
        Weight given to a strategy whose name is missing from the
        ``weights`` mapping passed to :meth:`evaluate` (default 0.1).
    """

    def __init__(
        self,
        strategies: list[BaseStrategy],
        threshold: float = 0.3,
        default_weight: float = 0.1,
    ) -> None:
        self.strategies = strategies
        self.threshold = threshold
        self.default_weight = default_weight

    async def evaluate(
        self,
        ticker: str,
        market: MarketSnapshot,
        sentiment: SentimentContext | None,
        weights: dict[str, float],
    ) -> TradeSignal | None:
        """Run all strategies and return a combined signal, or ``None``.

        Parameters
        ----------
        ticker:
            The stock ticker being evaluated.
        market:
            Current market snapshot including price, SMA, RSI.
        sentiment:
            Aggregated sentiment context (may be ``None``).
        weights:
            Mapping from strategy name to its weight.  Strategies that are
            not listed fall back to ``self.default_weight``.

        Returns
        -------
        A ``TradeSignal`` whose strength is the weighted average of the
        signed individual strengths (capped at 1.0 and rounded to four
        decimals), or ``None`` when no strategy fired, the weights do not
        add up to a positive total, the combined strength is below the
        threshold, or the long and short votes cancel out exactly.
        """
        # Step 1: run every strategy in order and keep the ones that fired.
        fired: list[tuple[BaseStrategy, TradeSignal]] = []
        for strategy in self.strategies:
            signal = await strategy.evaluate(ticker, market, sentiment)
            if signal is not None:
                fired.append((strategy, signal))

        if not fired:
            return None

        # Step 2: accumulate the signed, weighted strengths.
        weighted_sum = 0.0
        total_weight = 0.0
        for strategy, signal in fired:
            weight = weights.get(strategy.name, self.default_weight)
            # NOTE(review): every direction other than LONG counts as a
            # short vote -- confirm SignalDirection has no neutral member.
            sign = 1.0 if signal.direction == SignalDirection.LONG else -1.0
            weighted_sum += signal.strength * sign * weight
            total_weight += weight

        # A non-positive total cannot be normalised into a meaningful
        # strength (dividing by a negative total would yield a negative one).
        if total_weight <= 0:
            return None

        # Step 3: normalise and apply the emission threshold.
        combined_strength = abs(weighted_sum) / total_weight
        if combined_strength < self.threshold:
            return None

        # Step 4: the sign of the weighted sum decides the direction.
        if weighted_sum > 0:
            direction = SignalDirection.LONG
        elif weighted_sum < 0:
            direction = SignalDirection.SHORT
        else:
            # Only reachable with a threshold <= 0: a perfect tie.
            return None

        # Step 5: record each contributing strategy as "name:direction:strength".
        strategy_sources = [
            f"{strategy.name}:{signal.direction.value}:{signal.strength:.4f}"
            for strategy, signal in fired
        ]

        # Carry the headline sentiment figures forward when available.
        sentiment_ctx = None
        if sentiment is not None:
            sentiment_ctx = {
                "avg_score": sentiment.avg_score,
                "article_count": sentiment.article_count,
                "avg_confidence": sentiment.avg_confidence,
            }

        return TradeSignal(
            ticker=ticker,
            direction=direction,
            strength=round(min(combined_strength, 1.0), 4),
            strategy_sources=strategy_sources,
            sentiment_context=sentiment_ctx,
            timestamp=datetime.now(timezone.utc),
        )
|
||||
165
services/signal_generator/main.py
Normal file
165
services/signal_generator/main.py
Normal file
|
|
@ -0,0 +1,165 @@
|
|||
"""Signal Generator service -- main entry point.
|
||||
|
||||
Consumes ``news:scored`` articles from Redis Streams, updates sentiment
|
||||
context per ticker, runs the weighted ensemble of trading strategies, and
|
||||
publishes qualifying ``TradeSignal`` messages to ``signals:generated``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import defaultdict, deque
|
||||
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from services.signal_generator.config import SignalGeneratorConfig
|
||||
from services.signal_generator.ensemble import WeightedEnsemble
|
||||
from services.signal_generator.market_data import MarketDataManager
|
||||
from shared.redis_streams import StreamConsumer, StreamPublisher
|
||||
from shared.schemas.news import ScoredArticle
|
||||
from shared.schemas.trading import SentimentContext
|
||||
from shared.strategies import MeanReversionStrategy, MomentumStrategy, NewsDrivenStrategy
|
||||
from shared.telemetry import setup_telemetry
|
||||
|
||||
# Module-level logger used by the service loop.
logger = logging.getLogger(__name__)

# Cap on how many recent sentiment scores are kept for each ticker.
_MAX_SENTIMENT_SCORES = 50

# Fallback strategy weights: a near-even three-way split across strategies.
_DEFAULT_WEIGHTS: dict[str, float] = dict(
    momentum=0.333,
    mean_reversion=0.333,
    news_driven=0.334,
)
|
||||
|
||||
|
||||
def _build_sentiment_context(
    ticker: str,
    scores: deque[float],
    confidences: deque[float],
) -> SentimentContext:
    """Summarise the accumulated readings for *ticker* as a ``SentimentContext``.

    Parameters
    ----------
    ticker:
        Symbol the readings belong to.
    scores:
        Accumulated sentiment scores for the ticker.
    confidences:
        Accumulated confidence values for the ticker.

    Returns
    -------
    A context holding the mean score, the number of scores, the last ten
    scores, and the mean confidence.  Either mean is ``0.0`` when its
    input is empty.
    """
    history = list(scores)
    confidence_history = list(confidences)

    mean_score = sum(history) / len(history) if history else 0.0
    if confidence_history:
        mean_confidence = sum(confidence_history) / len(confidence_history)
    else:
        mean_confidence = 0.0

    return SentimentContext(
        ticker=ticker,
        avg_score=mean_score,
        article_count=len(history),
        recent_scores=history[-10:],
        avg_confidence=mean_confidence,
    )
|
||||
|
||||
|
||||
async def run(config: SignalGeneratorConfig | None = None) -> None:
    """Main service loop.

    Connects to Redis, initialises strategies and telemetry, then
    continuously consumes from ``news:scored`` and publishes qualifying
    signals to ``signals:generated``.  The Redis client is closed when the
    loop ends, whether it finishes, is cancelled, or raises.

    Parameters
    ----------
    config:
        Service settings.  When ``None``, a ``SignalGeneratorConfig`` is
        built from its defaults.
    """
    # Imported once up front instead of on every message that needs a
    # placeholder snapshot.
    from shared.schemas.trading import MarketSnapshot

    if config is None:
        config = SignalGeneratorConfig()

    logging.basicConfig(level=config.log_level)
    logger.info("Starting Signal Generator service")

    # --- Telemetry ---
    meter = setup_telemetry("signal-generator", config.otel_metrics_port)
    signals_generated = meter.create_counter(
        "signals_generated",
        description="Total trade signals emitted by the signal generator",
    )
    per_strategy_signal_count = meter.create_counter(
        "per_strategy_signal_count",
        description="Signals emitted, broken down by strategy",
    )

    # --- Redis ---
    redis = Redis.from_url(config.redis_url, decode_responses=False)
    try:
        consumer = StreamConsumer(redis, "news:scored", "signal-generator", "worker-1")
        publisher = StreamPublisher(redis, "signals:generated")

        # --- Market data ---
        # NOTE(review): nothing in this function calls
        # ``market_data.add_bar``, so ``get_snapshot`` below returns ``None``
        # and the placeholder snapshot is always used -- confirm whether a
        # bar feed is meant to be wired in here.
        market_data = MarketDataManager()

        # --- Strategies ---
        strategies = [
            MomentumStrategy(),
            MeanReversionStrategy(),
            NewsDrivenStrategy(),
        ]
        ensemble = WeightedEnsemble(strategies, threshold=config.signal_strength_threshold)

        # --- Strategy weights (default equal; could load from DB) ---
        weights = dict(_DEFAULT_WEIGHTS)

        # --- Per-ticker sentiment accumulators (bounded per ticker) ---
        sentiment_scores: dict[str, deque[float]] = defaultdict(
            lambda: deque(maxlen=_MAX_SENTIMENT_SCORES)
        )
        sentiment_confidences: dict[str, deque[float]] = defaultdict(
            lambda: deque(maxlen=_MAX_SENTIMENT_SCORES)
        )

        logger.info("Consuming from news:scored, publishing to signals:generated")

        # --- Consume loop ---
        async for _msg_id, data in consumer.consume():
            # One bad message must not stop the service: log it and move on.
            try:
                article = ScoredArticle.model_validate(data)
                ticker = article.ticker

                # Update sentiment accumulators
                sentiment_scores[ticker].append(article.sentiment_score)
                sentiment_confidences[ticker].append(article.confidence)

                # Build sentiment context
                sentiment = _build_sentiment_context(
                    ticker,
                    sentiment_scores[ticker],
                    sentiment_confidences[ticker],
                )

                # Get market snapshot (may be None if no bars received yet)
                snapshot = market_data.get_snapshot(ticker)
                if snapshot is None:
                    # Create a minimal snapshot from sentiment data alone
                    # (the news_driven strategy does not require market indicators)
                    snapshot = MarketSnapshot(
                        ticker=ticker,
                        current_price=0.0,
                        open=0.0,
                        high=0.0,
                        low=0.0,
                        close=0.0,
                        volume=0.0,
                    )

                # Run ensemble
                signal = await ensemble.evaluate(ticker, snapshot, sentiment, weights)

                if signal is not None:
                    await publisher.publish(signal.model_dump(mode="json"))
                    signals_generated.add(1)
                    for src in signal.strategy_sources:
                        # Sources are formatted "name:direction:strength".
                        strategy_name = src.split(":")[0]
                        per_strategy_signal_count.add(1, {"strategy": strategy_name})
                    logger.info(
                        "Signal generated: %s %s strength=%.4f sources=%s",
                        signal.direction.value,
                        ticker,
                        signal.strength,
                        signal.strategy_sources,
                    )

            except Exception:
                logger.exception("Error processing scored article: %s", data.get("title", "<unknown>"))
    finally:
        # Release the connection even when the loop is cancelled or fails.
        # Newer redis-py releases name this coroutine ``aclose`` and
        # deprecate ``close``; fall back so either version works.
        close_client = getattr(redis, "aclose", None) or redis.close
        await close_client()
|
||||
|
||||
|
||||
def main() -> None:
    """Command-line entry point: drive the service loop on a fresh event loop."""
    service = run()
    asyncio.run(service)


if __name__ == "__main__":
    main()
|
||||
122
services/signal_generator/market_data.py
Normal file
122
services/signal_generator/market_data.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
"""In-memory market data manager with rolling OHLCV windows.
|
||||
|
||||
Maintains a per-ticker deque of recent bars and computes technical
|
||||
indicators (SMA, RSI) on demand when building ``MarketSnapshot`` objects.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from typing import Any
|
||||
|
||||
from shared.schemas.trading import MarketSnapshot, OHLCVBar
|
||||
|
||||
|
||||
# Default rolling-window sizes
_DEFAULT_MAX_BARS = 100
_RSI_PERIOD = 14


class MarketDataManager:
    """Manages in-memory rolling windows of OHLCV bars per ticker.

    Parameters
    ----------
    max_bars:
        Maximum number of bars to retain per ticker.  Older bars are
        dropped automatically once the window is full.
    """

    def __init__(self, max_bars: int = _DEFAULT_MAX_BARS) -> None:
        self.max_bars = max_bars
        self._bars: dict[str, deque[OHLCVBar]] = {}

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def add_bar(self, ticker: str, bar_data: dict[str, Any] | OHLCVBar) -> None:
        """Append a bar to the rolling window for *ticker*.

        ``bar_data`` can be a dict (parsed from JSON), which is validated
        into an ``OHLCVBar``, or an ``OHLCVBar`` instance, which is stored
        as given.
        """
        if isinstance(bar_data, dict):
            bar = OHLCVBar.model_validate(bar_data)
        else:
            bar = bar_data

        window = self._bars.get(ticker)
        if window is None:
            # First bar for this ticker: create its bounded window.
            window = self._bars[ticker] = deque(maxlen=self.max_bars)
        window.append(bar)

    def get_snapshot(self, ticker: str) -> MarketSnapshot | None:
        """Build a ``MarketSnapshot`` from the rolling window.

        The latest bar supplies the price fields; SMA-20, SMA-50 and RSI
        are computed from the retained closes and are ``None`` until enough
        bars have accumulated.

        Returns ``None`` if no bars have been recorded for *ticker*.
        """
        bars = self._bars.get(ticker)
        if not bars:
            return None

        latest = bars[-1]
        closes = [b.close for b in bars]

        return MarketSnapshot(
            ticker=ticker,
            current_price=latest.close,
            open=latest.open,
            high=latest.high,
            low=latest.low,
            close=latest.close,
            volume=latest.volume,
            sma_20=self._compute_sma(closes, 20),
            sma_50=self._compute_sma(closes, 50),
            rsi=self._compute_rsi(closes, _RSI_PERIOD),
            bars=[b.model_dump(mode="json") for b in bars],
        )

    def has_ticker(self, ticker: str) -> bool:
        """Return ``True`` if at least one bar exists for *ticker*."""
        return bool(self._bars.get(ticker))

    # ------------------------------------------------------------------
    # Technical indicator helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _compute_sma(closes: list[float], period: int) -> float | None:
        """Compute the simple moving average over the last *period* closes.

        Returns ``None`` if there are fewer than *period* data points.

        Raises
        ------
        ValueError
            If *period* is not positive.
        """
        if period <= 0:
            raise ValueError("period must be a positive integer")
        if len(closes) < period:
            return None
        return sum(closes[-period:]) / period

    @staticmethod
    def _compute_rsi(closes: list[float], period: int = 14) -> float | None:
        """Compute the RSI over the last ``period + 1`` closes.

        Uses plain (unsmoothed) averages of the gains and losses in the
        window.  Returns ``None`` if there are not enough data points
        (``period + 1`` closes are needed for ``period`` deltas).

        Edge cases: a window with gains but no losses yields ``100.0``; a
        completely flat window (no gains and no losses) yields the neutral
        ``50.0`` rather than an overbought reading.

        Raises
        ------
        ValueError
            If *period* is not positive.
        """
        if period <= 0:
            raise ValueError("period must be a positive integer")
        if len(closes) < period + 1:
            return None

        # Only the most recent period+1 closes contribute.
        window = closes[-(period + 1):]
        deltas = [curr - prev for prev, curr in zip(window, window[1:])]

        avg_gain = sum(d for d in deltas if d > 0) / period
        avg_loss = sum(-d for d in deltas if d < 0) / period

        if avg_loss == 0:
            # RS is undefined here.  With gains the index saturates at 100;
            # with no movement at all there is no directional signal.
            return 100.0 if avg_gain > 0 else 50.0

        rs = avg_gain / avg_loss
        return round(100.0 - (100.0 / (1.0 + rs)), 4)
|
||||
Loading…
Add table
Add a link
Reference in a new issue