feat: real data pipeline — market data, DB persistence, portfolio sync, signal-trade linkage
Wire the trading bot to real Alpaca market data and persist pipeline state to the database so the dashboard displays live information. - Add market-data service fetching OHLCV bars from Alpaca, publishing to market:bars Redis Stream; signal generator consumes bars and injects current_price into signals for position sizing - Sentiment analyzer now persists Article + ArticleSentiment rows to DB after scoring, with duplicate and error handling - API gateway runs a background portfolio sync task that snapshots Alpaca account state into PortfolioSnapshot/Position DB tables during market hours - TradeSignal carries a signal_id UUID; signal generator and trade executor both persist their records to DB with cross-references - 303 unit tests pass (57 new tests added)
This commit is contained in:
parent
5a6b20c8f1
commit
e2a3bd456d
19 changed files with 2238 additions and 72 deletions
10
.env.example
10
.env.example
|
|
@ -10,6 +10,16 @@ TRADING_LOG_LEVEL=INFO
|
|||
TRADING_ALPACA_API_KEY=your_api_key_here
|
||||
TRADING_ALPACA_SECRET_KEY=your_secret_key_here
|
||||
TRADING_ALPACA_BASE_URL=https://paper-api.alpaca.markets
|
||||
TRADING_PAPER_TRADING=true
|
||||
|
||||
# Market data service — watchlist tickers (comma-separated)
|
||||
TRADING_WATCHLIST=["AAPL","TSLA","NVDA","MSFT","GOOGL"]
|
||||
TRADING_BAR_TIMEFRAME=5Min
|
||||
TRADING_POLL_INTERVAL_SECONDS=60
|
||||
TRADING_HISTORICAL_BARS=100
|
||||
|
||||
# Portfolio sync interval (seconds, api-gateway background task)
|
||||
TRADING_SNAPSHOT_INTERVAL_SECONDS=60
|
||||
|
||||
# JWT — REQUIRED, generate with: python -c "import secrets; print(secrets.token_hex(32))"
|
||||
TRADING_JWT_SECRET_KEY=
|
||||
|
|
|
|||
|
|
@ -142,6 +142,19 @@ services:
|
|||
env_file: .env
|
||||
restart: unless-stopped
|
||||
|
||||
market-data:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: docker/Dockerfile.service
|
||||
args:
|
||||
EXTRAS: "trading"
|
||||
SERVICE_MODULE: "market_data"
|
||||
depends_on:
|
||||
redis:
|
||||
condition: service_healthy
|
||||
env_file: .env
|
||||
restart: unless-stopped
|
||||
|
||||
api-gateway:
|
||||
build:
|
||||
context: .
|
||||
|
|
|
|||
|
|
@ -16,6 +16,12 @@ class ApiGatewayConfig(BaseConfig):
|
|||
access_token_expire_minutes: int = 15
|
||||
refresh_token_expire_days: int = 7
|
||||
|
||||
# Alpaca brokerage credentials (for portfolio sync)
|
||||
alpaca_api_key: str = ""
|
||||
alpaca_secret_key: str = ""
|
||||
paper_trading: bool = True
|
||||
snapshot_interval_seconds: int = 60
|
||||
|
||||
# CORS settings
|
||||
cors_origins: list[str] = ["http://localhost:5173"]
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncIterator
|
||||
|
|
@ -43,9 +44,23 @@ def create_app(config: ApiGatewayConfig | None = None) -> FastAPI:
|
|||
)
|
||||
app.state.config = config
|
||||
|
||||
# Start portfolio sync background task
|
||||
from services.api_gateway.tasks.portfolio_sync import portfolio_sync_loop
|
||||
|
||||
sync_task = asyncio.create_task(
|
||||
portfolio_sync_loop(config, session_factory)
|
||||
)
|
||||
|
||||
logger.info("API Gateway started")
|
||||
yield
|
||||
|
||||
# Cancel the sync task
|
||||
sync_task.cancel()
|
||||
try:
|
||||
await sync_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Cleanup
|
||||
await app.state.redis.aclose()
|
||||
await engine.dispose()
|
||||
|
|
|
|||
0
services/api_gateway/tasks/__init__.py
Normal file
0
services/api_gateway/tasks/__init__.py
Normal file
155
services/api_gateway/tasks/portfolio_sync.py
Normal file
155
services/api_gateway/tasks/portfolio_sync.py
Normal file
|
|
@ -0,0 +1,155 @@
|
|||
"""Background task that periodically snapshots Alpaca account state into the DB.
|
||||
|
||||
Runs on a configurable interval (default 60s) during US market hours,
|
||||
creating ``PortfolioSnapshot`` rows and upserting ``Position`` rows so
|
||||
the dashboard portfolio page reflects real brokerage data.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, time, timezone
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
|
||||
from services.api_gateway.config import ApiGatewayConfig
|
||||
from shared.broker.alpaca_broker import AlpacaBroker
|
||||
from shared.models.timeseries import PortfolioSnapshot
|
||||
from shared.models.trading import Position
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# US Eastern timezone for market hours check
|
||||
_ET = ZoneInfo("America/New_York")
|
||||
_MARKET_OPEN = time(9, 30)
|
||||
_MARKET_CLOSE = time(16, 0)
|
||||
|
||||
|
||||
def is_market_open(now_utc: datetime | None = None) -> bool:
|
||||
"""Return ``True`` if the US stock market is currently open.
|
||||
|
||||
Checks for weekday (Mon-Fri) and time between 9:30 AM and 4:00 PM ET.
|
||||
"""
|
||||
if now_utc is None:
|
||||
now_utc = datetime.now(timezone.utc)
|
||||
now_et = now_utc.astimezone(_ET)
|
||||
# Weekday check: Monday=0 ... Friday=4
|
||||
if now_et.weekday() >= 5:
|
||||
return False
|
||||
return _MARKET_OPEN <= now_et.time() < _MARKET_CLOSE
|
||||
|
||||
|
||||
async def _sync_once(
|
||||
broker: AlpacaBroker,
|
||||
session_factory: async_sessionmaker,
|
||||
) -> None:
|
||||
"""Perform a single portfolio snapshot and position upsert cycle."""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# 1. Snapshot account state
|
||||
account = await broker.get_account()
|
||||
|
||||
snapshot = PortfolioSnapshot(
|
||||
timestamp=now,
|
||||
total_value=account.portfolio_value,
|
||||
cash=account.cash,
|
||||
positions_value=account.portfolio_value - account.cash,
|
||||
daily_pnl=0.0,
|
||||
)
|
||||
|
||||
# 2. Fetch broker positions
|
||||
broker_positions = await broker.get_positions()
|
||||
broker_tickers = {p.ticker for p in broker_positions}
|
||||
|
||||
async with session_factory() as session:
|
||||
async with session.begin():
|
||||
# Insert portfolio snapshot
|
||||
session.add(snapshot)
|
||||
|
||||
# Upsert positions
|
||||
for pos_info in broker_positions:
|
||||
result = await session.execute(
|
||||
select(Position).where(Position.ticker == pos_info.ticker)
|
||||
)
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
if existing is not None:
|
||||
existing.qty = pos_info.qty
|
||||
existing.avg_entry = pos_info.avg_entry
|
||||
existing.unrealized_pnl = pos_info.unrealized_pnl
|
||||
else:
|
||||
new_pos = Position(
|
||||
ticker=pos_info.ticker,
|
||||
qty=pos_info.qty,
|
||||
avg_entry=pos_info.avg_entry,
|
||||
unrealized_pnl=pos_info.unrealized_pnl,
|
||||
stop_loss=None,
|
||||
take_profit=None,
|
||||
)
|
||||
session.add(new_pos)
|
||||
|
||||
# 3. Remove positions that are no longer held at the broker
|
||||
if broker_tickers:
|
||||
await session.execute(
|
||||
delete(Position).where(Position.ticker.notin_(broker_tickers))
|
||||
)
|
||||
else:
|
||||
# No positions at broker — delete all local positions
|
||||
await session.execute(delete(Position))
|
||||
|
||||
logger.info(
|
||||
"Portfolio sync complete: value=%.2f, cash=%.2f, positions=%d",
|
||||
account.portfolio_value,
|
||||
account.cash,
|
||||
len(broker_positions),
|
||||
)
|
||||
|
||||
|
||||
async def portfolio_sync_loop(
|
||||
config: ApiGatewayConfig,
|
||||
session_factory: async_sessionmaker,
|
||||
) -> None:
|
||||
"""Run the portfolio sync loop until cancelled.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config:
|
||||
API Gateway configuration containing Alpaca credentials and
|
||||
the snapshot interval.
|
||||
session_factory:
|
||||
SQLAlchemy async session factory for DB access.
|
||||
"""
|
||||
if not config.alpaca_api_key or not config.alpaca_secret_key:
|
||||
logger.warning(
|
||||
"Alpaca API credentials not configured — portfolio sync disabled"
|
||||
)
|
||||
return
|
||||
|
||||
broker = AlpacaBroker(
|
||||
api_key=config.alpaca_api_key,
|
||||
secret_key=config.alpaca_secret_key,
|
||||
paper=config.paper_trading,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Portfolio sync started (interval=%ds, paper=%s)",
|
||||
config.snapshot_interval_seconds,
|
||||
config.paper_trading,
|
||||
)
|
||||
|
||||
while True:
|
||||
try:
|
||||
if is_market_open():
|
||||
await _sync_once(broker, session_factory)
|
||||
else:
|
||||
logger.debug("Market closed — skipping portfolio snapshot")
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Portfolio sync task cancelled — shutting down")
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Portfolio sync error — will retry next interval")
|
||||
|
||||
await asyncio.sleep(config.snapshot_interval_seconds)
|
||||
1
services/market_data/__init__.py
Normal file
1
services/market_data/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Market Data service -- fetches OHLCV bars from Alpaca and publishes to Redis Streams."""
|
||||
3
services/market_data/__main__.py
Normal file
3
services/market_data/__main__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from services.market_data.main import main
|
||||
|
||||
main()
|
||||
16
services/market_data/config.py
Normal file
16
services/market_data/config.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
"""Configuration for the market data service."""
|
||||
|
||||
from shared.config import BaseConfig
|
||||
|
||||
|
||||
class MarketDataConfig(BaseConfig):
|
||||
"""Extends BaseConfig with market-data-specific settings."""
|
||||
|
||||
watchlist: list[str] = ["AAPL", "TSLA", "NVDA", "MSFT", "GOOGL"]
|
||||
bar_timeframe: str = "5Min"
|
||||
poll_interval_seconds: int = 60
|
||||
historical_bars: int = 100
|
||||
alpaca_api_key: str = ""
|
||||
alpaca_secret_key: str = ""
|
||||
|
||||
model_config = {"env_prefix": "TRADING_"}
|
||||
257
services/market_data/main.py
Normal file
257
services/market_data/main.py
Normal file
|
|
@ -0,0 +1,257 @@
|
|||
"""Market Data service -- main entry point.
|
||||
|
||||
Fetches historical and live OHLCV bars from Alpaca's market data API
|
||||
and publishes them to the ``market:bars`` Redis Stream for consumption
|
||||
by the signal generator and other downstream services.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import signal
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from services.market_data.config import MarketDataConfig
|
||||
from shared.redis_streams import StreamPublisher
|
||||
from shared.telemetry import setup_telemetry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MARKET_BARS_STREAM = "market:bars"
|
||||
|
||||
|
||||
def _parse_timeframe(timeframe_str: str):
|
||||
"""Parse a timeframe string like '5Min' into an Alpaca TimeFrame object.
|
||||
|
||||
Returns a ``TimeFrame`` instance suitable for ``StockBarsRequest``.
|
||||
"""
|
||||
from alpaca.data.timeframe import TimeFrame, TimeFrameUnit
|
||||
|
||||
# Supported formats: "1Min", "5Min", "15Min", "1Hour", "1Day"
|
||||
tf_map = {
|
||||
"1Min": TimeFrame(1, TimeFrameUnit.Minute),
|
||||
"5Min": TimeFrame(5, TimeFrameUnit.Minute),
|
||||
"15Min": TimeFrame(15, TimeFrameUnit.Minute),
|
||||
"1Hour": TimeFrame(1, TimeFrameUnit.Hour),
|
||||
"1Day": TimeFrame(1, TimeFrameUnit.Day),
|
||||
}
|
||||
tf = tf_map.get(timeframe_str)
|
||||
if tf is None:
|
||||
raise ValueError(
|
||||
f"Unsupported timeframe '{timeframe_str}'. "
|
||||
f"Supported values: {list(tf_map.keys())}"
|
||||
)
|
||||
return tf
|
||||
|
||||
|
||||
def _bar_to_dict(ticker: str, bar) -> dict:
|
||||
"""Convert an Alpaca Bar object to a flat dictionary for Redis publishing."""
|
||||
return {
|
||||
"ticker": ticker,
|
||||
"timestamp": bar.timestamp.isoformat(),
|
||||
"open": float(bar.open),
|
||||
"high": float(bar.high),
|
||||
"low": float(bar.low),
|
||||
"close": float(bar.close),
|
||||
"volume": float(bar.volume),
|
||||
}
|
||||
|
||||
|
||||
async def _fetch_historical_bars(
|
||||
client,
|
||||
watchlist: list[str],
|
||||
timeframe,
|
||||
limit: int,
|
||||
publisher: StreamPublisher,
|
||||
bars_published_counter,
|
||||
) -> int:
|
||||
"""Fetch historical bars for each ticker and publish to Redis.
|
||||
|
||||
Returns the total number of bars published.
|
||||
"""
|
||||
from alpaca.data.requests import StockBarsRequest
|
||||
|
||||
total_published = 0
|
||||
|
||||
# Use a start time far enough back to get the requested number of bars
|
||||
start = datetime.now(timezone.utc) - timedelta(days=30)
|
||||
|
||||
for ticker in watchlist:
|
||||
try:
|
||||
request = StockBarsRequest(
|
||||
symbol_or_symbols=[ticker],
|
||||
timeframe=timeframe,
|
||||
start=start,
|
||||
limit=limit,
|
||||
)
|
||||
bars = await asyncio.to_thread(client.get_stock_bars, request)
|
||||
|
||||
ticker_bars = bars[ticker] if ticker in bars else []
|
||||
for bar in ticker_bars:
|
||||
msg = _bar_to_dict(ticker, bar)
|
||||
await publisher.publish(msg)
|
||||
total_published += 1
|
||||
|
||||
logger.info(
|
||||
"Published %d historical bars for %s",
|
||||
len(ticker_bars),
|
||||
ticker,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch historical bars for %s", ticker)
|
||||
|
||||
if total_published:
|
||||
bars_published_counter.add(total_published)
|
||||
|
||||
return total_published
|
||||
|
||||
|
||||
async def _poll_latest_bars(
|
||||
client,
|
||||
watchlist: list[str],
|
||||
timeframe,
|
||||
publisher: StreamPublisher,
|
||||
bars_published_counter,
|
||||
) -> int:
|
||||
"""Fetch the latest bar for each ticker and publish to Redis.
|
||||
|
||||
Returns the number of bars published.
|
||||
"""
|
||||
from alpaca.data.requests import StockBarsRequest
|
||||
|
||||
published = 0
|
||||
|
||||
# Fetch bars from the last 10 minutes to ensure we get at least one
|
||||
start = datetime.now(timezone.utc) - timedelta(minutes=10)
|
||||
|
||||
for ticker in watchlist:
|
||||
try:
|
||||
request = StockBarsRequest(
|
||||
symbol_or_symbols=[ticker],
|
||||
timeframe=timeframe,
|
||||
start=start,
|
||||
limit=1,
|
||||
)
|
||||
bars = await asyncio.to_thread(client.get_stock_bars, request)
|
||||
|
||||
ticker_bars = bars[ticker] if ticker in bars else []
|
||||
if ticker_bars:
|
||||
# Publish only the most recent bar
|
||||
bar = ticker_bars[-1]
|
||||
msg = _bar_to_dict(ticker, bar)
|
||||
await publisher.publish(msg)
|
||||
published += 1
|
||||
logger.debug("Published latest bar for %s: close=%.2f", ticker, bar.close)
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch latest bar for %s", ticker)
|
||||
|
||||
if published:
|
||||
bars_published_counter.add(published)
|
||||
|
||||
return published
|
||||
|
||||
|
||||
async def run(config: MarketDataConfig | None = None) -> None:
|
||||
"""Main service loop.
|
||||
|
||||
Connects to Alpaca and Redis, fetches historical bars on startup,
|
||||
then polls for new bars at the configured interval.
|
||||
"""
|
||||
if config is None:
|
||||
config = MarketDataConfig()
|
||||
|
||||
logging.basicConfig(level=config.log_level)
|
||||
logger.info("Starting Market Data service")
|
||||
|
||||
# --- Telemetry ---
|
||||
meter = setup_telemetry("market-data", config.otel_metrics_port)
|
||||
bars_published_counter = meter.create_counter(
|
||||
"market_data.bars_published",
|
||||
description="Total OHLCV bars published to market:bars stream",
|
||||
)
|
||||
poll_errors_counter = meter.create_counter(
|
||||
"market_data.poll_errors",
|
||||
description="Total poll cycle errors",
|
||||
)
|
||||
|
||||
# --- Alpaca client ---
|
||||
from alpaca.data.historical import StockHistoricalDataClient
|
||||
|
||||
client = StockHistoricalDataClient(
|
||||
api_key=config.alpaca_api_key,
|
||||
secret_key=config.alpaca_secret_key,
|
||||
)
|
||||
|
||||
# --- Redis ---
|
||||
redis = Redis.from_url(config.redis_url, decode_responses=False)
|
||||
publisher = StreamPublisher(redis, MARKET_BARS_STREAM)
|
||||
|
||||
# --- Parse timeframe ---
|
||||
timeframe = _parse_timeframe(config.bar_timeframe)
|
||||
|
||||
# --- Graceful shutdown ---
|
||||
shutdown_event = asyncio.Event()
|
||||
loop = asyncio.get_running_loop()
|
||||
for sig in (signal.SIGTERM, signal.SIGINT):
|
||||
loop.add_signal_handler(sig, shutdown_event.set)
|
||||
|
||||
try:
|
||||
# Fetch historical bars on startup
|
||||
logger.info(
|
||||
"Fetching %d historical bars for watchlist: %s",
|
||||
config.historical_bars,
|
||||
config.watchlist,
|
||||
)
|
||||
total = await _fetch_historical_bars(
|
||||
client,
|
||||
config.watchlist,
|
||||
timeframe,
|
||||
config.historical_bars,
|
||||
publisher,
|
||||
bars_published_counter,
|
||||
)
|
||||
logger.info("Historical backfill complete: %d total bars published", total)
|
||||
|
||||
# Poll loop
|
||||
logger.info(
|
||||
"Starting poll loop (interval=%ds) for watchlist: %s",
|
||||
config.poll_interval_seconds,
|
||||
config.watchlist,
|
||||
)
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
shutdown_event.wait(),
|
||||
timeout=config.poll_interval_seconds,
|
||||
)
|
||||
break # Shutdown signaled
|
||||
except asyncio.TimeoutError:
|
||||
pass # Normal timeout — time to poll
|
||||
|
||||
try:
|
||||
count = await _poll_latest_bars(
|
||||
client,
|
||||
config.watchlist,
|
||||
timeframe,
|
||||
publisher,
|
||||
bars_published_counter,
|
||||
)
|
||||
logger.info("Poll cycle complete: %d bars published", count)
|
||||
except Exception:
|
||||
logger.exception("Poll cycle failed")
|
||||
poll_errors_counter.add(1)
|
||||
finally:
|
||||
await redis.aclose()
|
||||
logger.info("Market data service stopped gracefully")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""CLI entry point."""
|
||||
asyncio.run(run())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -3,7 +3,8 @@
|
|||
Consumes ``news:raw`` articles from Redis Streams, scores them using a
|
||||
tiered approach (FinBERT first, Ollama fallback for low-confidence results),
|
||||
extracts ticker mentions, and publishes ``ScoredArticle`` messages to
|
||||
``news:scored``.
|
||||
``news:scored``. Also persists scored articles to the database (articles +
|
||||
article_sentiments tables) so the dashboard can display real data.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -14,11 +15,15 @@ import signal
|
|||
import time
|
||||
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
|
||||
from services.sentiment_analyzer.analyzers.finbert import FinBERTAnalyzer
|
||||
from services.sentiment_analyzer.analyzers.ollama_analyzer import OllamaAnalyzer
|
||||
from services.sentiment_analyzer.config import SentimentAnalyzerConfig
|
||||
from services.sentiment_analyzer.ticker_extractor import extract_tickers
|
||||
from shared.db import create_db
|
||||
from shared.models.news import Article, ArticleSentiment
|
||||
from shared.redis_streams import StreamConsumer, StreamPublisher
|
||||
from shared.schemas.news import RawArticle, ScoredArticle
|
||||
from shared.telemetry import setup_telemetry
|
||||
|
|
@ -33,6 +38,7 @@ async def process_article(
|
|||
publisher: StreamPublisher,
|
||||
config: SentimentAnalyzerConfig,
|
||||
counters: dict,
|
||||
db_session_factory: async_sessionmaker | None = None,
|
||||
) -> None:
|
||||
"""Score a single article and publish one ScoredArticle per extracted ticker.
|
||||
|
||||
|
|
@ -50,6 +56,9 @@ async def process_article(
|
|||
Service configuration (confidence threshold, etc.).
|
||||
counters:
|
||||
Dict of OpenTelemetry counter/histogram instruments.
|
||||
db_session_factory:
|
||||
Optional async session factory for persisting to the DB.
|
||||
When ``None``, DB persistence is skipped (backward compatible).
|
||||
"""
|
||||
start = time.monotonic()
|
||||
|
||||
|
|
@ -103,6 +112,46 @@ async def process_article(
|
|||
|
||||
counters["articles_scored"].add(1)
|
||||
|
||||
# --- Step 5: Persist to DB ---
|
||||
if db_session_factory is not None:
|
||||
try:
|
||||
async with db_session_factory() as session:
|
||||
db_article = Article(
|
||||
source=article.source,
|
||||
url=article.url,
|
||||
title=article.title,
|
||||
published_at=article.published_at,
|
||||
fetched_at=article.fetched_at,
|
||||
content_hash=article.content_hash,
|
||||
)
|
||||
session.add(db_article)
|
||||
|
||||
for ticker in tickers:
|
||||
sentiment = ArticleSentiment(
|
||||
article_id=db_article.id,
|
||||
ticker=ticker,
|
||||
score=score,
|
||||
confidence=confidence,
|
||||
model_used=model_used,
|
||||
)
|
||||
session.add(sentiment)
|
||||
|
||||
await session.commit()
|
||||
logger.debug(
|
||||
"Persisted article '%s' with %d sentiments to DB",
|
||||
article.title[:60],
|
||||
len(tickers),
|
||||
)
|
||||
except IntegrityError:
|
||||
logger.debug(
|
||||
"Article already exists in DB (content_hash=%s), skipping",
|
||||
article.content_hash,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to persist article to DB: %s", article.title[:60]
|
||||
)
|
||||
|
||||
|
||||
async def run(config: SentimentAnalyzerConfig | None = None) -> None:
|
||||
"""Main service loop.
|
||||
|
|
@ -150,6 +199,14 @@ async def run(config: SentimentAnalyzerConfig | None = None) -> None:
|
|||
)
|
||||
ollama = OllamaAnalyzer(model=config.ollama_model, host=config.ollama_host)
|
||||
|
||||
# --- Database ---
|
||||
db_session_factory = None
|
||||
try:
|
||||
_engine, db_session_factory = create_db(config)
|
||||
logger.info("Database session factory initialised")
|
||||
except Exception:
|
||||
logger.exception("Failed to initialise DB — articles will NOT be persisted")
|
||||
|
||||
logger.info("Consuming from news:raw, publishing to news:scored")
|
||||
|
||||
# Graceful shutdown on SIGTERM/SIGINT
|
||||
|
|
@ -165,7 +222,7 @@ async def run(config: SentimentAnalyzerConfig | None = None) -> None:
|
|||
break
|
||||
try:
|
||||
article = RawArticle.model_validate(data)
|
||||
await process_article(article, finbert, ollama, publisher, config, counters)
|
||||
await process_article(article, finbert, ollama, publisher, config, counters, db_session_factory)
|
||||
except Exception:
|
||||
logger.exception("Error processing article: %s", data.get("title", "<unknown>"))
|
||||
finally:
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
"""Signal Generator service -- main entry point.
|
||||
|
||||
Consumes ``news:scored`` articles from Redis Streams, updates sentiment
|
||||
context per ticker, runs the weighted ensemble of trading strategies, and
|
||||
publishes qualifying ``TradeSignal`` messages to ``signals:generated``.
|
||||
Consumes ``news:scored`` articles and ``market:bars`` OHLCV data from
|
||||
Redis Streams, updates sentiment context and market data per ticker,
|
||||
runs the weighted ensemble of trading strategies, and publishes
|
||||
qualifying ``TradeSignal`` messages to ``signals:generated``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -10,16 +11,21 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import logging
|
||||
import signal
|
||||
import uuid
|
||||
from collections import defaultdict, deque
|
||||
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
|
||||
from services.signal_generator.config import SignalGeneratorConfig
|
||||
from services.signal_generator.ensemble import WeightedEnsemble
|
||||
from services.signal_generator.market_data import MarketDataManager
|
||||
from shared.db import create_db
|
||||
from shared.models.trading import Signal as SignalModel
|
||||
from shared.models.trading import SignalDirection as SignalDirectionModel
|
||||
from shared.redis_streams import StreamConsumer, StreamPublisher
|
||||
from shared.schemas.news import ScoredArticle
|
||||
from shared.schemas.trading import SentimentContext
|
||||
from shared.schemas.trading import MarketSnapshot, SentimentContext
|
||||
from shared.strategies import MeanReversionStrategy, MomentumStrategy, NewsDrivenStrategy
|
||||
from shared.telemetry import setup_telemetry
|
||||
|
||||
|
|
@ -53,12 +59,150 @@ def _build_sentiment_context(
|
|||
)
|
||||
|
||||
|
||||
async def _consume_market_bars(
|
||||
bars_consumer: StreamConsumer,
|
||||
market_data: MarketDataManager,
|
||||
shutdown_event: asyncio.Event,
|
||||
bars_received_counter,
|
||||
) -> None:
|
||||
"""Consume OHLCV bars from ``market:bars`` and feed them to the MarketDataManager.
|
||||
|
||||
Runs as a concurrent task alongside the scored-article consumer.
|
||||
"""
|
||||
logger.info("Starting market:bars consumer")
|
||||
async for _msg_id, data in bars_consumer.consume():
|
||||
if shutdown_event.is_set():
|
||||
break
|
||||
try:
|
||||
ticker = data.get("ticker")
|
||||
if not ticker:
|
||||
logger.warning("Received bar message without ticker field: %s", data)
|
||||
continue
|
||||
|
||||
# Build bar_data dict without the ticker key (OHLCVBar doesn't have it)
|
||||
bar_data = {k: v for k, v in data.items() if k != "ticker"}
|
||||
market_data.add_bar(ticker, bar_data)
|
||||
bars_received_counter.add(1)
|
||||
logger.debug("Added bar for %s: close=%s", ticker, data.get("close"))
|
||||
except Exception:
|
||||
logger.exception("Error processing market bar: %s", data)
|
||||
|
||||
|
||||
async def _consume_scored_articles(
|
||||
articles_consumer: StreamConsumer,
|
||||
market_data: MarketDataManager,
|
||||
ensemble: WeightedEnsemble,
|
||||
weights: dict[str, float],
|
||||
publisher: StreamPublisher,
|
||||
shutdown_event: asyncio.Event,
|
||||
signals_generated,
|
||||
per_strategy_signal_count,
|
||||
db_session_factory: async_sessionmaker | None = None,
|
||||
) -> None:
|
||||
"""Consume scored articles from ``news:scored``, run the ensemble, and publish signals.
|
||||
|
||||
Runs as a concurrent task alongside the market-bars consumer.
|
||||
"""
|
||||
# Per-ticker sentiment accumulators
|
||||
sentiment_scores: dict[str, deque[float]] = defaultdict(
|
||||
lambda: deque(maxlen=_MAX_SENTIMENT_SCORES)
|
||||
)
|
||||
sentiment_confidences: dict[str, deque[float]] = defaultdict(
|
||||
lambda: deque(maxlen=_MAX_SENTIMENT_SCORES)
|
||||
)
|
||||
|
||||
logger.info("Starting news:scored consumer")
|
||||
async for _msg_id, data in articles_consumer.consume():
|
||||
if shutdown_event.is_set():
|
||||
break
|
||||
try:
|
||||
article = ScoredArticle.model_validate(data)
|
||||
ticker = article.ticker
|
||||
|
||||
# Update sentiment accumulators
|
||||
sentiment_scores[ticker].append(article.sentiment_score)
|
||||
sentiment_confidences[ticker].append(article.confidence)
|
||||
|
||||
# Build sentiment context
|
||||
sentiment = _build_sentiment_context(
|
||||
ticker,
|
||||
sentiment_scores[ticker],
|
||||
sentiment_confidences[ticker],
|
||||
)
|
||||
|
||||
# Get market snapshot (may be None if no bars received yet)
|
||||
snapshot = market_data.get_snapshot(ticker)
|
||||
if snapshot is None:
|
||||
# Create a minimal snapshot from sentiment data alone
|
||||
# (the news_driven strategy does not require market indicators)
|
||||
snapshot = MarketSnapshot(
|
||||
ticker=ticker,
|
||||
current_price=0.0,
|
||||
open=0.0,
|
||||
high=0.0,
|
||||
low=0.0,
|
||||
close=0.0,
|
||||
volume=0.0,
|
||||
)
|
||||
|
||||
# Run ensemble
|
||||
signal_result = await ensemble.evaluate(ticker, snapshot, sentiment, weights)
|
||||
|
||||
if signal_result is not None:
|
||||
# Inject current price for trade executor position sizing
|
||||
if snapshot and snapshot.current_price > 0:
|
||||
if signal_result.sentiment_context is None:
|
||||
signal_result.sentiment_context = {}
|
||||
signal_result.sentiment_context["current_price"] = snapshot.current_price
|
||||
|
||||
# Persist signal to DB
|
||||
if db_session_factory is not None:
|
||||
try:
|
||||
async with db_session_factory() as session:
|
||||
direction_map = {
|
||||
"LONG": SignalDirectionModel.LONG,
|
||||
"SHORT": SignalDirectionModel.SHORT,
|
||||
"NEUTRAL": SignalDirectionModel.NEUTRAL,
|
||||
}
|
||||
db_signal = SignalModel(
|
||||
id=signal_result.signal_id,
|
||||
ticker=ticker,
|
||||
direction=direction_map[signal_result.direction.value],
|
||||
strength=signal_result.strength,
|
||||
strategy_sources=signal_result.strategy_sources,
|
||||
sentiment_score=sentiment.avg_score if sentiment else None,
|
||||
acted_on=False,
|
||||
)
|
||||
session.add(db_signal)
|
||||
await session.commit()
|
||||
except Exception:
|
||||
logger.exception("Failed to persist signal to DB")
|
||||
|
||||
await publisher.publish(signal_result.model_dump(mode="json"))
|
||||
signals_generated.add(1)
|
||||
for src in signal_result.strategy_sources:
|
||||
strategy_name = src.split(":")[0]
|
||||
per_strategy_signal_count.add(1, {"strategy": strategy_name})
|
||||
logger.info(
|
||||
"Signal generated: %s %s strength=%.4f sources=%s",
|
||||
signal_result.direction.value,
|
||||
ticker,
|
||||
signal_result.strength,
|
||||
signal_result.strategy_sources,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Error processing scored article: %s", data.get("title", "<unknown>")
|
||||
)
|
||||
|
||||
|
||||
async def run(config: SignalGeneratorConfig | None = None) -> None:
|
||||
"""Main service loop.
|
||||
|
||||
Connects to Redis, initialises strategies and telemetry, then
|
||||
continuously consumes from ``news:scored`` and publishes qualifying
|
||||
signals to ``signals:generated``.
|
||||
continuously consumes from ``news:scored`` and ``market:bars``,
|
||||
publishing qualifying signals to ``signals:generated``.
|
||||
"""
|
||||
if config is None:
|
||||
config = SignalGeneratorConfig()
|
||||
|
|
@ -76,10 +220,19 @@ async def run(config: SignalGeneratorConfig | None = None) -> None:
|
|||
"per_strategy_signal_count",
|
||||
description="Signals emitted, broken down by strategy",
|
||||
)
|
||||
bars_received_counter = meter.create_counter(
|
||||
"bars_received",
|
||||
description="Total OHLCV bars received from market:bars stream",
|
||||
)
|
||||
|
||||
# --- Redis ---
|
||||
redis = Redis.from_url(config.redis_url, decode_responses=False)
|
||||
consumer = StreamConsumer(redis, "news:scored", "signal-generator", "worker-1")
|
||||
articles_consumer = StreamConsumer(
|
||||
redis, "news:scored", "signal-generator", "worker-1"
|
||||
)
|
||||
bars_consumer = StreamConsumer(
|
||||
redis, "market:bars", "signal-generator", "bars-worker"
|
||||
)
|
||||
publisher = StreamPublisher(redis, "signals:generated")
|
||||
|
||||
# --- Market data ---
|
||||
|
|
@ -96,11 +249,17 @@ async def run(config: SignalGeneratorConfig | None = None) -> None:
|
|||
# --- Strategy weights (default equal; could load from DB) ---
|
||||
weights = dict(_DEFAULT_WEIGHTS)
|
||||
|
||||
# --- Per-ticker sentiment accumulators ---
|
||||
sentiment_scores: dict[str, deque[float]] = defaultdict(lambda: deque(maxlen=_MAX_SENTIMENT_SCORES))
|
||||
sentiment_confidences: dict[str, deque[float]] = defaultdict(lambda: deque(maxlen=_MAX_SENTIMENT_SCORES))
|
||||
# --- Database (for persisting signals) ---
|
||||
db_session_factory = None
|
||||
try:
|
||||
_engine, db_session_factory = create_db(config)
|
||||
logger.info("Database session factory initialised for signal persistence")
|
||||
except Exception:
|
||||
logger.exception("Failed to initialise DB — signals will NOT be persisted")
|
||||
|
||||
logger.info("Consuming from news:scored, publishing to signals:generated")
|
||||
logger.info(
|
||||
"Consuming from news:scored and market:bars, publishing to signals:generated"
|
||||
)
|
||||
|
||||
# Graceful shutdown on SIGTERM/SIGINT
|
||||
shutdown_event = asyncio.Event()
|
||||
|
|
@ -108,62 +267,30 @@ async def run(config: SignalGeneratorConfig | None = None) -> None:
|
|||
for sig in (signal.SIGTERM, signal.SIGINT):
|
||||
loop.add_signal_handler(sig, shutdown_event.set)
|
||||
|
||||
# --- Consume loop ---
|
||||
# --- Run both consumers concurrently ---
|
||||
try:
|
||||
async for _msg_id, data in consumer.consume():
|
||||
if shutdown_event.is_set():
|
||||
break
|
||||
try:
|
||||
article = ScoredArticle.model_validate(data)
|
||||
ticker = article.ticker
|
||||
|
||||
# Update sentiment accumulators
|
||||
sentiment_scores[ticker].append(article.sentiment_score)
|
||||
sentiment_confidences[ticker].append(article.confidence)
|
||||
|
||||
# Build sentiment context
|
||||
sentiment = _build_sentiment_context(
|
||||
ticker,
|
||||
sentiment_scores[ticker],
|
||||
sentiment_confidences[ticker],
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
tg.create_task(
|
||||
_consume_scored_articles(
|
||||
articles_consumer,
|
||||
market_data,
|
||||
ensemble,
|
||||
weights,
|
||||
publisher,
|
||||
shutdown_event,
|
||||
signals_generated,
|
||||
per_strategy_signal_count,
|
||||
db_session_factory,
|
||||
)
|
||||
|
||||
# Get market snapshot (may be None if no bars received yet)
|
||||
snapshot = market_data.get_snapshot(ticker)
|
||||
if snapshot is None:
|
||||
# Create a minimal snapshot from sentiment data alone
|
||||
# (the news_driven strategy does not require market indicators)
|
||||
from shared.schemas.trading import MarketSnapshot
|
||||
|
||||
snapshot = MarketSnapshot(
|
||||
ticker=ticker,
|
||||
current_price=0.0,
|
||||
open=0.0,
|
||||
high=0.0,
|
||||
low=0.0,
|
||||
close=0.0,
|
||||
volume=0.0,
|
||||
)
|
||||
|
||||
# Run ensemble
|
||||
signal_result = await ensemble.evaluate(ticker, snapshot, sentiment, weights)
|
||||
|
||||
if signal_result is not None:
|
||||
await publisher.publish(signal_result.model_dump(mode="json"))
|
||||
signals_generated.add(1)
|
||||
for src in signal_result.strategy_sources:
|
||||
strategy_name = src.split(":")[0]
|
||||
per_strategy_signal_count.add(1, {"strategy": strategy_name})
|
||||
logger.info(
|
||||
"Signal generated: %s %s strength=%.4f sources=%s",
|
||||
signal_result.direction.value,
|
||||
ticker,
|
||||
signal_result.strength,
|
||||
signal_result.strategy_sources,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error processing scored article: %s", data.get("title", "<unknown>"))
|
||||
)
|
||||
tg.create_task(
|
||||
_consume_market_bars(
|
||||
bars_consumer,
|
||||
market_data,
|
||||
shutdown_event,
|
||||
bars_received_counter,
|
||||
)
|
||||
)
|
||||
finally:
|
||||
await redis.aclose()
|
||||
logger.info("Signal generator stopped gracefully")
|
||||
|
|
|
|||
|
|
@ -15,10 +15,15 @@ import time
|
|||
import uuid
|
||||
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
|
||||
from services.trade_executor.config import TradeExecutorConfig
|
||||
from services.trade_executor.risk_manager import RiskManager
|
||||
from shared.broker.alpaca_broker import AlpacaBroker
|
||||
from shared.db import create_db
|
||||
from shared.models.trading import Trade as TradeModel
|
||||
from shared.models.trading import TradeSide as TradeSideModel
|
||||
from shared.models.trading import TradeStatus as TradeStatusModel
|
||||
from shared.redis_streams import StreamConsumer, StreamPublisher
|
||||
from shared.schemas.trading import (
|
||||
OrderRequest,
|
||||
|
|
@ -39,6 +44,7 @@ async def process_signal(
|
|||
broker: AlpacaBroker,
|
||||
publisher: StreamPublisher,
|
||||
counters: dict,
|
||||
db_session_factory: async_sessionmaker | None = None,
|
||||
) -> None:
|
||||
"""Process a single trade signal: risk check, order, record, publish.
|
||||
|
||||
|
|
@ -54,6 +60,8 @@ async def process_signal(
|
|||
Publishes execution results to ``trades:executed``.
|
||||
counters:
|
||||
Dict of OpenTelemetry counter/histogram instruments.
|
||||
db_session_factory:
|
||||
Optional async session factory for persisting trades to the DB.
|
||||
"""
|
||||
# --- Step 1: risk check ---
|
||||
approved, reason = await risk_manager.check_risk(signal)
|
||||
|
|
@ -93,12 +101,42 @@ async def process_signal(
|
|||
qty=result.qty,
|
||||
price=result.filled_price or 0.0,
|
||||
status=result.status,
|
||||
signal_id=None,
|
||||
signal_id=signal.signal_id,
|
||||
strategy_id=None,
|
||||
timestamp=result.timestamp,
|
||||
)
|
||||
|
||||
# --- Step 6: publish to trades:executed ---
|
||||
# --- Step 6: persist trade to DB ---
|
||||
if db_session_factory is not None:
|
||||
try:
|
||||
side_map = {
|
||||
OrderSide.BUY: TradeSideModel.BUY,
|
||||
OrderSide.SELL: TradeSideModel.SELL,
|
||||
}
|
||||
status_map = {
|
||||
OrderStatus.PENDING: TradeStatusModel.PENDING,
|
||||
OrderStatus.FILLED: TradeStatusModel.FILLED,
|
||||
OrderStatus.CANCELLED: TradeStatusModel.CANCELLED,
|
||||
OrderStatus.REJECTED: TradeStatusModel.REJECTED,
|
||||
}
|
||||
async with db_session_factory() as session:
|
||||
db_trade = TradeModel(
|
||||
id=trade_id,
|
||||
ticker=signal.ticker,
|
||||
side=side_map[side],
|
||||
qty=result.qty,
|
||||
price=result.filled_price or 0.0,
|
||||
timestamp=str(result.timestamp),
|
||||
signal_id=signal.signal_id,
|
||||
status=status_map.get(result.status, TradeStatusModel.PENDING),
|
||||
)
|
||||
session.add(db_trade)
|
||||
await session.commit()
|
||||
logger.debug("Persisted trade %s to DB (signal_id=%s)", trade_id, signal.signal_id)
|
||||
except Exception:
|
||||
logger.exception("Failed to persist trade to DB")
|
||||
|
||||
# --- Step 7: publish to trades:executed ---
|
||||
await publisher.publish(execution.model_dump(mode="json"))
|
||||
counters["trades_executed"].add(1)
|
||||
logger.info(
|
||||
|
|
@ -157,6 +195,14 @@ async def run(config: TradeExecutorConfig | None = None) -> None:
|
|||
# --- Risk manager ---
|
||||
risk_manager = RiskManager(config, broker)
|
||||
|
||||
# --- Database (for persisting trades) ---
|
||||
db_session_factory = None
|
||||
try:
|
||||
_engine, db_session_factory = create_db(config)
|
||||
logger.info("Database session factory initialised for trade persistence")
|
||||
except Exception:
|
||||
logger.exception("Failed to initialise DB — trades will NOT be persisted")
|
||||
|
||||
logger.info("Consuming from signals:generated, publishing to trades:executed")
|
||||
|
||||
# Graceful shutdown on SIGTERM/SIGINT
|
||||
|
|
@ -172,7 +218,7 @@ async def run(config: TradeExecutorConfig | None = None) -> None:
|
|||
break
|
||||
try:
|
||||
signal_msg = TradeSignal.model_validate(data)
|
||||
await process_signal(signal_msg, risk_manager, broker, publisher, counters)
|
||||
await process_signal(signal_msg, risk_manager, broker, publisher, counters, db_session_factory)
|
||||
except Exception:
|
||||
logger.exception("Error processing signal: %s", data)
|
||||
finally:
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
@ -96,6 +96,7 @@ class AccountInfo(BaseModel):
|
|||
class TradeSignal(BaseModel):
|
||||
"""Published to ``signals:generated`` by the signal generator."""
|
||||
|
||||
signal_id: UUID = Field(default_factory=uuid4)
|
||||
ticker: str
|
||||
direction: SignalDirection
|
||||
strength: float = Field(ge=0.0, le=1.0)
|
||||
|
|
|
|||
462
tests/services/test_market_data.py
Normal file
462
tests/services/test_market_data.py
Normal file
|
|
@ -0,0 +1,462 @@
|
|||
"""Tests for the Market Data service.
|
||||
|
||||
Covers configuration defaults, bar parsing from Alpaca response format,
|
||||
historical bar fetching, poll logic, and Redis publish behaviour.
|
||||
|
||||
All Alpaca SDK imports are mocked to avoid requiring pytz and other
|
||||
transitive dependencies in the test environment.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
from types import ModuleType
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from services.market_data.config import MarketDataConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeBar:
|
||||
"""Mimics an Alpaca Bar object with the required attributes."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
timestamp: datetime,
|
||||
open: float,
|
||||
high: float,
|
||||
low: float,
|
||||
close: float,
|
||||
volume: float,
|
||||
) -> None:
|
||||
self.timestamp = timestamp
|
||||
self.open = open
|
||||
self.high = high
|
||||
self.low = low
|
||||
self.close = close
|
||||
self.volume = volume
|
||||
|
||||
|
||||
def _make_fake_bar(close: float = 150.0) -> _FakeBar:
|
||||
return _FakeBar(
|
||||
timestamp=datetime(2026, 1, 15, 10, 30, tzinfo=timezone.utc),
|
||||
open=149.0,
|
||||
high=151.0,
|
||||
low=148.0,
|
||||
close=close,
|
||||
volume=10000.0,
|
||||
)
|
||||
|
||||
|
||||
class _FakeBarSet(dict):
|
||||
"""Mimics an Alpaca BarSet (dict-like) mapping ticker to list of bars."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class _FakeTimeFrame:
|
||||
"""Mimics alpaca.data.timeframe.TimeFrame."""
|
||||
|
||||
def __init__(self, amount, unit):
|
||||
self.amount = amount
|
||||
self.unit = unit
|
||||
|
||||
|
||||
class _FakeTimeFrameUnit:
|
||||
"""Mimics alpaca.data.timeframe.TimeFrameUnit."""
|
||||
|
||||
Minute = "Minute"
|
||||
Hour = "Hour"
|
||||
Day = "Day"
|
||||
|
||||
|
||||
class _FakeStockBarsRequest:
|
||||
"""Mimics alpaca.data.requests.StockBarsRequest."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
setattr(self, k, v)
|
||||
|
||||
|
||||
def _install_alpaca_mocks():
|
||||
"""Install mock modules for the alpaca SDK so we can import market_data.main."""
|
||||
timeframe_mod = ModuleType("alpaca.data.timeframe")
|
||||
timeframe_mod.TimeFrame = _FakeTimeFrame
|
||||
timeframe_mod.TimeFrameUnit = _FakeTimeFrameUnit
|
||||
|
||||
requests_mod = ModuleType("alpaca.data.requests")
|
||||
requests_mod.StockBarsRequest = _FakeStockBarsRequest
|
||||
|
||||
historical_mod = ModuleType("alpaca.data.historical")
|
||||
historical_mod.StockHistoricalDataClient = MagicMock
|
||||
|
||||
# Build the package hierarchy
|
||||
alpaca_mod = sys.modules.get("alpaca") or ModuleType("alpaca")
|
||||
data_mod = sys.modules.get("alpaca.data") or ModuleType("alpaca.data")
|
||||
|
||||
sys.modules.setdefault("alpaca", alpaca_mod)
|
||||
sys.modules["alpaca.data"] = data_mod
|
||||
sys.modules["alpaca.data.timeframe"] = timeframe_mod
|
||||
sys.modules["alpaca.data.requests"] = requests_mod
|
||||
sys.modules["alpaca.data.historical"] = historical_mod
|
||||
|
||||
|
||||
# Install mocks before importing from market_data.main
|
||||
_install_alpaca_mocks()
|
||||
|
||||
from services.market_data.main import ( # noqa: E402
|
||||
MARKET_BARS_STREAM,
|
||||
_bar_to_dict,
|
||||
_fetch_historical_bars,
|
||||
_parse_timeframe,
|
||||
_poll_latest_bars,
|
||||
)
|
||||
|
||||
|
||||
async def _to_thread_passthrough(func, *args, **kwargs):
|
||||
"""Replacement for asyncio.to_thread that calls synchronously."""
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config defaults
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMarketDataConfig:
|
||||
"""Tests for MarketDataConfig defaults."""
|
||||
|
||||
def test_default_watchlist(self):
|
||||
config = MarketDataConfig()
|
||||
assert config.watchlist == ["AAPL", "TSLA", "NVDA", "MSFT", "GOOGL"]
|
||||
|
||||
def test_default_bar_timeframe(self):
|
||||
config = MarketDataConfig()
|
||||
assert config.bar_timeframe == "5Min"
|
||||
|
||||
def test_default_poll_interval(self):
|
||||
config = MarketDataConfig()
|
||||
assert config.poll_interval_seconds == 60
|
||||
|
||||
def test_default_historical_bars(self):
|
||||
config = MarketDataConfig()
|
||||
assert config.historical_bars == 100
|
||||
|
||||
def test_default_alpaca_keys_empty(self):
|
||||
config = MarketDataConfig()
|
||||
assert config.alpaca_api_key == ""
|
||||
assert config.alpaca_secret_key == ""
|
||||
|
||||
def test_inherits_base_config_defaults(self):
|
||||
config = MarketDataConfig()
|
||||
assert config.log_level == "INFO"
|
||||
assert config.otel_metrics_port == 9090
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Timeframe parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseTimeframe:
|
||||
"""Tests for the _parse_timeframe helper."""
|
||||
|
||||
def test_parse_5min(self):
|
||||
tf = _parse_timeframe("5Min")
|
||||
assert tf is not None
|
||||
assert tf.amount == 5
|
||||
|
||||
def test_parse_1min(self):
|
||||
tf = _parse_timeframe("1Min")
|
||||
assert tf is not None
|
||||
assert tf.amount == 1
|
||||
|
||||
def test_parse_15min(self):
|
||||
tf = _parse_timeframe("15Min")
|
||||
assert tf is not None
|
||||
assert tf.amount == 15
|
||||
|
||||
def test_parse_1hour(self):
|
||||
tf = _parse_timeframe("1Hour")
|
||||
assert tf is not None
|
||||
assert tf.amount == 1
|
||||
|
||||
def test_parse_1day(self):
|
||||
tf = _parse_timeframe("1Day")
|
||||
assert tf is not None
|
||||
assert tf.amount == 1
|
||||
|
||||
def test_parse_invalid_raises(self):
|
||||
with pytest.raises(ValueError, match="Unsupported timeframe"):
|
||||
_parse_timeframe("3Min")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bar to dict conversion
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBarToDict:
|
||||
"""Tests for _bar_to_dict conversion."""
|
||||
|
||||
def test_basic_conversion(self):
|
||||
bar = _make_fake_bar(close=155.5)
|
||||
result = _bar_to_dict("AAPL", bar)
|
||||
|
||||
assert result["ticker"] == "AAPL"
|
||||
assert result["close"] == 155.5
|
||||
assert result["open"] == 149.0
|
||||
assert result["high"] == 151.0
|
||||
assert result["low"] == 148.0
|
||||
assert result["volume"] == 10000.0
|
||||
assert "timestamp" in result
|
||||
|
||||
def test_timestamp_is_iso_format(self):
|
||||
bar = _make_fake_bar()
|
||||
result = _bar_to_dict("TSLA", bar)
|
||||
# Should be parseable back
|
||||
parsed = datetime.fromisoformat(result["timestamp"])
|
||||
assert parsed.tzinfo is not None
|
||||
|
||||
def test_all_fields_are_floats(self):
|
||||
bar = _make_fake_bar()
|
||||
result = _bar_to_dict("AAPL", bar)
|
||||
for field in ("open", "high", "low", "close", "volume"):
|
||||
assert isinstance(result[field], float)
|
||||
|
||||
def test_ticker_is_passed_through(self):
|
||||
bar = _make_fake_bar()
|
||||
result = _bar_to_dict("NVDA", bar)
|
||||
assert result["ticker"] == "NVDA"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fetch historical bars
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFetchHistoricalBars:
|
||||
"""Tests for _fetch_historical_bars with mocked Alpaca client."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publishes_bars_for_each_ticker(self):
|
||||
"""Should publish all historical bars for each ticker in the watchlist."""
|
||||
bars_aapl = [_make_fake_bar(150.0), _make_fake_bar(151.0)]
|
||||
bars_tsla = [_make_fake_bar(200.0)]
|
||||
|
||||
bar_set = _FakeBarSet({"AAPL": bars_aapl, "TSLA": bars_tsla})
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_stock_bars = MagicMock(return_value=bar_set)
|
||||
|
||||
publisher = AsyncMock()
|
||||
publisher.publish = AsyncMock()
|
||||
counter = MagicMock()
|
||||
|
||||
with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough):
|
||||
total = await _fetch_historical_bars(
|
||||
mock_client, ["AAPL", "TSLA"], MagicMock(), 100, publisher, counter
|
||||
)
|
||||
|
||||
assert total == 3
|
||||
assert publisher.publish.call_count == 3
|
||||
counter.add.assert_called_once_with(3)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_empty_response(self):
|
||||
"""Should handle tickers with no bars gracefully."""
|
||||
bar_set = _FakeBarSet({})
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_stock_bars = MagicMock(return_value=bar_set)
|
||||
|
||||
publisher = AsyncMock()
|
||||
counter = MagicMock()
|
||||
|
||||
with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough):
|
||||
total = await _fetch_historical_bars(
|
||||
mock_client, ["AAPL"], MagicMock(), 100, publisher, counter
|
||||
)
|
||||
|
||||
assert total == 0
|
||||
publisher.publish.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_client_exception(self):
|
||||
"""Should log and continue if one ticker fails."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_stock_bars = MagicMock(side_effect=Exception("API error"))
|
||||
|
||||
publisher = AsyncMock()
|
||||
counter = MagicMock()
|
||||
|
||||
with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough):
|
||||
total = await _fetch_historical_bars(
|
||||
mock_client, ["AAPL", "TSLA"], MagicMock(), 100, publisher, counter
|
||||
)
|
||||
|
||||
assert total == 0
|
||||
publisher.publish.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_published_message_format(self):
|
||||
"""Published messages should have the expected fields."""
|
||||
bar = _make_fake_bar(175.0)
|
||||
bar_set = _FakeBarSet({"MSFT": [bar]})
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_stock_bars = MagicMock(return_value=bar_set)
|
||||
|
||||
publisher = AsyncMock()
|
||||
publisher.publish = AsyncMock()
|
||||
counter = MagicMock()
|
||||
|
||||
with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough):
|
||||
await _fetch_historical_bars(
|
||||
mock_client, ["MSFT"], MagicMock(), 100, publisher, counter
|
||||
)
|
||||
|
||||
msg = publisher.publish.call_args[0][0]
|
||||
assert msg["ticker"] == "MSFT"
|
||||
assert msg["close"] == 175.0
|
||||
assert "timestamp" in msg
|
||||
assert "open" in msg
|
||||
assert "high" in msg
|
||||
assert "low" in msg
|
||||
assert "volume" in msg
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Poll latest bars
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPollLatestBars:
|
||||
"""Tests for _poll_latest_bars with mocked Alpaca client."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publishes_latest_bar_per_ticker(self):
|
||||
"""Should publish one bar per ticker."""
|
||||
bar_set = _FakeBarSet({"AAPL": [_make_fake_bar(155.0)]})
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_stock_bars = MagicMock(return_value=bar_set)
|
||||
|
||||
publisher = AsyncMock()
|
||||
publisher.publish = AsyncMock()
|
||||
counter = MagicMock()
|
||||
|
||||
with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough):
|
||||
count = await _poll_latest_bars(
|
||||
mock_client, ["AAPL"], MagicMock(), publisher, counter
|
||||
)
|
||||
|
||||
assert count == 1
|
||||
assert publisher.publish.call_count == 1
|
||||
|
||||
published_msg = publisher.publish.call_args[0][0]
|
||||
assert published_msg["ticker"] == "AAPL"
|
||||
assert published_msg["close"] == 155.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publishes_only_most_recent_bar(self):
|
||||
"""When multiple bars returned, should only publish the last one."""
|
||||
bars = [_make_fake_bar(150.0), _make_fake_bar(155.0)]
|
||||
bar_set = _FakeBarSet({"AAPL": bars})
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_stock_bars = MagicMock(return_value=bar_set)
|
||||
|
||||
publisher = AsyncMock()
|
||||
publisher.publish = AsyncMock()
|
||||
counter = MagicMock()
|
||||
|
||||
with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough):
|
||||
count = await _poll_latest_bars(
|
||||
mock_client, ["AAPL"], MagicMock(), publisher, counter
|
||||
)
|
||||
|
||||
assert count == 1
|
||||
published_msg = publisher.publish.call_args[0][0]
|
||||
assert published_msg["close"] == 155.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_no_bars_for_ticker(self):
|
||||
"""Should return 0 when a ticker has no bars."""
|
||||
bar_set = _FakeBarSet({"AAPL": []})
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_stock_bars = MagicMock(return_value=bar_set)
|
||||
|
||||
publisher = AsyncMock()
|
||||
counter = MagicMock()
|
||||
|
||||
with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough):
|
||||
count = await _poll_latest_bars(
|
||||
mock_client, ["AAPL"], MagicMock(), publisher, counter
|
||||
)
|
||||
|
||||
assert count == 0
|
||||
publisher.publish.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_ticker_not_in_response(self):
|
||||
"""Should handle gracefully when the response doesn't contain the ticker."""
|
||||
bar_set = _FakeBarSet({})
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_stock_bars = MagicMock(return_value=bar_set)
|
||||
|
||||
publisher = AsyncMock()
|
||||
counter = MagicMock()
|
||||
|
||||
with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough):
|
||||
count = await _poll_latest_bars(
|
||||
mock_client, ["AAPL"], MagicMock(), publisher, counter
|
||||
)
|
||||
|
||||
assert count == 0
|
||||
publisher.publish.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_tickers(self):
|
||||
"""Should poll and publish bars for all tickers in watchlist."""
|
||||
bar_set_aapl = _FakeBarSet({"AAPL": [_make_fake_bar(150.0)]})
|
||||
bar_set_tsla = _FakeBarSet({"TSLA": [_make_fake_bar(250.0)]})
|
||||
|
||||
mock_client = MagicMock()
|
||||
# Return different bar sets for each call
|
||||
mock_client.get_stock_bars = MagicMock(
|
||||
side_effect=[bar_set_aapl, bar_set_tsla]
|
||||
)
|
||||
|
||||
publisher = AsyncMock()
|
||||
publisher.publish = AsyncMock()
|
||||
counter = MagicMock()
|
||||
|
||||
with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough):
|
||||
count = await _poll_latest_bars(
|
||||
mock_client, ["AAPL", "TSLA"], MagicMock(), publisher, counter
|
||||
)
|
||||
|
||||
assert count == 2
|
||||
assert publisher.publish.call_count == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stream name constant
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStreamConstants:
|
||||
"""Verify the stream name constant."""
|
||||
|
||||
def test_market_bars_stream_name(self):
|
||||
assert MARKET_BARS_STREAM == "market:bars"
|
||||
456
tests/services/test_portfolio_sync.py
Normal file
456
tests/services/test_portfolio_sync.py
Normal file
|
|
@ -0,0 +1,456 @@
|
|||
"""Tests for portfolio sync background task.
|
||||
|
||||
Verifies that the sync loop correctly:
|
||||
- Creates PortfolioSnapshot rows from broker account data
|
||||
- Upserts Position rows from broker positions
|
||||
- Removes Position rows for closed positions
|
||||
- Handles broker errors gracefully
|
||||
- Respects US market hours
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, time, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from services.api_gateway.config import ApiGatewayConfig
|
||||
from services.api_gateway.tasks.portfolio_sync import (
|
||||
_sync_once,
|
||||
is_market_open,
|
||||
portfolio_sync_loop,
|
||||
)
|
||||
from shared.schemas.trading import AccountInfo, PositionInfo
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def config() -> ApiGatewayConfig:
|
||||
return ApiGatewayConfig(
|
||||
jwt_secret_key="test-secret-for-sync",
|
||||
database_url="sqlite+aiosqlite:///:memory:",
|
||||
redis_url="redis://localhost:6379/0",
|
||||
alpaca_api_key="test-key",
|
||||
alpaca_secret_key="test-secret",
|
||||
paper_trading=True,
|
||||
snapshot_interval_seconds=1,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def config_no_creds() -> ApiGatewayConfig:
|
||||
return ApiGatewayConfig(
|
||||
jwt_secret_key="test-secret-for-sync",
|
||||
database_url="sqlite+aiosqlite:///:memory:",
|
||||
redis_url="redis://localhost:6379/0",
|
||||
alpaca_api_key="",
|
||||
alpaca_secret_key="",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_account() -> AccountInfo:
|
||||
return AccountInfo(
|
||||
equity=105000.0,
|
||||
cash=50000.0,
|
||||
buying_power=100000.0,
|
||||
portfolio_value=105000.0,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_positions() -> list[PositionInfo]:
|
||||
return [
|
||||
PositionInfo(
|
||||
ticker="AAPL",
|
||||
qty=10.0,
|
||||
avg_entry=150.0,
|
||||
current_price=155.0,
|
||||
unrealized_pnl=50.0,
|
||||
market_value=1550.0,
|
||||
),
|
||||
PositionInfo(
|
||||
ticker="MSFT",
|
||||
qty=5.0,
|
||||
avg_entry=400.0,
|
||||
current_price=410.0,
|
||||
unrealized_pnl=50.0,
|
||||
market_value=2050.0,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_broker(mock_account, mock_positions):
|
||||
broker = AsyncMock()
|
||||
broker.get_account = AsyncMock(return_value=mock_account)
|
||||
broker.get_positions = AsyncMock(return_value=mock_positions)
|
||||
return broker
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_session():
|
||||
"""Create a mock async session with context manager support."""
|
||||
session = AsyncMock()
|
||||
session.__aenter__ = AsyncMock(return_value=session)
|
||||
session.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
# Mock the begin() context manager
|
||||
begin_ctx = AsyncMock()
|
||||
begin_ctx.__aenter__ = AsyncMock(return_value=None)
|
||||
begin_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
session.begin = MagicMock(return_value=begin_ctx)
|
||||
|
||||
# session.add is synchronous in SQLAlchemy — use MagicMock to avoid warnings
|
||||
session.add = MagicMock()
|
||||
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_session_factory(mock_session):
|
||||
factory = MagicMock()
|
||||
factory.return_value = mock_session
|
||||
return factory
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Market hours tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMarketHours:
|
||||
"""Tests for the is_market_open() function."""
|
||||
|
||||
def test_weekday_during_market_hours(self) -> None:
|
||||
# Wednesday 2024-01-10 at 10:00 AM ET = 15:00 UTC
|
||||
dt = datetime(2024, 1, 10, 15, 0, 0, tzinfo=timezone.utc)
|
||||
assert is_market_open(dt) is True
|
||||
|
||||
def test_weekday_before_market_open(self) -> None:
|
||||
# Wednesday 2024-01-10 at 9:00 AM ET = 14:00 UTC
|
||||
dt = datetime(2024, 1, 10, 14, 0, 0, tzinfo=timezone.utc)
|
||||
assert is_market_open(dt) is False
|
||||
|
||||
def test_weekday_after_market_close(self) -> None:
|
||||
# Wednesday 2024-01-10 at 4:30 PM ET = 21:30 UTC
|
||||
dt = datetime(2024, 1, 10, 21, 30, 0, tzinfo=timezone.utc)
|
||||
assert is_market_open(dt) is False
|
||||
|
||||
def test_weekend_saturday(self) -> None:
|
||||
# Saturday 2024-01-13 at 12:00 PM ET = 17:00 UTC
|
||||
dt = datetime(2024, 1, 13, 17, 0, 0, tzinfo=timezone.utc)
|
||||
assert is_market_open(dt) is False
|
||||
|
||||
def test_weekend_sunday(self) -> None:
|
||||
# Sunday 2024-01-14 at 12:00 PM ET = 17:00 UTC
|
||||
dt = datetime(2024, 1, 14, 17, 0, 0, tzinfo=timezone.utc)
|
||||
assert is_market_open(dt) is False
|
||||
|
||||
def test_market_open_boundary(self) -> None:
|
||||
# Wednesday 2024-01-10 at exactly 9:30 AM ET = 14:30 UTC
|
||||
dt = datetime(2024, 1, 10, 14, 30, 0, tzinfo=timezone.utc)
|
||||
assert is_market_open(dt) is True
|
||||
|
||||
def test_market_close_boundary(self) -> None:
|
||||
# Wednesday 2024-01-10 at exactly 4:00 PM ET = 21:00 UTC
|
||||
dt = datetime(2024, 1, 10, 21, 0, 0, tzinfo=timezone.utc)
|
||||
assert is_market_open(dt) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Snapshot creation tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSyncOnce:
|
||||
"""Tests for the _sync_once() function."""
|
||||
|
||||
async def test_creates_portfolio_snapshot(
|
||||
self, mock_broker, mock_session_factory, mock_session
|
||||
) -> None:
|
||||
# Mock the select query to return None (no existing positions)
|
||||
execute_result = MagicMock()
|
||||
execute_result.scalar_one_or_none.return_value = None
|
||||
mock_session.execute = AsyncMock(return_value=execute_result)
|
||||
|
||||
await _sync_once(mock_broker, mock_session_factory)
|
||||
|
||||
# Verify the broker was called
|
||||
mock_broker.get_account.assert_awaited_once()
|
||||
mock_broker.get_positions.assert_awaited_once()
|
||||
|
||||
# Verify session.add was called (snapshot + 2 new positions)
|
||||
assert mock_session.add.call_count == 3 # 1 snapshot + 2 positions
|
||||
|
||||
# Check the snapshot
|
||||
snapshot_call = mock_session.add.call_args_list[0]
|
||||
snapshot = snapshot_call[0][0]
|
||||
assert snapshot.total_value == 105000.0
|
||||
assert snapshot.cash == 50000.0
|
||||
assert snapshot.positions_value == 55000.0 # 105000 - 50000
|
||||
assert snapshot.daily_pnl == 0.0
|
||||
|
||||
async def test_creates_position_rows_for_new_positions(
|
||||
self, mock_broker, mock_session_factory, mock_session
|
||||
) -> None:
|
||||
# No existing positions in DB
|
||||
execute_result = MagicMock()
|
||||
execute_result.scalar_one_or_none.return_value = None
|
||||
mock_session.execute = AsyncMock(return_value=execute_result)
|
||||
|
||||
await _sync_once(mock_broker, mock_session_factory)
|
||||
|
||||
# Positions are added via session.add (after the snapshot)
|
||||
position_calls = mock_session.add.call_args_list[1:]
|
||||
assert len(position_calls) == 2
|
||||
|
||||
pos1 = position_calls[0][0][0]
|
||||
assert pos1.ticker == "AAPL"
|
||||
assert pos1.qty == 10.0
|
||||
assert pos1.avg_entry == 150.0
|
||||
assert pos1.unrealized_pnl == 50.0
|
||||
|
||||
pos2 = position_calls[1][0][0]
|
||||
assert pos2.ticker == "MSFT"
|
||||
assert pos2.qty == 5.0
|
||||
assert pos2.avg_entry == 400.0
|
||||
|
||||
async def test_updates_existing_position(
|
||||
self, mock_broker, mock_session_factory, mock_session
|
||||
) -> None:
|
||||
# Mock an existing position for AAPL, None for MSFT
|
||||
existing_aapl = MagicMock()
|
||||
existing_aapl.ticker = "AAPL"
|
||||
existing_aapl.qty = 5.0 # old qty
|
||||
existing_aapl.avg_entry = 140.0 # old entry
|
||||
|
||||
result_aapl = MagicMock()
|
||||
result_aapl.scalar_one_or_none.return_value = existing_aapl
|
||||
result_msft = MagicMock()
|
||||
result_msft.scalar_one_or_none.return_value = None
|
||||
|
||||
# First execute call is for the delete of stale positions;
|
||||
# but within the loop, select calls come first
|
||||
mock_session.execute = AsyncMock(
|
||||
side_effect=[result_aapl, result_msft, MagicMock()]
|
||||
)
|
||||
|
||||
await _sync_once(mock_broker, mock_session_factory)
|
||||
|
||||
# AAPL should be updated in place
|
||||
assert existing_aapl.qty == 10.0
|
||||
assert existing_aapl.avg_entry == 150.0
|
||||
assert existing_aapl.unrealized_pnl == 50.0
|
||||
|
||||
# MSFT should be added as new (snapshot + MSFT = 2 adds)
|
||||
assert mock_session.add.call_count == 2 # snapshot + new MSFT
|
||||
|
||||
async def test_removes_closed_positions(
|
||||
self, mock_session_factory, mock_session
|
||||
) -> None:
|
||||
# Broker returns only AAPL (MSFT was sold)
|
||||
broker = AsyncMock()
|
||||
broker.get_account = AsyncMock(
|
||||
return_value=AccountInfo(
|
||||
equity=100000, cash=90000, buying_power=90000, portfolio_value=100000
|
||||
)
|
||||
)
|
||||
broker.get_positions = AsyncMock(
|
||||
return_value=[
|
||||
PositionInfo(
|
||||
ticker="AAPL",
|
||||
qty=10.0,
|
||||
avg_entry=150.0,
|
||||
current_price=155.0,
|
||||
unrealized_pnl=50.0,
|
||||
market_value=1550.0,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
execute_result = MagicMock()
|
||||
execute_result.scalar_one_or_none.return_value = None
|
||||
mock_session.execute = AsyncMock(return_value=execute_result)
|
||||
|
||||
await _sync_once(broker, mock_session_factory)
|
||||
|
||||
# The delete statement should have been executed
|
||||
# Find the delete call among execute calls
|
||||
delete_called = False
|
||||
for call in mock_session.execute.call_args_list:
|
||||
stmt = call[0][0]
|
||||
# Check if it's a delete statement (SQLAlchemy Delete object)
|
||||
stmt_str = str(stmt)
|
||||
if "DELETE" in stmt_str.upper():
|
||||
delete_called = True
|
||||
break
|
||||
assert delete_called, "Expected a DELETE statement for closed positions"
|
||||
|
||||
async def test_removes_all_positions_when_broker_has_none(
|
||||
self, mock_session_factory, mock_session
|
||||
) -> None:
|
||||
broker = AsyncMock()
|
||||
broker.get_account = AsyncMock(
|
||||
return_value=AccountInfo(
|
||||
equity=100000, cash=100000, buying_power=100000, portfolio_value=100000
|
||||
)
|
||||
)
|
||||
broker.get_positions = AsyncMock(return_value=[])
|
||||
|
||||
mock_session.execute = AsyncMock(return_value=MagicMock())
|
||||
|
||||
await _sync_once(broker, mock_session_factory)
|
||||
|
||||
# Should delete all positions since broker has none
|
||||
delete_called = False
|
||||
for call in mock_session.execute.call_args_list:
|
||||
stmt = call[0][0]
|
||||
stmt_str = str(stmt)
|
||||
if "DELETE" in stmt_str.upper():
|
||||
delete_called = True
|
||||
break
|
||||
assert delete_called, "Expected a DELETE statement to clear all positions"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Error handling tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSyncErrorHandling:
|
||||
"""Tests that the sync loop handles errors gracefully."""
|
||||
|
||||
async def test_broker_error_does_not_crash_loop(
|
||||
self, config, mock_session_factory
|
||||
) -> None:
|
||||
"""Broker raises an exception — loop should catch it and continue."""
|
||||
call_count = 0
|
||||
|
||||
async def mock_sync_once(broker, sf):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
raise ConnectionError("Broker API down")
|
||||
# Second call succeeds
|
||||
|
||||
with (
|
||||
patch(
|
||||
"services.api_gateway.tasks.portfolio_sync.AlpacaBroker"
|
||||
) as MockBroker,
|
||||
patch(
|
||||
"services.api_gateway.tasks.portfolio_sync._sync_once",
|
||||
side_effect=mock_sync_once,
|
||||
),
|
||||
patch(
|
||||
"services.api_gateway.tasks.portfolio_sync.is_market_open",
|
||||
return_value=True,
|
||||
),
|
||||
):
|
||||
MockBroker.return_value = AsyncMock()
|
||||
|
||||
task = asyncio.create_task(portfolio_sync_loop(config, mock_session_factory))
|
||||
|
||||
# Give it time for 2 iterations (interval=1s)
|
||||
await asyncio.sleep(2.5)
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
assert call_count >= 2, "Loop should have retried after the error"
|
||||
|
||||
async def test_no_credentials_returns_immediately(
|
||||
self, config_no_creds, mock_session_factory
|
||||
) -> None:
|
||||
"""When Alpaca credentials are empty, the loop should exit immediately."""
|
||||
task = asyncio.create_task(
|
||||
portfolio_sync_loop(config_no_creds, mock_session_factory)
|
||||
)
|
||||
# Should complete almost immediately since no creds
|
||||
await asyncio.wait_for(task, timeout=2.0)
|
||||
# If we get here without timeout, the function returned correctly
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Market hours integration with loop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSyncLoopMarketHours:
|
||||
"""Tests that the loop respects market hours."""
|
||||
|
||||
async def test_skips_sync_outside_market_hours(
|
||||
self, config, mock_session_factory
|
||||
) -> None:
|
||||
sync_called = False
|
||||
|
||||
async def mock_sync(broker, sf):
|
||||
nonlocal sync_called
|
||||
sync_called = True
|
||||
|
||||
with (
|
||||
patch(
|
||||
"services.api_gateway.tasks.portfolio_sync.AlpacaBroker"
|
||||
) as MockBroker,
|
||||
patch(
|
||||
"services.api_gateway.tasks.portfolio_sync._sync_once",
|
||||
side_effect=mock_sync,
|
||||
),
|
||||
patch(
|
||||
"services.api_gateway.tasks.portfolio_sync.is_market_open",
|
||||
return_value=False,
|
||||
),
|
||||
):
|
||||
MockBroker.return_value = AsyncMock()
|
||||
|
||||
task = asyncio.create_task(portfolio_sync_loop(config, mock_session_factory))
|
||||
await asyncio.sleep(1.5)
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
assert not sync_called, "Sync should not run outside market hours"
|
||||
|
||||
async def test_runs_sync_during_market_hours(
|
||||
self, config, mock_session_factory
|
||||
) -> None:
|
||||
sync_called = False
|
||||
|
||||
async def mock_sync(broker, sf):
|
||||
nonlocal sync_called
|
||||
sync_called = True
|
||||
|
||||
with (
|
||||
patch(
|
||||
"services.api_gateway.tasks.portfolio_sync.AlpacaBroker"
|
||||
) as MockBroker,
|
||||
patch(
|
||||
"services.api_gateway.tasks.portfolio_sync._sync_once",
|
||||
side_effect=mock_sync,
|
||||
),
|
||||
patch(
|
||||
"services.api_gateway.tasks.portfolio_sync.is_market_open",
|
||||
return_value=True,
|
||||
),
|
||||
):
|
||||
MockBroker.return_value = AsyncMock()
|
||||
|
||||
task = asyncio.create_task(portfolio_sync_loop(config, mock_session_factory))
|
||||
await asyncio.sleep(1.5)
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
assert sync_called, "Sync should run during market hours"
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
"""Tests for the sentiment analyzer service.
|
||||
|
||||
Covers FinBERT analyzer, Ollama analyzer, ticker extraction, and the main
|
||||
service flow.
|
||||
service flow including DB persistence.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -10,6 +10,7 @@ from datetime import datetime, timezone
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from services.sentiment_analyzer.analyzers.finbert import FinBERTAnalyzer
|
||||
from services.sentiment_analyzer.analyzers.ollama_analyzer import OllamaAnalyzer
|
||||
|
|
@ -409,3 +410,229 @@ class TestMainFlow:
|
|||
publisher.publish.assert_not_called()
|
||||
# Still counted as scored
|
||||
counters["articles_scored"].add.assert_called_once_with(1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers for DB persistence tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_counters() -> dict:
|
||||
"""Create a dict of mock OpenTelemetry counter/histogram instruments."""
|
||||
return {
|
||||
"articles_scored": MagicMock(),
|
||||
"finbert_count": MagicMock(),
|
||||
"ollama_count": MagicMock(),
|
||||
"inference_latency": MagicMock(),
|
||||
}
|
||||
|
||||
|
||||
def _make_mock_db_session_factory(session: AsyncMock | None = None) -> AsyncMock:
|
||||
"""Create a mock async_sessionmaker that yields a mock session.
|
||||
|
||||
The returned factory, when called, returns an async context manager
|
||||
that yields ``session`` (a mock AsyncSession).
|
||||
"""
|
||||
if session is None:
|
||||
session = AsyncMock()
|
||||
session.add = MagicMock()
|
||||
session.commit = AsyncMock()
|
||||
|
||||
factory = MagicMock()
|
||||
|
||||
# factory() should return an async context manager (the session)
|
||||
ctx = AsyncMock()
|
||||
ctx.__aenter__ = AsyncMock(return_value=session)
|
||||
ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
factory.return_value = ctx
|
||||
|
||||
return factory
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DB Persistence Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDBPersistence:
|
||||
"""Tests for the DB write step in process_article."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_db_write_creates_article_and_sentiments(self):
|
||||
"""When db_session_factory is provided, Article and ArticleSentiment rows are created."""
|
||||
finbert = AsyncMock(spec=FinBERTAnalyzer)
|
||||
finbert.analyze = AsyncMock(return_value=(0.75, 0.88))
|
||||
|
||||
ollama = AsyncMock(spec=OllamaAnalyzer)
|
||||
|
||||
publisher = AsyncMock()
|
||||
publisher.publish = AsyncMock(return_value=b"1-0")
|
||||
|
||||
config = SentimentAnalyzerConfig(
|
||||
finbert_confidence_threshold=0.6,
|
||||
otel_metrics_port=0,
|
||||
)
|
||||
|
||||
counters = _make_counters()
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
db_factory = _make_mock_db_session_factory(mock_session)
|
||||
|
||||
article = _make_raw_article(
|
||||
title="$AAPL and $MSFT report strong earnings",
|
||||
content="Both Apple and Microsoft beat estimates.",
|
||||
)
|
||||
|
||||
await process_article(
|
||||
article, finbert, ollama, publisher, config, counters, db_factory
|
||||
)
|
||||
|
||||
# session.add should be called: 1 Article + 2 ArticleSentiments = 3 calls
|
||||
assert mock_session.add.call_count == 3
|
||||
|
||||
# Verify the types of objects added
|
||||
added_objects = [call.args[0] for call in mock_session.add.call_args_list]
|
||||
|
||||
from shared.models.news import Article, ArticleSentiment
|
||||
|
||||
articles = [o for o in added_objects if isinstance(o, Article)]
|
||||
sentiments = [o for o in added_objects if isinstance(o, ArticleSentiment)]
|
||||
|
||||
assert len(articles) == 1
|
||||
assert len(sentiments) == 2
|
||||
|
||||
# Verify article fields
|
||||
db_article = articles[0]
|
||||
assert db_article.source == "test"
|
||||
assert db_article.url == "https://example.com/article"
|
||||
assert db_article.content_hash == "abc123"
|
||||
|
||||
# Verify sentiment fields
|
||||
tickers_in_sentiments = {s.ticker for s in sentiments}
|
||||
assert "AAPL" in tickers_in_sentiments
|
||||
assert "MSFT" in tickers_in_sentiments
|
||||
for s in sentiments:
|
||||
assert s.score == 0.75
|
||||
assert s.confidence == 0.88
|
||||
assert s.model_used == "finbert"
|
||||
|
||||
# session.commit should be called once
|
||||
mock_session.commit.assert_awaited_once()
|
||||
|
||||
# Redis publishing should still happen
|
||||
assert publisher.publish.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_db_duplicate_article_handled_gracefully(self):
|
||||
"""Duplicate content_hash (IntegrityError) should be caught silently."""
|
||||
finbert = AsyncMock(spec=FinBERTAnalyzer)
|
||||
finbert.analyze = AsyncMock(return_value=(0.5, 0.9))
|
||||
|
||||
ollama = AsyncMock(spec=OllamaAnalyzer)
|
||||
|
||||
publisher = AsyncMock()
|
||||
publisher.publish = AsyncMock(return_value=b"1-0")
|
||||
|
||||
config = SentimentAnalyzerConfig(
|
||||
finbert_confidence_threshold=0.6,
|
||||
otel_metrics_port=0,
|
||||
)
|
||||
|
||||
counters = _make_counters()
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.add = MagicMock()
|
||||
# Simulate IntegrityError on commit (duplicate content_hash)
|
||||
mock_session.commit = AsyncMock(
|
||||
side_effect=IntegrityError("duplicate", {}, Exception())
|
||||
)
|
||||
|
||||
db_factory = _make_mock_db_session_factory(mock_session)
|
||||
|
||||
article = _make_raw_article(
|
||||
title="$AAPL news",
|
||||
content="Apple earnings report.",
|
||||
)
|
||||
|
||||
# Should NOT raise — IntegrityError is caught
|
||||
await process_article(
|
||||
article, finbert, ollama, publisher, config, counters, db_factory
|
||||
)
|
||||
|
||||
# Redis publishing should still have happened before the DB write
|
||||
assert publisher.publish.call_count >= 1
|
||||
counters["articles_scored"].add.assert_called_once_with(1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_db_none_backward_compatible(self):
|
||||
"""When db_session_factory is None, process_article works without DB writes."""
|
||||
finbert = AsyncMock(spec=FinBERTAnalyzer)
|
||||
finbert.analyze = AsyncMock(return_value=(0.6, 0.85))
|
||||
|
||||
ollama = AsyncMock(spec=OllamaAnalyzer)
|
||||
|
||||
publisher = AsyncMock()
|
||||
publisher.publish = AsyncMock(return_value=b"1-0")
|
||||
|
||||
config = SentimentAnalyzerConfig(
|
||||
finbert_confidence_threshold=0.6,
|
||||
otel_metrics_port=0,
|
||||
)
|
||||
|
||||
counters = _make_counters()
|
||||
|
||||
article = _make_raw_article(
|
||||
title="$GOOG quarterly results",
|
||||
content="Google reports revenue growth.",
|
||||
)
|
||||
|
||||
# Pass db_session_factory=None explicitly (the default)
|
||||
await process_article(
|
||||
article, finbert, ollama, publisher, config, counters, db_session_factory=None
|
||||
)
|
||||
|
||||
# Redis publishing should work as before
|
||||
assert publisher.publish.call_count >= 1
|
||||
counters["articles_scored"].add.assert_called_once_with(1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_db_error_does_not_break_processing(self):
|
||||
"""A DB error should be logged but not prevent Redis publishing."""
|
||||
finbert = AsyncMock(spec=FinBERTAnalyzer)
|
||||
finbert.analyze = AsyncMock(return_value=(0.3, 0.7))
|
||||
|
||||
ollama = AsyncMock(spec=OllamaAnalyzer)
|
||||
|
||||
publisher = AsyncMock()
|
||||
publisher.publish = AsyncMock(return_value=b"1-0")
|
||||
|
||||
config = SentimentAnalyzerConfig(
|
||||
finbert_confidence_threshold=0.6,
|
||||
otel_metrics_port=0,
|
||||
)
|
||||
|
||||
counters = _make_counters()
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.add = MagicMock()
|
||||
# Simulate a generic DB error
|
||||
mock_session.commit = AsyncMock(
|
||||
side_effect=RuntimeError("connection lost")
|
||||
)
|
||||
|
||||
db_factory = _make_mock_db_session_factory(mock_session)
|
||||
|
||||
article = _make_raw_article(
|
||||
title="$TSLA stock update",
|
||||
content="Tesla announces new factory.",
|
||||
)
|
||||
|
||||
# Should NOT raise — generic exceptions in DB write are caught
|
||||
await process_article(
|
||||
article, finbert, ollama, publisher, config, counters, db_factory
|
||||
)
|
||||
|
||||
# Redis publishing should still have happened
|
||||
assert publisher.publish.call_count >= 1
|
||||
counters["articles_scored"].add.assert_called_once_with(1)
|
||||
|
|
|
|||
|
|
@ -357,3 +357,201 @@ class TestEnsembleTagsStrategySources:
|
|||
assert parts[0] == "alpha"
|
||||
assert parts[1] == "LONG"
|
||||
assert float(parts[2]) == pytest.approx(0.75, abs=0.01)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Signal Generator — current_price injection into sentiment_context
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCurrentPriceInjection:
|
||||
"""Verify that current_price flows into sentiment_context on published signals."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_current_price_set_when_snapshot_has_price(self):
|
||||
"""When snapshot has a positive current_price, it should appear in sentiment_context."""
|
||||
s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.9))
|
||||
ensemble = WeightedEnsemble([s1], threshold=0.0)
|
||||
market = MarketSnapshot(
|
||||
ticker="AAPL",
|
||||
current_price=155.50,
|
||||
open=154.0,
|
||||
high=156.0,
|
||||
low=153.0,
|
||||
close=155.50,
|
||||
volume=5000,
|
||||
)
|
||||
sentiment = SentimentContext(
|
||||
ticker="AAPL",
|
||||
avg_score=0.7,
|
||||
article_count=5,
|
||||
recent_scores=[0.6, 0.7, 0.8],
|
||||
avg_confidence=0.85,
|
||||
)
|
||||
weights = {"alpha": 1.0}
|
||||
|
||||
signal_result = await ensemble.evaluate("AAPL", market, sentiment, weights)
|
||||
assert signal_result is not None
|
||||
|
||||
# Simulate the current_price injection from the signal generator main loop
|
||||
if signal_result.sentiment_context is None:
|
||||
signal_result.sentiment_context = {}
|
||||
if market.current_price > 0:
|
||||
signal_result.sentiment_context["current_price"] = market.current_price
|
||||
|
||||
assert "current_price" in signal_result.sentiment_context
|
||||
assert signal_result.sentiment_context["current_price"] == 155.50
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_current_price_not_set_when_snapshot_is_none(self):
|
||||
"""When snapshot is None (minimal fallback), current_price should not be injected."""
|
||||
s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.9))
|
||||
ensemble = WeightedEnsemble([s1], threshold=0.0)
|
||||
# Minimal fallback snapshot with current_price=0.0
|
||||
market = MarketSnapshot(
|
||||
ticker="AAPL",
|
||||
current_price=0.0,
|
||||
open=0.0,
|
||||
high=0.0,
|
||||
low=0.0,
|
||||
close=0.0,
|
||||
volume=0.0,
|
||||
)
|
||||
weights = {"alpha": 1.0}
|
||||
|
||||
signal_result = await ensemble.evaluate("AAPL", market, None, weights)
|
||||
assert signal_result is not None
|
||||
|
||||
# Simulate the injection logic — should NOT set current_price when price is 0
|
||||
if market.current_price > 0:
|
||||
if signal_result.sentiment_context is None:
|
||||
signal_result.sentiment_context = {}
|
||||
signal_result.sentiment_context["current_price"] = market.current_price
|
||||
|
||||
# sentiment_context should be None (no sentiment was passed, and no price injection)
|
||||
assert signal_result.sentiment_context is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_current_price_not_set_when_price_is_zero(self):
|
||||
"""current_price should NOT be injected when snapshot.current_price is exactly 0."""
|
||||
s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.8))
|
||||
ensemble = WeightedEnsemble([s1], threshold=0.0)
|
||||
sentiment = SentimentContext(
|
||||
ticker="AAPL",
|
||||
avg_score=0.5,
|
||||
article_count=3,
|
||||
recent_scores=[0.5],
|
||||
avg_confidence=0.9,
|
||||
)
|
||||
market = MarketSnapshot(
|
||||
ticker="AAPL",
|
||||
current_price=0.0,
|
||||
open=0.0,
|
||||
high=0.0,
|
||||
low=0.0,
|
||||
close=0.0,
|
||||
volume=0.0,
|
||||
)
|
||||
weights = {"alpha": 1.0}
|
||||
|
||||
signal_result = await ensemble.evaluate("AAPL", market, sentiment, weights)
|
||||
assert signal_result is not None
|
||||
|
||||
# Simulate injection logic
|
||||
if market.current_price > 0:
|
||||
if signal_result.sentiment_context is None:
|
||||
signal_result.sentiment_context = {}
|
||||
signal_result.sentiment_context["current_price"] = market.current_price
|
||||
|
||||
# sentiment_context should exist (from ensemble) but WITHOUT current_price
|
||||
assert signal_result.sentiment_context is not None
|
||||
assert "current_price" not in signal_result.sentiment_context
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_current_price_preserves_existing_sentiment_context(self):
|
||||
"""Injecting current_price should not overwrite other sentiment_context fields."""
|
||||
s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.9))
|
||||
ensemble = WeightedEnsemble([s1], threshold=0.0)
|
||||
sentiment = SentimentContext(
|
||||
ticker="AAPL",
|
||||
avg_score=0.75,
|
||||
article_count=10,
|
||||
recent_scores=[0.7, 0.8],
|
||||
avg_confidence=0.9,
|
||||
)
|
||||
market = MarketSnapshot(
|
||||
ticker="AAPL",
|
||||
current_price=200.0,
|
||||
open=198.0,
|
||||
high=202.0,
|
||||
low=197.0,
|
||||
close=200.0,
|
||||
volume=8000,
|
||||
)
|
||||
weights = {"alpha": 1.0}
|
||||
|
||||
signal_result = await ensemble.evaluate("AAPL", market, sentiment, weights)
|
||||
assert signal_result is not None
|
||||
assert signal_result.sentiment_context is not None
|
||||
|
||||
# Preserve existing keys
|
||||
original_keys = set(signal_result.sentiment_context.keys())
|
||||
|
||||
# Inject current_price
|
||||
if market.current_price > 0:
|
||||
signal_result.sentiment_context["current_price"] = market.current_price
|
||||
|
||||
# All original keys should still be present
|
||||
for key in original_keys:
|
||||
assert key in signal_result.sentiment_context
|
||||
# Plus the new one
|
||||
assert signal_result.sentiment_context["current_price"] == 200.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Signal Generator — signal_id generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSignalIdGeneration:
|
||||
"""Verify that TradeSignal includes a signal_id UUID."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_signal_has_uuid(self):
|
||||
"""Each generated signal should have a signal_id UUID."""
|
||||
s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.9))
|
||||
ensemble = WeightedEnsemble([s1], threshold=0.0)
|
||||
market = MarketSnapshot(
|
||||
ticker="AAPL", current_price=150.0,
|
||||
open=149.0, high=151.0, low=148.0, close=150.0, volume=1000,
|
||||
)
|
||||
weights = {"alpha": 1.0}
|
||||
|
||||
signal_result = await ensemble.evaluate("AAPL", market, None, weights)
|
||||
assert signal_result is not None
|
||||
assert signal_result.signal_id is not None
|
||||
from uuid import UUID
|
||||
assert isinstance(signal_result.signal_id, UUID)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_signal_ids_are_unique(self):
|
||||
"""Multiple signals should get different signal_id values."""
|
||||
s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.9))
|
||||
ensemble = WeightedEnsemble([s1], threshold=0.0)
|
||||
market = MarketSnapshot(
|
||||
ticker="AAPL", current_price=150.0,
|
||||
open=149.0, high=151.0, low=148.0, close=150.0, volume=1000,
|
||||
)
|
||||
weights = {"alpha": 1.0}
|
||||
|
||||
sig1 = await ensemble.evaluate("AAPL", market, None, weights)
|
||||
sig2 = await ensemble.evaluate("AAPL", market, None, weights)
|
||||
assert sig1 is not None and sig2 is not None
|
||||
assert sig1.signal_id != sig2.signal_id
|
||||
|
||||
def test_signal_id_serializes_in_json(self):
|
||||
"""signal_id should serialize properly in JSON mode."""
|
||||
signal = _make_signal(SignalDirection.LONG, 0.8)
|
||||
data = signal.model_dump(mode="json")
|
||||
assert "signal_id" in data
|
||||
assert isinstance(data["signal_id"], str)
|
||||
|
|
|
|||
|
|
@ -401,3 +401,119 @@ class TestExecutorFlowRejected:
|
|||
|
||||
# Rejection counter should have been incremented
|
||||
counters["rejections"].add.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Executor flow — DB persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_mock_db_session_factory(session=None):
|
||||
"""Create a mock async_sessionmaker that yields a mock session."""
|
||||
if session is None:
|
||||
session = AsyncMock()
|
||||
session.add = MagicMock()
|
||||
session.commit = AsyncMock()
|
||||
|
||||
factory = MagicMock()
|
||||
ctx = AsyncMock()
|
||||
ctx.__aenter__ = AsyncMock(return_value=session)
|
||||
ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
factory.return_value = ctx
|
||||
return factory
|
||||
|
||||
|
||||
class TestExecutorDBPersistence:
|
||||
"""Verify that trades are persisted to the DB when db_session_factory is provided."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trade_persisted_with_signal_id(self):
|
||||
"""When db_session_factory is provided, a Trade row should be created."""
|
||||
config = _make_config()
|
||||
broker = _mock_broker(positions=[], account=_make_account(100_000))
|
||||
publisher = AsyncMock()
|
||||
publisher.publish = AsyncMock(return_value=b"1-0")
|
||||
counters = {
|
||||
"trades_executed": MagicMock(),
|
||||
"rejections": MagicMock(),
|
||||
"fill_latency": MagicMock(),
|
||||
}
|
||||
|
||||
signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0)
|
||||
mock_session = AsyncMock()
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
db_factory = _make_mock_db_session_factory(mock_session)
|
||||
|
||||
with patch.object(RiskManager, "check_risk", return_value=(True, "approved")):
|
||||
await process_signal(
|
||||
signal, RiskManager(config, broker), broker, publisher, counters, db_factory
|
||||
)
|
||||
|
||||
# Trade should be persisted
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_awaited_once()
|
||||
|
||||
# Verify the trade object
|
||||
trade_obj = mock_session.add.call_args[0][0]
|
||||
assert trade_obj.ticker == "AAPL"
|
||||
assert trade_obj.signal_id == signal.signal_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trade_not_persisted_without_db(self):
|
||||
"""When db_session_factory is None, no DB write should happen."""
|
||||
config = _make_config()
|
||||
broker = _mock_broker(positions=[], account=_make_account(100_000))
|
||||
publisher = AsyncMock()
|
||||
publisher.publish = AsyncMock(return_value=b"1-0")
|
||||
counters = {
|
||||
"trades_executed": MagicMock(),
|
||||
"rejections": MagicMock(),
|
||||
"fill_latency": MagicMock(),
|
||||
}
|
||||
|
||||
signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0)
|
||||
|
||||
with patch.object(RiskManager, "check_risk", return_value=(True, "approved")):
|
||||
await process_signal(
|
||||
signal, RiskManager(config, broker), broker, publisher, counters, None
|
||||
)
|
||||
|
||||
# Should still publish
|
||||
publisher.publish.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_db_error_does_not_block_publishing(self):
|
||||
"""A DB error should not prevent the trade from being published."""
|
||||
config = _make_config()
|
||||
broker = _mock_broker(positions=[], account=_make_account(100_000))
|
||||
publisher = AsyncMock()
|
||||
publisher.publish = AsyncMock(return_value=b"1-0")
|
||||
counters = {
|
||||
"trades_executed": MagicMock(),
|
||||
"rejections": MagicMock(),
|
||||
"fill_latency": MagicMock(),
|
||||
}
|
||||
|
||||
signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0)
|
||||
mock_session = AsyncMock()
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = AsyncMock(side_effect=RuntimeError("DB connection lost"))
|
||||
db_factory = _make_mock_db_session_factory(mock_session)
|
||||
|
||||
with patch.object(RiskManager, "check_risk", return_value=(True, "approved")):
|
||||
await process_signal(
|
||||
signal, RiskManager(config, broker), broker, publisher, counters, db_factory
|
||||
)
|
||||
|
||||
# Trade should still be published despite DB error
|
||||
publisher.publish.assert_called_once()
|
||||
counters["trades_executed"].add.assert_called_once_with(1)
|
||||
|
||||
def test_signal_id_flows_through_execution(self):
|
||||
"""signal_id from TradeSignal should appear in the published TradeExecution."""
|
||||
signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0)
|
||||
assert signal.signal_id is not None
|
||||
# Verify signal_id is a UUID
|
||||
from uuid import UUID
|
||||
assert isinstance(signal.signal_id, UUID)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue