feat: real data pipeline — market data, DB persistence, portfolio sync, signal-trade linkage

Wire the trading bot to real Alpaca market data and persist pipeline
state to the database so the dashboard displays live information.

- Add market-data service fetching OHLCV bars from Alpaca, publishing
  to market:bars Redis Stream; signal generator consumes bars and
  injects current_price into signals for position sizing
- Sentiment analyzer now persists Article + ArticleSentiment rows to
  DB after scoring, with duplicate and error handling
- API gateway runs a background portfolio sync task that snapshots
  Alpaca account state into PortfolioSnapshot/Position DB tables
  during market hours
- TradeSignal carries a signal_id UUID; signal generator and trade
  executor both persist their records to DB with cross-references
- 303 unit tests pass (57 new tests added)
This commit is contained in:
Viktor Barzin 2026-02-22 19:52:45 +00:00
parent 5a6b20c8f1
commit e2a3bd456d
No known key found for this signature in database
GPG key ID: 0EB088298288D958
19 changed files with 2238 additions and 72 deletions

View file

@ -16,6 +16,12 @@ class ApiGatewayConfig(BaseConfig):
access_token_expire_minutes: int = 15
refresh_token_expire_days: int = 7
# Alpaca brokerage credentials (for portfolio sync)
alpaca_api_key: str = ""
alpaca_secret_key: str = ""
paper_trading: bool = True
snapshot_interval_seconds: int = 60
# CORS settings
cors_origins: list[str] = ["http://localhost:5173"]

View file

@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
import logging
from contextlib import asynccontextmanager
from typing import AsyncIterator
@ -43,9 +44,23 @@ def create_app(config: ApiGatewayConfig | None = None) -> FastAPI:
)
app.state.config = config
# Start portfolio sync background task
from services.api_gateway.tasks.portfolio_sync import portfolio_sync_loop
sync_task = asyncio.create_task(
portfolio_sync_loop(config, session_factory)
)
logger.info("API Gateway started")
yield
# Cancel the sync task
sync_task.cancel()
try:
await sync_task
except asyncio.CancelledError:
pass
# Cleanup
await app.state.redis.aclose()
await engine.dispose()

View file

View file

@ -0,0 +1,155 @@
"""Background task that periodically snapshots Alpaca account state into the DB.
Runs on a configurable interval (default 60s) during US market hours,
creating ``PortfolioSnapshot`` rows and upserting ``Position`` rows so
the dashboard portfolio page reflects real brokerage data.
"""
from __future__ import annotations
import asyncio
import logging
from datetime import datetime, time, timezone
from zoneinfo import ZoneInfo
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import async_sessionmaker
from services.api_gateway.config import ApiGatewayConfig
from shared.broker.alpaca_broker import AlpacaBroker
from shared.models.timeseries import PortfolioSnapshot
from shared.models.trading import Position
logger = logging.getLogger(__name__)
# US Eastern timezone for market hours check
_ET = ZoneInfo("America/New_York")
_MARKET_OPEN = time(9, 30)
_MARKET_CLOSE = time(16, 0)
def is_market_open(now_utc: datetime | None = None) -> bool:
"""Return ``True`` if the US stock market is currently open.
Checks for weekday (Mon-Fri) and time between 9:30 AM and 4:00 PM ET.
"""
if now_utc is None:
now_utc = datetime.now(timezone.utc)
now_et = now_utc.astimezone(_ET)
# Weekday check: Monday=0 ... Friday=4
if now_et.weekday() >= 5:
return False
return _MARKET_OPEN <= now_et.time() < _MARKET_CLOSE
async def _sync_once(
    broker: AlpacaBroker,
    session_factory: async_sessionmaker,
) -> None:
    """Run one sync cycle: record an account snapshot and mirror positions.

    The broker is queried for account totals and open positions first; all
    database writes then happen inside a single transaction:

    * one new ``PortfolioSnapshot`` row (``daily_pnl`` is always written
      as ``0.0`` here),
    * an update-or-insert of a ``Position`` row per broker position,
      matched by ticker,
    * deletion of every ``Position`` row whose ticker the broker no longer
      reports.
    """
    taken_at = datetime.now(timezone.utc)

    # --- Account totals -------------------------------------------------
    account = await broker.get_account()
    snapshot = PortfolioSnapshot(
        timestamp=taken_at,
        total_value=account.portfolio_value,
        cash=account.cash,
        positions_value=account.portfolio_value - account.cash,
        daily_pnl=0.0,
    )

    # --- Open positions as the broker sees them -------------------------
    held = await broker.get_positions()
    held_tickers = {item.ticker for item in held}

    async with session_factory() as session, session.begin():
        session.add(snapshot)

        for item in held:
            lookup = await session.execute(
                select(Position).where(Position.ticker == item.ticker)
            )
            row = lookup.scalar_one_or_none()

            if row is None:
                # First time this ticker is seen: create the row.
                session.add(
                    Position(
                        ticker=item.ticker,
                        qty=item.qty,
                        avg_entry=item.avg_entry,
                        unrealized_pnl=item.unrealized_pnl,
                        stop_loss=None,
                        take_profit=None,
                    )
                )
                continue

            # Known ticker: refresh the broker-derived fields in place.
            row.qty = item.qty
            row.avg_entry = item.avg_entry
            row.unrealized_pnl = item.unrealized_pnl

        # Drop local rows the broker no longer holds; with nothing held at
        # the broker the statement is left unfiltered and clears the table.
        purge = delete(Position)
        if held_tickers:
            purge = purge.where(Position.ticker.notin_(held_tickers))
        await session.execute(purge)

    logger.info(
        "Portfolio sync complete: value=%.2f, cash=%.2f, positions=%d",
        account.portfolio_value,
        account.cash,
        len(held),
    )
