# trading/services/signal_generator/main.py

"""Signal Generator service -- main entry point.
Consumes ``news:scored`` articles and ``market:bars`` OHLCV data from
Redis Streams, updates sentiment context and market data per ticker,
runs the weighted ensemble of trading strategies, and publishes
qualifying ``TradeSignal`` messages to ``signals:generated``.
"""
from __future__ import annotations
import asyncio
import logging
import signal
import uuid
from collections import defaultdict, deque
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import async_sessionmaker
from services.signal_generator.config import SignalGeneratorConfig
from services.signal_generator.ensemble import WeightedEnsemble
from services.signal_generator.market_data import MarketDataManager
from shared.db import create_db
from shared.models.trading import Signal as SignalModel
from shared.models.trading import SignalDirection as SignalDirectionModel
from shared.redis_streams import StreamConsumer, StreamPublisher
from shared.schemas.news import ScoredArticle
from shared.schemas.trading import FundamentalsSnapshot, MarketSnapshot, SentimentContext
from shared.strategies import (
BollingerBreakoutStrategy,
LiquidityStrategy,
MACDCrossoverStrategy,
MAStackStrategy,
MeanReversionStrategy,
MomentumStrategy,
NewsDrivenStrategy,
ValueStrategy,
VWAPStrategy,
)
from shared.fundamentals.alpha_vantage import AlphaVantageProvider
from shared.fundamentals.fmp import FMPProvider
from shared.fundamentals.yahoo import YahooFinanceProvider
from shared.fundamentals.rotating import RotatingProvider
from shared.fundamentals.cache import CachedFundamentalsProvider
from shared.telemetry import setup_telemetry
logger = logging.getLogger(__name__)

# Upper bound on how many recent sentiment scores (and matching confidences)
# are retained per ticker; older values fall off the bounded deques.
_MAX_SENTIMENT_SCORES = 50

# Default strategy weights: roughly equal (~1/9 each) across the nine
# strategies. "liquidity" carries 0.112 rather than 0.111 so that the nine
# weights add up to 1.0 (8 * 0.111 + 0.112), up to float rounding.
_DEFAULT_WEIGHTS: dict[str, float] = dict(
    momentum=0.111,
    mean_reversion=0.111,
    news_driven=0.111,
    value=0.111,
    macd_crossover=0.111,
    bollinger_breakout=0.111,
    vwap=0.111,
    liquidity=0.112,
    ma_stack=0.111,
)
def _build_sentiment_context(
    ticker: str,
    scores: deque[float],
    confidences: deque[float],
) -> SentimentContext:
    """Summarise a ticker's accumulated sentiment into a ``SentimentContext``.

    ``avg_score`` / ``avg_confidence`` are arithmetic means of everything
    currently held in *scores* / *confidences* (0.0 when empty),
    ``article_count`` is how many scores are held (so it is capped by the
    deque's ``maxlen``, if any), and ``recent_scores`` is the last 10 of them.
    """

    def _mean(values: list[float]) -> float:
        # Empty history -> neutral 0.0 instead of ZeroDivisionError.
        if not values:
            return 0.0
        return sum(values) / len(values)

    held_scores = list(scores)
    held_confidences = list(confidences)
    return SentimentContext(
        ticker=ticker,
        avg_score=_mean(held_scores),
        article_count=len(held_scores),
        recent_scores=held_scores[-10:],
        avg_confidence=_mean(held_confidences),
    )
async def _consume_market_bars(
    bars_consumer: StreamConsumer,
    market_data: MarketDataManager,
    shutdown_event: asyncio.Event,
    bars_received_counter,
) -> None:
    """Consume OHLCV bars from ``market:bars`` and feed them to the MarketDataManager.

    Runs as a concurrent task alongside the scored-article consumer. Each
    message must carry a ``ticker`` field; every other field is handed to
    ``market_data.add_bar`` and the ``bars_received_counter`` is incremented.
    Messages without a ticker are logged and skipped, and any processing error
    is logged without stopping the loop. The shutdown flag is only inspected
    when the consumer delivers a message.
    """
    logger.info("Starting market:bars consumer")
    async for _msg_id, payload in bars_consumer.consume():
        if shutdown_event.is_set():
            break
        try:
            symbol = payload.get("ticker")
            if symbol:
                # OHLCVBar has no ticker field, so strip it from the bar payload.
                fields = {key: value for key, value in payload.items() if key != "ticker"}
                market_data.add_bar(symbol, fields)
                bars_received_counter.add(1)
                logger.debug("Added bar for %s: close=%s", symbol, payload.get("close"))
            else:
                logger.warning("Received bar message without ticker field: %s", payload)
        except Exception:
            logger.exception("Error processing market bar: %s", payload)
def _placeholder_snapshot(ticker: str) -> MarketSnapshot:
    """Return an all-zero ``MarketSnapshot`` for a ticker that has no bars yet.

    Lets the ensemble still run on sentiment data alone (the news_driven
    strategy does not require market indicators).
    """
    return MarketSnapshot(
        ticker=ticker,
        current_price=0.0,
        open=0.0,
        high=0.0,
        low=0.0,
        close=0.0,
        volume=0.0,
    )


async def _persist_signal(
    db_session_factory: async_sessionmaker,
    signal_result,
    ticker: str,
    sentiment_score: float,
) -> None:
    """Best-effort insert of an emitted signal as a ``SignalModel`` row.

    Any exception (including an unknown direction value) is logged and
    swallowed, so a database problem never prevents the signal from being
    published afterwards.
    """
    try:
        async with db_session_factory() as session:
            # Schema direction string -> ORM enum member.
            direction_map = {
                "LONG": SignalDirectionModel.LONG,
                "SHORT": SignalDirectionModel.SHORT,
                "NEUTRAL": SignalDirectionModel.NEUTRAL,
            }
            session.add(
                SignalModel(
                    id=signal_result.signal_id,
                    ticker=ticker,
                    direction=direction_map[signal_result.direction.value],
                    strength=signal_result.strength,
                    strategy_sources=signal_result.strategy_sources,
                    sentiment_score=sentiment_score,
                    acted_on=False,
                )
            )
            await session.commit()
    except Exception:
        logger.exception("Failed to persist signal to DB")


async def _consume_scored_articles(
    articles_consumer: StreamConsumer,
    market_data: MarketDataManager,
    ensemble: WeightedEnsemble,
    weights: dict[str, float],
    publisher: StreamPublisher,
    shutdown_event: asyncio.Event,
    signals_generated,
    per_strategy_signal_count,
    db_session_factory: async_sessionmaker | None = None,
    fundamentals_cache: dict[str, FundamentalsSnapshot] | None = None,
) -> None:
    """Consume scored articles from ``news:scored``, run the ensemble, and publish signals.

    Runs as a concurrent task alongside the market-bars consumer. For each
    article:

    1. append its sentiment score/confidence to the ticker's bounded history
       (last ``_MAX_SENTIMENT_SCORES`` values) and build a ``SentimentContext``;
    2. take the ticker's ``MarketSnapshot`` from *market_data*, or a zero-valued
       placeholder when no bars have been received, and attach cached
       fundamentals when *fundamentals_cache* is non-empty;
    3. evaluate the weighted *ensemble*; for a non-``None`` result, inject the
       current price (when known), optionally persist it to the DB, publish it
       via *publisher*, bump the counters, and log it.

    Per-message errors are logged and the loop continues. The shutdown flag is
    only checked when a new message is delivered.
    """
    # Per-ticker sentiment accumulators, bounded so memory stays flat.
    sentiment_scores: dict[str, deque[float]] = defaultdict(
        lambda: deque(maxlen=_MAX_SENTIMENT_SCORES)
    )
    sentiment_confidences: dict[str, deque[float]] = defaultdict(
        lambda: deque(maxlen=_MAX_SENTIMENT_SCORES)
    )
    logger.info("Starting news:scored consumer")
    async for _msg_id, data in articles_consumer.consume():
        if shutdown_event.is_set():
            break
        try:
            article = ScoredArticle.model_validate(data)
            ticker = article.ticker

            sentiment_scores[ticker].append(article.sentiment_score)
            sentiment_confidences[ticker].append(article.confidence)
            sentiment = _build_sentiment_context(
                ticker,
                sentiment_scores[ticker],
                sentiment_confidences[ticker],
            )

            snapshot = market_data.get_snapshot(ticker)
            if snapshot is None:
                snapshot = _placeholder_snapshot(ticker)
            if fundamentals_cache:
                snapshot.fundamentals = fundamentals_cache.get(ticker)

            signal_result = await ensemble.evaluate(ticker, snapshot, sentiment, weights)
            if signal_result is None:
                continue

            # The trade executor sizes positions from this price; a placeholder
            # snapshot has price 0.0 and is therefore not injected.
            if snapshot.current_price > 0:
                if signal_result.sentiment_context is None:
                    signal_result.sentiment_context = {}
                signal_result.sentiment_context["current_price"] = snapshot.current_price

            if db_session_factory is not None:
                await _persist_signal(
                    db_session_factory, signal_result, ticker, sentiment.avg_score
                )

            await publisher.publish(signal_result.model_dump(mode="json"))
            signals_generated.add(1)
            for src in signal_result.strategy_sources:
                # Sources look like "<strategy>:<detail>"; count by strategy name.
                per_strategy_signal_count.add(1, {"strategy": src.split(":")[0]})
            logger.info(
                "Signal generated: %s %s strength=%.4f sources=%s",
                signal_result.direction.value,
                ticker,
                signal_result.strength,
                signal_result.strategy_sources,
            )
        except Exception:
            logger.exception(
                "Error processing scored article: %s", data.get("title", "<unknown>")
            )
async def _refresh_fundamentals(
    provider: CachedFundamentalsProvider,
    cache: dict[str, FundamentalsSnapshot],
    watchlist: list[str],
    shutdown_event: asyncio.Event,
    interval_seconds: float = 24 * 3600,
) -> None:
    """Periodically refresh fundamental data for all watchlist tickers.

    Every *interval_seconds* (default 24 hours), ``provider.fetch`` is awaited
    for each ticker in *watchlist*, and any truthy result overwrites
    ``cache[ticker]``. Per-ticker fetch errors are logged and skipped.

    The wait between refreshes ends as soon as *shutdown_event* is set, so this
    task never holds up service shutdown for a full interval.
    """
    while not shutdown_event.is_set():
        try:
            # Sleep until the next refresh, but wake immediately on shutdown.
            await asyncio.wait_for(shutdown_event.wait(), timeout=interval_seconds)
        except asyncio.TimeoutError:
            pass  # interval elapsed with no shutdown request -> refresh
        else:
            break  # shutdown requested while waiting
        logger.info("Starting daily fundamentals refresh")
        for ticker in watchlist:
            if shutdown_event.is_set():
                break
            try:
                snap = await provider.fetch(ticker)
            except Exception:
                logger.exception("Failed to refresh fundamentals for %s", ticker)
                continue
            if snap:
                cache[ticker] = snap
        logger.info("Fundamentals refresh complete")
async def _cancel_on_shutdown(
    shutdown_event: asyncio.Event,
    tasks: list[asyncio.Task],
) -> None:
    """Cancel *tasks* once *shutdown_event* is set.

    The stream consumers only look at the shutdown flag after the next message
    is delivered, and the fundamentals refresher may be waiting between
    refreshes. Without this watcher, a SIGTERM/SIGINT can therefore leave the
    enclosing TaskGroup waiting indefinitely. If every task finishes on its own
    first, this returns without cancelling anything.
    """
    stop = asyncio.create_task(shutdown_event.wait())
    remaining = set(tasks)
    try:
        while remaining and not stop.done():
            done, _pending = await asyncio.wait(
                remaining | {stop}, return_when=asyncio.FIRST_COMPLETED
            )
            remaining -= done
    finally:
        # Don't leave the event waiter pending if we exit for any other reason.
        stop.cancel()
    for task in remaining:
        task.cancel()


async def run(config: SignalGeneratorConfig | None = None) -> None:
    """Main service loop.

    Connects to Redis, initialises strategies and telemetry, then
    continuously consumes from ``news:scored`` and ``market:bars``,
    publishing qualifying signals to ``signals:generated``.

    DB persistence and fundamentals are optional: if their setup raises, the
    error is logged and the service runs without them. SIGTERM/SIGINT set a
    shutdown event, and a watcher task then cancels the worker tasks so the
    service stops even when no further stream messages arrive. The Redis client
    is closed on exit.
    """
    if config is None:
        config = SignalGeneratorConfig()
    logging.basicConfig(level=config.log_level)
    logger.info("Starting Signal Generator service")

    # --- Telemetry ---
    meter = setup_telemetry("signal-generator", config.otel_metrics_port)
    signals_generated = meter.create_counter(
        "signals_generated",
        description="Total trade signals emitted by the signal generator",
    )
    per_strategy_signal_count = meter.create_counter(
        "per_strategy_signal_count",
        description="Signals emitted, broken down by strategy",
    )
    bars_received_counter = meter.create_counter(
        "bars_received",
        description="Total OHLCV bars received from market:bars stream",
    )

    # --- Redis ---
    redis = Redis.from_url(config.redis_url, decode_responses=False)
    articles_consumer = StreamConsumer(
        redis, "news:scored", "signal-generator", "worker-1"
    )
    bars_consumer = StreamConsumer(
        redis, "market:bars", "signal-generator", "bars-worker"
    )
    publisher = StreamPublisher(redis, "signals:generated")

    # --- Market data ---
    market_data = MarketDataManager()

    # --- Strategies ---
    strategies = [
        MomentumStrategy(),
        MeanReversionStrategy(),
        NewsDrivenStrategy(),
        ValueStrategy(),
        MACDCrossoverStrategy(),
        BollingerBreakoutStrategy(),
        VWAPStrategy(),
        LiquidityStrategy(),
        MAStackStrategy(),
    ]
    ensemble = WeightedEnsemble(strategies, threshold=config.signal_strength_threshold)

    # --- Strategy weights (default equal; could load from DB) ---
    # Copy so per-run tweaks never mutate the module-level defaults.
    weights = dict(_DEFAULT_WEIGHTS)

    # --- Database (for persisting signals) ---
    db_session_factory = None
    try:
        _engine, db_session_factory = create_db(config)
        logger.info("Database session factory initialised for signal persistence")
    except Exception:
        logger.exception("Failed to initialise DB — signals will NOT be persisted")

    # --- Fundamentals ---
    # The same dict is shared with the article consumer and the refresher, so
    # refreshed entries become visible to the consumer in place.
    fundamentals_cache: dict[str, FundamentalsSnapshot] = {}
    cached_fundamentals_provider: CachedFundamentalsProvider | None = None
    try:
        providers = []
        if config.alpha_vantage_api_key:
            providers.append(AlphaVantageProvider(api_key=config.alpha_vantage_api_key))
        if config.fmp_api_key:
            providers.append(FMPProvider(api_key=config.fmp_api_key))
        providers.append(YahooFinanceProvider())  # no API key needed
        if providers and db_session_factory is not None:
            rotating = RotatingProvider(providers)
            cached_fundamentals_provider = CachedFundamentalsProvider(
                rotating, db_session_factory, cache_ttl_hours=config.fundamentals_cache_ttl_hours,
            )
            # Pre-fetch fundamentals for watchlist
            for ticker in config.watchlist:
                try:
                    snap = await cached_fundamentals_provider.fetch(ticker)
                    if snap:
                        fundamentals_cache[ticker] = snap
                        logger.info("Loaded fundamentals for %s", ticker)
                except Exception:
                    logger.exception("Failed to fetch fundamentals for %s", ticker)
            logger.info("Fundamentals loaded for %d/%d tickers", len(fundamentals_cache), len(config.watchlist))
    except Exception:
        logger.exception("Failed to initialise fundamentals — strategies will run without fundamental data")

    logger.info(
        "Consuming from news:scored and market:bars, publishing to signals:generated"
    )

    # Graceful shutdown on SIGTERM/SIGINT
    shutdown_event = asyncio.Event()
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, shutdown_event.set)

    # --- Run both consumers (and the optional refresher) concurrently ---
    try:
        async with asyncio.TaskGroup() as tg:
            workers = [
                tg.create_task(
                    _consume_scored_articles(
                        articles_consumer,
                        market_data,
                        ensemble,
                        weights,
                        publisher,
                        shutdown_event,
                        signals_generated,
                        per_strategy_signal_count,
                        db_session_factory,
                        fundamentals_cache,
                    )
                ),
                tg.create_task(
                    _consume_market_bars(
                        bars_consumer,
                        market_data,
                        shutdown_event,
                        bars_received_counter,
                    )
                ),
            ]
            if cached_fundamentals_provider is not None:
                workers.append(
                    tg.create_task(
                        _refresh_fundamentals(
                            cached_fundamentals_provider,
                            fundamentals_cache,
                            config.watchlist,
                            shutdown_event,
                        )
                    )
                )
            # Workers cancelled by the watcher finish as "cancelled", which the
            # TaskGroup treats as a normal exit rather than an error.
            tg.create_task(_cancel_on_shutdown(shutdown_event, workers))
    finally:
        await redis.aclose()
    logger.info("Signal generator stopped gracefully")
def main() -> None:
    """Command-line entry point: run the service with its default configuration."""
    service = run(config=None)
    asyncio.run(service)


if __name__ == "__main__":
    main()