feat: real data pipeline — market data, DB persistence, portfolio sync, signal-trade linkage

Wire the trading bot to real Alpaca market data and persist pipeline
state to the database so the dashboard displays live information.

- Add market-data service fetching OHLCV bars from Alpaca, publishing
  to market:bars Redis Stream; signal generator consumes bars and
  injects current_price into signals for position sizing
- Sentiment analyzer now persists Article + ArticleSentiment rows to
  DB after scoring, with duplicate and error handling
- API gateway runs a background portfolio sync task that snapshots
  Alpaca account state into PortfolioSnapshot/Position DB tables
  during market hours
- TradeSignal carries a signal_id UUID; signal generator and trade
  executor both persist their records to DB with cross-references
- 303 unit tests pass (57 new tests added)

Author: Viktor Barzin
Date:   2026-02-22 19:52:45 +00:00
Commit: e2a3bd456d (parent 5a6b20c8f1)
19 changed files with 2238 additions and 72 deletions


@@ -10,6 +10,16 @@ TRADING_LOG_LEVEL=INFO
TRADING_ALPACA_API_KEY=your_api_key_here
TRADING_ALPACA_SECRET_KEY=your_secret_key_here
TRADING_ALPACA_BASE_URL=https://paper-api.alpaca.markets
TRADING_PAPER_TRADING=true
# Market data service — watchlist tickers (JSON array)
TRADING_WATCHLIST=["AAPL","TSLA","NVDA","MSFT","GOOGL"]
TRADING_BAR_TIMEFRAME=5Min
TRADING_POLL_INTERVAL_SECONDS=60
TRADING_HISTORICAL_BARS=100
# Portfolio sync interval (seconds, api-gateway background task)
TRADING_SNAPSHOT_INTERVAL_SECONDS=60
# JWT — REQUIRED, generate with: python -c "import secrets; print(secrets.token_hex(32))"
TRADING_JWT_SECRET_KEY=


@@ -142,6 +142,19 @@ services:
env_file: .env
restart: unless-stopped
market-data:
build:
context: .
dockerfile: docker/Dockerfile.service
args:
EXTRAS: "trading"
SERVICE_MODULE: "market_data"
depends_on:
redis:
condition: service_healthy
env_file: .env
restart: unless-stopped
api-gateway:
build:
context: .


@@ -16,6 +16,12 @@ class ApiGatewayConfig(BaseConfig):
access_token_expire_minutes: int = 15
refresh_token_expire_days: int = 7
# Alpaca brokerage credentials (for portfolio sync)
alpaca_api_key: str = ""
alpaca_secret_key: str = ""
paper_trading: bool = True
snapshot_interval_seconds: int = 60
# CORS settings
cors_origins: list[str] = ["http://localhost:5173"]


@@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
import logging
from contextlib import asynccontextmanager
from typing import AsyncIterator
@@ -43,9 +44,23 @@ def create_app(config: ApiGatewayConfig | None = None) -> FastAPI:
)
app.state.config = config
# Start portfolio sync background task
from services.api_gateway.tasks.portfolio_sync import portfolio_sync_loop
sync_task = asyncio.create_task(
portfolio_sync_loop(config, session_factory)
)
logger.info("API Gateway started")
yield
# Cancel the sync task
sync_task.cancel()
try:
await sync_task
except asyncio.CancelledError:
pass
# Cleanup
await app.state.redis.aclose()
await engine.dispose()



@@ -0,0 +1,155 @@
"""Background task that periodically snapshots Alpaca account state into the DB.
Runs on a configurable interval (default 60s) during US market hours,
creating ``PortfolioSnapshot`` rows and upserting ``Position`` rows so
the dashboard portfolio page reflects real brokerage data.
"""
from __future__ import annotations
import asyncio
import logging
from datetime import datetime, time, timezone
from zoneinfo import ZoneInfo
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import async_sessionmaker
from services.api_gateway.config import ApiGatewayConfig
from shared.broker.alpaca_broker import AlpacaBroker
from shared.models.timeseries import PortfolioSnapshot
from shared.models.trading import Position
logger = logging.getLogger(__name__)
# US Eastern timezone for market hours check
_ET = ZoneInfo("America/New_York")
_MARKET_OPEN = time(9, 30)
_MARKET_CLOSE = time(16, 0)
def is_market_open(now_utc: datetime | None = None) -> bool:
"""Return ``True`` if the US stock market is currently open.
Checks for a weekday (Mon-Fri) and a time from 9:30 AM (inclusive) to 4:00 PM (exclusive) ET.
"""
if now_utc is None:
now_utc = datetime.now(timezone.utc)
now_et = now_utc.astimezone(_ET)
# Weekday check: Monday=0 ... Friday=4
if now_et.weekday() >= 5:
return False
return _MARKET_OPEN <= now_et.time() < _MARKET_CLOSE
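# Worked example (these exact values are exercised by the unit tests further down):
#   is_market_open(datetime(2024, 1, 10, 15, 0, tzinfo=timezone.utc))  # True: Wednesday 10:00 ET
#   is_market_open(datetime(2024, 1, 13, 17, 0, tzinfo=timezone.utc))  # False: Saturday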
async def _sync_once(
broker: AlpacaBroker,
session_factory: async_sessionmaker,
) -> None:
"""Perform a single portfolio snapshot and position upsert cycle."""
now = datetime.now(timezone.utc)
# 1. Snapshot account state
account = await broker.get_account()
snapshot = PortfolioSnapshot(
timestamp=now,
total_value=account.portfolio_value,
cash=account.cash,
positions_value=account.portfolio_value - account.cash,
daily_pnl=0.0,
)
# 2. Fetch broker positions
broker_positions = await broker.get_positions()
broker_tickers = {p.ticker for p in broker_positions}
async with session_factory() as session:
async with session.begin():
# Insert portfolio snapshot
session.add(snapshot)
# Upsert positions
for pos_info in broker_positions:
result = await session.execute(
select(Position).where(Position.ticker == pos_info.ticker)
)
existing = result.scalar_one_or_none()
if existing is not None:
existing.qty = pos_info.qty
existing.avg_entry = pos_info.avg_entry
existing.unrealized_pnl = pos_info.unrealized_pnl
else:
new_pos = Position(
ticker=pos_info.ticker,
qty=pos_info.qty,
avg_entry=pos_info.avg_entry,
unrealized_pnl=pos_info.unrealized_pnl,
stop_loss=None,
take_profit=None,
)
session.add(new_pos)
# 3. Remove positions that are no longer held at the broker
if broker_tickers:
await session.execute(
delete(Position).where(Position.ticker.notin_(broker_tickers))
)
else:
# No positions at broker — delete all local positions
await session.execute(delete(Position))
logger.info(
"Portfolio sync complete: value=%.2f, cash=%.2f, positions=%d",
account.portfolio_value,
account.cash,
len(broker_positions),
)
async def portfolio_sync_loop(
config: ApiGatewayConfig,
session_factory: async_sessionmaker,
) -> None:
"""Run the portfolio sync loop until cancelled.
Parameters
----------
config:
API Gateway configuration containing Alpaca credentials and
the snapshot interval.
session_factory:
SQLAlchemy async session factory for DB access.
"""
if not config.alpaca_api_key or not config.alpaca_secret_key:
logger.warning(
"Alpaca API credentials not configured — portfolio sync disabled"
)
return
broker = AlpacaBroker(
api_key=config.alpaca_api_key,
secret_key=config.alpaca_secret_key,
paper=config.paper_trading,
)
logger.info(
"Portfolio sync started (interval=%ds, paper=%s)",
config.snapshot_interval_seconds,
config.paper_trading,
)
while True:
try:
if is_market_open():
await _sync_once(broker, session_factory)
else:
logger.debug("Market closed — skipping portfolio snapshot")
except asyncio.CancelledError:
logger.info("Portfolio sync task cancelled — shutting down")
raise
except Exception:
logger.exception("Portfolio sync error — will retry next interval")
await asyncio.sleep(config.snapshot_interval_seconds)


@@ -0,0 +1 @@
"""Market Data service -- fetches OHLCV bars from Alpaca and publishes to Redis Streams."""


@@ -0,0 +1,3 @@
from services.market_data.main import main
main()


@@ -0,0 +1,16 @@
"""Configuration for the market data service."""
from shared.config import BaseConfig
class MarketDataConfig(BaseConfig):
"""Extends BaseConfig with market-data-specific settings."""
watchlist: list[str] = ["AAPL", "TSLA", "NVDA", "MSFT", "GOOGL"]
bar_timeframe: str = "5Min"
poll_interval_seconds: int = 60
historical_bars: int = 100
alpaca_api_key: str = ""
alpaca_secret_key: str = ""
model_config = {"env_prefix": "TRADING_"}
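# Usage sketch — assuming BaseConfig is a pydantic-settings BaseSettings (implied
# by the TRADING_ env_prefix above), list fields are parsed from JSON env values:
#   export TRADING_WATCHLIST='["AAPL","TSLA"]'
#   MarketDataConfig().watchlist  # -> ["AAPL", "TSLA"]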


@@ -0,0 +1,257 @@
"""Market Data service -- main entry point.
Fetches historical and live OHLCV bars from Alpaca's market data API
and publishes them to the ``market:bars`` Redis Stream for consumption
by the signal generator and other downstream services.
"""
from __future__ import annotations
import asyncio
import logging
import signal
from datetime import datetime, timedelta, timezone
from redis.asyncio import Redis
from services.market_data.config import MarketDataConfig
from shared.redis_streams import StreamPublisher
from shared.telemetry import setup_telemetry
logger = logging.getLogger(__name__)
MARKET_BARS_STREAM = "market:bars"
def _parse_timeframe(timeframe_str: str):
"""Parse a timeframe string like '5Min' into an Alpaca TimeFrame object.
Returns a ``TimeFrame`` instance suitable for ``StockBarsRequest``.
"""
from alpaca.data.timeframe import TimeFrame, TimeFrameUnit
# Supported formats: "1Min", "5Min", "15Min", "1Hour", "1Day"
tf_map = {
"1Min": TimeFrame(1, TimeFrameUnit.Minute),
"5Min": TimeFrame(5, TimeFrameUnit.Minute),
"15Min": TimeFrame(15, TimeFrameUnit.Minute),
"1Hour": TimeFrame(1, TimeFrameUnit.Hour),
"1Day": TimeFrame(1, TimeFrameUnit.Day),
}
tf = tf_map.get(timeframe_str)
if tf is None:
raise ValueError(
f"Unsupported timeframe '{timeframe_str}'. "
f"Supported values: {list(tf_map.keys())}"
)
return tf
def _bar_to_dict(ticker: str, bar) -> dict:
"""Convert an Alpaca Bar object to a flat dictionary for Redis publishing."""
return {
"ticker": ticker,
"timestamp": bar.timestamp.isoformat(),
"open": float(bar.open),
"high": float(bar.high),
"low": float(bar.low),
"close": float(bar.close),
"volume": float(bar.volume),
}
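# For reference, a single published bar message looks like (values illustrative):
#   {"ticker": "AAPL", "timestamp": "2026-01-15T10:30:00+00:00", "open": 149.0,
#    "high": 151.0, "low": 148.0, "close": 150.0, "volume": 10000.0}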
async def _fetch_historical_bars(
client,
watchlist: list[str],
timeframe,
limit: int,
publisher: StreamPublisher,
bars_published_counter,
) -> int:
"""Fetch historical bars for each ticker and publish to Redis.
Returns the total number of bars published.
"""
from alpaca.data.requests import StockBarsRequest
total_published = 0
# Use a start time far enough back to get the requested number of bars
start = datetime.now(timezone.utc) - timedelta(days=30)
for ticker in watchlist:
try:
request = StockBarsRequest(
symbol_or_symbols=[ticker],
timeframe=timeframe,
start=start,
limit=limit,
)
bars = await asyncio.to_thread(client.get_stock_bars, request)
ticker_bars = bars[ticker] if ticker in bars else []
for bar in ticker_bars:
msg = _bar_to_dict(ticker, bar)
await publisher.publish(msg)
total_published += 1
logger.info(
"Published %d historical bars for %s",
len(ticker_bars),
ticker,
)
except Exception:
logger.exception("Failed to fetch historical bars for %s", ticker)
if total_published:
bars_published_counter.add(total_published)
return total_published
async def _poll_latest_bars(
client,
watchlist: list[str],
timeframe,
publisher: StreamPublisher,
bars_published_counter,
) -> int:
"""Fetch the latest bar for each ticker and publish to Redis.
Returns the number of bars published.
"""
from alpaca.data.requests import StockBarsRequest
published = 0
# Fetch bars from the last 10 minutes to ensure we get at least one
start = datetime.now(timezone.utc) - timedelta(minutes=10)
for ticker in watchlist:
try:
request = StockBarsRequest(
symbol_or_symbols=[ticker],
timeframe=timeframe,
start=start,
limit=1,
)
bars = await asyncio.to_thread(client.get_stock_bars, request)
ticker_bars = bars[ticker] if ticker in bars else []
if ticker_bars:
# Publish only the most recent bar
bar = ticker_bars[-1]
msg = _bar_to_dict(ticker, bar)
await publisher.publish(msg)
published += 1
logger.debug("Published latest bar for %s: close=%.2f", ticker, bar.close)
except Exception:
logger.exception("Failed to fetch latest bar for %s", ticker)
if published:
bars_published_counter.add(published)
return published
async def run(config: MarketDataConfig | None = None) -> None:
"""Main service loop.
Connects to Alpaca and Redis, fetches historical bars on startup,
then polls for new bars at the configured interval.
"""
if config is None:
config = MarketDataConfig()
logging.basicConfig(level=config.log_level)
logger.info("Starting Market Data service")
# --- Telemetry ---
meter = setup_telemetry("market-data", config.otel_metrics_port)
bars_published_counter = meter.create_counter(
"market_data.bars_published",
description="Total OHLCV bars published to market:bars stream",
)
poll_errors_counter = meter.create_counter(
"market_data.poll_errors",
description="Total poll cycle errors",
)
# --- Alpaca client ---
from alpaca.data.historical import StockHistoricalDataClient
client = StockHistoricalDataClient(
api_key=config.alpaca_api_key,
secret_key=config.alpaca_secret_key,
)
# --- Redis ---
redis = Redis.from_url(config.redis_url, decode_responses=False)
publisher = StreamPublisher(redis, MARKET_BARS_STREAM)
# --- Parse timeframe ---
timeframe = _parse_timeframe(config.bar_timeframe)
# --- Graceful shutdown ---
shutdown_event = asyncio.Event()
loop = asyncio.get_running_loop()
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, shutdown_event.set)
try:
# Fetch historical bars on startup
logger.info(
"Fetching %d historical bars for watchlist: %s",
config.historical_bars,
config.watchlist,
)
total = await _fetch_historical_bars(
client,
config.watchlist,
timeframe,
config.historical_bars,
publisher,
bars_published_counter,
)
logger.info("Historical backfill complete: %d total bars published", total)
# Poll loop
logger.info(
"Starting poll loop (interval=%ds) for watchlist: %s",
config.poll_interval_seconds,
config.watchlist,
)
while not shutdown_event.is_set():
try:
await asyncio.wait_for(
shutdown_event.wait(),
timeout=config.poll_interval_seconds,
)
break # Shutdown signaled
except asyncio.TimeoutError:
pass # Normal timeout — time to poll
try:
count = await _poll_latest_bars(
client,
config.watchlist,
timeframe,
publisher,
bars_published_counter,
)
logger.info("Poll cycle complete: %d bars published", count)
except Exception:
logger.exception("Poll cycle failed")
poll_errors_counter.add(1)
finally:
await redis.aclose()
logger.info("Market data service stopped gracefully")
def main() -> None:
"""CLI entry point."""
asyncio.run(run())
if __name__ == "__main__":
main()


@@ -3,7 +3,8 @@
Consumes ``news:raw`` articles from Redis Streams, scores them using a
tiered approach (FinBERT first, Ollama fallback for low-confidence results),
extracts ticker mentions, and publishes ``ScoredArticle`` messages to
``news:scored``.
``news:scored``. Also persists scored articles to the database (articles +
article_sentiments tables) so the dashboard can display real data.
"""
from __future__ import annotations
@@ -14,11 +15,15 @@ import signal
import time
from redis.asyncio import Redis
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import async_sessionmaker
from services.sentiment_analyzer.analyzers.finbert import FinBERTAnalyzer
from services.sentiment_analyzer.analyzers.ollama_analyzer import OllamaAnalyzer
from services.sentiment_analyzer.config import SentimentAnalyzerConfig
from services.sentiment_analyzer.ticker_extractor import extract_tickers
from shared.db import create_db
from shared.models.news import Article, ArticleSentiment
from shared.redis_streams import StreamConsumer, StreamPublisher
from shared.schemas.news import RawArticle, ScoredArticle
from shared.telemetry import setup_telemetry
@@ -33,6 +38,7 @@ async def process_article(
publisher: StreamPublisher,
config: SentimentAnalyzerConfig,
counters: dict,
db_session_factory: async_sessionmaker | None = None,
) -> None:
"""Score a single article and publish one ScoredArticle per extracted ticker.
@@ -50,6 +56,9 @@
Service configuration (confidence threshold, etc.).
counters:
Dict of OpenTelemetry counter/histogram instruments.
db_session_factory:
Optional async session factory for persisting to the DB.
When ``None``, DB persistence is skipped (backward compatible).
"""
start = time.monotonic()
@@ -103,6 +112,46 @@
counters["articles_scored"].add(1)
# --- Step 5: Persist to DB ---
if db_session_factory is not None:
try:
async with db_session_factory() as session:
db_article = Article(
source=article.source,
url=article.url,
title=article.title,
published_at=article.published_at,
fetched_at=article.fetched_at,
content_hash=article.content_hash,
)
session.add(db_article)
for ticker in tickers:
sentiment = ArticleSentiment(
article_id=db_article.id,
ticker=ticker,
score=score,
confidence=confidence,
model_used=model_used,
)
session.add(sentiment)
await session.commit()
logger.debug(
"Persisted article '%s' with %d sentiments to DB",
article.title[:60],
len(tickers),
)
except IntegrityError:
logger.debug(
"Article already exists in DB (content_hash=%s), skipping",
article.content_hash,
)
except Exception:
logger.exception(
"Failed to persist article to DB: %s", article.title[:60]
)
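# Note: the IntegrityError branch above presumably relies on a uniqueness
# constraint on articles.content_hash in shared.models.news (not shown in this
# diff), e.g. something like:
#   content_hash: Mapped[str] = mapped_column(unique=True)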
async def run(config: SentimentAnalyzerConfig | None = None) -> None:
"""Main service loop.
@@ -150,6 +199,14 @@ async def run(config: SentimentAnalyzerConfig | None = None) -> None:
)
ollama = OllamaAnalyzer(model=config.ollama_model, host=config.ollama_host)
# --- Database ---
db_session_factory = None
try:
_engine, db_session_factory = create_db(config)
logger.info("Database session factory initialised")
except Exception:
logger.exception("Failed to initialise DB — articles will NOT be persisted")
logger.info("Consuming from news:raw, publishing to news:scored")
# Graceful shutdown on SIGTERM/SIGINT
@@ -165,7 +222,7 @@ async def run(config: SentimentAnalyzerConfig | None = None) -> None:
break
try:
article = RawArticle.model_validate(data)
await process_article(article, finbert, ollama, publisher, config, counters)
await process_article(article, finbert, ollama, publisher, config, counters, db_session_factory)
except Exception:
logger.exception("Error processing article: %s", data.get("title", "<unknown>"))
finally:


@@ -1,8 +1,9 @@
"""Signal Generator service -- main entry point.
Consumes ``news:scored`` articles from Redis Streams, updates sentiment
context per ticker, runs the weighted ensemble of trading strategies, and
publishes qualifying ``TradeSignal`` messages to ``signals:generated``.
Consumes ``news:scored`` articles and ``market:bars`` OHLCV data from
Redis Streams, updates sentiment context and market data per ticker,
runs the weighted ensemble of trading strategies, and publishes
qualifying ``TradeSignal`` messages to ``signals:generated``.
"""
from __future__ import annotations
@@ -10,16 +11,21 @@ from __future__ import annotations
import asyncio
import logging
import signal
import uuid
from collections import defaultdict, deque
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import async_sessionmaker
from services.signal_generator.config import SignalGeneratorConfig
from services.signal_generator.ensemble import WeightedEnsemble
from services.signal_generator.market_data import MarketDataManager
from shared.db import create_db
from shared.models.trading import Signal as SignalModel
from shared.models.trading import SignalDirection as SignalDirectionModel
from shared.redis_streams import StreamConsumer, StreamPublisher
from shared.schemas.news import ScoredArticle
from shared.schemas.trading import SentimentContext
from shared.schemas.trading import MarketSnapshot, SentimentContext
from shared.strategies import MeanReversionStrategy, MomentumStrategy, NewsDrivenStrategy
from shared.telemetry import setup_telemetry
@@ -53,12 +59,150 @@
)
async def _consume_market_bars(
bars_consumer: StreamConsumer,
market_data: MarketDataManager,
shutdown_event: asyncio.Event,
bars_received_counter,
) -> None:
"""Consume OHLCV bars from ``market:bars`` and feed them to the MarketDataManager.
Runs as a concurrent task alongside the scored-article consumer.
"""
logger.info("Starting market:bars consumer")
async for _msg_id, data in bars_consumer.consume():
if shutdown_event.is_set():
break
try:
ticker = data.get("ticker")
if not ticker:
logger.warning("Received bar message without ticker field: %s", data)
continue
# Build bar_data dict without the ticker key (OHLCVBar doesn't have it)
bar_data = {k: v for k, v in data.items() if k != "ticker"}
market_data.add_bar(ticker, bar_data)
bars_received_counter.add(1)
logger.debug("Added bar for %s: close=%s", ticker, data.get("close"))
except Exception:
logger.exception("Error processing market bar: %s", data)
async def _consume_scored_articles(
articles_consumer: StreamConsumer,
market_data: MarketDataManager,
ensemble: WeightedEnsemble,
weights: dict[str, float],
publisher: StreamPublisher,
shutdown_event: asyncio.Event,
signals_generated,
per_strategy_signal_count,
db_session_factory: async_sessionmaker | None = None,
) -> None:
"""Consume scored articles from ``news:scored``, run the ensemble, and publish signals.
Runs as a concurrent task alongside the market-bars consumer.
"""
# Per-ticker sentiment accumulators
sentiment_scores: dict[str, deque[float]] = defaultdict(
lambda: deque(maxlen=_MAX_SENTIMENT_SCORES)
)
sentiment_confidences: dict[str, deque[float]] = defaultdict(
lambda: deque(maxlen=_MAX_SENTIMENT_SCORES)
)
logger.info("Starting news:scored consumer")
async for _msg_id, data in articles_consumer.consume():
if shutdown_event.is_set():
break
try:
article = ScoredArticle.model_validate(data)
ticker = article.ticker
# Update sentiment accumulators
sentiment_scores[ticker].append(article.sentiment_score)
sentiment_confidences[ticker].append(article.confidence)
# Build sentiment context
sentiment = _build_sentiment_context(
ticker,
sentiment_scores[ticker],
sentiment_confidences[ticker],
)
# Get market snapshot (may be None if no bars received yet)
snapshot = market_data.get_snapshot(ticker)
if snapshot is None:
# Create a minimal snapshot from sentiment data alone
# (the news_driven strategy does not require market indicators)
snapshot = MarketSnapshot(
ticker=ticker,
current_price=0.0,
open=0.0,
high=0.0,
low=0.0,
close=0.0,
volume=0.0,
)
# Run ensemble
signal_result = await ensemble.evaluate(ticker, snapshot, sentiment, weights)
if signal_result is not None:
# Inject current price for trade executor position sizing
if snapshot and snapshot.current_price > 0:
if signal_result.sentiment_context is None:
signal_result.sentiment_context = {}
signal_result.sentiment_context["current_price"] = snapshot.current_price
# Persist signal to DB
if db_session_factory is not None:
try:
async with db_session_factory() as session:
direction_map = {
"LONG": SignalDirectionModel.LONG,
"SHORT": SignalDirectionModel.SHORT,
"NEUTRAL": SignalDirectionModel.NEUTRAL,
}
db_signal = SignalModel(
id=signal_result.signal_id,
ticker=ticker,
direction=direction_map[signal_result.direction.value],
strength=signal_result.strength,
strategy_sources=signal_result.strategy_sources,
sentiment_score=sentiment.avg_score if sentiment else None,
acted_on=False,
)
session.add(db_signal)
await session.commit()
except Exception:
logger.exception("Failed to persist signal to DB")
await publisher.publish(signal_result.model_dump(mode="json"))
signals_generated.add(1)
for src in signal_result.strategy_sources:
strategy_name = src.split(":")[0]
per_strategy_signal_count.add(1, {"strategy": strategy_name})
logger.info(
"Signal generated: %s %s strength=%.4f sources=%s",
signal_result.direction.value,
ticker,
signal_result.strength,
signal_result.strategy_sources,
)
except Exception:
logger.exception(
"Error processing scored article: %s", data.get("title", "<unknown>")
)
async def run(config: SignalGeneratorConfig | None = None) -> None:
"""Main service loop.
Connects to Redis, initialises strategies and telemetry, then
continuously consumes from ``news:scored`` and publishes qualifying
signals to ``signals:generated``.
continuously consumes from ``news:scored`` and ``market:bars``,
publishing qualifying signals to ``signals:generated``.
"""
if config is None:
config = SignalGeneratorConfig()
@@ -76,10 +220,19 @@ async def run(config: SignalGeneratorConfig | None = None) -> None:
"per_strategy_signal_count",
description="Signals emitted, broken down by strategy",
)
bars_received_counter = meter.create_counter(
"bars_received",
description="Total OHLCV bars received from market:bars stream",
)
# --- Redis ---
redis = Redis.from_url(config.redis_url, decode_responses=False)
consumer = StreamConsumer(redis, "news:scored", "signal-generator", "worker-1")
articles_consumer = StreamConsumer(
redis, "news:scored", "signal-generator", "worker-1"
)
bars_consumer = StreamConsumer(
redis, "market:bars", "signal-generator", "bars-worker"
)
publisher = StreamPublisher(redis, "signals:generated")
# --- Market data ---
@@ -96,11 +249,17 @@ async def run(config: SignalGeneratorConfig | None = None) -> None:
# --- Strategy weights (default equal; could load from DB) ---
weights = dict(_DEFAULT_WEIGHTS)
# --- Per-ticker sentiment accumulators ---
sentiment_scores: dict[str, deque[float]] = defaultdict(lambda: deque(maxlen=_MAX_SENTIMENT_SCORES))
sentiment_confidences: dict[str, deque[float]] = defaultdict(lambda: deque(maxlen=_MAX_SENTIMENT_SCORES))
# --- Database (for persisting signals) ---
db_session_factory = None
try:
_engine, db_session_factory = create_db(config)
logger.info("Database session factory initialised for signal persistence")
except Exception:
logger.exception("Failed to initialise DB — signals will NOT be persisted")
logger.info("Consuming from news:scored, publishing to signals:generated")
logger.info(
"Consuming from news:scored and market:bars, publishing to signals:generated"
)
# Graceful shutdown on SIGTERM/SIGINT
shutdown_event = asyncio.Event()
@@ -108,62 +267,30 @@ async def run(config: SignalGeneratorConfig | None = None) -> None:
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, shutdown_event.set)
# --- Consume loop ---
# --- Run both consumers concurrently ---
try:
async for _msg_id, data in consumer.consume():
if shutdown_event.is_set():
break
try:
article = ScoredArticle.model_validate(data)
ticker = article.ticker
# Update sentiment accumulators
sentiment_scores[ticker].append(article.sentiment_score)
sentiment_confidences[ticker].append(article.confidence)
# Build sentiment context
sentiment = _build_sentiment_context(
ticker,
sentiment_scores[ticker],
sentiment_confidences[ticker],
async with asyncio.TaskGroup() as tg:
tg.create_task(
_consume_scored_articles(
articles_consumer,
market_data,
ensemble,
weights,
publisher,
shutdown_event,
signals_generated,
per_strategy_signal_count,
db_session_factory,
)
# Get market snapshot (may be None if no bars received yet)
snapshot = market_data.get_snapshot(ticker)
if snapshot is None:
# Create a minimal snapshot from sentiment data alone
# (the news_driven strategy does not require market indicators)
from shared.schemas.trading import MarketSnapshot
snapshot = MarketSnapshot(
ticker=ticker,
current_price=0.0,
open=0.0,
high=0.0,
low=0.0,
close=0.0,
volume=0.0,
)
# Run ensemble
signal_result = await ensemble.evaluate(ticker, snapshot, sentiment, weights)
if signal_result is not None:
await publisher.publish(signal_result.model_dump(mode="json"))
signals_generated.add(1)
for src in signal_result.strategy_sources:
strategy_name = src.split(":")[0]
per_strategy_signal_count.add(1, {"strategy": strategy_name})
logger.info(
"Signal generated: %s %s strength=%.4f sources=%s",
signal_result.direction.value,
ticker,
signal_result.strength,
signal_result.strategy_sources,
)
except Exception:
logger.exception("Error processing scored article: %s", data.get("title", "<unknown>"))
)
tg.create_task(
_consume_market_bars(
bars_consumer,
market_data,
shutdown_event,
bars_received_counter,
)
)
finally:
await redis.aclose()
logger.info("Signal generator stopped gracefully")


@@ -15,10 +15,15 @@ import time
import uuid
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import async_sessionmaker
from services.trade_executor.config import TradeExecutorConfig
from services.trade_executor.risk_manager import RiskManager
from shared.broker.alpaca_broker import AlpacaBroker
from shared.db import create_db
from shared.models.trading import Trade as TradeModel
from shared.models.trading import TradeSide as TradeSideModel
from shared.models.trading import TradeStatus as TradeStatusModel
from shared.redis_streams import StreamConsumer, StreamPublisher
from shared.schemas.trading import (
OrderRequest,
@@ -39,6 +44,7 @@ async def process_signal(
broker: AlpacaBroker,
publisher: StreamPublisher,
counters: dict,
db_session_factory: async_sessionmaker | None = None,
) -> None:
"""Process a single trade signal: risk check, order, record, publish.
@@ -54,6 +60,8 @@
Publishes execution results to ``trades:executed``.
counters:
Dict of OpenTelemetry counter/histogram instruments.
db_session_factory:
Optional async session factory for persisting trades to the DB.
"""
# --- Step 1: risk check ---
approved, reason = await risk_manager.check_risk(signal)
@@ -93,12 +101,42 @@
qty=result.qty,
price=result.filled_price or 0.0,
status=result.status,
signal_id=None,
signal_id=signal.signal_id,
strategy_id=None,
timestamp=result.timestamp,
)
# --- Step 6: publish to trades:executed ---
# --- Step 6: persist trade to DB ---
if db_session_factory is not None:
try:
side_map = {
OrderSide.BUY: TradeSideModel.BUY,
OrderSide.SELL: TradeSideModel.SELL,
}
status_map = {
OrderStatus.PENDING: TradeStatusModel.PENDING,
OrderStatus.FILLED: TradeStatusModel.FILLED,
OrderStatus.CANCELLED: TradeStatusModel.CANCELLED,
OrderStatus.REJECTED: TradeStatusModel.REJECTED,
}
async with db_session_factory() as session:
db_trade = TradeModel(
id=trade_id,
ticker=signal.ticker,
side=side_map[side],
qty=result.qty,
price=result.filled_price or 0.0,
timestamp=str(result.timestamp),
signal_id=signal.signal_id,
status=status_map.get(result.status, TradeStatusModel.PENDING),
)
session.add(db_trade)
await session.commit()
logger.debug("Persisted trade %s to DB (signal_id=%s)", trade_id, signal.signal_id)
except Exception:
logger.exception("Failed to persist trade to DB")
# --- Step 7: publish to trades:executed ---
await publisher.publish(execution.model_dump(mode="json"))
counters["trades_executed"].add(1)
logger.info(
@@ -157,6 +195,14 @@ async def run(config: TradeExecutorConfig | None = None) -> None:
# --- Risk manager ---
risk_manager = RiskManager(config, broker)
# --- Database (for persisting trades) ---
db_session_factory = None
try:
_engine, db_session_factory = create_db(config)
logger.info("Database session factory initialised for trade persistence")
except Exception:
logger.exception("Failed to initialise DB — trades will NOT be persisted")
logger.info("Consuming from signals:generated, publishing to trades:executed")
# Graceful shutdown on SIGTERM/SIGINT
@@ -172,7 +218,7 @@ async def run(config: TradeExecutorConfig | None = None) -> None:
break
try:
signal_msg = TradeSignal.model_validate(data)
await process_signal(signal_msg, risk_manager, broker, publisher, counters)
await process_signal(signal_msg, risk_manager, broker, publisher, counters, db_session_factory)
except Exception:
logger.exception("Error processing signal: %s", data)
finally:


@@ -3,7 +3,7 @@
from datetime import datetime
from enum import Enum
from typing import Any
from uuid import UUID
from uuid import UUID, uuid4
from pydantic import BaseModel, Field
@@ -96,6 +96,7 @@ class AccountInfo(BaseModel):
class TradeSignal(BaseModel):
"""Published to ``signals:generated`` by the signal generator."""
signal_id: UUID = Field(default_factory=uuid4)
ticker: str
direction: SignalDirection
strength: float = Field(ge=0.0, le=1.0)
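# Minimal sketch of the signal-trade linkage this default enables (assumes the
# model's remaining fields are defaulted or supplied):
#   signal = TradeSignal(ticker="AAPL", direction=SignalDirection.LONG, strength=0.8)
#   signal.signal_id   # UUID auto-generated by default_factory=uuid4
# The trade executor then persists Trade rows with signal_id=signal.signal_id
# (see the executor hunk above), giving the dashboard a signal-to-trade join key.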


@@ -0,0 +1,462 @@
"""Tests for the Market Data service.
Covers configuration defaults, bar parsing from Alpaca response format,
historical bar fetching, poll logic, and Redis publish behaviour.
All Alpaca SDK imports are mocked to avoid requiring pytz and other
transitive dependencies in the test environment.
"""
from __future__ import annotations
import sys
from datetime import datetime, timezone
from types import ModuleType
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from services.market_data.config import MarketDataConfig
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class _FakeBar:
"""Mimics an Alpaca Bar object with the required attributes."""
def __init__(
self,
timestamp: datetime,
open: float,
high: float,
low: float,
close: float,
volume: float,
) -> None:
self.timestamp = timestamp
self.open = open
self.high = high
self.low = low
self.close = close
self.volume = volume
def _make_fake_bar(close: float = 150.0) -> _FakeBar:
return _FakeBar(
timestamp=datetime(2026, 1, 15, 10, 30, tzinfo=timezone.utc),
open=149.0,
high=151.0,
low=148.0,
close=close,
volume=10000.0,
)
class _FakeBarSet(dict):
"""Mimics an Alpaca BarSet (dict-like) mapping ticker to list of bars."""
pass
class _FakeTimeFrame:
"""Mimics alpaca.data.timeframe.TimeFrame."""
def __init__(self, amount, unit):
self.amount = amount
self.unit = unit
class _FakeTimeFrameUnit:
"""Mimics alpaca.data.timeframe.TimeFrameUnit."""
Minute = "Minute"
Hour = "Hour"
Day = "Day"
class _FakeStockBarsRequest:
"""Mimics alpaca.data.requests.StockBarsRequest."""
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
def _install_alpaca_mocks():
"""Install mock modules for the alpaca SDK so we can import market_data.main."""
timeframe_mod = ModuleType("alpaca.data.timeframe")
timeframe_mod.TimeFrame = _FakeTimeFrame
timeframe_mod.TimeFrameUnit = _FakeTimeFrameUnit
requests_mod = ModuleType("alpaca.data.requests")
requests_mod.StockBarsRequest = _FakeStockBarsRequest
historical_mod = ModuleType("alpaca.data.historical")
historical_mod.StockHistoricalDataClient = MagicMock
# Build the package hierarchy
alpaca_mod = sys.modules.get("alpaca") or ModuleType("alpaca")
data_mod = sys.modules.get("alpaca.data") or ModuleType("alpaca.data")
sys.modules.setdefault("alpaca", alpaca_mod)
sys.modules["alpaca.data"] = data_mod
sys.modules["alpaca.data.timeframe"] = timeframe_mod
sys.modules["alpaca.data.requests"] = requests_mod
sys.modules["alpaca.data.historical"] = historical_mod
# Install mocks before importing from market_data.main
_install_alpaca_mocks()
from services.market_data.main import ( # noqa: E402
MARKET_BARS_STREAM,
_bar_to_dict,
_fetch_historical_bars,
_parse_timeframe,
_poll_latest_bars,
)
async def _to_thread_passthrough(func, *args, **kwargs):
"""Replacement for asyncio.to_thread that calls synchronously."""
return func(*args, **kwargs)
# ---------------------------------------------------------------------------
# Config defaults
# ---------------------------------------------------------------------------
class TestMarketDataConfig:
"""Tests for MarketDataConfig defaults."""
def test_default_watchlist(self):
config = MarketDataConfig()
assert config.watchlist == ["AAPL", "TSLA", "NVDA", "MSFT", "GOOGL"]
def test_default_bar_timeframe(self):
config = MarketDataConfig()
assert config.bar_timeframe == "5Min"
def test_default_poll_interval(self):
config = MarketDataConfig()
assert config.poll_interval_seconds == 60
def test_default_historical_bars(self):
config = MarketDataConfig()
assert config.historical_bars == 100
def test_default_alpaca_keys_empty(self):
config = MarketDataConfig()
assert config.alpaca_api_key == ""
assert config.alpaca_secret_key == ""
def test_inherits_base_config_defaults(self):
config = MarketDataConfig()
assert config.log_level == "INFO"
assert config.otel_metrics_port == 9090
# ---------------------------------------------------------------------------
# Timeframe parsing
# ---------------------------------------------------------------------------
class TestParseTimeframe:
"""Tests for the _parse_timeframe helper."""
def test_parse_5min(self):
tf = _parse_timeframe("5Min")
assert tf is not None
assert tf.amount == 5
def test_parse_1min(self):
tf = _parse_timeframe("1Min")
assert tf is not None
assert tf.amount == 1
def test_parse_15min(self):
tf = _parse_timeframe("15Min")
assert tf is not None
assert tf.amount == 15
def test_parse_1hour(self):
tf = _parse_timeframe("1Hour")
assert tf is not None
assert tf.amount == 1
def test_parse_1day(self):
tf = _parse_timeframe("1Day")
assert tf is not None
assert tf.amount == 1
def test_parse_invalid_raises(self):
with pytest.raises(ValueError, match="Unsupported timeframe"):
_parse_timeframe("3Min")
# ---------------------------------------------------------------------------
# Bar to dict conversion
# ---------------------------------------------------------------------------
class TestBarToDict:
"""Tests for _bar_to_dict conversion."""
def test_basic_conversion(self):
bar = _make_fake_bar(close=155.5)
result = _bar_to_dict("AAPL", bar)
assert result["ticker"] == "AAPL"
assert result["close"] == 155.5
assert result["open"] == 149.0
assert result["high"] == 151.0
assert result["low"] == 148.0
assert result["volume"] == 10000.0
assert "timestamp" in result
def test_timestamp_is_iso_format(self):
bar = _make_fake_bar()
result = _bar_to_dict("TSLA", bar)
# Should be parseable back
parsed = datetime.fromisoformat(result["timestamp"])
assert parsed.tzinfo is not None
def test_all_fields_are_floats(self):
bar = _make_fake_bar()
result = _bar_to_dict("AAPL", bar)
for field in ("open", "high", "low", "close", "volume"):
assert isinstance(result[field], float)
def test_ticker_is_passed_through(self):
bar = _make_fake_bar()
result = _bar_to_dict("NVDA", bar)
assert result["ticker"] == "NVDA"
# ---------------------------------------------------------------------------
# Fetch historical bars
# ---------------------------------------------------------------------------
class TestFetchHistoricalBars:
"""Tests for _fetch_historical_bars with mocked Alpaca client."""
@pytest.mark.asyncio
async def test_publishes_bars_for_each_ticker(self):
"""Should publish all historical bars for each ticker in the watchlist."""
bars_aapl = [_make_fake_bar(150.0), _make_fake_bar(151.0)]
bars_tsla = [_make_fake_bar(200.0)]
bar_set = _FakeBarSet({"AAPL": bars_aapl, "TSLA": bars_tsla})
mock_client = MagicMock()
mock_client.get_stock_bars = MagicMock(return_value=bar_set)
publisher = AsyncMock()
publisher.publish = AsyncMock()
counter = MagicMock()
with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough):
total = await _fetch_historical_bars(
mock_client, ["AAPL", "TSLA"], MagicMock(), 100, publisher, counter
)
assert total == 3
assert publisher.publish.call_count == 3
counter.add.assert_called_once_with(3)
@pytest.mark.asyncio
async def test_handles_empty_response(self):
"""Should handle tickers with no bars gracefully."""
bar_set = _FakeBarSet({})
mock_client = MagicMock()
mock_client.get_stock_bars = MagicMock(return_value=bar_set)
publisher = AsyncMock()
counter = MagicMock()
with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough):
total = await _fetch_historical_bars(
mock_client, ["AAPL"], MagicMock(), 100, publisher, counter
)
assert total == 0
publisher.publish.assert_not_called()
@pytest.mark.asyncio
async def test_handles_client_exception(self):
"""Should log and continue if one ticker fails."""
mock_client = MagicMock()
mock_client.get_stock_bars = MagicMock(side_effect=Exception("API error"))
publisher = AsyncMock()
counter = MagicMock()
with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough):
total = await _fetch_historical_bars(
mock_client, ["AAPL", "TSLA"], MagicMock(), 100, publisher, counter
)
assert total == 0
publisher.publish.assert_not_called()
@pytest.mark.asyncio
async def test_published_message_format(self):
"""Published messages should have the expected fields."""
bar = _make_fake_bar(175.0)
bar_set = _FakeBarSet({"MSFT": [bar]})
mock_client = MagicMock()
mock_client.get_stock_bars = MagicMock(return_value=bar_set)
publisher = AsyncMock()
publisher.publish = AsyncMock()
counter = MagicMock()
with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough):
await _fetch_historical_bars(
mock_client, ["MSFT"], MagicMock(), 100, publisher, counter
)
msg = publisher.publish.call_args[0][0]
assert msg["ticker"] == "MSFT"
assert msg["close"] == 175.0
assert "timestamp" in msg
assert "open" in msg
assert "high" in msg
assert "low" in msg
assert "volume" in msg
# ---------------------------------------------------------------------------
# Poll latest bars
# ---------------------------------------------------------------------------
class TestPollLatestBars:
"""Tests for _poll_latest_bars with mocked Alpaca client."""
@pytest.mark.asyncio
async def test_publishes_latest_bar_per_ticker(self):
"""Should publish one bar per ticker."""
bar_set = _FakeBarSet({"AAPL": [_make_fake_bar(155.0)]})
mock_client = MagicMock()
mock_client.get_stock_bars = MagicMock(return_value=bar_set)
publisher = AsyncMock()
publisher.publish = AsyncMock()
counter = MagicMock()
with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough):
count = await _poll_latest_bars(
mock_client, ["AAPL"], MagicMock(), publisher, counter
)
assert count == 1
assert publisher.publish.call_count == 1
published_msg = publisher.publish.call_args[0][0]
assert published_msg["ticker"] == "AAPL"
assert published_msg["close"] == 155.0
@pytest.mark.asyncio
async def test_publishes_only_most_recent_bar(self):
"""When multiple bars returned, should only publish the last one."""
bars = [_make_fake_bar(150.0), _make_fake_bar(155.0)]
bar_set = _FakeBarSet({"AAPL": bars})
mock_client = MagicMock()
mock_client.get_stock_bars = MagicMock(return_value=bar_set)
publisher = AsyncMock()
publisher.publish = AsyncMock()
counter = MagicMock()
with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough):
count = await _poll_latest_bars(
mock_client, ["AAPL"], MagicMock(), publisher, counter
)
assert count == 1
published_msg = publisher.publish.call_args[0][0]
assert published_msg["close"] == 155.0
@pytest.mark.asyncio
async def test_handles_no_bars_for_ticker(self):
"""Should return 0 when a ticker has no bars."""
bar_set = _FakeBarSet({"AAPL": []})
mock_client = MagicMock()
mock_client.get_stock_bars = MagicMock(return_value=bar_set)
publisher = AsyncMock()
counter = MagicMock()
with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough):
count = await _poll_latest_bars(
mock_client, ["AAPL"], MagicMock(), publisher, counter
)
assert count == 0
publisher.publish.assert_not_called()
@pytest.mark.asyncio
async def test_handles_ticker_not_in_response(self):
"""Should handle gracefully when the response doesn't contain the ticker."""
bar_set = _FakeBarSet({})
mock_client = MagicMock()
mock_client.get_stock_bars = MagicMock(return_value=bar_set)
publisher = AsyncMock()
counter = MagicMock()
with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough):
count = await _poll_latest_bars(
mock_client, ["AAPL"], MagicMock(), publisher, counter
)
assert count == 0
publisher.publish.assert_not_called()
@pytest.mark.asyncio
async def test_multiple_tickers(self):
"""Should poll and publish bars for all tickers in watchlist."""
bar_set_aapl = _FakeBarSet({"AAPL": [_make_fake_bar(150.0)]})
bar_set_tsla = _FakeBarSet({"TSLA": [_make_fake_bar(250.0)]})
mock_client = MagicMock()
# Return different bar sets for each call
mock_client.get_stock_bars = MagicMock(
side_effect=[bar_set_aapl, bar_set_tsla]
)
publisher = AsyncMock()
publisher.publish = AsyncMock()
counter = MagicMock()
with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough):
count = await _poll_latest_bars(
mock_client, ["AAPL", "TSLA"], MagicMock(), publisher, counter
)
assert count == 2
assert publisher.publish.call_count == 2
# ---------------------------------------------------------------------------
# Stream name constant
# ---------------------------------------------------------------------------
class TestStreamConstants:
"""Verify the stream name constant."""
def test_market_bars_stream_name(self):
assert MARKET_BARS_STREAM == "market:bars"


@@ -0,0 +1,456 @@
"""Tests for portfolio sync background task.
Verifies that the sync loop correctly:
- Creates PortfolioSnapshot rows from broker account data
- Upserts Position rows from broker positions
- Removes Position rows for closed positions
- Handles broker errors gracefully
- Respects US market hours
"""
from __future__ import annotations
import asyncio
from datetime import datetime, time, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from services.api_gateway.config import ApiGatewayConfig
from services.api_gateway.tasks.portfolio_sync import (
_sync_once,
is_market_open,
portfolio_sync_loop,
)
from shared.schemas.trading import AccountInfo, PositionInfo
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture()
def config() -> ApiGatewayConfig:
return ApiGatewayConfig(
jwt_secret_key="test-secret-for-sync",
database_url="sqlite+aiosqlite:///:memory:",
redis_url="redis://localhost:6379/0",
alpaca_api_key="test-key",
alpaca_secret_key="test-secret",
paper_trading=True,
snapshot_interval_seconds=1,
)
@pytest.fixture()
def config_no_creds() -> ApiGatewayConfig:
return ApiGatewayConfig(
jwt_secret_key="test-secret-for-sync",
database_url="sqlite+aiosqlite:///:memory:",
redis_url="redis://localhost:6379/0",
alpaca_api_key="",
alpaca_secret_key="",
)
@pytest.fixture()
def mock_account() -> AccountInfo:
return AccountInfo(
equity=105000.0,
cash=50000.0,
buying_power=100000.0,
portfolio_value=105000.0,
)
@pytest.fixture()
def mock_positions() -> list[PositionInfo]:
return [
PositionInfo(
ticker="AAPL",
qty=10.0,
avg_entry=150.0,
current_price=155.0,
unrealized_pnl=50.0,
market_value=1550.0,
),
PositionInfo(
ticker="MSFT",
qty=5.0,
avg_entry=400.0,
current_price=410.0,
unrealized_pnl=50.0,
market_value=2050.0,
),
]
@pytest.fixture()
def mock_broker(mock_account, mock_positions):
broker = AsyncMock()
broker.get_account = AsyncMock(return_value=mock_account)
broker.get_positions = AsyncMock(return_value=mock_positions)
return broker
@pytest.fixture()
def mock_session():
"""Create a mock async session with context manager support."""
session = AsyncMock()
session.__aenter__ = AsyncMock(return_value=session)
session.__aexit__ = AsyncMock(return_value=False)
# Mock the begin() context manager
begin_ctx = AsyncMock()
begin_ctx.__aenter__ = AsyncMock(return_value=None)
begin_ctx.__aexit__ = AsyncMock(return_value=False)
session.begin = MagicMock(return_value=begin_ctx)
# session.add is synchronous in SQLAlchemy — use MagicMock to avoid warnings
session.add = MagicMock()
return session
@pytest.fixture()
def mock_session_factory(mock_session):
factory = MagicMock()
factory.return_value = mock_session
return factory
# ---------------------------------------------------------------------------
# Market hours tests
# ---------------------------------------------------------------------------
class TestMarketHours:
"""Tests for the is_market_open() function."""
def test_weekday_during_market_hours(self) -> None:
# Wednesday 2024-01-10 at 10:00 AM ET = 15:00 UTC
dt = datetime(2024, 1, 10, 15, 0, 0, tzinfo=timezone.utc)
assert is_market_open(dt) is True
def test_weekday_before_market_open(self) -> None:
# Wednesday 2024-01-10 at 9:00 AM ET = 14:00 UTC
dt = datetime(2024, 1, 10, 14, 0, 0, tzinfo=timezone.utc)
assert is_market_open(dt) is False
def test_weekday_after_market_close(self) -> None:
# Wednesday 2024-01-10 at 4:30 PM ET = 21:30 UTC
dt = datetime(2024, 1, 10, 21, 30, 0, tzinfo=timezone.utc)
assert is_market_open(dt) is False
def test_weekend_saturday(self) -> None:
# Saturday 2024-01-13 at 12:00 PM ET = 17:00 UTC
dt = datetime(2024, 1, 13, 17, 0, 0, tzinfo=timezone.utc)
assert is_market_open(dt) is False
def test_weekend_sunday(self) -> None:
# Sunday 2024-01-14 at 12:00 PM ET = 17:00 UTC
dt = datetime(2024, 1, 14, 17, 0, 0, tzinfo=timezone.utc)
assert is_market_open(dt) is False
def test_market_open_boundary(self) -> None:
# Wednesday 2024-01-10 at exactly 9:30 AM ET = 14:30 UTC
dt = datetime(2024, 1, 10, 14, 30, 0, tzinfo=timezone.utc)
assert is_market_open(dt) is True
def test_market_close_boundary(self) -> None:
# Wednesday 2024-01-10 at exactly 4:00 PM ET = 21:00 UTC
dt = datetime(2024, 1, 10, 21, 0, 0, tzinfo=timezone.utc)
assert is_market_open(dt) is False
# ---------------------------------------------------------------------------
# Snapshot creation tests
# ---------------------------------------------------------------------------
class TestSyncOnce:
"""Tests for the _sync_once() function."""
async def test_creates_portfolio_snapshot(
self, mock_broker, mock_session_factory, mock_session
) -> None:
# Mock the select query to return None (no existing positions)
execute_result = MagicMock()
execute_result.scalar_one_or_none.return_value = None
mock_session.execute = AsyncMock(return_value=execute_result)
await _sync_once(mock_broker, mock_session_factory)
# Verify the broker was called
mock_broker.get_account.assert_awaited_once()
mock_broker.get_positions.assert_awaited_once()
# Verify session.add was called (snapshot + 2 new positions)
assert mock_session.add.call_count == 3 # 1 snapshot + 2 positions
# Check the snapshot
snapshot_call = mock_session.add.call_args_list[0]
snapshot = snapshot_call[0][0]
assert snapshot.total_value == 105000.0
assert snapshot.cash == 50000.0
assert snapshot.positions_value == 55000.0 # 105000 - 50000
assert snapshot.daily_pnl == 0.0
async def test_creates_position_rows_for_new_positions(
self, mock_broker, mock_session_factory, mock_session
) -> None:
# No existing positions in DB
execute_result = MagicMock()
execute_result.scalar_one_or_none.return_value = None
mock_session.execute = AsyncMock(return_value=execute_result)
await _sync_once(mock_broker, mock_session_factory)
# Positions are added via session.add (after the snapshot)
position_calls = mock_session.add.call_args_list[1:]
assert len(position_calls) == 2
pos1 = position_calls[0][0][0]
assert pos1.ticker == "AAPL"
assert pos1.qty == 10.0
assert pos1.avg_entry == 150.0
assert pos1.unrealized_pnl == 50.0
pos2 = position_calls[1][0][0]
assert pos2.ticker == "MSFT"
assert pos2.qty == 5.0
assert pos2.avg_entry == 400.0
async def test_updates_existing_position(
self, mock_broker, mock_session_factory, mock_session
) -> None:
# Mock an existing position for AAPL, None for MSFT
existing_aapl = MagicMock()
existing_aapl.ticker = "AAPL"
existing_aapl.qty = 5.0 # old qty
existing_aapl.avg_entry = 140.0 # old entry
result_aapl = MagicMock()
result_aapl.scalar_one_or_none.return_value = existing_aapl
result_msft = MagicMock()
result_msft.scalar_one_or_none.return_value = None
# execute() serves the two SELECTs in the upsert loop first (AAPL, then MSFT),
# then the DELETE of stale positions — hence the side_effect order below
mock_session.execute = AsyncMock(
side_effect=[result_aapl, result_msft, MagicMock()]
)
await _sync_once(mock_broker, mock_session_factory)
# AAPL should be updated in place
assert existing_aapl.qty == 10.0
assert existing_aapl.avg_entry == 150.0
assert existing_aapl.unrealized_pnl == 50.0
# MSFT should be added as new (snapshot + MSFT = 2 adds)
assert mock_session.add.call_count == 2 # snapshot + new MSFT
async def test_removes_closed_positions(
self, mock_session_factory, mock_session
) -> None:
# Broker returns only AAPL (MSFT was sold)
broker = AsyncMock()
broker.get_account = AsyncMock(
return_value=AccountInfo(
equity=100000, cash=90000, buying_power=90000, portfolio_value=100000
)
)
broker.get_positions = AsyncMock(
return_value=[
PositionInfo(
ticker="AAPL",
qty=10.0,
avg_entry=150.0,
current_price=155.0,
unrealized_pnl=50.0,
market_value=1550.0,
)
]
)
execute_result = MagicMock()
execute_result.scalar_one_or_none.return_value = None
mock_session.execute = AsyncMock(return_value=execute_result)
await _sync_once(broker, mock_session_factory)
# The delete statement should have been executed
# Find the delete call among execute calls
delete_called = False
for call in mock_session.execute.call_args_list:
stmt = call[0][0]
# Check if it's a delete statement (SQLAlchemy Delete object)
stmt_str = str(stmt)
if "DELETE" in stmt_str.upper():
delete_called = True
break
assert delete_called, "Expected a DELETE statement for closed positions"
async def test_removes_all_positions_when_broker_has_none(
self, mock_session_factory, mock_session
) -> None:
broker = AsyncMock()
broker.get_account = AsyncMock(
return_value=AccountInfo(
equity=100000, cash=100000, buying_power=100000, portfolio_value=100000
)
)
broker.get_positions = AsyncMock(return_value=[])
mock_session.execute = AsyncMock(return_value=MagicMock())
await _sync_once(broker, mock_session_factory)
# Should delete all positions since broker has none
delete_called = False
for call in mock_session.execute.call_args_list:
stmt = call[0][0]
stmt_str = str(stmt)
if "DELETE" in stmt_str.upper():
delete_called = True
break
assert delete_called, "Expected a DELETE statement to clear all positions"
# ---------------------------------------------------------------------------
# Error handling tests
# ---------------------------------------------------------------------------
class TestSyncErrorHandling:
"""Tests that the sync loop handles errors gracefully."""
async def test_broker_error_does_not_crash_loop(
self, config, mock_session_factory
) -> None:
"""Broker raises an exception — loop should catch it and continue."""
call_count = 0
async def mock_sync_once(broker, sf):
nonlocal call_count
call_count += 1
if call_count == 1:
raise ConnectionError("Broker API down")
# Second call succeeds
with (
patch(
"services.api_gateway.tasks.portfolio_sync.AlpacaBroker"
) as MockBroker,
patch(
"services.api_gateway.tasks.portfolio_sync._sync_once",
side_effect=mock_sync_once,
),
patch(
"services.api_gateway.tasks.portfolio_sync.is_market_open",
return_value=True,
),
):
MockBroker.return_value = AsyncMock()
task = asyncio.create_task(portfolio_sync_loop(config, mock_session_factory))
# Give it time for 2 iterations (interval=1s)
await asyncio.sleep(2.5)
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
assert call_count >= 2, "Loop should have retried after the error"
async def test_no_credentials_returns_immediately(
self, config_no_creds, mock_session_factory
) -> None:
"""When Alpaca credentials are empty, the loop should exit immediately."""
task = asyncio.create_task(
portfolio_sync_loop(config_no_creds, mock_session_factory)
)
# Should complete almost immediately since no creds
await asyncio.wait_for(task, timeout=2.0)
# If we get here without timeout, the function returned correctly
# ---------------------------------------------------------------------------
# Market hours integration with loop
# ---------------------------------------------------------------------------
class TestSyncLoopMarketHours:
"""Tests that the loop respects market hours."""
async def test_skips_sync_outside_market_hours(
self, config, mock_session_factory
) -> None:
sync_called = False
async def mock_sync(broker, sf):
nonlocal sync_called
sync_called = True
with (
patch(
"services.api_gateway.tasks.portfolio_sync.AlpacaBroker"
) as MockBroker,
patch(
"services.api_gateway.tasks.portfolio_sync._sync_once",
side_effect=mock_sync,
),
patch(
"services.api_gateway.tasks.portfolio_sync.is_market_open",
return_value=False,
),
):
MockBroker.return_value = AsyncMock()
task = asyncio.create_task(portfolio_sync_loop(config, mock_session_factory))
await asyncio.sleep(1.5)
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
assert not sync_called, "Sync should not run outside market hours"
async def test_runs_sync_during_market_hours(
self, config, mock_session_factory
) -> None:
sync_called = False
async def mock_sync(broker, sf):
nonlocal sync_called
sync_called = True
with (
patch(
"services.api_gateway.tasks.portfolio_sync.AlpacaBroker"
) as MockBroker,
patch(
"services.api_gateway.tasks.portfolio_sync._sync_once",
side_effect=mock_sync,
),
patch(
"services.api_gateway.tasks.portfolio_sync.is_market_open",
return_value=True,
),
):
MockBroker.return_value = AsyncMock()
task = asyncio.create_task(portfolio_sync_loop(config, mock_session_factory))
await asyncio.sleep(1.5)
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
assert sync_called, "Sync should run during market hours"


@@ -1,7 +1,7 @@
"""Tests for the sentiment analyzer service.
Covers FinBERT analyzer, Ollama analyzer, ticker extraction, and the main
service flow.
service flow including DB persistence.
"""
from __future__ import annotations
@@ -10,6 +10,7 @@ from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from sqlalchemy.exc import IntegrityError
from services.sentiment_analyzer.analyzers.finbert import FinBERTAnalyzer
from services.sentiment_analyzer.analyzers.ollama_analyzer import OllamaAnalyzer
@@ -409,3 +410,229 @@ class TestMainFlow:
publisher.publish.assert_not_called()
# Still counted as scored
counters["articles_scored"].add.assert_called_once_with(1)
# ---------------------------------------------------------------------------
# Helpers for DB persistence tests
# ---------------------------------------------------------------------------
def _make_counters() -> dict:
"""Create a dict of mock OpenTelemetry counter/histogram instruments."""
return {
"articles_scored": MagicMock(),
"finbert_count": MagicMock(),
"ollama_count": MagicMock(),
"inference_latency": MagicMock(),
}
def _make_mock_db_session_factory(session: AsyncMock | None = None) -> AsyncMock:
"""Create a mock async_sessionmaker that yields a mock session.
The returned factory, when called, returns an async context manager
that yields ``session`` (a mock AsyncSession).
"""
if session is None:
session = AsyncMock()
session.add = MagicMock()
session.commit = AsyncMock()
factory = MagicMock()
# factory() returns an async context manager that yields the session
ctx = AsyncMock()
ctx.__aenter__ = AsyncMock(return_value=session)
ctx.__aexit__ = AsyncMock(return_value=False)
factory.return_value = ctx
return factory
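The __aenter__/__aexit__ pair is what lets the production code's "async with factory() as session:" block run against the mock. A quick standalone self-check of that protocol, using only the helper above (the _demo name is ours):

import asyncio
from unittest.mock import AsyncMock, MagicMock

async def _demo() -> None:
    factory = _make_mock_db_session_factory()
    async with factory() as session:   # ctx.__aenter__ yields the session
        session.add(object())          # recorded synchronously (MagicMock)
        await session.commit()         # awaited (AsyncMock)
    session.add.assert_called_once()

asyncio.run(_demo())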
# ---------------------------------------------------------------------------
# DB Persistence Tests
# ---------------------------------------------------------------------------
class TestDBPersistence:
"""Tests for the DB write step in process_article."""
@pytest.mark.asyncio
async def test_db_write_creates_article_and_sentiments(self):
"""When db_session_factory is provided, Article and ArticleSentiment rows are created."""
finbert = AsyncMock(spec=FinBERTAnalyzer)
finbert.analyze = AsyncMock(return_value=(0.75, 0.88))
ollama = AsyncMock(spec=OllamaAnalyzer)
publisher = AsyncMock()
publisher.publish = AsyncMock(return_value=b"1-0")
config = SentimentAnalyzerConfig(
finbert_confidence_threshold=0.6,
otel_metrics_port=0,
)
counters = _make_counters()
mock_session = AsyncMock()
mock_session.add = MagicMock()
mock_session.commit = AsyncMock()
db_factory = _make_mock_db_session_factory(mock_session)
article = _make_raw_article(
title="$AAPL and $MSFT report strong earnings",
content="Both Apple and Microsoft beat estimates.",
)
await process_article(
article, finbert, ollama, publisher, config, counters, db_factory
)
# session.add should be called: 1 Article + 2 ArticleSentiments = 3 calls
assert mock_session.add.call_count == 3
# Verify the types of objects added
added_objects = [call.args[0] for call in mock_session.add.call_args_list]
from shared.models.news import Article, ArticleSentiment
articles = [o for o in added_objects if isinstance(o, Article)]
sentiments = [o for o in added_objects if isinstance(o, ArticleSentiment)]
assert len(articles) == 1
assert len(sentiments) == 2
# Verify article fields
db_article = articles[0]
assert db_article.source == "test"
assert db_article.url == "https://example.com/article"
assert db_article.content_hash == "abc123"
# Verify sentiment fields
tickers_in_sentiments = {s.ticker for s in sentiments}
assert "AAPL" in tickers_in_sentiments
assert "MSFT" in tickers_in_sentiments
for s in sentiments:
assert s.score == 0.75
assert s.confidence == 0.88
assert s.model_used == "finbert"
# session.commit should be called once
mock_session.commit.assert_awaited_once()
# Redis publishing should still happen
assert publisher.publish.call_count == 2
@pytest.mark.asyncio
async def test_db_duplicate_article_handled_gracefully(self):
"""Duplicate content_hash (IntegrityError) should be caught silently."""
finbert = AsyncMock(spec=FinBERTAnalyzer)
finbert.analyze = AsyncMock(return_value=(0.5, 0.9))
ollama = AsyncMock(spec=OllamaAnalyzer)
publisher = AsyncMock()
publisher.publish = AsyncMock(return_value=b"1-0")
config = SentimentAnalyzerConfig(
finbert_confidence_threshold=0.6,
otel_metrics_port=0,
)
counters = _make_counters()
mock_session = AsyncMock()
mock_session.add = MagicMock()
# Simulate IntegrityError on commit (duplicate content_hash)
mock_session.commit = AsyncMock(
side_effect=IntegrityError("duplicate", {}, Exception())
)
db_factory = _make_mock_db_session_factory(mock_session)
article = _make_raw_article(
title="$AAPL news",
content="Apple earnings report.",
)
# Should NOT raise — IntegrityError is caught
await process_article(
article, finbert, ollama, publisher, config, counters, db_factory
)
# Redis publishing should still have happened before the DB write
assert publisher.publish.call_count >= 1
counters["articles_scored"].add.assert_called_once_with(1)
@pytest.mark.asyncio
async def test_db_none_backward_compatible(self):
"""When db_session_factory is None, process_article works without DB writes."""
finbert = AsyncMock(spec=FinBERTAnalyzer)
finbert.analyze = AsyncMock(return_value=(0.6, 0.85))
ollama = AsyncMock(spec=OllamaAnalyzer)
publisher = AsyncMock()
publisher.publish = AsyncMock(return_value=b"1-0")
config = SentimentAnalyzerConfig(
finbert_confidence_threshold=0.6,
otel_metrics_port=0,
)
counters = _make_counters()
article = _make_raw_article(
title="$GOOG quarterly results",
content="Google reports revenue growth.",
)
# Pass db_session_factory=None explicitly (the default)
await process_article(
article, finbert, ollama, publisher, config, counters, db_session_factory=None
)
# Redis publishing should work as before
assert publisher.publish.call_count >= 1
counters["articles_scored"].add.assert_called_once_with(1)
@pytest.mark.asyncio
async def test_db_error_does_not_break_processing(self):
"""A DB error should be logged but not prevent Redis publishing."""
finbert = AsyncMock(spec=FinBERTAnalyzer)
finbert.analyze = AsyncMock(return_value=(0.3, 0.7))
ollama = AsyncMock(spec=OllamaAnalyzer)
publisher = AsyncMock()
publisher.publish = AsyncMock(return_value=b"1-0")
config = SentimentAnalyzerConfig(
finbert_confidence_threshold=0.6,
otel_metrics_port=0,
)
counters = _make_counters()
mock_session = AsyncMock()
mock_session.add = MagicMock()
# Simulate a generic DB error
mock_session.commit = AsyncMock(
side_effect=RuntimeError("connection lost")
)
db_factory = _make_mock_db_session_factory(mock_session)
article = _make_raw_article(
title="$TSLA stock update",
content="Tesla announces new factory.",
)
# Should NOT raise — generic exceptions in DB write are caught
await process_article(
article, finbert, ollama, publisher, config, counters, db_factory
)
# Redis publishing should still have happened
assert publisher.publish.call_count >= 1
counters["articles_scored"].add.assert_called_once_with(1)


@@ -357,3 +357,201 @@ class TestEnsembleTagsStrategySources:
assert parts[0] == "alpha"
assert parts[1] == "LONG"
assert float(parts[2]) == pytest.approx(0.75, abs=0.01)
# ---------------------------------------------------------------------------
# Signal Generator — current_price injection into sentiment_context
# ---------------------------------------------------------------------------
class TestCurrentPriceInjection:
"""Verify that current_price flows into sentiment_context on published signals."""
@pytest.mark.asyncio
async def test_current_price_set_when_snapshot_has_price(self):
"""When snapshot has a positive current_price, it should appear in sentiment_context."""
s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.9))
ensemble = WeightedEnsemble([s1], threshold=0.0)
market = MarketSnapshot(
ticker="AAPL",
current_price=155.50,
open=154.0,
high=156.0,
low=153.0,
close=155.50,
volume=5000,
)
sentiment = SentimentContext(
ticker="AAPL",
avg_score=0.7,
article_count=5,
recent_scores=[0.6, 0.7, 0.8],
avg_confidence=0.85,
)
weights = {"alpha": 1.0}
signal_result = await ensemble.evaluate("AAPL", market, sentiment, weights)
assert signal_result is not None
# Simulate the current_price injection from the signal generator main loop
if signal_result.sentiment_context is None:
signal_result.sentiment_context = {}
if market.current_price > 0:
signal_result.sentiment_context["current_price"] = market.current_price
assert "current_price" in signal_result.sentiment_context
assert signal_result.sentiment_context["current_price"] == 155.50
@pytest.mark.asyncio
async def test_current_price_not_set_when_snapshot_is_none(self):
"""When snapshot is None (minimal fallback), current_price should not be injected."""
s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.9))
ensemble = WeightedEnsemble([s1], threshold=0.0)
# Minimal fallback snapshot with current_price=0.0
market = MarketSnapshot(
ticker="AAPL",
current_price=0.0,
open=0.0,
high=0.0,
low=0.0,
close=0.0,
volume=0.0,
)
weights = {"alpha": 1.0}
signal_result = await ensemble.evaluate("AAPL", market, None, weights)
assert signal_result is not None
# Simulate the injection logic — should NOT set current_price when price is 0
if market.current_price > 0:
if signal_result.sentiment_context is None:
signal_result.sentiment_context = {}
signal_result.sentiment_context["current_price"] = market.current_price
# sentiment_context should be None (no sentiment was passed, and no price injection)
assert signal_result.sentiment_context is None
@pytest.mark.asyncio
async def test_current_price_not_set_when_price_is_zero(self):
"""current_price should NOT be injected when snapshot.current_price is exactly 0."""
s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.8))
ensemble = WeightedEnsemble([s1], threshold=0.0)
sentiment = SentimentContext(
ticker="AAPL",
avg_score=0.5,
article_count=3,
recent_scores=[0.5],
avg_confidence=0.9,
)
market = MarketSnapshot(
ticker="AAPL",
current_price=0.0,
open=0.0,
high=0.0,
low=0.0,
close=0.0,
volume=0.0,
)
weights = {"alpha": 1.0}
signal_result = await ensemble.evaluate("AAPL", market, sentiment, weights)
assert signal_result is not None
# Simulate injection logic
if market.current_price > 0:
if signal_result.sentiment_context is None:
signal_result.sentiment_context = {}
signal_result.sentiment_context["current_price"] = market.current_price
# sentiment_context should exist (from ensemble) but WITHOUT current_price
assert signal_result.sentiment_context is not None
assert "current_price" not in signal_result.sentiment_context
@pytest.mark.asyncio
async def test_current_price_preserves_existing_sentiment_context(self):
"""Injecting current_price should not overwrite other sentiment_context fields."""
s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.9))
ensemble = WeightedEnsemble([s1], threshold=0.0)
sentiment = SentimentContext(
ticker="AAPL",
avg_score=0.75,
article_count=10,
recent_scores=[0.7, 0.8],
avg_confidence=0.9,
)
market = MarketSnapshot(
ticker="AAPL",
current_price=200.0,
open=198.0,
high=202.0,
low=197.0,
close=200.0,
volume=8000,
)
weights = {"alpha": 1.0}
signal_result = await ensemble.evaluate("AAPL", market, sentiment, weights)
assert signal_result is not None
assert signal_result.sentiment_context is not None
# Preserve existing keys
original_keys = set(signal_result.sentiment_context.keys())
# Inject current_price
if market.current_price > 0:
signal_result.sentiment_context["current_price"] = market.current_price
# All original keys should still be present
for key in original_keys:
assert key in signal_result.sentiment_context
# Plus the new one
assert signal_result.sentiment_context["current_price"] == 200.0
# ---------------------------------------------------------------------------
# Signal Generator — signal_id generation
# ---------------------------------------------------------------------------
class TestSignalIdGeneration:
"""Verify that TradeSignal includes a signal_id UUID."""
@pytest.mark.asyncio
async def test_signal_has_uuid(self):
"""Each generated signal should have a signal_id UUID."""
s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.9))
ensemble = WeightedEnsemble([s1], threshold=0.0)
market = MarketSnapshot(
ticker="AAPL", current_price=150.0,
open=149.0, high=151.0, low=148.0, close=150.0, volume=1000,
)
weights = {"alpha": 1.0}
signal_result = await ensemble.evaluate("AAPL", market, None, weights)
assert signal_result is not None
assert signal_result.signal_id is not None
from uuid import UUID
assert isinstance(signal_result.signal_id, UUID)
@pytest.mark.asyncio
async def test_signal_ids_are_unique(self):
"""Multiple signals should get different signal_id values."""
s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.9))
ensemble = WeightedEnsemble([s1], threshold=0.0)
market = MarketSnapshot(
ticker="AAPL", current_price=150.0,
open=149.0, high=151.0, low=148.0, close=150.0, volume=1000,
)
weights = {"alpha": 1.0}
sig1 = await ensemble.evaluate("AAPL", market, None, weights)
sig2 = await ensemble.evaluate("AAPL", market, None, weights)
assert sig1 is not None and sig2 is not None
assert sig1.signal_id != sig2.signal_id
def test_signal_id_serializes_in_json(self):
"""signal_id should serialize properly in JSON mode."""
signal = _make_signal(SignalDirection.LONG, 0.8)
data = signal.model_dump(mode="json")
assert "signal_id" in data
assert isinstance(data["signal_id"], str)
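These three tests are satisfied by a pydantic field with a per-instance UUID default. A minimal model consistent with them — the real TradeSignal carries many more fields; this is only the relevant slice, under an assumed class name:

from uuid import UUID, uuid4

from pydantic import BaseModel, Field

class TradeSignalSketch(BaseModel):
    # default_factory runs per instantiation, so every signal gets a fresh UUID
    signal_id: UUID = Field(default_factory=uuid4)

a, b = TradeSignalSketch(), TradeSignalSketch()
assert a.signal_id != b.signal_id                               # uniqueness
assert isinstance(a.model_dump(mode="json")["signal_id"], str)  # JSON mode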


@@ -401,3 +401,119 @@ class TestExecutorFlowRejected:
# Rejection counter should have been incremented
counters["rejections"].add.assert_called_once()
# ---------------------------------------------------------------------------
# Executor flow — DB persistence
# ---------------------------------------------------------------------------
def _make_mock_db_session_factory(session=None):
"""Create a mock async_sessionmaker that yields a mock session."""
if session is None:
session = AsyncMock()
session.add = MagicMock()
session.commit = AsyncMock()
factory = MagicMock()
ctx = AsyncMock()
ctx.__aenter__ = AsyncMock(return_value=session)
ctx.__aexit__ = AsyncMock(return_value=False)
factory.return_value = ctx
return factory
class TestExecutorDBPersistence:
"""Verify that trades are persisted to the DB when db_session_factory is provided."""
@pytest.mark.asyncio
async def test_trade_persisted_with_signal_id(self):
"""When db_session_factory is provided, a Trade row should be created."""
config = _make_config()
broker = _mock_broker(positions=[], account=_make_account(100_000))
publisher = AsyncMock()
publisher.publish = AsyncMock(return_value=b"1-0")
counters = {
"trades_executed": MagicMock(),
"rejections": MagicMock(),
"fill_latency": MagicMock(),
}
signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0)
mock_session = AsyncMock()
mock_session.add = MagicMock()
mock_session.commit = AsyncMock()
db_factory = _make_mock_db_session_factory(mock_session)
with patch.object(RiskManager, "check_risk", return_value=(True, "approved")):
await process_signal(
signal, RiskManager(config, broker), broker, publisher, counters, db_factory
)
# Trade should be persisted
mock_session.add.assert_called_once()
mock_session.commit.assert_awaited_once()
# Verify the trade object
trade_obj = mock_session.add.call_args[0][0]
assert trade_obj.ticker == "AAPL"
assert trade_obj.signal_id == signal.signal_id
@pytest.mark.asyncio
async def test_trade_not_persisted_without_db(self):
"""When db_session_factory is None, no DB write should happen."""
config = _make_config()
broker = _mock_broker(positions=[], account=_make_account(100_000))
publisher = AsyncMock()
publisher.publish = AsyncMock(return_value=b"1-0")
counters = {
"trades_executed": MagicMock(),
"rejections": MagicMock(),
"fill_latency": MagicMock(),
}
signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0)
with patch.object(RiskManager, "check_risk", return_value=(True, "approved")):
await process_signal(
signal, RiskManager(config, broker), broker, publisher, counters, None
)
# Should still publish
publisher.publish.assert_called_once()
@pytest.mark.asyncio
async def test_db_error_does_not_block_publishing(self):
"""A DB error should not prevent the trade from being published."""
config = _make_config()
broker = _mock_broker(positions=[], account=_make_account(100_000))
publisher = AsyncMock()
publisher.publish = AsyncMock(return_value=b"1-0")
counters = {
"trades_executed": MagicMock(),
"rejections": MagicMock(),
"fill_latency": MagicMock(),
}
signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0)
mock_session = AsyncMock()
mock_session.add = MagicMock()
mock_session.commit = AsyncMock(side_effect=RuntimeError("DB connection lost"))
db_factory = _make_mock_db_session_factory(mock_session)
with patch.object(RiskManager, "check_risk", return_value=(True, "approved")):
await process_signal(
signal, RiskManager(config, broker), broker, publisher, counters, db_factory
)
# Trade should still be published despite DB error
publisher.publish.assert_called_once()
counters["trades_executed"].add.assert_called_once_with(1)
def test_signal_id_flows_through_execution(self):
"""signal_id from TradeSignal should appear in the published TradeExecution."""
signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0)
assert signal.signal_id is not None
# Verify signal_id is a UUID
from uuid import UUID
assert isinstance(signal.signal_id, UUID)
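Mirroring the sentiment analyzer's pattern, these executor tests fix the shape of the DB step in process_signal: an optional factory, a single Trade row carrying the signal's UUID, and a commit failure that is logged without blocking the publish or the metrics counter. A sketch under those assumptions — the helper name and the trade_row construction are ours, not the service's actual code:

import logging

logger = logging.getLogger(__name__)

async def _persist_trade(db_session_factory, trade_row) -> None:
    """Best-effort DB write; never raises into the execution path."""
    if db_session_factory is None:
        return  # persistence is optional
    try:
        async with db_session_factory() as session:
            session.add(trade_row)   # one Trade row, signal_id included
            await session.commit()
    except Exception:
        # test_db_error_does_not_block_publishing: publish and counters
        # must proceed even when the commit fails.
        logger.exception("trade DB write failed")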