async def portfolio_sync_loop(
    config: ApiGatewayConfig,
    session_factory: async_sessionmaker,
) -> None:
    """Run the portfolio sync loop until cancelled.

    Returns immediately (after logging a warning) when either Alpaca
    credential is empty.  Otherwise loops forever: while the market is open
    a sync cycle is performed, then the task sleeps for
    ``config.snapshot_interval_seconds``.  Errors raised by a cycle are
    logged and the loop carries on with the next interval.

    Parameters
    ----------
    config:
        API Gateway configuration containing Alpaca credentials and
        the snapshot interval.
    session_factory:
        SQLAlchemy async session factory for DB access.

    Raises
    ------
    asyncio.CancelledError
        Re-raised after logging when the task is cancelled.
    """
    if not config.alpaca_api_key or not config.alpaca_secret_key:
        logger.warning(
            "Alpaca API credentials not configured — portfolio sync disabled"
        )
        return

    broker = AlpacaBroker(
        api_key=config.alpaca_api_key,
        secret_key=config.alpaca_secret_key,
        paper=config.paper_trading,
    )
    logger.info(
        "Portfolio sync started (interval=%ds, paper=%s)",
        config.snapshot_interval_seconds,
        config.paper_trading,
    )

    # The cancellation handler wraps the whole loop, including the sleep:
    # the task is parked in ``asyncio.sleep`` almost all of the time, so a
    # handler around the sync call alone would rarely see the cancellation
    # and the shutdown message would not be logged.
    try:
        while True:
            try:
                if is_market_open():
                    await _sync_once(broker, session_factory)
                else:
                    logger.debug("Market closed — skipping portfolio snapshot")
            except Exception:
                # Best effort: one failed cycle must not kill the task.
                logger.exception("Portfolio sync error — will retry next interval")
            await asyncio.sleep(config.snapshot_interval_seconds)
    except asyncio.CancelledError:
        logger.info("Portfolio sync task cancelled — shutting down")
        raise

View file

@ -0,0 +1 @@
"""Market Data service -- fetches OHLCV bars from Alpaca and publishes to Redis Streams."""

View file

@ -0,0 +1,3 @@
# Entry script for the market-data package: importing this module starts the
# service immediately by calling ``main()``.
# NOTE(review): presumably this is the package's ``__main__`` module (run via
# ``python -m services.market_data``) — confirm against the file name.
from services.market_data.main import main

main()

View file

@ -0,0 +1,16 @@
"""Configuration for the market data service."""
from shared.config import BaseConfig
class MarketDataConfig(BaseConfig):
    """Extends BaseConfig with market-data-specific settings.

    These values drive the market data service: which tickers are fetched,
    the bar size, how often the service polls, how many bars are requested
    in the startup backfill, and the Alpaca credentials handed to the data
    client.
    """

    # Tickers whose OHLCV bars are fetched and published.
    watchlist: list[str] = ["AAPL", "TSLA", "NVDA", "MSFT", "GOOGL"]
    # Bar size as a string; the service accepts "1Min", "5Min", "15Min",
    # "1Hour" and "1Day" and raises ValueError for anything else.
    bar_timeframe: str = "5Min"
    # Seconds the service waits between poll cycles.
    poll_interval_seconds: int = 60
    # Number of bars requested per ticker during the startup backfill.
    historical_bars: int = 100
    # Alpaca credentials passed to the historical data client.
    alpaca_api_key: str = ""
    alpaca_secret_key: str = ""
    # NOTE(review): presumably a pydantic-settings config, making these
    # fields readable from TRADING_-prefixed environment variables —
    # confirm against BaseConfig.
    model_config = {"env_prefix": "TRADING_"}

View file

