From e2a3bd456d4795c8248ed1b8655c4bdfb95f0ed0 Mon Sep 17 00:00:00 2001 From: Viktor Barzin Date: Sun, 22 Feb 2026 19:52:45 +0000 Subject: [PATCH] =?UTF-8?q?feat:=20real=20data=20pipeline=20=E2=80=94=20ma?= =?UTF-8?q?rket=20data,=20DB=20persistence,=20portfolio=20sync,=20signal-t?= =?UTF-8?q?rade=20linkage?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wire the trading bot to real Alpaca market data and persist pipeline state to the database so the dashboard displays live information. - Add market-data service fetching OHLCV bars from Alpaca, publishing to market:bars Redis Stream; signal generator consumes bars and injects current_price into signals for position sizing - Sentiment analyzer now persists Article + ArticleSentiment rows to DB after scoring, with duplicate and error handling - API gateway runs a background portfolio sync task that snapshots Alpaca account state into PortfolioSnapshot/Position DB tables during market hours - TradeSignal carries a signal_id UUID; signal generator and trade executor both persist their records to DB with cross-references - 303 unit tests pass (57 new tests added) --- .env.example | 10 + docker-compose.yml | 13 + services/api_gateway/config.py | 6 + services/api_gateway/main.py | 15 + services/api_gateway/tasks/__init__.py | 0 services/api_gateway/tasks/portfolio_sync.py | 155 +++++++ services/market_data/__init__.py | 1 + services/market_data/__main__.py | 3 + services/market_data/config.py | 16 + services/market_data/main.py | 257 +++++++++++ services/sentiment_analyzer/main.py | 61 ++- services/signal_generator/main.py | 257 ++++++++--- services/trade_executor/main.py | 52 ++- shared/schemas/trading.py | 3 +- tests/services/test_market_data.py | 462 +++++++++++++++++++ tests/services/test_portfolio_sync.py | 456 ++++++++++++++++++ tests/services/test_sentiment_analyzer.py | 229 ++++++++- tests/services/test_signal_generator.py | 198 ++++++++ tests/services/test_trade_executor.py | 116 +++++ 19 files changed, 2238 insertions(+), 72 deletions(-) create mode 100644 services/api_gateway/tasks/__init__.py create mode 100644 services/api_gateway/tasks/portfolio_sync.py create mode 100644 services/market_data/__init__.py create mode 100644 services/market_data/__main__.py create mode 100644 services/market_data/config.py create mode 100644 services/market_data/main.py create mode 100644 tests/services/test_market_data.py create mode 100644 tests/services/test_portfolio_sync.py diff --git a/.env.example b/.env.example index 11a8909..6f8d24d 100644 --- a/.env.example +++ b/.env.example @@ -10,6 +10,16 @@ TRADING_LOG_LEVEL=INFO TRADING_ALPACA_API_KEY=your_api_key_here TRADING_ALPACA_SECRET_KEY=your_secret_key_here TRADING_ALPACA_BASE_URL=https://paper-api.alpaca.markets +TRADING_PAPER_TRADING=true + +# Market data service — watchlist tickers (comma-separated) +TRADING_WATCHLIST=["AAPL","TSLA","NVDA","MSFT","GOOGL"] +TRADING_BAR_TIMEFRAME=5Min +TRADING_POLL_INTERVAL_SECONDS=60 +TRADING_HISTORICAL_BARS=100 + +# Portfolio sync interval (seconds, api-gateway background task) +TRADING_SNAPSHOT_INTERVAL_SECONDS=60 # JWT — REQUIRED, generate with: python -c "import secrets; print(secrets.token_hex(32))" TRADING_JWT_SECRET_KEY= diff --git a/docker-compose.yml b/docker-compose.yml index d070b0b..236ba5a 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -142,6 +142,19 @@ services: env_file: .env restart: unless-stopped + market-data: + build: + context: . 
+ dockerfile: docker/Dockerfile.service + args: + EXTRAS: "trading" + SERVICE_MODULE: "market_data" + depends_on: + redis: + condition: service_healthy + env_file: .env + restart: unless-stopped + api-gateway: build: context: . diff --git a/services/api_gateway/config.py b/services/api_gateway/config.py index a078208..b915626 100644 --- a/services/api_gateway/config.py +++ b/services/api_gateway/config.py @@ -16,6 +16,12 @@ class ApiGatewayConfig(BaseConfig): access_token_expire_minutes: int = 15 refresh_token_expire_days: int = 7 + # Alpaca brokerage credentials (for portfolio sync) + alpaca_api_key: str = "" + alpaca_secret_key: str = "" + paper_trading: bool = True + snapshot_interval_seconds: int = 60 + # CORS settings cors_origins: list[str] = ["http://localhost:5173"] diff --git a/services/api_gateway/main.py b/services/api_gateway/main.py index f2df699..ec11e0e 100644 --- a/services/api_gateway/main.py +++ b/services/api_gateway/main.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import logging from contextlib import asynccontextmanager from typing import AsyncIterator @@ -43,9 +44,23 @@ def create_app(config: ApiGatewayConfig | None = None) -> FastAPI: ) app.state.config = config + # Start portfolio sync background task + from services.api_gateway.tasks.portfolio_sync import portfolio_sync_loop + + sync_task = asyncio.create_task( + portfolio_sync_loop(config, session_factory) + ) + logger.info("API Gateway started") yield + # Cancel the sync task + sync_task.cancel() + try: + await sync_task + except asyncio.CancelledError: + pass + # Cleanup await app.state.redis.aclose() await engine.dispose() diff --git a/services/api_gateway/tasks/__init__.py b/services/api_gateway/tasks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/services/api_gateway/tasks/portfolio_sync.py b/services/api_gateway/tasks/portfolio_sync.py new file mode 100644 index 0000000..f97468d --- /dev/null +++ b/services/api_gateway/tasks/portfolio_sync.py @@ -0,0 +1,155 @@ +"""Background task that periodically snapshots Alpaca account state into the DB. + +Runs on a configurable interval (default 60s) during US market hours, +creating ``PortfolioSnapshot`` rows and upserting ``Position`` rows so +the dashboard portfolio page reflects real brokerage data. +""" + +from __future__ import annotations + +import asyncio +import logging +from datetime import datetime, time, timezone +from zoneinfo import ZoneInfo + +from sqlalchemy import delete, select +from sqlalchemy.ext.asyncio import async_sessionmaker + +from services.api_gateway.config import ApiGatewayConfig +from shared.broker.alpaca_broker import AlpacaBroker +from shared.models.timeseries import PortfolioSnapshot +from shared.models.trading import Position + +logger = logging.getLogger(__name__) + +# US Eastern timezone for market hours check +_ET = ZoneInfo("America/New_York") +_MARKET_OPEN = time(9, 30) +_MARKET_CLOSE = time(16, 0) + + +def is_market_open(now_utc: datetime | None = None) -> bool: + """Return ``True`` if the US stock market is currently open. + + Checks for weekday (Mon-Fri) and time between 9:30 AM and 4:00 PM ET. + """ + if now_utc is None: + now_utc = datetime.now(timezone.utc) + now_et = now_utc.astimezone(_ET) + # Weekday check: Monday=0 ... 
Friday=4 + if now_et.weekday() >= 5: + return False + return _MARKET_OPEN <= now_et.time() < _MARKET_CLOSE + + +async def _sync_once( + broker: AlpacaBroker, + session_factory: async_sessionmaker, +) -> None: + """Perform a single portfolio snapshot and position upsert cycle.""" + now = datetime.now(timezone.utc) + + # 1. Snapshot account state + account = await broker.get_account() + + snapshot = PortfolioSnapshot( + timestamp=now, + total_value=account.portfolio_value, + cash=account.cash, + positions_value=account.portfolio_value - account.cash, + daily_pnl=0.0, + ) + + # 2. Fetch broker positions + broker_positions = await broker.get_positions() + broker_tickers = {p.ticker for p in broker_positions} + + async with session_factory() as session: + async with session.begin(): + # Insert portfolio snapshot + session.add(snapshot) + + # Upsert positions + for pos_info in broker_positions: + result = await session.execute( + select(Position).where(Position.ticker == pos_info.ticker) + ) + existing = result.scalar_one_or_none() + + if existing is not None: + existing.qty = pos_info.qty + existing.avg_entry = pos_info.avg_entry + existing.unrealized_pnl = pos_info.unrealized_pnl + else: + new_pos = Position( + ticker=pos_info.ticker, + qty=pos_info.qty, + avg_entry=pos_info.avg_entry, + unrealized_pnl=pos_info.unrealized_pnl, + stop_loss=None, + take_profit=None, + ) + session.add(new_pos) + + # 3. Remove positions that are no longer held at the broker + if broker_tickers: + await session.execute( + delete(Position).where(Position.ticker.notin_(broker_tickers)) + ) + else: + # No positions at broker — delete all local positions + await session.execute(delete(Position)) + + logger.info( + "Portfolio sync complete: value=%.2f, cash=%.2f, positions=%d", + account.portfolio_value, + account.cash, + len(broker_positions), + ) + + +async def portfolio_sync_loop( + config: ApiGatewayConfig, + session_factory: async_sessionmaker, +) -> None: + """Run the portfolio sync loop until cancelled. + + Parameters + ---------- + config: + API Gateway configuration containing Alpaca credentials and + the snapshot interval. + session_factory: + SQLAlchemy async session factory for DB access. 
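+
+    Notes
+    -----
+    ``is_market_open`` checks only the weekday and the ET wall-clock
+    time; US market holidays are not excluded, so snapshots may also be
+    taken on a weekday holiday.
+
+    The API gateway's lifespan starts this as a background task and
+    cancels it on shutdown (see ``services/api_gateway/main.py``)::
+
+        sync_task = asyncio.create_task(
+            portfolio_sync_loop(config, session_factory)
+        )
+        ...
+        sync_task.cancel()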
+ """ + if not config.alpaca_api_key or not config.alpaca_secret_key: + logger.warning( + "Alpaca API credentials not configured — portfolio sync disabled" + ) + return + + broker = AlpacaBroker( + api_key=config.alpaca_api_key, + secret_key=config.alpaca_secret_key, + paper=config.paper_trading, + ) + + logger.info( + "Portfolio sync started (interval=%ds, paper=%s)", + config.snapshot_interval_seconds, + config.paper_trading, + ) + + while True: + try: + if is_market_open(): + await _sync_once(broker, session_factory) + else: + logger.debug("Market closed — skipping portfolio snapshot") + except asyncio.CancelledError: + logger.info("Portfolio sync task cancelled — shutting down") + raise + except Exception: + logger.exception("Portfolio sync error — will retry next interval") + + await asyncio.sleep(config.snapshot_interval_seconds) diff --git a/services/market_data/__init__.py b/services/market_data/__init__.py new file mode 100644 index 0000000..17d58c2 --- /dev/null +++ b/services/market_data/__init__.py @@ -0,0 +1 @@ +"""Market Data service -- fetches OHLCV bars from Alpaca and publishes to Redis Streams.""" diff --git a/services/market_data/__main__.py b/services/market_data/__main__.py new file mode 100644 index 0000000..9f35174 --- /dev/null +++ b/services/market_data/__main__.py @@ -0,0 +1,3 @@ +from services.market_data.main import main + +main() diff --git a/services/market_data/config.py b/services/market_data/config.py new file mode 100644 index 0000000..41cba15 --- /dev/null +++ b/services/market_data/config.py @@ -0,0 +1,16 @@ +"""Configuration for the market data service.""" + +from shared.config import BaseConfig + + +class MarketDataConfig(BaseConfig): + """Extends BaseConfig with market-data-specific settings.""" + + watchlist: list[str] = ["AAPL", "TSLA", "NVDA", "MSFT", "GOOGL"] + bar_timeframe: str = "5Min" + poll_interval_seconds: int = 60 + historical_bars: int = 100 + alpaca_api_key: str = "" + alpaca_secret_key: str = "" + + model_config = {"env_prefix": "TRADING_"} diff --git a/services/market_data/main.py b/services/market_data/main.py new file mode 100644 index 0000000..3378835 --- /dev/null +++ b/services/market_data/main.py @@ -0,0 +1,257 @@ +"""Market Data service -- main entry point. + +Fetches historical and live OHLCV bars from Alpaca's market data API +and publishes them to the ``market:bars`` Redis Stream for consumption +by the signal generator and other downstream services. +""" + +from __future__ import annotations + +import asyncio +import logging +import signal +from datetime import datetime, timedelta, timezone + +from redis.asyncio import Redis + +from services.market_data.config import MarketDataConfig +from shared.redis_streams import StreamPublisher +from shared.telemetry import setup_telemetry + +logger = logging.getLogger(__name__) + +MARKET_BARS_STREAM = "market:bars" + + +def _parse_timeframe(timeframe_str: str): + """Parse a timeframe string like '5Min' into an Alpaca TimeFrame object. + + Returns a ``TimeFrame`` instance suitable for ``StockBarsRequest``. 
+ """ + from alpaca.data.timeframe import TimeFrame, TimeFrameUnit + + # Supported formats: "1Min", "5Min", "15Min", "1Hour", "1Day" + tf_map = { + "1Min": TimeFrame(1, TimeFrameUnit.Minute), + "5Min": TimeFrame(5, TimeFrameUnit.Minute), + "15Min": TimeFrame(15, TimeFrameUnit.Minute), + "1Hour": TimeFrame(1, TimeFrameUnit.Hour), + "1Day": TimeFrame(1, TimeFrameUnit.Day), + } + tf = tf_map.get(timeframe_str) + if tf is None: + raise ValueError( + f"Unsupported timeframe '{timeframe_str}'. " + f"Supported values: {list(tf_map.keys())}" + ) + return tf + + +def _bar_to_dict(ticker: str, bar) -> dict: + """Convert an Alpaca Bar object to a flat dictionary for Redis publishing.""" + return { + "ticker": ticker, + "timestamp": bar.timestamp.isoformat(), + "open": float(bar.open), + "high": float(bar.high), + "low": float(bar.low), + "close": float(bar.close), + "volume": float(bar.volume), + } + + +async def _fetch_historical_bars( + client, + watchlist: list[str], + timeframe, + limit: int, + publisher: StreamPublisher, + bars_published_counter, +) -> int: + """Fetch historical bars for each ticker and publish to Redis. + + Returns the total number of bars published. + """ + from alpaca.data.requests import StockBarsRequest + + total_published = 0 + + # Use a start time far enough back to get the requested number of bars + start = datetime.now(timezone.utc) - timedelta(days=30) + + for ticker in watchlist: + try: + request = StockBarsRequest( + symbol_or_symbols=[ticker], + timeframe=timeframe, + start=start, + limit=limit, + ) + bars = await asyncio.to_thread(client.get_stock_bars, request) + + ticker_bars = bars[ticker] if ticker in bars else [] + for bar in ticker_bars: + msg = _bar_to_dict(ticker, bar) + await publisher.publish(msg) + total_published += 1 + + logger.info( + "Published %d historical bars for %s", + len(ticker_bars), + ticker, + ) + except Exception: + logger.exception("Failed to fetch historical bars for %s", ticker) + + if total_published: + bars_published_counter.add(total_published) + + return total_published + + +async def _poll_latest_bars( + client, + watchlist: list[str], + timeframe, + publisher: StreamPublisher, + bars_published_counter, +) -> int: + """Fetch the latest bar for each ticker and publish to Redis. + + Returns the number of bars published. + """ + from alpaca.data.requests import StockBarsRequest + + published = 0 + + # Fetch bars from the last 10 minutes to ensure we get at least one + start = datetime.now(timezone.utc) - timedelta(minutes=10) + + for ticker in watchlist: + try: + request = StockBarsRequest( + symbol_or_symbols=[ticker], + timeframe=timeframe, + start=start, + limit=1, + ) + bars = await asyncio.to_thread(client.get_stock_bars, request) + + ticker_bars = bars[ticker] if ticker in bars else [] + if ticker_bars: + # Publish only the most recent bar + bar = ticker_bars[-1] + msg = _bar_to_dict(ticker, bar) + await publisher.publish(msg) + published += 1 + logger.debug("Published latest bar for %s: close=%.2f", ticker, bar.close) + except Exception: + logger.exception("Failed to fetch latest bar for %s", ticker) + + if published: + bars_published_counter.add(published) + + return published + + +async def run(config: MarketDataConfig | None = None) -> None: + """Main service loop. + + Connects to Alpaca and Redis, fetches historical bars on startup, + then polls for new bars at the configured interval. 
+ """ + if config is None: + config = MarketDataConfig() + + logging.basicConfig(level=config.log_level) + logger.info("Starting Market Data service") + + # --- Telemetry --- + meter = setup_telemetry("market-data", config.otel_metrics_port) + bars_published_counter = meter.create_counter( + "market_data.bars_published", + description="Total OHLCV bars published to market:bars stream", + ) + poll_errors_counter = meter.create_counter( + "market_data.poll_errors", + description="Total poll cycle errors", + ) + + # --- Alpaca client --- + from alpaca.data.historical import StockHistoricalDataClient + + client = StockHistoricalDataClient( + api_key=config.alpaca_api_key, + secret_key=config.alpaca_secret_key, + ) + + # --- Redis --- + redis = Redis.from_url(config.redis_url, decode_responses=False) + publisher = StreamPublisher(redis, MARKET_BARS_STREAM) + + # --- Parse timeframe --- + timeframe = _parse_timeframe(config.bar_timeframe) + + # --- Graceful shutdown --- + shutdown_event = asyncio.Event() + loop = asyncio.get_running_loop() + for sig in (signal.SIGTERM, signal.SIGINT): + loop.add_signal_handler(sig, shutdown_event.set) + + try: + # Fetch historical bars on startup + logger.info( + "Fetching %d historical bars for watchlist: %s", + config.historical_bars, + config.watchlist, + ) + total = await _fetch_historical_bars( + client, + config.watchlist, + timeframe, + config.historical_bars, + publisher, + bars_published_counter, + ) + logger.info("Historical backfill complete: %d total bars published", total) + + # Poll loop + logger.info( + "Starting poll loop (interval=%ds) for watchlist: %s", + config.poll_interval_seconds, + config.watchlist, + ) + while not shutdown_event.is_set(): + try: + await asyncio.wait_for( + shutdown_event.wait(), + timeout=config.poll_interval_seconds, + ) + break # Shutdown signaled + except asyncio.TimeoutError: + pass # Normal timeout — time to poll + + try: + count = await _poll_latest_bars( + client, + config.watchlist, + timeframe, + publisher, + bars_published_counter, + ) + logger.info("Poll cycle complete: %d bars published", count) + except Exception: + logger.exception("Poll cycle failed") + poll_errors_counter.add(1) + finally: + await redis.aclose() + logger.info("Market data service stopped gracefully") + + +def main() -> None: + """CLI entry point.""" + asyncio.run(run()) + + +if __name__ == "__main__": + main() diff --git a/services/sentiment_analyzer/main.py b/services/sentiment_analyzer/main.py index a88c1b1..c645223 100644 --- a/services/sentiment_analyzer/main.py +++ b/services/sentiment_analyzer/main.py @@ -3,7 +3,8 @@ Consumes ``news:raw`` articles from Redis Streams, scores them using a tiered approach (FinBERT first, Ollama fallback for low-confidence results), extracts ticker mentions, and publishes ``ScoredArticle`` messages to -``news:scored``. +``news:scored``. Also persists scored articles to the database (articles + +article_sentiments tables) so the dashboard can display real data. 
""" from __future__ import annotations @@ -14,11 +15,15 @@ import signal import time from redis.asyncio import Redis +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import async_sessionmaker from services.sentiment_analyzer.analyzers.finbert import FinBERTAnalyzer from services.sentiment_analyzer.analyzers.ollama_analyzer import OllamaAnalyzer from services.sentiment_analyzer.config import SentimentAnalyzerConfig from services.sentiment_analyzer.ticker_extractor import extract_tickers +from shared.db import create_db +from shared.models.news import Article, ArticleSentiment from shared.redis_streams import StreamConsumer, StreamPublisher from shared.schemas.news import RawArticle, ScoredArticle from shared.telemetry import setup_telemetry @@ -33,6 +38,7 @@ async def process_article( publisher: StreamPublisher, config: SentimentAnalyzerConfig, counters: dict, + db_session_factory: async_sessionmaker | None = None, ) -> None: """Score a single article and publish one ScoredArticle per extracted ticker. @@ -50,6 +56,9 @@ async def process_article( Service configuration (confidence threshold, etc.). counters: Dict of OpenTelemetry counter/histogram instruments. + db_session_factory: + Optional async session factory for persisting to the DB. + When ``None``, DB persistence is skipped (backward compatible). """ start = time.monotonic() @@ -103,6 +112,46 @@ async def process_article( counters["articles_scored"].add(1) + # --- Step 5: Persist to DB --- + if db_session_factory is not None: + try: + async with db_session_factory() as session: + db_article = Article( + source=article.source, + url=article.url, + title=article.title, + published_at=article.published_at, + fetched_at=article.fetched_at, + content_hash=article.content_hash, + ) + session.add(db_article) + + for ticker in tickers: + sentiment = ArticleSentiment( + article_id=db_article.id, + ticker=ticker, + score=score, + confidence=confidence, + model_used=model_used, + ) + session.add(sentiment) + + await session.commit() + logger.debug( + "Persisted article '%s' with %d sentiments to DB", + article.title[:60], + len(tickers), + ) + except IntegrityError: + logger.debug( + "Article already exists in DB (content_hash=%s), skipping", + article.content_hash, + ) + except Exception: + logger.exception( + "Failed to persist article to DB: %s", article.title[:60] + ) + async def run(config: SentimentAnalyzerConfig | None = None) -> None: """Main service loop. 
@@ -150,6 +199,14 @@ async def run(config: SentimentAnalyzerConfig | None = None) -> None: ) ollama = OllamaAnalyzer(model=config.ollama_model, host=config.ollama_host) + # --- Database --- + db_session_factory = None + try: + _engine, db_session_factory = create_db(config) + logger.info("Database session factory initialised") + except Exception: + logger.exception("Failed to initialise DB — articles will NOT be persisted") + logger.info("Consuming from news:raw, publishing to news:scored") # Graceful shutdown on SIGTERM/SIGINT @@ -165,7 +222,7 @@ async def run(config: SentimentAnalyzerConfig | None = None) -> None: break try: article = RawArticle.model_validate(data) - await process_article(article, finbert, ollama, publisher, config, counters) + await process_article(article, finbert, ollama, publisher, config, counters, db_session_factory) except Exception: logger.exception("Error processing article: %s", data.get("title", "")) finally: diff --git a/services/signal_generator/main.py b/services/signal_generator/main.py index e8a04d0..48210c6 100644 --- a/services/signal_generator/main.py +++ b/services/signal_generator/main.py @@ -1,8 +1,9 @@ """Signal Generator service -- main entry point. -Consumes ``news:scored`` articles from Redis Streams, updates sentiment -context per ticker, runs the weighted ensemble of trading strategies, and -publishes qualifying ``TradeSignal`` messages to ``signals:generated``. +Consumes ``news:scored`` articles and ``market:bars`` OHLCV data from +Redis Streams, updates sentiment context and market data per ticker, +runs the weighted ensemble of trading strategies, and publishes +qualifying ``TradeSignal`` messages to ``signals:generated``. """ from __future__ import annotations @@ -10,16 +11,21 @@ from __future__ import annotations import asyncio import logging import signal +import uuid from collections import defaultdict, deque from redis.asyncio import Redis +from sqlalchemy.ext.asyncio import async_sessionmaker from services.signal_generator.config import SignalGeneratorConfig from services.signal_generator.ensemble import WeightedEnsemble from services.signal_generator.market_data import MarketDataManager +from shared.db import create_db +from shared.models.trading import Signal as SignalModel +from shared.models.trading import SignalDirection as SignalDirectionModel from shared.redis_streams import StreamConsumer, StreamPublisher from shared.schemas.news import ScoredArticle -from shared.schemas.trading import SentimentContext +from shared.schemas.trading import MarketSnapshot, SentimentContext from shared.strategies import MeanReversionStrategy, MomentumStrategy, NewsDrivenStrategy from shared.telemetry import setup_telemetry @@ -53,12 +59,150 @@ def _build_sentiment_context( ) +async def _consume_market_bars( + bars_consumer: StreamConsumer, + market_data: MarketDataManager, + shutdown_event: asyncio.Event, + bars_received_counter, +) -> None: + """Consume OHLCV bars from ``market:bars`` and feed them to the MarketDataManager. + + Runs as a concurrent task alongside the scored-article consumer. 
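+
+    Messages are expected in the flat shape published by the
+    market-data service::
+
+        {"ticker": "AAPL", "timestamp": "...", "open": 149.0,
+         "high": 151.0, "low": 148.0, "close": 150.0, "volume": 10000.0}
+
+    The ``ticker`` key is stripped before the remaining OHLCV fields
+    are handed to ``MarketDataManager.add_bar``.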
+ """ + logger.info("Starting market:bars consumer") + async for _msg_id, data in bars_consumer.consume(): + if shutdown_event.is_set(): + break + try: + ticker = data.get("ticker") + if not ticker: + logger.warning("Received bar message without ticker field: %s", data) + continue + + # Build bar_data dict without the ticker key (OHLCVBar doesn't have it) + bar_data = {k: v for k, v in data.items() if k != "ticker"} + market_data.add_bar(ticker, bar_data) + bars_received_counter.add(1) + logger.debug("Added bar for %s: close=%s", ticker, data.get("close")) + except Exception: + logger.exception("Error processing market bar: %s", data) + + +async def _consume_scored_articles( + articles_consumer: StreamConsumer, + market_data: MarketDataManager, + ensemble: WeightedEnsemble, + weights: dict[str, float], + publisher: StreamPublisher, + shutdown_event: asyncio.Event, + signals_generated, + per_strategy_signal_count, + db_session_factory: async_sessionmaker | None = None, +) -> None: + """Consume scored articles from ``news:scored``, run the ensemble, and publish signals. + + Runs as a concurrent task alongside the market-bars consumer. + """ + # Per-ticker sentiment accumulators + sentiment_scores: dict[str, deque[float]] = defaultdict( + lambda: deque(maxlen=_MAX_SENTIMENT_SCORES) + ) + sentiment_confidences: dict[str, deque[float]] = defaultdict( + lambda: deque(maxlen=_MAX_SENTIMENT_SCORES) + ) + + logger.info("Starting news:scored consumer") + async for _msg_id, data in articles_consumer.consume(): + if shutdown_event.is_set(): + break + try: + article = ScoredArticle.model_validate(data) + ticker = article.ticker + + # Update sentiment accumulators + sentiment_scores[ticker].append(article.sentiment_score) + sentiment_confidences[ticker].append(article.confidence) + + # Build sentiment context + sentiment = _build_sentiment_context( + ticker, + sentiment_scores[ticker], + sentiment_confidences[ticker], + ) + + # Get market snapshot (may be None if no bars received yet) + snapshot = market_data.get_snapshot(ticker) + if snapshot is None: + # Create a minimal snapshot from sentiment data alone + # (the news_driven strategy does not require market indicators) + snapshot = MarketSnapshot( + ticker=ticker, + current_price=0.0, + open=0.0, + high=0.0, + low=0.0, + close=0.0, + volume=0.0, + ) + + # Run ensemble + signal_result = await ensemble.evaluate(ticker, snapshot, sentiment, weights) + + if signal_result is not None: + # Inject current price for trade executor position sizing + if snapshot and snapshot.current_price > 0: + if signal_result.sentiment_context is None: + signal_result.sentiment_context = {} + signal_result.sentiment_context["current_price"] = snapshot.current_price + + # Persist signal to DB + if db_session_factory is not None: + try: + async with db_session_factory() as session: + direction_map = { + "LONG": SignalDirectionModel.LONG, + "SHORT": SignalDirectionModel.SHORT, + "NEUTRAL": SignalDirectionModel.NEUTRAL, + } + db_signal = SignalModel( + id=signal_result.signal_id, + ticker=ticker, + direction=direction_map[signal_result.direction.value], + strength=signal_result.strength, + strategy_sources=signal_result.strategy_sources, + sentiment_score=sentiment.avg_score if sentiment else None, + acted_on=False, + ) + session.add(db_signal) + await session.commit() + except Exception: + logger.exception("Failed to persist signal to DB") + + await publisher.publish(signal_result.model_dump(mode="json")) + signals_generated.add(1) + for src in signal_result.strategy_sources: 
+ strategy_name = src.split(":")[0] + per_strategy_signal_count.add(1, {"strategy": strategy_name}) + logger.info( + "Signal generated: %s %s strength=%.4f sources=%s", + signal_result.direction.value, + ticker, + signal_result.strength, + signal_result.strategy_sources, + ) + + except Exception: + logger.exception( + "Error processing scored article: %s", data.get("title", "") + ) + + async def run(config: SignalGeneratorConfig | None = None) -> None: """Main service loop. Connects to Redis, initialises strategies and telemetry, then - continuously consumes from ``news:scored`` and publishes qualifying - signals to ``signals:generated``. + continuously consumes from ``news:scored`` and ``market:bars``, + publishing qualifying signals to ``signals:generated``. """ if config is None: config = SignalGeneratorConfig() @@ -76,10 +220,19 @@ async def run(config: SignalGeneratorConfig | None = None) -> None: "per_strategy_signal_count", description="Signals emitted, broken down by strategy", ) + bars_received_counter = meter.create_counter( + "bars_received", + description="Total OHLCV bars received from market:bars stream", + ) # --- Redis --- redis = Redis.from_url(config.redis_url, decode_responses=False) - consumer = StreamConsumer(redis, "news:scored", "signal-generator", "worker-1") + articles_consumer = StreamConsumer( + redis, "news:scored", "signal-generator", "worker-1" + ) + bars_consumer = StreamConsumer( + redis, "market:bars", "signal-generator", "bars-worker" + ) publisher = StreamPublisher(redis, "signals:generated") # --- Market data --- @@ -96,11 +249,17 @@ async def run(config: SignalGeneratorConfig | None = None) -> None: # --- Strategy weights (default equal; could load from DB) --- weights = dict(_DEFAULT_WEIGHTS) - # --- Per-ticker sentiment accumulators --- - sentiment_scores: dict[str, deque[float]] = defaultdict(lambda: deque(maxlen=_MAX_SENTIMENT_SCORES)) - sentiment_confidences: dict[str, deque[float]] = defaultdict(lambda: deque(maxlen=_MAX_SENTIMENT_SCORES)) + # --- Database (for persisting signals) --- + db_session_factory = None + try: + _engine, db_session_factory = create_db(config) + logger.info("Database session factory initialised for signal persistence") + except Exception: + logger.exception("Failed to initialise DB — signals will NOT be persisted") - logger.info("Consuming from news:scored, publishing to signals:generated") + logger.info( + "Consuming from news:scored and market:bars, publishing to signals:generated" + ) # Graceful shutdown on SIGTERM/SIGINT shutdown_event = asyncio.Event() @@ -108,62 +267,30 @@ async def run(config: SignalGeneratorConfig | None = None) -> None: for sig in (signal.SIGTERM, signal.SIGINT): loop.add_signal_handler(sig, shutdown_event.set) - # --- Consume loop --- + # --- Run both consumers concurrently --- try: - async for _msg_id, data in consumer.consume(): - if shutdown_event.is_set(): - break - try: - article = ScoredArticle.model_validate(data) - ticker = article.ticker - - # Update sentiment accumulators - sentiment_scores[ticker].append(article.sentiment_score) - sentiment_confidences[ticker].append(article.confidence) - - # Build sentiment context - sentiment = _build_sentiment_context( - ticker, - sentiment_scores[ticker], - sentiment_confidences[ticker], + async with asyncio.TaskGroup() as tg: + tg.create_task( + _consume_scored_articles( + articles_consumer, + market_data, + ensemble, + weights, + publisher, + shutdown_event, + signals_generated, + per_strategy_signal_count, + db_session_factory, ) - - # Get 
market snapshot (may be None if no bars received yet) - snapshot = market_data.get_snapshot(ticker) - if snapshot is None: - # Create a minimal snapshot from sentiment data alone - # (the news_driven strategy does not require market indicators) - from shared.schemas.trading import MarketSnapshot - - snapshot = MarketSnapshot( - ticker=ticker, - current_price=0.0, - open=0.0, - high=0.0, - low=0.0, - close=0.0, - volume=0.0, - ) - - # Run ensemble - signal_result = await ensemble.evaluate(ticker, snapshot, sentiment, weights) - - if signal_result is not None: - await publisher.publish(signal_result.model_dump(mode="json")) - signals_generated.add(1) - for src in signal_result.strategy_sources: - strategy_name = src.split(":")[0] - per_strategy_signal_count.add(1, {"strategy": strategy_name}) - logger.info( - "Signal generated: %s %s strength=%.4f sources=%s", - signal_result.direction.value, - ticker, - signal_result.strength, - signal_result.strategy_sources, - ) - - except Exception: - logger.exception("Error processing scored article: %s", data.get("title", "")) + ) + tg.create_task( + _consume_market_bars( + bars_consumer, + market_data, + shutdown_event, + bars_received_counter, + ) + ) finally: await redis.aclose() logger.info("Signal generator stopped gracefully") diff --git a/services/trade_executor/main.py b/services/trade_executor/main.py index 5d3d7d0..8d820c3 100644 --- a/services/trade_executor/main.py +++ b/services/trade_executor/main.py @@ -15,10 +15,15 @@ import time import uuid from redis.asyncio import Redis +from sqlalchemy.ext.asyncio import async_sessionmaker from services.trade_executor.config import TradeExecutorConfig from services.trade_executor.risk_manager import RiskManager from shared.broker.alpaca_broker import AlpacaBroker +from shared.db import create_db +from shared.models.trading import Trade as TradeModel +from shared.models.trading import TradeSide as TradeSideModel +from shared.models.trading import TradeStatus as TradeStatusModel from shared.redis_streams import StreamConsumer, StreamPublisher from shared.schemas.trading import ( OrderRequest, @@ -39,6 +44,7 @@ async def process_signal( broker: AlpacaBroker, publisher: StreamPublisher, counters: dict, + db_session_factory: async_sessionmaker | None = None, ) -> None: """Process a single trade signal: risk check, order, record, publish. @@ -54,6 +60,8 @@ async def process_signal( Publishes execution results to ``trades:executed``. counters: Dict of OpenTelemetry counter/histogram instruments. + db_session_factory: + Optional async session factory for persisting trades to the DB. 
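+
+    The persisted ``Trade`` row carries ``signal_id``, so executed
+    trades can be joined back to the ``Signal`` rows written by the
+    signal generator. An illustrative SQLAlchemy query (not used by
+    this service; ``Trade`` and ``Signal`` from ``shared.models.trading``)::
+
+        select(Trade, Signal).join(Signal, Trade.signal_id == Signal.id)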
""" # --- Step 1: risk check --- approved, reason = await risk_manager.check_risk(signal) @@ -93,12 +101,42 @@ async def process_signal( qty=result.qty, price=result.filled_price or 0.0, status=result.status, - signal_id=None, + signal_id=signal.signal_id, strategy_id=None, timestamp=result.timestamp, ) - # --- Step 6: publish to trades:executed --- + # --- Step 6: persist trade to DB --- + if db_session_factory is not None: + try: + side_map = { + OrderSide.BUY: TradeSideModel.BUY, + OrderSide.SELL: TradeSideModel.SELL, + } + status_map = { + OrderStatus.PENDING: TradeStatusModel.PENDING, + OrderStatus.FILLED: TradeStatusModel.FILLED, + OrderStatus.CANCELLED: TradeStatusModel.CANCELLED, + OrderStatus.REJECTED: TradeStatusModel.REJECTED, + } + async with db_session_factory() as session: + db_trade = TradeModel( + id=trade_id, + ticker=signal.ticker, + side=side_map[side], + qty=result.qty, + price=result.filled_price or 0.0, + timestamp=str(result.timestamp), + signal_id=signal.signal_id, + status=status_map.get(result.status, TradeStatusModel.PENDING), + ) + session.add(db_trade) + await session.commit() + logger.debug("Persisted trade %s to DB (signal_id=%s)", trade_id, signal.signal_id) + except Exception: + logger.exception("Failed to persist trade to DB") + + # --- Step 7: publish to trades:executed --- await publisher.publish(execution.model_dump(mode="json")) counters["trades_executed"].add(1) logger.info( @@ -157,6 +195,14 @@ async def run(config: TradeExecutorConfig | None = None) -> None: # --- Risk manager --- risk_manager = RiskManager(config, broker) + # --- Database (for persisting trades) --- + db_session_factory = None + try: + _engine, db_session_factory = create_db(config) + logger.info("Database session factory initialised for trade persistence") + except Exception: + logger.exception("Failed to initialise DB — trades will NOT be persisted") + logger.info("Consuming from signals:generated, publishing to trades:executed") # Graceful shutdown on SIGTERM/SIGINT @@ -172,7 +218,7 @@ async def run(config: TradeExecutorConfig | None = None) -> None: break try: signal_msg = TradeSignal.model_validate(data) - await process_signal(signal_msg, risk_manager, broker, publisher, counters) + await process_signal(signal_msg, risk_manager, broker, publisher, counters, db_session_factory) except Exception: logger.exception("Error processing signal: %s", data) finally: diff --git a/shared/schemas/trading.py b/shared/schemas/trading.py index 0a0046b..c3d1cd5 100644 --- a/shared/schemas/trading.py +++ b/shared/schemas/trading.py @@ -3,7 +3,7 @@ from datetime import datetime from enum import Enum from typing import Any -from uuid import UUID +from uuid import UUID, uuid4 from pydantic import BaseModel, Field @@ -96,6 +96,7 @@ class AccountInfo(BaseModel): class TradeSignal(BaseModel): """Published to ``signals:generated`` by the signal generator.""" + signal_id: UUID = Field(default_factory=uuid4) ticker: str direction: SignalDirection strength: float = Field(ge=0.0, le=1.0) diff --git a/tests/services/test_market_data.py b/tests/services/test_market_data.py new file mode 100644 index 0000000..d3b07f6 --- /dev/null +++ b/tests/services/test_market_data.py @@ -0,0 +1,462 @@ +"""Tests for the Market Data service. + +Covers configuration defaults, bar parsing from Alpaca response format, +historical bar fetching, poll logic, and Redis publish behaviour. + +All Alpaca SDK imports are mocked to avoid requiring pytz and other +transitive dependencies in the test environment. 
+""" + +from __future__ import annotations + +import sys +from datetime import datetime, timezone +from types import ModuleType +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from services.market_data.config import MarketDataConfig + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class _FakeBar: + """Mimics an Alpaca Bar object with the required attributes.""" + + def __init__( + self, + timestamp: datetime, + open: float, + high: float, + low: float, + close: float, + volume: float, + ) -> None: + self.timestamp = timestamp + self.open = open + self.high = high + self.low = low + self.close = close + self.volume = volume + + +def _make_fake_bar(close: float = 150.0) -> _FakeBar: + return _FakeBar( + timestamp=datetime(2026, 1, 15, 10, 30, tzinfo=timezone.utc), + open=149.0, + high=151.0, + low=148.0, + close=close, + volume=10000.0, + ) + + +class _FakeBarSet(dict): + """Mimics an Alpaca BarSet (dict-like) mapping ticker to list of bars.""" + + pass + + +class _FakeTimeFrame: + """Mimics alpaca.data.timeframe.TimeFrame.""" + + def __init__(self, amount, unit): + self.amount = amount + self.unit = unit + + +class _FakeTimeFrameUnit: + """Mimics alpaca.data.timeframe.TimeFrameUnit.""" + + Minute = "Minute" + Hour = "Hour" + Day = "Day" + + +class _FakeStockBarsRequest: + """Mimics alpaca.data.requests.StockBarsRequest.""" + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + +def _install_alpaca_mocks(): + """Install mock modules for the alpaca SDK so we can import market_data.main.""" + timeframe_mod = ModuleType("alpaca.data.timeframe") + timeframe_mod.TimeFrame = _FakeTimeFrame + timeframe_mod.TimeFrameUnit = _FakeTimeFrameUnit + + requests_mod = ModuleType("alpaca.data.requests") + requests_mod.StockBarsRequest = _FakeStockBarsRequest + + historical_mod = ModuleType("alpaca.data.historical") + historical_mod.StockHistoricalDataClient = MagicMock + + # Build the package hierarchy + alpaca_mod = sys.modules.get("alpaca") or ModuleType("alpaca") + data_mod = sys.modules.get("alpaca.data") or ModuleType("alpaca.data") + + sys.modules.setdefault("alpaca", alpaca_mod) + sys.modules["alpaca.data"] = data_mod + sys.modules["alpaca.data.timeframe"] = timeframe_mod + sys.modules["alpaca.data.requests"] = requests_mod + sys.modules["alpaca.data.historical"] = historical_mod + + +# Install mocks before importing from market_data.main +_install_alpaca_mocks() + +from services.market_data.main import ( # noqa: E402 + MARKET_BARS_STREAM, + _bar_to_dict, + _fetch_historical_bars, + _parse_timeframe, + _poll_latest_bars, +) + + +async def _to_thread_passthrough(func, *args, **kwargs): + """Replacement for asyncio.to_thread that calls synchronously.""" + return func(*args, **kwargs) + + +# --------------------------------------------------------------------------- +# Config defaults +# --------------------------------------------------------------------------- + + +class TestMarketDataConfig: + """Tests for MarketDataConfig defaults.""" + + def test_default_watchlist(self): + config = MarketDataConfig() + assert config.watchlist == ["AAPL", "TSLA", "NVDA", "MSFT", "GOOGL"] + + def test_default_bar_timeframe(self): + config = MarketDataConfig() + assert config.bar_timeframe == "5Min" + + def test_default_poll_interval(self): + config = MarketDataConfig() + assert config.poll_interval_seconds == 60 + + def 
test_default_historical_bars(self): + config = MarketDataConfig() + assert config.historical_bars == 100 + + def test_default_alpaca_keys_empty(self): + config = MarketDataConfig() + assert config.alpaca_api_key == "" + assert config.alpaca_secret_key == "" + + def test_inherits_base_config_defaults(self): + config = MarketDataConfig() + assert config.log_level == "INFO" + assert config.otel_metrics_port == 9090 + + +# --------------------------------------------------------------------------- +# Timeframe parsing +# --------------------------------------------------------------------------- + + +class TestParseTimeframe: + """Tests for the _parse_timeframe helper.""" + + def test_parse_5min(self): + tf = _parse_timeframe("5Min") + assert tf is not None + assert tf.amount == 5 + + def test_parse_1min(self): + tf = _parse_timeframe("1Min") + assert tf is not None + assert tf.amount == 1 + + def test_parse_15min(self): + tf = _parse_timeframe("15Min") + assert tf is not None + assert tf.amount == 15 + + def test_parse_1hour(self): + tf = _parse_timeframe("1Hour") + assert tf is not None + assert tf.amount == 1 + + def test_parse_1day(self): + tf = _parse_timeframe("1Day") + assert tf is not None + assert tf.amount == 1 + + def test_parse_invalid_raises(self): + with pytest.raises(ValueError, match="Unsupported timeframe"): + _parse_timeframe("3Min") + + +# --------------------------------------------------------------------------- +# Bar to dict conversion +# --------------------------------------------------------------------------- + + +class TestBarToDict: + """Tests for _bar_to_dict conversion.""" + + def test_basic_conversion(self): + bar = _make_fake_bar(close=155.5) + result = _bar_to_dict("AAPL", bar) + + assert result["ticker"] == "AAPL" + assert result["close"] == 155.5 + assert result["open"] == 149.0 + assert result["high"] == 151.0 + assert result["low"] == 148.0 + assert result["volume"] == 10000.0 + assert "timestamp" in result + + def test_timestamp_is_iso_format(self): + bar = _make_fake_bar() + result = _bar_to_dict("TSLA", bar) + # Should be parseable back + parsed = datetime.fromisoformat(result["timestamp"]) + assert parsed.tzinfo is not None + + def test_all_fields_are_floats(self): + bar = _make_fake_bar() + result = _bar_to_dict("AAPL", bar) + for field in ("open", "high", "low", "close", "volume"): + assert isinstance(result[field], float) + + def test_ticker_is_passed_through(self): + bar = _make_fake_bar() + result = _bar_to_dict("NVDA", bar) + assert result["ticker"] == "NVDA" + + +# --------------------------------------------------------------------------- +# Fetch historical bars +# --------------------------------------------------------------------------- + + +class TestFetchHistoricalBars: + """Tests for _fetch_historical_bars with mocked Alpaca client.""" + + @pytest.mark.asyncio + async def test_publishes_bars_for_each_ticker(self): + """Should publish all historical bars for each ticker in the watchlist.""" + bars_aapl = [_make_fake_bar(150.0), _make_fake_bar(151.0)] + bars_tsla = [_make_fake_bar(200.0)] + + bar_set = _FakeBarSet({"AAPL": bars_aapl, "TSLA": bars_tsla}) + + mock_client = MagicMock() + mock_client.get_stock_bars = MagicMock(return_value=bar_set) + + publisher = AsyncMock() + publisher.publish = AsyncMock() + counter = MagicMock() + + with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough): + total = await _fetch_historical_bars( + mock_client, ["AAPL", "TSLA"], MagicMock(), 100, publisher, counter + ) + 
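+        # 2 AAPL bars + 1 TSLA bar = 3 published messages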
+ assert total == 3 + assert publisher.publish.call_count == 3 + counter.add.assert_called_once_with(3) + + @pytest.mark.asyncio + async def test_handles_empty_response(self): + """Should handle tickers with no bars gracefully.""" + bar_set = _FakeBarSet({}) + + mock_client = MagicMock() + mock_client.get_stock_bars = MagicMock(return_value=bar_set) + + publisher = AsyncMock() + counter = MagicMock() + + with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough): + total = await _fetch_historical_bars( + mock_client, ["AAPL"], MagicMock(), 100, publisher, counter + ) + + assert total == 0 + publisher.publish.assert_not_called() + + @pytest.mark.asyncio + async def test_handles_client_exception(self): + """Should log and continue if one ticker fails.""" + mock_client = MagicMock() + mock_client.get_stock_bars = MagicMock(side_effect=Exception("API error")) + + publisher = AsyncMock() + counter = MagicMock() + + with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough): + total = await _fetch_historical_bars( + mock_client, ["AAPL", "TSLA"], MagicMock(), 100, publisher, counter + ) + + assert total == 0 + publisher.publish.assert_not_called() + + @pytest.mark.asyncio + async def test_published_message_format(self): + """Published messages should have the expected fields.""" + bar = _make_fake_bar(175.0) + bar_set = _FakeBarSet({"MSFT": [bar]}) + + mock_client = MagicMock() + mock_client.get_stock_bars = MagicMock(return_value=bar_set) + + publisher = AsyncMock() + publisher.publish = AsyncMock() + counter = MagicMock() + + with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough): + await _fetch_historical_bars( + mock_client, ["MSFT"], MagicMock(), 100, publisher, counter + ) + + msg = publisher.publish.call_args[0][0] + assert msg["ticker"] == "MSFT" + assert msg["close"] == 175.0 + assert "timestamp" in msg + assert "open" in msg + assert "high" in msg + assert "low" in msg + assert "volume" in msg + + +# --------------------------------------------------------------------------- +# Poll latest bars +# --------------------------------------------------------------------------- + + +class TestPollLatestBars: + """Tests for _poll_latest_bars with mocked Alpaca client.""" + + @pytest.mark.asyncio + async def test_publishes_latest_bar_per_ticker(self): + """Should publish one bar per ticker.""" + bar_set = _FakeBarSet({"AAPL": [_make_fake_bar(155.0)]}) + + mock_client = MagicMock() + mock_client.get_stock_bars = MagicMock(return_value=bar_set) + + publisher = AsyncMock() + publisher.publish = AsyncMock() + counter = MagicMock() + + with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough): + count = await _poll_latest_bars( + mock_client, ["AAPL"], MagicMock(), publisher, counter + ) + + assert count == 1 + assert publisher.publish.call_count == 1 + + published_msg = publisher.publish.call_args[0][0] + assert published_msg["ticker"] == "AAPL" + assert published_msg["close"] == 155.0 + + @pytest.mark.asyncio + async def test_publishes_only_most_recent_bar(self): + """When multiple bars returned, should only publish the last one.""" + bars = [_make_fake_bar(150.0), _make_fake_bar(155.0)] + bar_set = _FakeBarSet({"AAPL": bars}) + + mock_client = MagicMock() + mock_client.get_stock_bars = MagicMock(return_value=bar_set) + + publisher = AsyncMock() + publisher.publish = AsyncMock() + counter = MagicMock() + + with 
patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough): + count = await _poll_latest_bars( + mock_client, ["AAPL"], MagicMock(), publisher, counter + ) + + assert count == 1 + published_msg = publisher.publish.call_args[0][0] + assert published_msg["close"] == 155.0 + + @pytest.mark.asyncio + async def test_handles_no_bars_for_ticker(self): + """Should return 0 when a ticker has no bars.""" + bar_set = _FakeBarSet({"AAPL": []}) + + mock_client = MagicMock() + mock_client.get_stock_bars = MagicMock(return_value=bar_set) + + publisher = AsyncMock() + counter = MagicMock() + + with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough): + count = await _poll_latest_bars( + mock_client, ["AAPL"], MagicMock(), publisher, counter + ) + + assert count == 0 + publisher.publish.assert_not_called() + + @pytest.mark.asyncio + async def test_handles_ticker_not_in_response(self): + """Should handle gracefully when the response doesn't contain the ticker.""" + bar_set = _FakeBarSet({}) + + mock_client = MagicMock() + mock_client.get_stock_bars = MagicMock(return_value=bar_set) + + publisher = AsyncMock() + counter = MagicMock() + + with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough): + count = await _poll_latest_bars( + mock_client, ["AAPL"], MagicMock(), publisher, counter + ) + + assert count == 0 + publisher.publish.assert_not_called() + + @pytest.mark.asyncio + async def test_multiple_tickers(self): + """Should poll and publish bars for all tickers in watchlist.""" + bar_set_aapl = _FakeBarSet({"AAPL": [_make_fake_bar(150.0)]}) + bar_set_tsla = _FakeBarSet({"TSLA": [_make_fake_bar(250.0)]}) + + mock_client = MagicMock() + # Return different bar sets for each call + mock_client.get_stock_bars = MagicMock( + side_effect=[bar_set_aapl, bar_set_tsla] + ) + + publisher = AsyncMock() + publisher.publish = AsyncMock() + counter = MagicMock() + + with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough): + count = await _poll_latest_bars( + mock_client, ["AAPL", "TSLA"], MagicMock(), publisher, counter + ) + + assert count == 2 + assert publisher.publish.call_count == 2 + + +# --------------------------------------------------------------------------- +# Stream name constant +# --------------------------------------------------------------------------- + + +class TestStreamConstants: + """Verify the stream name constant.""" + + def test_market_bars_stream_name(self): + assert MARKET_BARS_STREAM == "market:bars" diff --git a/tests/services/test_portfolio_sync.py b/tests/services/test_portfolio_sync.py new file mode 100644 index 0000000..634c9d7 --- /dev/null +++ b/tests/services/test_portfolio_sync.py @@ -0,0 +1,456 @@ +"""Tests for portfolio sync background task. 
+ +Verifies that the sync loop correctly: +- Creates PortfolioSnapshot rows from broker account data +- Upserts Position rows from broker positions +- Removes Position rows for closed positions +- Handles broker errors gracefully +- Respects US market hours +""" + +from __future__ import annotations + +import asyncio +from datetime import datetime, time, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from services.api_gateway.config import ApiGatewayConfig +from services.api_gateway.tasks.portfolio_sync import ( + _sync_once, + is_market_open, + portfolio_sync_loop, +) +from shared.schemas.trading import AccountInfo, PositionInfo + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def config() -> ApiGatewayConfig: + return ApiGatewayConfig( + jwt_secret_key="test-secret-for-sync", + database_url="sqlite+aiosqlite:///:memory:", + redis_url="redis://localhost:6379/0", + alpaca_api_key="test-key", + alpaca_secret_key="test-secret", + paper_trading=True, + snapshot_interval_seconds=1, + ) + + +@pytest.fixture() +def config_no_creds() -> ApiGatewayConfig: + return ApiGatewayConfig( + jwt_secret_key="test-secret-for-sync", + database_url="sqlite+aiosqlite:///:memory:", + redis_url="redis://localhost:6379/0", + alpaca_api_key="", + alpaca_secret_key="", + ) + + +@pytest.fixture() +def mock_account() -> AccountInfo: + return AccountInfo( + equity=105000.0, + cash=50000.0, + buying_power=100000.0, + portfolio_value=105000.0, + ) + + +@pytest.fixture() +def mock_positions() -> list[PositionInfo]: + return [ + PositionInfo( + ticker="AAPL", + qty=10.0, + avg_entry=150.0, + current_price=155.0, + unrealized_pnl=50.0, + market_value=1550.0, + ), + PositionInfo( + ticker="MSFT", + qty=5.0, + avg_entry=400.0, + current_price=410.0, + unrealized_pnl=50.0, + market_value=2050.0, + ), + ] + + +@pytest.fixture() +def mock_broker(mock_account, mock_positions): + broker = AsyncMock() + broker.get_account = AsyncMock(return_value=mock_account) + broker.get_positions = AsyncMock(return_value=mock_positions) + return broker + + +@pytest.fixture() +def mock_session(): + """Create a mock async session with context manager support.""" + session = AsyncMock() + session.__aenter__ = AsyncMock(return_value=session) + session.__aexit__ = AsyncMock(return_value=False) + + # Mock the begin() context manager + begin_ctx = AsyncMock() + begin_ctx.__aenter__ = AsyncMock(return_value=None) + begin_ctx.__aexit__ = AsyncMock(return_value=False) + session.begin = MagicMock(return_value=begin_ctx) + + # session.add is synchronous in SQLAlchemy — use MagicMock to avoid warnings + session.add = MagicMock() + + return session + + +@pytest.fixture() +def mock_session_factory(mock_session): + factory = MagicMock() + factory.return_value = mock_session + return factory + + +# --------------------------------------------------------------------------- +# Market hours tests +# --------------------------------------------------------------------------- + + +class TestMarketHours: + """Tests for the is_market_open() function.""" + + def test_weekday_during_market_hours(self) -> None: + # Wednesday 2024-01-10 at 10:00 AM ET = 15:00 UTC + dt = datetime(2024, 1, 10, 15, 0, 0, tzinfo=timezone.utc) + assert is_market_open(dt) is True + + def test_weekday_before_market_open(self) -> None: + # Wednesday 2024-01-10 at 9:00 AM ET = 14:00 UTC + dt = datetime(2024, 1, 
10, 14, 0, 0, tzinfo=timezone.utc) + assert is_market_open(dt) is False + + def test_weekday_after_market_close(self) -> None: + # Wednesday 2024-01-10 at 4:30 PM ET = 21:30 UTC + dt = datetime(2024, 1, 10, 21, 30, 0, tzinfo=timezone.utc) + assert is_market_open(dt) is False + + def test_weekend_saturday(self) -> None: + # Saturday 2024-01-13 at 12:00 PM ET = 17:00 UTC + dt = datetime(2024, 1, 13, 17, 0, 0, tzinfo=timezone.utc) + assert is_market_open(dt) is False + + def test_weekend_sunday(self) -> None: + # Sunday 2024-01-14 at 12:00 PM ET = 17:00 UTC + dt = datetime(2024, 1, 14, 17, 0, 0, tzinfo=timezone.utc) + assert is_market_open(dt) is False + + def test_market_open_boundary(self) -> None: + # Wednesday 2024-01-10 at exactly 9:30 AM ET = 14:30 UTC + dt = datetime(2024, 1, 10, 14, 30, 0, tzinfo=timezone.utc) + assert is_market_open(dt) is True + + def test_market_close_boundary(self) -> None: + # Wednesday 2024-01-10 at exactly 4:00 PM ET = 21:00 UTC + dt = datetime(2024, 1, 10, 21, 0, 0, tzinfo=timezone.utc) + assert is_market_open(dt) is False + + +# --------------------------------------------------------------------------- +# Snapshot creation tests +# --------------------------------------------------------------------------- + + +class TestSyncOnce: + """Tests for the _sync_once() function.""" + + async def test_creates_portfolio_snapshot( + self, mock_broker, mock_session_factory, mock_session + ) -> None: + # Mock the select query to return None (no existing positions) + execute_result = MagicMock() + execute_result.scalar_one_or_none.return_value = None + mock_session.execute = AsyncMock(return_value=execute_result) + + await _sync_once(mock_broker, mock_session_factory) + + # Verify the broker was called + mock_broker.get_account.assert_awaited_once() + mock_broker.get_positions.assert_awaited_once() + + # Verify session.add was called (snapshot + 2 new positions) + assert mock_session.add.call_count == 3 # 1 snapshot + 2 positions + + # Check the snapshot + snapshot_call = mock_session.add.call_args_list[0] + snapshot = snapshot_call[0][0] + assert snapshot.total_value == 105000.0 + assert snapshot.cash == 50000.0 + assert snapshot.positions_value == 55000.0 # 105000 - 50000 + assert snapshot.daily_pnl == 0.0 + + async def test_creates_position_rows_for_new_positions( + self, mock_broker, mock_session_factory, mock_session + ) -> None: + # No existing positions in DB + execute_result = MagicMock() + execute_result.scalar_one_or_none.return_value = None + mock_session.execute = AsyncMock(return_value=execute_result) + + await _sync_once(mock_broker, mock_session_factory) + + # Positions are added via session.add (after the snapshot) + position_calls = mock_session.add.call_args_list[1:] + assert len(position_calls) == 2 + + pos1 = position_calls[0][0][0] + assert pos1.ticker == "AAPL" + assert pos1.qty == 10.0 + assert pos1.avg_entry == 150.0 + assert pos1.unrealized_pnl == 50.0 + + pos2 = position_calls[1][0][0] + assert pos2.ticker == "MSFT" + assert pos2.qty == 5.0 + assert pos2.avg_entry == 400.0 + + async def test_updates_existing_position( + self, mock_broker, mock_session_factory, mock_session + ) -> None: + # Mock an existing position for AAPL, None for MSFT + existing_aapl = MagicMock() + existing_aapl.ticker = "AAPL" + existing_aapl.qty = 5.0 # old qty + existing_aapl.avg_entry = 140.0 # old entry + + result_aapl = MagicMock() + result_aapl.scalar_one_or_none.return_value = existing_aapl + result_msft = MagicMock() + result_msft.scalar_one_or_none.return_value = 
None

+        # execute() is awaited once per broker position for the SELECTs
+        # (AAPL, then MSFT), then once more for the DELETE of stale positions.
+        mock_session.execute = AsyncMock(
+            side_effect=[result_aapl, result_msft, MagicMock()]
+        )
+
+        await _sync_once(mock_broker, mock_session_factory)
+
+        # AAPL should be updated in place
+        assert existing_aapl.qty == 10.0
+        assert existing_aapl.avg_entry == 150.0
+        assert existing_aapl.unrealized_pnl == 50.0
+
+        # MSFT should be added as new (snapshot + MSFT = 2 adds)
+        assert mock_session.add.call_count == 2  # snapshot + new MSFT
+
+    async def test_removes_closed_positions(
+        self, mock_session_factory, mock_session
+    ) -> None:
+        # Broker returns only AAPL (MSFT was sold)
+        broker = AsyncMock()
+        broker.get_account = AsyncMock(
+            return_value=AccountInfo(
+                equity=100000, cash=90000, buying_power=90000, portfolio_value=100000
+            )
+        )
+        broker.get_positions = AsyncMock(
+            return_value=[
+                PositionInfo(
+                    ticker="AAPL",
+                    qty=10.0,
+                    avg_entry=150.0,
+                    current_price=155.0,
+                    unrealized_pnl=50.0,
+                    market_value=1550.0,
+                )
+            ]
+        )
+
+        execute_result = MagicMock()
+        execute_result.scalar_one_or_none.return_value = None
+        mock_session.execute = AsyncMock(return_value=execute_result)
+
+        await _sync_once(broker, mock_session_factory)
+
+        # The delete statement should have been executed.
+        # Find the DELETE among the execute calls.
+        delete_called = False
+        for call in mock_session.execute.call_args_list:
+            stmt = call[0][0]
+            # Check if it's a delete statement (SQLAlchemy Delete object)
+            stmt_str = str(stmt)
+            if "DELETE" in stmt_str.upper():
+                delete_called = True
+                break
+        assert delete_called, "Expected a DELETE statement for closed positions"
+
+    async def test_removes_all_positions_when_broker_has_none(
+        self, mock_session_factory, mock_session
+    ) -> None:
+        broker = AsyncMock()
+        broker.get_account = AsyncMock(
+            return_value=AccountInfo(
+                equity=100000, cash=100000, buying_power=100000, portfolio_value=100000
+            )
+        )
+        broker.get_positions = AsyncMock(return_value=[])
+
+        mock_session.execute = AsyncMock(return_value=MagicMock())
+
+        await _sync_once(broker, mock_session_factory)
+
+        # Should delete all positions since broker has none
+        delete_called = False
+        for call in mock_session.execute.call_args_list:
+            stmt = call[0][0]
+            stmt_str = str(stmt)
+            if "DELETE" in stmt_str.upper():
+                delete_called = True
+                break
+        assert delete_called, "Expected a DELETE statement to clear all positions"
+
+
+# ---------------------------------------------------------------------------
+# Error handling tests
+# ---------------------------------------------------------------------------
+
+
+class TestSyncErrorHandling:
+    """Tests that the sync loop handles errors gracefully."""
+
+    async def test_broker_error_does_not_crash_loop(
+        self, config, mock_session_factory
+    ) -> None:
+        """Broker raises an exception — loop should catch it and continue."""
+        call_count = 0
+
+        async def mock_sync_once(broker, sf):
+            nonlocal call_count
+            call_count += 1
+            if call_count == 1:
+                raise ConnectionError("Broker API down")
+            # Second call succeeds
+
+        with (
+            patch(
+                "services.api_gateway.tasks.portfolio_sync.AlpacaBroker"
+            ) as MockBroker,
+            patch(
+                "services.api_gateway.tasks.portfolio_sync._sync_once",
+                side_effect=mock_sync_once,
+            ),
+            patch(
+                "services.api_gateway.tasks.portfolio_sync.is_market_open",
+                return_value=True,
+            ),
+        ):
+            MockBroker.return_value = AsyncMock()
+
+            task = asyncio.create_task(portfolio_sync_loop(config, mock_session_factory))
+
+            # Give it time for 2 
iterations (interval=1s) + await asyncio.sleep(2.5) + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + assert call_count >= 2, "Loop should have retried after the error" + + async def test_no_credentials_returns_immediately( + self, config_no_creds, mock_session_factory + ) -> None: + """When Alpaca credentials are empty, the loop should exit immediately.""" + task = asyncio.create_task( + portfolio_sync_loop(config_no_creds, mock_session_factory) + ) + # Should complete almost immediately since no creds + await asyncio.wait_for(task, timeout=2.0) + # If we get here without timeout, the function returned correctly + + +# --------------------------------------------------------------------------- +# Market hours integration with loop +# --------------------------------------------------------------------------- + + +class TestSyncLoopMarketHours: + """Tests that the loop respects market hours.""" + + async def test_skips_sync_outside_market_hours( + self, config, mock_session_factory + ) -> None: + sync_called = False + + async def mock_sync(broker, sf): + nonlocal sync_called + sync_called = True + + with ( + patch( + "services.api_gateway.tasks.portfolio_sync.AlpacaBroker" + ) as MockBroker, + patch( + "services.api_gateway.tasks.portfolio_sync._sync_once", + side_effect=mock_sync, + ), + patch( + "services.api_gateway.tasks.portfolio_sync.is_market_open", + return_value=False, + ), + ): + MockBroker.return_value = AsyncMock() + + task = asyncio.create_task(portfolio_sync_loop(config, mock_session_factory)) + await asyncio.sleep(1.5) + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + assert not sync_called, "Sync should not run outside market hours" + + async def test_runs_sync_during_market_hours( + self, config, mock_session_factory + ) -> None: + sync_called = False + + async def mock_sync(broker, sf): + nonlocal sync_called + sync_called = True + + with ( + patch( + "services.api_gateway.tasks.portfolio_sync.AlpacaBroker" + ) as MockBroker, + patch( + "services.api_gateway.tasks.portfolio_sync._sync_once", + side_effect=mock_sync, + ), + patch( + "services.api_gateway.tasks.portfolio_sync.is_market_open", + return_value=True, + ), + ): + MockBroker.return_value = AsyncMock() + + task = asyncio.create_task(portfolio_sync_loop(config, mock_session_factory)) + await asyncio.sleep(1.5) + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + assert sync_called, "Sync should run during market hours" diff --git a/tests/services/test_sentiment_analyzer.py b/tests/services/test_sentiment_analyzer.py index 0e354c9..6f56e2b 100644 --- a/tests/services/test_sentiment_analyzer.py +++ b/tests/services/test_sentiment_analyzer.py @@ -1,7 +1,7 @@ """Tests for the sentiment analyzer service. Covers FinBERT analyzer, Ollama analyzer, ticker extraction, and the main -service flow. +service flow including DB persistence. 
""" from __future__ import annotations @@ -10,6 +10,7 @@ from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock, patch import pytest +from sqlalchemy.exc import IntegrityError from services.sentiment_analyzer.analyzers.finbert import FinBERTAnalyzer from services.sentiment_analyzer.analyzers.ollama_analyzer import OllamaAnalyzer @@ -409,3 +410,229 @@ class TestMainFlow: publisher.publish.assert_not_called() # Still counted as scored counters["articles_scored"].add.assert_called_once_with(1) + + +# --------------------------------------------------------------------------- +# Helpers for DB persistence tests +# --------------------------------------------------------------------------- + +def _make_counters() -> dict: + """Create a dict of mock OpenTelemetry counter/histogram instruments.""" + return { + "articles_scored": MagicMock(), + "finbert_count": MagicMock(), + "ollama_count": MagicMock(), + "inference_latency": MagicMock(), + } + + +def _make_mock_db_session_factory(session: AsyncMock | None = None) -> AsyncMock: + """Create a mock async_sessionmaker that yields a mock session. + + The returned factory, when called, returns an async context manager + that yields ``session`` (a mock AsyncSession). + """ + if session is None: + session = AsyncMock() + session.add = MagicMock() + session.commit = AsyncMock() + + factory = MagicMock() + + # factory() should return an async context manager (the session) + ctx = AsyncMock() + ctx.__aenter__ = AsyncMock(return_value=session) + ctx.__aexit__ = AsyncMock(return_value=False) + factory.return_value = ctx + + return factory + + +# --------------------------------------------------------------------------- +# DB Persistence Tests +# --------------------------------------------------------------------------- + +class TestDBPersistence: + """Tests for the DB write step in process_article.""" + + @pytest.mark.asyncio + async def test_db_write_creates_article_and_sentiments(self): + """When db_session_factory is provided, Article and ArticleSentiment rows are created.""" + finbert = AsyncMock(spec=FinBERTAnalyzer) + finbert.analyze = AsyncMock(return_value=(0.75, 0.88)) + + ollama = AsyncMock(spec=OllamaAnalyzer) + + publisher = AsyncMock() + publisher.publish = AsyncMock(return_value=b"1-0") + + config = SentimentAnalyzerConfig( + finbert_confidence_threshold=0.6, + otel_metrics_port=0, + ) + + counters = _make_counters() + + mock_session = AsyncMock() + mock_session.add = MagicMock() + mock_session.commit = AsyncMock() + + db_factory = _make_mock_db_session_factory(mock_session) + + article = _make_raw_article( + title="$AAPL and $MSFT report strong earnings", + content="Both Apple and Microsoft beat estimates.", + ) + + await process_article( + article, finbert, ollama, publisher, config, counters, db_factory + ) + + # session.add should be called: 1 Article + 2 ArticleSentiments = 3 calls + assert mock_session.add.call_count == 3 + + # Verify the types of objects added + added_objects = [call.args[0] for call in mock_session.add.call_args_list] + + from shared.models.news import Article, ArticleSentiment + + articles = [o for o in added_objects if isinstance(o, Article)] + sentiments = [o for o in added_objects if isinstance(o, ArticleSentiment)] + + assert len(articles) == 1 + assert len(sentiments) == 2 + + # Verify article fields + db_article = articles[0] + assert db_article.source == "test" + assert db_article.url == "https://example.com/article" + assert db_article.content_hash == "abc123" + + # Verify 
sentiment fields + tickers_in_sentiments = {s.ticker for s in sentiments} + assert "AAPL" in tickers_in_sentiments + assert "MSFT" in tickers_in_sentiments + for s in sentiments: + assert s.score == 0.75 + assert s.confidence == 0.88 + assert s.model_used == "finbert" + + # session.commit should be called once + mock_session.commit.assert_awaited_once() + + # Redis publishing should still happen + assert publisher.publish.call_count == 2 + + @pytest.mark.asyncio + async def test_db_duplicate_article_handled_gracefully(self): + """Duplicate content_hash (IntegrityError) should be caught silently.""" + finbert = AsyncMock(spec=FinBERTAnalyzer) + finbert.analyze = AsyncMock(return_value=(0.5, 0.9)) + + ollama = AsyncMock(spec=OllamaAnalyzer) + + publisher = AsyncMock() + publisher.publish = AsyncMock(return_value=b"1-0") + + config = SentimentAnalyzerConfig( + finbert_confidence_threshold=0.6, + otel_metrics_port=0, + ) + + counters = _make_counters() + + mock_session = AsyncMock() + mock_session.add = MagicMock() + # Simulate IntegrityError on commit (duplicate content_hash) + mock_session.commit = AsyncMock( + side_effect=IntegrityError("duplicate", {}, Exception()) + ) + + db_factory = _make_mock_db_session_factory(mock_session) + + article = _make_raw_article( + title="$AAPL news", + content="Apple earnings report.", + ) + + # Should NOT raise — IntegrityError is caught + await process_article( + article, finbert, ollama, publisher, config, counters, db_factory + ) + + # Redis publishing should still have happened before the DB write + assert publisher.publish.call_count >= 1 + counters["articles_scored"].add.assert_called_once_with(1) + + @pytest.mark.asyncio + async def test_db_none_backward_compatible(self): + """When db_session_factory is None, process_article works without DB writes.""" + finbert = AsyncMock(spec=FinBERTAnalyzer) + finbert.analyze = AsyncMock(return_value=(0.6, 0.85)) + + ollama = AsyncMock(spec=OllamaAnalyzer) + + publisher = AsyncMock() + publisher.publish = AsyncMock(return_value=b"1-0") + + config = SentimentAnalyzerConfig( + finbert_confidence_threshold=0.6, + otel_metrics_port=0, + ) + + counters = _make_counters() + + article = _make_raw_article( + title="$GOOG quarterly results", + content="Google reports revenue growth.", + ) + + # Pass db_session_factory=None explicitly (the default) + await process_article( + article, finbert, ollama, publisher, config, counters, db_session_factory=None + ) + + # Redis publishing should work as before + assert publisher.publish.call_count >= 1 + counters["articles_scored"].add.assert_called_once_with(1) + + @pytest.mark.asyncio + async def test_db_error_does_not_break_processing(self): + """A DB error should be logged but not prevent Redis publishing.""" + finbert = AsyncMock(spec=FinBERTAnalyzer) + finbert.analyze = AsyncMock(return_value=(0.3, 0.7)) + + ollama = AsyncMock(spec=OllamaAnalyzer) + + publisher = AsyncMock() + publisher.publish = AsyncMock(return_value=b"1-0") + + config = SentimentAnalyzerConfig( + finbert_confidence_threshold=0.6, + otel_metrics_port=0, + ) + + counters = _make_counters() + + mock_session = AsyncMock() + mock_session.add = MagicMock() + # Simulate a generic DB error + mock_session.commit = AsyncMock( + side_effect=RuntimeError("connection lost") + ) + + db_factory = _make_mock_db_session_factory(mock_session) + + article = _make_raw_article( + title="$TSLA stock update", + content="Tesla announces new factory.", + ) + + # Should NOT raise — generic exceptions in DB write are caught + await 
process_article(
+            article, finbert, ollama, publisher, config, counters, db_factory
+        )
+
+        # Redis publishing should still have happened
+        assert publisher.publish.call_count >= 1
+        counters["articles_scored"].add.assert_called_once_with(1)
diff --git a/tests/services/test_signal_generator.py b/tests/services/test_signal_generator.py
index 9386364..684d874 100644
--- a/tests/services/test_signal_generator.py
+++ b/tests/services/test_signal_generator.py
@@ -357,3 +357,201 @@ class TestEnsembleTagsStrategySources:
         assert parts[0] == "alpha"
         assert parts[1] == "LONG"
         assert float(parts[2]) == pytest.approx(0.75, abs=0.01)
+
+
+# ---------------------------------------------------------------------------
+# Signal Generator — current_price injection into sentiment_context
+# ---------------------------------------------------------------------------
+
+
+class TestCurrentPriceInjection:
+    """Verify that current_price flows into sentiment_context on published signals."""
+
+    @pytest.mark.asyncio
+    async def test_current_price_set_when_snapshot_has_price(self):
+        """When snapshot has a positive current_price, it should appear in sentiment_context."""
+        s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.9))
+        ensemble = WeightedEnsemble([s1], threshold=0.0)
+        market = MarketSnapshot(
+            ticker="AAPL",
+            current_price=155.50,
+            open=154.0,
+            high=156.0,
+            low=153.0,
+            close=155.50,
+            volume=5000,
+        )
+        sentiment = SentimentContext(
+            ticker="AAPL",
+            avg_score=0.7,
+            article_count=5,
+            recent_scores=[0.6, 0.7, 0.8],
+            avg_confidence=0.85,
+        )
+        weights = {"alpha": 1.0}
+
+        signal_result = await ensemble.evaluate("AAPL", market, sentiment, weights)
+        assert signal_result is not None
+
+        # Simulate the current_price injection from the signal generator main
+        # loop (same guard order as the tests below: only touch the context
+        # when there is a real price)
+        if market.current_price > 0:
+            if signal_result.sentiment_context is None:
+                signal_result.sentiment_context = {}
+            signal_result.sentiment_context["current_price"] = market.current_price
+
+        assert "current_price" in signal_result.sentiment_context
+        assert signal_result.sentiment_context["current_price"] == 155.50
+
+    @pytest.mark.asyncio
+    async def test_current_price_not_set_for_fallback_snapshot(self):
+        """A minimal fallback snapshot (current_price=0, no sentiment) should leave the context as None."""
+        s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.9))
+        ensemble = WeightedEnsemble([s1], threshold=0.0)
+        # Minimal fallback snapshot with current_price=0.0
+        market = MarketSnapshot(
+            ticker="AAPL",
+            current_price=0.0,
+            open=0.0,
+            high=0.0,
+            low=0.0,
+            close=0.0,
+            volume=0.0,
+        )
+        weights = {"alpha": 1.0}
+
+        signal_result = await ensemble.evaluate("AAPL", market, None, weights)
+        assert signal_result is not None
+
+        # Simulate the injection logic — should NOT set current_price when price is 0
+        if market.current_price > 0:
+            if signal_result.sentiment_context is None:
+                signal_result.sentiment_context = {}
+            signal_result.sentiment_context["current_price"] = market.current_price
+
+        # sentiment_context should be None (no sentiment was passed, and no price injection)
+        assert signal_result.sentiment_context is None
+
+    @pytest.mark.asyncio
+    async def test_current_price_not_set_when_price_is_zero(self):
+        """current_price should NOT be injected when snapshot.current_price is exactly 0."""
+        s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.8))
+        ensemble = WeightedEnsemble([s1], threshold=0.0)
+        sentiment = SentimentContext(
+            ticker="AAPL",
+            avg_score=0.5,
+            article_count=3,
+            
recent_scores=[0.5], + avg_confidence=0.9, + ) + market = MarketSnapshot( + ticker="AAPL", + current_price=0.0, + open=0.0, + high=0.0, + low=0.0, + close=0.0, + volume=0.0, + ) + weights = {"alpha": 1.0} + + signal_result = await ensemble.evaluate("AAPL", market, sentiment, weights) + assert signal_result is not None + + # Simulate injection logic + if market.current_price > 0: + if signal_result.sentiment_context is None: + signal_result.sentiment_context = {} + signal_result.sentiment_context["current_price"] = market.current_price + + # sentiment_context should exist (from ensemble) but WITHOUT current_price + assert signal_result.sentiment_context is not None + assert "current_price" not in signal_result.sentiment_context + + @pytest.mark.asyncio + async def test_current_price_preserves_existing_sentiment_context(self): + """Injecting current_price should not overwrite other sentiment_context fields.""" + s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.9)) + ensemble = WeightedEnsemble([s1], threshold=0.0) + sentiment = SentimentContext( + ticker="AAPL", + avg_score=0.75, + article_count=10, + recent_scores=[0.7, 0.8], + avg_confidence=0.9, + ) + market = MarketSnapshot( + ticker="AAPL", + current_price=200.0, + open=198.0, + high=202.0, + low=197.0, + close=200.0, + volume=8000, + ) + weights = {"alpha": 1.0} + + signal_result = await ensemble.evaluate("AAPL", market, sentiment, weights) + assert signal_result is not None + assert signal_result.sentiment_context is not None + + # Preserve existing keys + original_keys = set(signal_result.sentiment_context.keys()) + + # Inject current_price + if market.current_price > 0: + signal_result.sentiment_context["current_price"] = market.current_price + + # All original keys should still be present + for key in original_keys: + assert key in signal_result.sentiment_context + # Plus the new one + assert signal_result.sentiment_context["current_price"] == 200.0 + + +# --------------------------------------------------------------------------- +# Signal Generator — signal_id generation +# --------------------------------------------------------------------------- + + +class TestSignalIdGeneration: + """Verify that TradeSignal includes a signal_id UUID.""" + + @pytest.mark.asyncio + async def test_signal_has_uuid(self): + """Each generated signal should have a signal_id UUID.""" + s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.9)) + ensemble = WeightedEnsemble([s1], threshold=0.0) + market = MarketSnapshot( + ticker="AAPL", current_price=150.0, + open=149.0, high=151.0, low=148.0, close=150.0, volume=1000, + ) + weights = {"alpha": 1.0} + + signal_result = await ensemble.evaluate("AAPL", market, None, weights) + assert signal_result is not None + assert signal_result.signal_id is not None + from uuid import UUID + assert isinstance(signal_result.signal_id, UUID) + + @pytest.mark.asyncio + async def test_signal_ids_are_unique(self): + """Multiple signals should get different signal_id values.""" + s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.9)) + ensemble = WeightedEnsemble([s1], threshold=0.0) + market = MarketSnapshot( + ticker="AAPL", current_price=150.0, + open=149.0, high=151.0, low=148.0, close=150.0, volume=1000, + ) + weights = {"alpha": 1.0} + + sig1 = await ensemble.evaluate("AAPL", market, None, weights) + sig2 = await ensemble.evaluate("AAPL", market, None, weights) + assert sig1 is not None and sig2 is not None + assert sig1.signal_id != sig2.signal_id + + def 
test_signal_id_serializes_in_json(self): + """signal_id should serialize properly in JSON mode.""" + signal = _make_signal(SignalDirection.LONG, 0.8) + data = signal.model_dump(mode="json") + assert "signal_id" in data + assert isinstance(data["signal_id"], str) diff --git a/tests/services/test_trade_executor.py b/tests/services/test_trade_executor.py index 8ec42bf..b99a9b3 100644 --- a/tests/services/test_trade_executor.py +++ b/tests/services/test_trade_executor.py @@ -401,3 +401,119 @@ class TestExecutorFlowRejected: # Rejection counter should have been incremented counters["rejections"].add.assert_called_once() + + +# --------------------------------------------------------------------------- +# Executor flow — DB persistence +# --------------------------------------------------------------------------- + + +def _make_mock_db_session_factory(session=None): + """Create a mock async_sessionmaker that yields a mock session.""" + if session is None: + session = AsyncMock() + session.add = MagicMock() + session.commit = AsyncMock() + + factory = MagicMock() + ctx = AsyncMock() + ctx.__aenter__ = AsyncMock(return_value=session) + ctx.__aexit__ = AsyncMock(return_value=False) + factory.return_value = ctx + return factory + + +class TestExecutorDBPersistence: + """Verify that trades are persisted to the DB when db_session_factory is provided.""" + + @pytest.mark.asyncio + async def test_trade_persisted_with_signal_id(self): + """When db_session_factory is provided, a Trade row should be created.""" + config = _make_config() + broker = _mock_broker(positions=[], account=_make_account(100_000)) + publisher = AsyncMock() + publisher.publish = AsyncMock(return_value=b"1-0") + counters = { + "trades_executed": MagicMock(), + "rejections": MagicMock(), + "fill_latency": MagicMock(), + } + + signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0) + mock_session = AsyncMock() + mock_session.add = MagicMock() + mock_session.commit = AsyncMock() + db_factory = _make_mock_db_session_factory(mock_session) + + with patch.object(RiskManager, "check_risk", return_value=(True, "approved")): + await process_signal( + signal, RiskManager(config, broker), broker, publisher, counters, db_factory + ) + + # Trade should be persisted + mock_session.add.assert_called_once() + mock_session.commit.assert_awaited_once() + + # Verify the trade object + trade_obj = mock_session.add.call_args[0][0] + assert trade_obj.ticker == "AAPL" + assert trade_obj.signal_id == signal.signal_id + + @pytest.mark.asyncio + async def test_trade_not_persisted_without_db(self): + """When db_session_factory is None, no DB write should happen.""" + config = _make_config() + broker = _mock_broker(positions=[], account=_make_account(100_000)) + publisher = AsyncMock() + publisher.publish = AsyncMock(return_value=b"1-0") + counters = { + "trades_executed": MagicMock(), + "rejections": MagicMock(), + "fill_latency": MagicMock(), + } + + signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0) + + with patch.object(RiskManager, "check_risk", return_value=(True, "approved")): + await process_signal( + signal, RiskManager(config, broker), broker, publisher, counters, None + ) + + # Should still publish + publisher.publish.assert_called_once() + + @pytest.mark.asyncio + async def test_db_error_does_not_block_publishing(self): + """A DB error should not prevent the trade from being published.""" + config = _make_config() + broker = _mock_broker(positions=[], account=_make_account(100_000)) + publisher = AsyncMock() + 
publisher.publish = AsyncMock(return_value=b"1-0")
+        counters = {
+            "trades_executed": MagicMock(),
+            "rejections": MagicMock(),
+            "fill_latency": MagicMock(),
+        }
+
+        signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0)
+        mock_session = AsyncMock()
+        mock_session.add = MagicMock()
+        mock_session.commit = AsyncMock(side_effect=RuntimeError("DB connection lost"))
+        db_factory = _make_mock_db_session_factory(mock_session)
+
+        with patch.object(RiskManager, "check_risk", return_value=(True, "approved")):
+            await process_signal(
+                signal, RiskManager(config, broker), broker, publisher, counters, db_factory
+            )
+
+        # Trade should still be published despite DB error
+        publisher.publish.assert_called_once()
+        counters["trades_executed"].add.assert_called_once_with(1)
+
+    def test_signal_id_flows_through_execution(self):
+        """TradeSignal always carries a signal_id UUID, the key used for signal-trade linkage above."""
+        signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0)
+        assert signal.signal_id is not None
+        # Verify signal_id is a UUID
+        from uuid import UUID
+        assert isinstance(signal.signal_id, UUID)
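
The tests above pin the signal-trade linkage down at the level of matching signal_id UUIDs. As a minimal sketch of how that linkage could be read back (for example, by the dashboard), assuming a Trade ORM model in shared.models.trading with the signal_id column the executor tests assert on; the helper name and import path are illustrative, not part of this patch:

    from uuid import UUID

    from sqlalchemy import select
    from sqlalchemy.ext.asyncio import AsyncSession

    from shared.models.trading import Trade  # assumed model name/location


    async def trades_for_signal(session: AsyncSession, signal_id: UUID) -> list[Trade]:
        """Fetch every trade executed against a given signal (hypothetical helper)."""
        # The executor persists the originating TradeSignal's signal_id on each
        # Trade row, so attribution is a single filtered SELECT.
        result = await session.execute(
            select(Trade).where(Trade.signal_id == signal_id)
        )
        return list(result.scalars())

Because the signal generator mints the UUID before publishing, the same key identifies a signal in the Redis stream and in the DB rows written by both services.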