@ -0,0 +1,257 @@
"""Market Data service -- main entry point.
Fetches historical and live OHLCV bars from Alpaca's market data API
and publishes them to the ``market:bars`` Redis Stream for consumption
by the signal generator and other downstream services.
"""
from __future__ import annotations
import asyncio
import logging
import signal
from datetime import datetime, timedelta, timezone
from redis.asyncio import Redis
from services.market_data.config import MarketDataConfig
from shared.redis_streams import StreamPublisher
from shared.telemetry import setup_telemetry
logger = logging.getLogger(__name__)
MARKET_BARS_STREAM = "market:bars"
def _parse_timeframe(timeframe_str: str):
    """Translate a config string such as ``"5Min"`` into an Alpaca ``TimeFrame``.

    Only a fixed set of spellings is accepted: ``1Min``, ``5Min``, ``15Min``,
    ``1Hour`` and ``1Day``.

    Raises
    ------
    ValueError
        If ``timeframe_str`` is not one of the accepted spellings.
    """
    from alpaca.data.timeframe import TimeFrame, TimeFrameUnit

    # (amount, unit) per accepted spelling; the key order is also the order
    # in which the accepted values are listed in the error message.
    known = {
        "1Min": (1, TimeFrameUnit.Minute),
        "5Min": (5, TimeFrameUnit.Minute),
        "15Min": (15, TimeFrameUnit.Minute),
        "1Hour": (1, TimeFrameUnit.Hour),
        "1Day": (1, TimeFrameUnit.Day),
    }

    if timeframe_str not in known:
        raise ValueError(
            f"Unsupported timeframe '{timeframe_str}'. "
            f"Supported values: {list(known)}"
        )

    amount, unit = known[timeframe_str]
    return TimeFrame(amount, unit)
def _bar_to_dict(ticker: str, bar) -> dict:
    """Flatten a bar object into the dictionary published to Redis.

    The result carries ``ticker``, the bar's ``timestamp`` in ISO-8601 form,
    and the ``open``/``high``/``low``/``close``/``volume`` attributes, each
    coerced to ``float``.
    """
    payload = {
        "ticker": ticker,
        "timestamp": bar.timestamp.isoformat(),
    }
    # Numeric fields are read in a fixed order so the key order is stable.
    for field in ("open", "high", "low", "close", "volume"):
        payload[field] = float(getattr(bar, field))
    return payload
async def _fetch_historical_bars(
    client,
    watchlist: list[str],
    timeframe,
    limit: int,
    publisher: StreamPublisher,
    bars_published_counter,
) -> int:
    """Backfill the most recent ``limit`` bars per ticker and publish them.

    Each ticker is requested separately for the last 30 days.  A failure for
    one ticker is logged and does not stop the remaining tickers.

    Parameters
    ----------
    client:
        Alpaca historical data client; ``get_stock_bars`` is called on it in
        a worker thread.
    watchlist:
        Tickers to backfill.
    timeframe:
        Alpaca ``TimeFrame`` describing the bar size.
    limit:
        Maximum number of (most recent) bars to publish per ticker.  A
        non-positive value publishes nothing.
    publisher:
        Stream publisher each bar dictionary is sent to.
    bars_published_counter:
        Counter incremented once with the total number of bars published.

    Returns
    -------
    int
        Total number of bars published across all tickers.
    """
    from alpaca.data.requests import StockBarsRequest

    if limit <= 0:
        return 0

    total_published = 0
    # 30 days is assumed to be wide enough to contain ``limit`` bars.
    # NOTE(review): for small bar sizes this downloads far more bars than are
    # published; consider narrowing the window per timeframe.
    start = datetime.now(timezone.utc) - timedelta(days=30)

    for ticker in watchlist:
        try:
            # No request-level ``limit``: Alpaca returns bars oldest-first by
            # default, so capping the request would keep the bars from the
            # *start* of the window (a month old) rather than the latest
            # ones.  The tail is sliced locally instead.
            request = StockBarsRequest(
                symbol_or_symbols=[ticker],
                timeframe=timeframe,
                start=start,
            )
            bars = await asyncio.to_thread(client.get_stock_bars, request)

            # Index and catch KeyError rather than using ``ticker in bars``.
            # NOTE(review): the response object is presumably a pydantic
            # model, for which ``in`` would not test symbol membership —
            # confirm against alpaca-py.
            try:
                ticker_bars = list(bars[ticker])
            except KeyError:
                ticker_bars = []

            recent_bars = ticker_bars[-limit:]
            for bar in recent_bars:
                await publisher.publish(_bar_to_dict(ticker, bar))
                total_published += 1

            logger.info(
                "Published %d historical bars for %s",
                len(recent_bars),
                ticker,
            )
        except Exception:
            logger.exception("Failed to fetch historical bars for %s", ticker)

    if total_published:
        bars_published_counter.add(total_published)
    return total_published
async def _poll_latest_bars(
    client,
    watchlist: list[str],
    timeframe,
    publisher: StreamPublisher,
    bars_published_counter,
    *,
    lookback: timedelta = timedelta(minutes=10),
) -> int:
    """Fetch the latest bar for each ticker and publish it to Redis.

    Each ticker is requested separately; a failure for one ticker is logged
    and does not stop the remaining tickers.

    Parameters
    ----------
    client:
        Alpaca historical data client; ``get_stock_bars`` is called on it in
        a worker thread.
    watchlist:
        Tickers to poll.
    timeframe:
        Alpaca ``TimeFrame`` describing the bar size.
    publisher:
        Stream publisher the bar dictionary is sent to.
    bars_published_counter:
        Counter incremented once with the number of bars published.
    lookback:
        How far back from "now" to search for bars (default 10 minutes).
        It should be at least as long as the bar size, otherwise the window
        may contain no bar at all.

    Returns
    -------
    int
        Number of bars published (at most one per ticker).
    """
    from alpaca.data.requests import StockBarsRequest

    published = 0
    start = datetime.now(timezone.utc) - lookback

    # NOTE(review): when the poll interval is shorter than the bar size the
    # same bar is published on consecutive polls; confirm that consumers
    # tolerate or de-duplicate repeated bars.
    for ticker in watchlist:
        try:
            # No ``limit=1`` on the request: Alpaca returns bars oldest-first
            # by default, so a limit of one would yield the *oldest* bar in
            # the window.  Fetch the window and take the last element.
            request = StockBarsRequest(
                symbol_or_symbols=[ticker],
                timeframe=timeframe,
                start=start,
            )
            bars = await asyncio.to_thread(client.get_stock_bars, request)

            # Index and catch KeyError rather than using ``ticker in bars``.
            # NOTE(review): the response object is presumably a pydantic
            # model, for which ``in`` would not test symbol membership —
            # confirm against alpaca-py.
            try:
                ticker_bars = bars[ticker]
            except KeyError:
                ticker_bars = []

            if not ticker_bars:
                continue

            bar = ticker_bars[-1]
            await publisher.publish(_bar_to_dict(ticker, bar))
            published += 1
            logger.debug("Published latest bar for %s: close=%.2f", ticker, bar.close)
        except Exception:
            logger.exception("Failed to fetch latest bar for %s", ticker)

    if published:
        bars_published_counter.add(published)
    return published
async def run(config: MarketDataConfig | None = None) -> None:
    """Main service loop.

    Connects to Alpaca and Redis, fetches historical bars on startup,
    then polls for new bars at the configured interval until SIGTERM or
    SIGINT is received.

    Parameters
    ----------
    config:
        Service configuration; a default ``MarketDataConfig`` is built when
        omitted.

    Raises
    ------
    ValueError
        If ``config.bar_timeframe`` is not a supported timeframe string.
    """
    if config is None:
        config = MarketDataConfig()
    logging.basicConfig(level=config.log_level)
    logger.info("Starting Market Data service")

    # --- Parse timeframe ---
    # Done before any resource is created so that a bad configuration value
    # fails fast and cannot leave an unclosed Redis client behind.
    timeframe = _parse_timeframe(config.bar_timeframe)

    # --- Telemetry ---
    meter = setup_telemetry("market-data", config.otel_metrics_port)
    bars_published_counter = meter.create_counter(
        "market_data.bars_published",
        description="Total OHLCV bars published to market:bars stream",
    )
    poll_errors_counter = meter.create_counter(
        "market_data.poll_errors",
        description="Total poll cycle errors",
    )

    # --- Alpaca client ---
    from alpaca.data.historical import StockHistoricalDataClient

    client = StockHistoricalDataClient(
        api_key=config.alpaca_api_key,
        secret_key=config.alpaca_secret_key,
    )

    # --- Redis ---
    redis = Redis.from_url(config.redis_url, decode_responses=False)
    publisher = StreamPublisher(redis, MARKET_BARS_STREAM)

    # Everything after the Redis client exists runs under ``finally`` so the
    # client is closed on every exit path.
    try:
        # --- Graceful shutdown ---
        shutdown_event = asyncio.Event()
        loop = asyncio.get_running_loop()
        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, shutdown_event.set)

        # Fetch historical bars on startup
        logger.info(
            "Fetching %d historical bars for watchlist: %s",
            config.historical_bars,
            config.watchlist,
        )
        total = await _fetch_historical_bars(
            client,
            config.watchlist,
            timeframe,
            config.historical_bars,
            publisher,
            bars_published_counter,
        )
        logger.info("Historical backfill complete: %d total bars published", total)

        # Poll loop
        logger.info(
            "Starting poll loop (interval=%ds) for watchlist: %s",
            config.poll_interval_seconds,
            config.watchlist,
        )
        while not shutdown_event.is_set():
            # Sleep for one interval, waking early if shutdown is signalled.
            try:
                await asyncio.wait_for(
                    shutdown_event.wait(),
                    timeout=config.poll_interval_seconds,
                )
                break  # Shutdown signaled
            except asyncio.TimeoutError:
                pass  # Normal timeout — time to poll

            try:
                count = await _poll_latest_bars(
                    client,
                    config.watchlist,
                    timeframe,
                    publisher,
                    bars_published_counter,
                )
                logger.info("Poll cycle complete: %d bars published", count)
            except Exception:
                # A failed cycle is counted and logged; polling continues.
                logger.exception("Poll cycle failed")
                poll_errors_counter.add(1)
    finally:
        await redis.aclose()
        logger.info("Market data service stopped gracefully")
def main() -> None:
    """CLI entry point.

    Runs the asynchronous ``run`` coroutine to completion with
    ``asyncio.run``.  ``run`` is called without arguments, so it builds its
    own default ``MarketDataConfig``.
    """
    asyncio.run(run())


# Start the service when this module is executed directly as a script.
if __name__ == "__main__":
    main()

View file

@ -3,7 +3,8 @@
Consumes ``news:raw`` articles from Redis Streams, scores them using a
tiered approach (FinBERT first, Ollama fallback for low-confidence results),
extracts ticker mentions, and publishes ``ScoredArticle`` messages to
``news:scored``.
``news:scored``. Also persists scored articles to the database (articles +
article_sentiments tables) so the dashboard can display real data.
"""
from __future__ import annotations
@ -14,11 +15,15 @@ import signal
import time
from redis.asyncio import Redis
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import async_sessionmaker
from services.sentiment_analyzer.analyzers.finbert import FinBERTAnalyzer
from services.sentiment_analyzer.analyzers.ollama_analyzer import OllamaAnalyzer
from services.sentiment_analyzer.config import SentimentAnalyzerConfig
from services.sentiment_analyzer.ticker_extractor import extract_tickers
from shared.db import create_db
from shared.models.news import Article, ArticleSentiment
from shared.redis_streams import StreamConsumer, StreamPublisher
from shared.schemas.news import RawArticle, ScoredArticle
from shared.telemetry import setup_telemetry
@ -33,6 +38,7 @@ async def process_article(
publisher: StreamPublisher,
config: SentimentAnalyzerConfig,
counters: dict,
db_session_factory: async_sessionmaker | None = None,
) -> None:
"""Score a single article and publish one ScoredArticle per extracted ticker.
@ -50,6 +56,9 @@ async def process_article(
Service configuration (confidence threshold, etc.).
counters:
Dict of OpenTelemetry counter/histogram instruments.
db_session_factory:
Optional async session factory for persisting to the DB.
When ``None``, DB persistence is skipped (backward compatible).
"""
start = time.monotonic()
@ -103,6 +112,46 @@ async def process_article(
counters["articles_scored"].add(1)
# --- Step 5: Persist to DB ---
if db_session_factory is not None:
try:
async with db_session_factory() as session:
db_article = Article(
source=article.source,
url=article.url,
title=article.title,
published_at=article.published_at,
fetched_at=article.fetched_at,
content_hash=article.content_hash,
)
session.add(db_article)
for ticker in tickers:
sentiment = ArticleSentiment(
article_id=db_article.id,
ticker=ticker,
score=score,
confidence=confidence,
model_used=model_used,
)
session.add(sentiment)
await session.commit()
logger.debug(
"Persisted article '%s' with %d sentiments to DB",
article.title[:60],
len(tickers),
)
except IntegrityError:
logger.debug(
"Article already exists in DB (content_hash=%s), skipping",
article.content_hash,
)
except Exception:
logger.exception(
"Failed to persist article to DB: %s", article.title[:60]
)
async def run(config: SentimentAnalyzerConfig | None = None) -> None:
"""Main service loop.
@ -150,6 +199,14 @@ async def run(config: SentimentAnalyzerConfig | None = None) -> None:
)
ollama = OllamaAnalyzer(model=config.ollama_model, host=config.ollama_host)
# --- Database ---
db_session_factory = None
try:
_engine, db_session_factory = create_db(config)
logger.info("Database session factory initialised")
except Exception:
logger.exception("Failed to initialise DB — articles will NOT be persisted")
logger.info("Consuming from news:raw, publishing to news:scored")
# Graceful shutdown on SIGTERM/SIGINT
@ -165,7 +222,7 @@ async def run(config: SentimentAnalyzerConfig | None = None) -> None:
break
try:
article = RawArticle.model_validate(data)
await process_article(article, finbert, ollama, publisher, config, counters)
await process_article(article, finbert, ollama, publisher, config, counters, db_session_factory)
except Exception:
logger.exception("Error processing article: %s", data.get("title", "<unknown>"))
finally:

View file

@ -1,8 +1,9 @@
"""Signal Generator service -- main entry point.
Consumes ``news:scored`` articles from Redis Streams, updates sentiment
context per ticker, runs the weighted ensemble of trading strategies, and
publishes qualifying ``TradeSignal`` messages to ``signals:generated``.
Consumes ``news:scored`` articles and ``market:bars`` OHLCV data from
Redis Streams, updates sentiment context and market data per ticker,
runs the weighted ensemble of trading strategies, and publishes
qualifying ``TradeSignal`` messages to ``signals:generated``.
"""
from __future__ import annotations
@ -10,16 +11,21 @@ from __future__ import annotations
import asyncio
import logging
import signal
import uuid
from collections import defaultdict, deque
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import async_sessionmaker
from services.signal_generator.config import SignalGeneratorConfig
from services.signal_generator.ensemble import WeightedEnsemble
from services.signal_generator.market_data import MarketDataManager
from shared.db import create_db
from shared.models.trading import Signal as SignalModel
from shared.models.trading import SignalDirection as SignalDirectionModel
from shared.redis_streams import StreamConsumer, StreamPublisher
from shared.schemas.news import ScoredArticle
from shared.schemas.trading import SentimentContext
from shared.schemas.trading import MarketSnapshot, SentimentContext
from shared.strategies import MeanReversionStrategy, MomentumStrategy, NewsDrivenStrategy
from shared.telemetry import setup_telemetry
@ -53,12 +59,150 @@ def _build_sentiment_context(
)
async def _consume_market_bars(
    bars_consumer: StreamConsumer,
    market_data: MarketDataManager,
    shutdown_event: asyncio.Event,
    bars_received_counter,
) -> None:
    """Feed bars read from ``market:bars`` into the ``MarketDataManager``.

    Intended to run as a task next to the scored-article consumer.  Every
    message must carry a ``ticker`` field; messages without one are logged
    and skipped.  The remaining fields are handed to
    ``market_data.add_bar``.  An error while handling one message is logged
    and the loop moves on.  The loop exits once ``shutdown_event`` is found
    set when a message arrives.
    """
    logger.info("Starting market:bars consumer")

    async for _msg_id, data in bars_consumer.consume():
        if shutdown_event.is_set():
            break

        try:
            symbol = data.get("ticker")
            if not symbol:
                logger.warning("Received bar message without ticker field: %s", data)
                continue

            # The ticker is passed separately, so strip it from the bar
            # fields (OHLCVBar doesn't have it).
            bar_fields = dict(data.items())
            bar_fields.pop("ticker", None)

            market_data.add_bar(symbol, bar_fields)
            bars_received_counter.add(1)
            logger.debug("Added bar for %s: close=%s", symbol, data.get("close"))
        except Exception:
            logger.exception("Error processing market bar: %s", data)
async def _store_signal_row(
    db_session_factory: async_sessionmaker,
    signal_result,
    ticker: str,
    sentiment,
) -> None:
    """Insert one generated signal into the DB; failures are logged, not raised."""
    try:
        async with db_session_factory() as session:
            # Schema direction value (string) -> ORM enum member.
            directions = {
                "LONG": SignalDirectionModel.LONG,
                "SHORT": SignalDirectionModel.SHORT,
                "NEUTRAL": SignalDirectionModel.NEUTRAL,
            }
            session.add(
                SignalModel(
                    id=signal_result.signal_id,
                    ticker=ticker,
                    direction=directions[signal_result.direction.value],
                    strength=signal_result.strength,
                    strategy_sources=signal_result.strategy_sources,
                    sentiment_score=sentiment.avg_score if sentiment else None,
                    acted_on=False,
                )
            )
            await session.commit()
    except Exception:
        logger.exception("Failed to persist signal to DB")


async def _consume_scored_articles(
    articles_consumer: StreamConsumer,
    market_data: MarketDataManager,
    ensemble: WeightedEnsemble,
    weights: dict[str, float],
    publisher: StreamPublisher,
    shutdown_event: asyncio.Event,
    signals_generated,
    per_strategy_signal_count,
    db_session_factory: async_sessionmaker | None = None,
) -> None:
    """Turn scored articles from ``news:scored`` into published trade signals.

    Intended to run as a task next to the market-bars consumer.  For every
    article the per-ticker rolling sentiment windows are updated, the
    ensemble is evaluated against the ticker's market snapshot (or an
    all-zero placeholder when no snapshot is available), and any resulting
    signal is:

    1. enriched with ``current_price`` when the snapshot has a positive one,
    2. written to the DB when ``db_session_factory`` is given (best effort),
    3. published, counted and logged.

    An error while handling one article is logged and the loop moves on.
    The loop exits once ``shutdown_event`` is found set when a message
    arrives.
    """

    def _new_window() -> deque:
        return deque(maxlen=_MAX_SENTIMENT_SCORES)

    # Rolling per-ticker history, capped at _MAX_SENTIMENT_SCORES entries.
    sentiment_scores: dict[str, deque[float]] = defaultdict(_new_window)
    sentiment_confidences: dict[str, deque[float]] = defaultdict(_new_window)

    logger.info("Starting news:scored consumer")

    async for _msg_id, data in articles_consumer.consume():
        if shutdown_event.is_set():
            break

        try:
            article = ScoredArticle.model_validate(data)
            ticker = article.ticker

            # Fold this article into the ticker's rolling sentiment.
            sentiment_scores[ticker].append(article.sentiment_score)
            sentiment_confidences[ticker].append(article.confidence)
            sentiment = _build_sentiment_context(
                ticker,
                sentiment_scores[ticker],
                sentiment_confidences[ticker],
            )

            # No bars yet for this ticker: fall back to an all-zero snapshot
            # (the news_driven strategy does not require market indicators).
            snapshot = market_data.get_snapshot(ticker)
            if snapshot is None:
                snapshot = MarketSnapshot(
                    ticker=ticker,
                    current_price=0.0,
                    open=0.0,
                    high=0.0,
                    low=0.0,
                    close=0.0,
                    volume=0.0,
                )

            signal_result = await ensemble.evaluate(ticker, snapshot, sentiment, weights)
            if signal_result is None:
                continue

            # Carry the current price along for trade executor position sizing.
            if snapshot and snapshot.current_price > 0:
                if signal_result.sentiment_context is None:
                    signal_result.sentiment_context = {}
                signal_result.sentiment_context["current_price"] = snapshot.current_price

            if db_session_factory is not None:
                await _store_signal_row(
                    db_session_factory, signal_result, ticker, sentiment
                )

            await publisher.publish(signal_result.model_dump(mode="json"))
            signals_generated.add(1)

            # Sources look like "<strategy>:<detail>"; count by strategy name.
            for src in signal_result.strategy_sources:
                strategy_name = src.partition(":")[0]
                per_strategy_signal_count.add(1, {"strategy": strategy_name})

            logger.info(
                "Signal generated: %s %s strength=%.4f sources=%s",
                signal_result.direction.value,
                ticker,
                signal_result.strength,
                signal_result.strategy_sources,
            )
        except Exception:
            logger.exception(
                "Error processing scored article: %s", data.get("title", "<unknown>")
            )
async def run(config: SignalGeneratorConfig | None = None) -> None:
"""Main service loop.
Connects to Redis, initialises strategies and telemetry, then
continuously consumes from ``news:scored`` and publishes qualifying
signals to ``signals:generated``.
continuously consumes from ``news:scored`` and ``market:bars``,
publishing qualifying signals to ``signals:generated``.
"""
if config is None:
config = SignalGeneratorConfig()
@ -76,10 +220,19 @@ async def run(config: SignalGeneratorConfig | None = None) -> None:
"per_strategy_signal_count",
description="Signals emitted, broken down by strategy",
)
bars_received_counter = meter.create_counter(
"bars_received",
description="Total OHLCV bars received from market:bars stream",
)
# --- Redis ---
redis = Redis.from_url(config.redis_url, decode_responses=False)
consumer = StreamConsumer(redis, "news:scored", "signal-generator", "worker-1")
articles_consumer = StreamConsumer(
redis, "news:scored", "signal-generator", "worker-1"
)
bars_consumer = StreamConsumer(
redis, "market:bars", "signal-generator", "bars-worker"
)
publisher = StreamPublisher(redis, "signals:generated")
# --- Market data ---
@ -96,11 +249,17 @@ async def run(config: SignalGeneratorConfig | None = None) -> None:
# --- Strategy weights (default equal; could load from DB) ---
weights = dict(_DEFAULT_WEIGHTS)
# --- Per-ticker sentiment accumulators ---
sentiment_scores: dict[str, deque[float]] = defaultdict(lambda: deque(maxlen=_MAX_SENTIMENT_SCORES))
sentiment_confidences: dict[str, deque[float]] = defaultdict(lambda: deque(maxlen=_MAX_SENTIMENT_SCORES))
# --- Database (for persisting signals) ---
db_session_factory = None
try:
_engine, db_session_factory = create_db(config)
logger.info("Database session factory initialised for signal persistence")
except Exception:
logger.exception("Failed to initialise DB — signals will NOT be persisted")
logger.info("Consuming from news:scored, publishing to signals:generated")
logger.info(
"Consuming from news:scored and market:bars, publishing to signals:generated"
)
# Graceful shutdown on SIGTERM/SIGINT
shutdown_event = asyncio.Event()
@ -108,62 +267,30 @@ async def run(config: SignalGeneratorConfig | None = None) -> None:
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, shutdown_event.set)
# --- Consume loop ---
# --- Run both consumers concurrently ---
try:
async for _msg_id, data in consumer.consume():
if shutdown_event.is_set():
break
try:
article = ScoredArticle.model_validate(data)
ticker = article.ticker
# Update sentiment accumulators
sentiment_scores[ticker].append(article.sentiment_score)
sentiment_confidences[ticker].append(article.confidence)
# Build sentiment context
sentiment = _build_sentiment_context(
ticker,
sentiment_scores[ticker],
sentiment_confidences[ticker],
async with asyncio.TaskGroup() as tg:
tg.create_task(
_consume_scored_articles(
articles_consumer,
market_data,
ensemble,
weights,
publisher,
shutdown_event,
signals_generated,
per_strategy_signal_count,
db_session_factory,
)
# Get market snapshot (may be None if no bars received yet)
snapshot = market_data.get_snapshot(ticker)
if snapshot is None:
# Create a minimal snapshot from sentiment data alone
# (the news_driven strategy does not require market indicators)
from shared.schemas.trading import MarketSnapshot
snapshot = MarketSnapshot(
ticker=ticker,
current_price=0.0,
open=0.0,
high=0.0,
low=0.0,
close=0.0,
volume=0.0,
)
# Run ensemble
signal_result = await ensemble.evaluate(ticker, snapshot, sentiment, weights)
if signal_result is not None:
await publisher.publish(signal_result.model_dump(mode="json"))
signals_generated.add(1)
for src in signal_result.strategy_sources:
strategy_name = src.split(":")[0]
per_strategy_signal_count.add(1, {"strategy": strategy_name})
logger.info(
"Signal generated: %s %s strength=%.4f sources=%s",
signal_result.direction.value,
ticker,
signal_result.strength,
signal_result.strategy_sources,
)
except Exception:
logger.exception("Error processing scored article: %s", data.get("title", "<unknown>"))
)
tg.create_task(
_consume_market_bars(
bars_consumer,
market_data,
shutdown_event,
bars_received_counter,
)
)
finally:
await redis.aclose()
logger.info("Signal generator stopped gracefully")

View file

@ -15,10 +15,15 @@ import time
import uuid
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import async_sessionmaker
from services.trade_executor.config import TradeExecutorConfig
from services.trade_executor.risk_manager import RiskManager
from shared.broker.alpaca_broker import AlpacaBroker
from shared.db import create_db
from shared.models.trading import Trade as TradeModel
from shared.models.trading import TradeSide as TradeSideModel
from shared.models.trading import TradeStatus as TradeStatusModel
from shared.redis_streams import StreamConsumer, StreamPublisher
from shared.schemas.trading import (
OrderRequest,
@ -39,6 +44,7 @@ async def process_signal(
broker: AlpacaBroker,
publisher: StreamPublisher,
counters: dict,
db_session_factory: async_sessionmaker | None = None,
) -> None:
"""Process a single trade signal: risk check, order, record, publish.
@ -54,6 +60,8 @@ async def process_signal(
Publishes execution results to ``trades:executed``.
counters:
Dict of OpenTelemetry counter/histogram instruments.
db_session_factory:
Optional async session factory for persisting trades to the DB.
"""
# --- Step 1: risk check ---
approved, reason = await risk_manager.check_risk(signal)
@ -93,12 +101,42 @@ async def process_signal(
qty=result.qty,
price=result.filled_price or 0.0,
status=result.status,
signal_id=None,
signal_id=signal.signal_id,
strategy_id=None,
timestamp=result.timestamp,
)
# --- Step 6: publish to trades:executed ---
# --- Step 6: persist trade to DB ---
if db_session_factory is not None:
try:
side_map = {
OrderSide.BUY: TradeSideModel.BUY,
OrderSide.SELL: TradeSideModel.SELL,
}
status_map = {
OrderStatus.PENDING: TradeStatusModel.PENDING,
OrderStatus.FILLED: TradeStatusModel.FILLED,
OrderStatus.CANCELLED: TradeStatusModel.CANCELLED,
OrderStatus.REJECTED: TradeStatusModel.REJECTED,
}
async with db_session_factory() as session:
db_trade = TradeModel(
id=trade_id,
ticker=signal.ticker,
side=side_map[side],
qty=result.qty,
price=result.filled_price or 0.0,
timestamp=str(result.timestamp),
signal_id=signal.signal_id,
status=status_map.get(result.status, TradeStatusModel.PENDING),
)
session.add(db_trade)
await session.commit()
logger.debug("Persisted trade %s to DB (signal_id=%s)", trade_id, signal.signal_id)
except Exception:
logger.exception("Failed to persist trade to DB")
# --- Step 7: publish to trades:executed ---
await publisher.publish(execution.model_dump(mode="json"))
counters["trades_executed"].add(1)
logger.info(
@ -157,6 +195,14 @@ async def run(config: TradeExecutorConfig | None = None) -> None:
# --- Risk manager ---
risk_manager = RiskManager(config, broker)
# --- Database (for persisting trades) ---
db_session_factory = None
try:
_engine, db_session_factory = create_db(config)
logger.info("Database session factory initialised for trade persistence")
except Exception:
logger.exception("Failed to initialise DB — trades will NOT be persisted")
logger.info("Consuming from signals:generated, publishing to trades:executed")
# Graceful shutdown on SIGTERM/SIGINT
@ -172,7 +218,7 @@ async def run(config: TradeExecutorConfig | None = None) -> None:
break
try:
signal_msg = TradeSignal.model_validate(data)
await process_signal(signal_msg, risk_manager, broker, publisher, counters)
await process_signal(signal_msg, risk_manager, broker, publisher, counters, db_session_factory)
except Exception:
logger.exception("Error processing signal: %s", data)
finally: