diff --git a/pyproject.toml b/pyproject.toml index fbd6349..64e9c2e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,11 +27,12 @@ requires = ["setuptools>=70.0"] build-backend = "setuptools.build_meta" [tool.setuptools.packages.find] -include = ["shared*", "services*", "backtester*", "tests*"] +include = ["shared*", "services*", "backtester*", "scripts*", "tests*"] [tool.pytest.ini_options] asyncio_mode = "auto" testpaths = ["tests"] +markers = ["integration: marks tests requiring docker services (redis, postgres)"] [tool.ruff] line-length = 120 diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/seed_strategies.py b/scripts/seed_strategies.py new file mode 100644 index 0000000..390ba0d --- /dev/null +++ b/scripts/seed_strategies.py @@ -0,0 +1,109 @@ +"""Seed default trading strategies. + +Inserts three strategies with equal initial weights (0.333 each): + - momentum + - mean_reversion + - news_driven + +Usage: + python -m scripts.seed_strategies +""" + +from __future__ import annotations + +import asyncio +import logging + +from sqlalchemy import select + +from shared.config import BaseConfig +from shared.db import create_db +from shared.models.trading import Strategy + +logger = logging.getLogger(__name__) + +# Default strategies to seed +DEFAULT_STRATEGIES = [ + { + "name": "momentum", + "description": ( + "Buy when price crosses above N-period SMA with increasing volume; " + "sell when it crosses below." + ), + "current_weight": 0.333, + "active": True, + }, + { + "name": "mean_reversion", + "description": ( + "Buy when RSI < 30 (oversold); sell when RSI > 70 (overbought). " + "Signal strength proportional to RSI extremity." + ), + "current_weight": 0.333, + "active": True, + }, + { + "name": "news_driven", + "description": ( + "Buy on strong positive sentiment (score > 0.7, confidence > 0.6); " + "sell on strong negative. Decay factor for stale news (> 4 hours)." 
+ ), + "current_weight": 0.333, + "active": True, + }, +] + + +async def seed(database_url: str | None = None) -> None: + """Insert default strategies if they do not already exist. + + Parameters + ---------- + database_url: + Override for the database URL. If ``None``, the default from + :class:`~shared.config.BaseConfig` is used. + """ + config = BaseConfig() + if database_url: + config.database_url = database_url + + _engine, session_factory = create_db(config) + + async with session_factory() as session: + for strategy_data in DEFAULT_STRATEGIES: + # Check if the strategy already exists by name + result = await session.execute( + select(Strategy).where(Strategy.name == strategy_data["name"]) + ) + existing = result.scalar_one_or_none() + + if existing: + logger.info( + "Strategy '%s' already exists (weight=%.3f), skipping", + existing.name, + existing.current_weight, + ) + continue + + strategy = Strategy(**strategy_data) + session.add(strategy) + logger.info( + "Inserted strategy '%s' with weight %.3f", + strategy_data["name"], + strategy_data["current_weight"], + ) + + await session.commit() + + await _engine.dispose() + logger.info("Strategy seeding complete") + + +def main() -> None: + """CLI entry-point.""" + logging.basicConfig(level=logging.INFO) + asyncio.run(seed()) + + +if __name__ == "__main__": + main() diff --git a/scripts/smoke_test.sh b/scripts/smoke_test.sh new file mode 100755 index 0000000..91b84bf --- /dev/null +++ b/scripts/smoke_test.sh @@ -0,0 +1,139 @@ +#!/bin/bash +# Smoke test for the full trading-bot Docker Compose stack. +# +# Usage: +# ./scripts/smoke_test.sh +# +# Prerequisites: +# - Docker Compose stack must be running (docker compose up -d) +# +# This script: +# 1. Waits for services to become healthy +# 2. Hits GET /health -> expects 200 +# 3. Hits GET /api/portfolio -> expects 401 (unauthenticated) +# 4. Hits GET /api/strategies -> expects 401 (unauthenticated) +# 5. Checks docker compose ps shows all services running +# 6. 
Exits 0 on success, 1 on failure + +set -euo pipefail + +# Configuration +API_BASE="${API_BASE:-http://localhost:8000}" +DASHBOARD_BASE="${DASHBOARD_BASE:-http://localhost:3000}" +MAX_RETRIES="${MAX_RETRIES:-30}" +RETRY_INTERVAL="${RETRY_INTERVAL:-2}" + +PASS=0 +FAIL=0 + +# --------------------------------------------------------------------------- +# Helper functions +# --------------------------------------------------------------------------- + +log() { + echo "[smoke-test] $*" +} + +pass() { + log "PASS: $*" + PASS=$((PASS + 1)) +} + +fail() { + log "FAIL: $*" + FAIL=$((FAIL + 1)) +} + +wait_for_endpoint() { + local url="$1" + local expected_code="$2" + local description="$3" + local attempt=0 + + while [ "$attempt" -lt "$MAX_RETRIES" ]; do + attempt=$((attempt + 1)) + status_code=$(curl -s -o /dev/null -w "%{http_code}" "$url" 2>/dev/null || echo "000") + if [ "$status_code" = "$expected_code" ]; then + return 0 + fi + log "Waiting for $description ($url) ... attempt $attempt/$MAX_RETRIES (got $status_code, want $expected_code)" + sleep "$RETRY_INTERVAL" + done + + return 1 +} + +check_endpoint() { + local url="$1" + local expected_code="$2" + local description="$3" + + status_code=$(curl -s -o /dev/null -w "%{http_code}" "$url" 2>/dev/null || echo "000") + if [ "$status_code" = "$expected_code" ]; then + pass "$description -> $status_code" + else + fail "$description -> expected $expected_code, got $status_code" + fi +} + +# --------------------------------------------------------------------------- +# 1. Wait for the API gateway health endpoint +# --------------------------------------------------------------------------- +log "Waiting for API gateway to be healthy ..." 
+if wait_for_endpoint "$API_BASE/health" "200" "API health"; then + pass "API gateway is healthy" +else + fail "API gateway did not become healthy within timeout" + log "Aborting — cannot run further checks without a healthy API" + exit 1 +fi + +# --------------------------------------------------------------------------- +# 2. Health check +# --------------------------------------------------------------------------- +check_endpoint "$API_BASE/health" "200" "GET /health" + +# --------------------------------------------------------------------------- +# 3. Unauthenticated trading endpoints should return 401/403 +# --------------------------------------------------------------------------- +check_endpoint "$API_BASE/api/portfolio" "401" "GET /api/portfolio (no auth)" +check_endpoint "$API_BASE/api/strategies" "401" "GET /api/strategies (no auth)" + +# --------------------------------------------------------------------------- +# 4. Dashboard responds +# --------------------------------------------------------------------------- +log "Checking dashboard ..." +if wait_for_endpoint "$DASHBOARD_BASE/" "200" "Dashboard"; then + pass "Dashboard is serving" +else + fail "Dashboard did not respond" +fi + +# --------------------------------------------------------------------------- +# 5. Docker Compose services status +# --------------------------------------------------------------------------- +log "Checking docker compose service status ..." 
+if command -v docker &>/dev/null; then + running_count=$(docker compose ps --format json 2>/dev/null | grep -c '"running"' || true) + if [ "$running_count" -gt 0 ]; then + pass "docker compose shows $running_count running services" + else + fail "No running services found in docker compose ps" + fi +else + log "SKIP: docker command not available" +fi + +# --------------------------------------------------------------------------- +# Summary +# --------------------------------------------------------------------------- +echo "" +log "================================" +log "Results: $PASS passed, $FAIL failed" +log "================================" + +if [ "$FAIL" -gt 0 ]; then + exit 1 +fi + +exit 0 diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/test_news_pipeline.py b/tests/integration/test_news_pipeline.py new file mode 100644 index 0000000..6b0934b --- /dev/null +++ b/tests/integration/test_news_pipeline.py @@ -0,0 +1,299 @@ +"""Integration test: news fetcher -> sentiment analyzer pipeline. + +Publishes a mock RawArticle to the ``news:raw`` Redis stream and verifies +that a ScoredArticle appears on ``news:scored``. + +Requires a running Redis instance (from docker-compose). +FinBERT and Ollama are mocked so the test does not need GPU / model weights. 
+ +Run with: + pytest tests/integration/test_news_pipeline.py -v -m integration +""" + +from __future__ import annotations + +import asyncio +import json +from datetime import datetime, timezone +from unittest.mock import AsyncMock, patch + +import pytest +from redis.asyncio import Redis + +from shared.redis_streams import StreamConsumer, StreamPublisher +from shared.schemas.news import RawArticle, ScoredArticle +from services.sentiment_analyzer.main import process_article +from services.sentiment_analyzer.config import SentimentAnalyzerConfig + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +REDIS_URL = "redis://localhost:6379/1" # Use DB 1 to avoid conflicts + +RAW_STREAM = "test:news:raw" +SCORED_STREAM = "test:news:scored" + + +@pytest.fixture +async def redis_client(): + """Provide a clean Redis connection on DB 1 and clean up streams after.""" + client = Redis.from_url(REDIS_URL, decode_responses=False) + # Ensure streams are clean before the test + await client.delete(RAW_STREAM, SCORED_STREAM) + yield client + # Clean up after + await client.delete(RAW_STREAM, SCORED_STREAM) + await client.aclose() + + +@pytest.fixture +def sample_article() -> RawArticle: + """Return a sample RawArticle mentioning AAPL.""" + return RawArticle( + source="rss", + url="https://example.com/aapl-news", + title="Apple Inc AAPL reports record quarterly earnings", + content=( + "Apple Inc ($AAPL) reported record-breaking quarterly earnings " + "today, beating analyst estimates by a wide margin. Revenue grew " + "15% year-over-year driven by strong iPhone and Services demand." 
+ ), + published_at=datetime.now(timezone.utc), + fetched_at=datetime.now(timezone.utc), + content_hash="test-hash-aapl-001", + ) + + +# --------------------------------------------------------------------------- +# Mock counters (stand-in for OpenTelemetry instruments) +# --------------------------------------------------------------------------- + + +class _FakeCounter: + """Minimal fake that records how many times ``add`` was called.""" + + def __init__(self): + self.total = 0 + + def add(self, amount: int = 1, attributes: dict | None = None): + self.total += amount + + +class _FakeHistogram: + def __init__(self): + self.values: list[float] = [] + + def record(self, value: float, attributes: dict | None = None): + self.values.append(value) + + +def _make_counters() -> dict: + return { + "articles_scored": _FakeCounter(), + "finbert_count": _FakeCounter(), + "ollama_count": _FakeCounter(), + "inference_latency": _FakeHistogram(), + } + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_raw_article_flows_to_scored(redis_client: Redis, sample_article: RawArticle): + """Publish a RawArticle to news:raw, run the sentiment analyzer's + process_article function, and verify a ScoredArticle is published + to news:scored. 
+ """ + publisher = StreamPublisher(redis_client, SCORED_STREAM) + + # Mock FinBERT to return high-confidence positive sentiment + mock_finbert = AsyncMock() + mock_finbert.analyze = AsyncMock(return_value=(0.85, 0.92)) + + # Mock Ollama (should not be called when FinBERT confidence is high) + mock_ollama = AsyncMock() + mock_ollama.analyze = AsyncMock(return_value=(0.0, 0.0)) + + config = SentimentAnalyzerConfig() + counters = _make_counters() + + # Process the article + await process_article( + sample_article, + mock_finbert, + mock_ollama, + publisher, + config, + counters, + ) + + # FinBERT should have been called, Ollama should NOT + mock_finbert.analyze.assert_called_once() + mock_ollama.analyze.assert_not_called() + + # Verify a ScoredArticle was published to the scored stream + messages = await redis_client.xrange(SCORED_STREAM) + assert len(messages) >= 1, "Expected at least one message on the scored stream" + + # Parse the first message + _msg_id, fields = messages[0] + data = json.loads(fields[b"data"]) + scored = ScoredArticle.model_validate(data) + + assert scored.ticker == "AAPL" + assert scored.sentiment_score == pytest.approx(0.85, abs=0.01) + assert scored.confidence == pytest.approx(0.92, abs=0.01) + assert scored.model_used == "finbert" + assert scored.source == "rss" + assert scored.title == sample_article.title + + # Counter checks + assert counters["articles_scored"].total == 1 + assert counters["finbert_count"].total == 1 + assert counters["ollama_count"].total == 0 + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_low_confidence_falls_back_to_ollama(redis_client: Redis, sample_article: RawArticle): + """When FinBERT confidence is below the threshold, the sentiment + analyzer should fall back to Ollama. 
+ """ + publisher = StreamPublisher(redis_client, SCORED_STREAM) + + # FinBERT returns low confidence -> triggers Ollama fallback + mock_finbert = AsyncMock() + mock_finbert.analyze = AsyncMock(return_value=(0.3, 0.4)) + + mock_ollama = AsyncMock() + mock_ollama.analyze = AsyncMock(return_value=(0.72, 0.88)) + + config = SentimentAnalyzerConfig() + config.finbert_confidence_threshold = 0.6 # 0.4 < 0.6 -> fallback + counters = _make_counters() + + await process_article( + sample_article, + mock_finbert, + mock_ollama, + publisher, + config, + counters, + ) + + # Both should have been called + mock_finbert.analyze.assert_called_once() + mock_ollama.analyze.assert_called_once() + + # Verify the published message used Ollama's scores + messages = await redis_client.xrange(SCORED_STREAM) + assert len(messages) >= 1 + + _msg_id, fields = messages[0] + data = json.loads(fields[b"data"]) + scored = ScoredArticle.model_validate(data) + + assert scored.model_used == "ollama" + assert scored.sentiment_score == pytest.approx(0.72, abs=0.01) + assert scored.confidence == pytest.approx(0.88, abs=0.01) + + # Counter checks + assert counters["ollama_count"].total == 1 + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_article_without_tickers_does_not_publish(redis_client: Redis): + """An article with no recognizable ticker mentions should not produce + any ScoredArticle messages. 
+ """ + article = RawArticle( + source="reddit", + url="https://reddit.com/r/finance/post123", + title="General market outlook for next week", + content="The market is looking bullish with strong consumer spending data.", + published_at=datetime.now(timezone.utc), + fetched_at=datetime.now(timezone.utc), + content_hash="test-hash-no-ticker-001", + ) + + publisher = StreamPublisher(redis_client, SCORED_STREAM) + + mock_finbert = AsyncMock() + mock_finbert.analyze = AsyncMock(return_value=(0.6, 0.85)) + + mock_ollama = AsyncMock() + + config = SentimentAnalyzerConfig() + counters = _make_counters() + + await process_article( + article, + mock_finbert, + mock_ollama, + publisher, + config, + counters, + ) + + # No tickers extracted -> no messages on the scored stream + messages = await redis_client.xrange(SCORED_STREAM) + assert len(messages) == 0 + + # Article was still counted as scored + assert counters["articles_scored"].total == 1 + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_publish_and_consume_roundtrip(redis_client: Redis, sample_article: RawArticle): + """End-to-end: publish a RawArticle to the raw stream, consume it via + StreamConsumer, process it, and verify the scored output is consumable. 
+ """ + raw_publisher = StreamPublisher(redis_client, RAW_STREAM) + scored_publisher = StreamPublisher(redis_client, SCORED_STREAM) + + # Publish the raw article + await raw_publisher.publish(sample_article.model_dump(mode="json")) + + # Verify it's on the raw stream + raw_messages = await redis_client.xrange(RAW_STREAM) + assert len(raw_messages) == 1 + + # Parse it back + _msg_id, fields = raw_messages[0] + data = json.loads(fields[b"data"]) + parsed = RawArticle.model_validate(data) + assert parsed.title == sample_article.title + + # Now process it through the analyzer + mock_finbert = AsyncMock() + mock_finbert.analyze = AsyncMock(return_value=(0.9, 0.95)) + mock_ollama = AsyncMock() + + config = SentimentAnalyzerConfig() + counters = _make_counters() + + await process_article( + parsed, + mock_finbert, + mock_ollama, + scored_publisher, + config, + counters, + ) + + # Verify scored output + scored_messages = await redis_client.xrange(SCORED_STREAM) + assert len(scored_messages) >= 1 + + _msg_id, fields = scored_messages[0] + scored_data = json.loads(fields[b"data"]) + scored = ScoredArticle.model_validate(scored_data) + assert scored.ticker == "AAPL" + assert scored.sentiment_score == pytest.approx(0.9, abs=0.01) diff --git a/tests/integration/test_trading_flow.py b/tests/integration/test_trading_flow.py new file mode 100644 index 0000000..ba97306 --- /dev/null +++ b/tests/integration/test_trading_flow.py @@ -0,0 +1,399 @@ +"""Integration test: signal generator -> trade executor flow. + +Publishes a mock TradeSignal to the ``signals:generated`` Redis stream +and verifies that a TradeExecution appears on ``trades:executed``. + +Requires a running Redis instance (from docker-compose). +The Alpaca broker is mocked. 
+ +Run with: + pytest tests/integration/test_trading_flow.py -v -m integration +""" + +from __future__ import annotations + +import json +import uuid +from datetime import datetime, timezone +from unittest.mock import AsyncMock, patch + +import pytest +from redis.asyncio import Redis + +from shared.redis_streams import StreamPublisher +from shared.schemas.trading import ( + AccountInfo, + OrderResult, + OrderSide, + OrderStatus, + PositionInfo, + SignalDirection, + TradeExecution, + TradeSignal, +) +from services.trade_executor.config import TradeExecutorConfig +from services.trade_executor.main import process_signal +from services.trade_executor.risk_manager import RiskManager + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +REDIS_URL = "redis://localhost:6379/1" # Use DB 1 to avoid conflicts + +SIGNALS_STREAM = "test:signals:generated" +TRADES_STREAM = "test:trades:executed" + + +@pytest.fixture +async def redis_client(): + """Provide a clean Redis connection on DB 1 and clean up streams after.""" + client = Redis.from_url(REDIS_URL, decode_responses=False) + await client.delete(SIGNALS_STREAM, TRADES_STREAM) + yield client + await client.delete(SIGNALS_STREAM, TRADES_STREAM) + await client.aclose() + + +@pytest.fixture +def sample_signal() -> TradeSignal: + """Return a sample trade signal for AAPL.""" + return TradeSignal( + ticker="AAPL", + direction=SignalDirection.LONG, + strength=0.8, + strategy_sources=["momentum", "news_driven"], + sentiment_context={"avg_score": 0.85, "current_price": 190.50}, + timestamp=datetime.now(timezone.utc), + ) + + +@pytest.fixture +def mock_account() -> AccountInfo: + """Return a mock account with 100k equity.""" + return AccountInfo( + equity=100_000.0, + cash=50_000.0, + buying_power=100_000.0, + portfolio_value=100_000.0, + ) + + +@pytest.fixture +def mock_order_result() -> OrderResult: + """Return a 
mock filled order result.""" + return OrderResult( + order_id="test-order-001", + ticker="AAPL", + side=OrderSide.BUY, + qty=20.0, + filled_price=190.50, + status=OrderStatus.FILLED, + timestamp=datetime.now(timezone.utc), + ) + + +# --------------------------------------------------------------------------- +# Mock counters +# --------------------------------------------------------------------------- + + +class _FakeCounter: + def __init__(self): + self.total = 0 + self.attrs: list[dict] = [] + + def add(self, amount: int = 1, attributes: dict | None = None): + self.total += amount + if attributes: + self.attrs.append(attributes) + + +class _FakeHistogram: + def __init__(self): + self.values: list[float] = [] + + def record(self, value: float, attributes: dict | None = None): + self.values.append(value) + + +def _make_counters() -> dict: + return { + "trades_executed": _FakeCounter(), + "rejections": _FakeCounter(), + "fill_latency": _FakeHistogram(), + } + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_signal_produces_trade_execution( + redis_client: Redis, + sample_signal: TradeSignal, + mock_account: AccountInfo, + mock_order_result: OrderResult, +): + """Process a trade signal through the executor and verify a + TradeExecution is published to the trades:executed stream. 
+ """ + publisher = StreamPublisher(redis_client, TRADES_STREAM) + counters = _make_counters() + + # Create mock broker + mock_broker = AsyncMock() + mock_broker.get_account = AsyncMock(return_value=mock_account) + mock_broker.get_positions = AsyncMock(return_value=[]) + mock_broker.submit_order = AsyncMock(return_value=mock_order_result) + + # Create risk manager with the mock broker, patching market hours check + config = TradeExecutorConfig() + risk_manager = RiskManager(config, mock_broker) + + # Patch _is_market_hours to always return True + with patch.object(RiskManager, "_is_market_hours", return_value=True): + await process_signal( + sample_signal, + risk_manager, + mock_broker, + publisher, + counters, + ) + + # Verify the broker was called + mock_broker.submit_order.assert_called_once() + order_arg = mock_broker.submit_order.call_args[0][0] + assert order_arg.ticker == "AAPL" + assert order_arg.side == OrderSide.BUY + + # Verify a TradeExecution was published + messages = await redis_client.xrange(TRADES_STREAM) + assert len(messages) == 1 + + _msg_id, fields = messages[0] + data = json.loads(fields[b"data"]) + execution = TradeExecution.model_validate(data) + + assert execution.ticker == "AAPL" + assert execution.side == OrderSide.BUY + assert execution.qty == 20.0 + assert execution.price == 190.50 + assert execution.status == OrderStatus.FILLED + + # Counter checks + assert counters["trades_executed"].total == 1 + assert len(counters["fill_latency"].values) == 1 + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_short_signal_produces_sell_execution( + redis_client: Redis, + mock_account: AccountInfo, +): + """A SHORT signal should produce a SELL order.""" + short_signal = TradeSignal( + ticker="TSLA", + direction=SignalDirection.SHORT, + strength=0.7, + strategy_sources=["mean_reversion"], + sentiment_context={"avg_score": -0.6, "current_price": 250.00}, + timestamp=datetime.now(timezone.utc), + ) + + sell_result = OrderResult( + 
order_id="test-order-002", + ticker="TSLA", + side=OrderSide.SELL, + qty=14.0, + filled_price=250.00, + status=OrderStatus.FILLED, + timestamp=datetime.now(timezone.utc), + ) + + publisher = StreamPublisher(redis_client, TRADES_STREAM) + counters = _make_counters() + + mock_broker = AsyncMock() + mock_broker.get_account = AsyncMock(return_value=mock_account) + mock_broker.get_positions = AsyncMock(return_value=[]) + mock_broker.submit_order = AsyncMock(return_value=sell_result) + + config = TradeExecutorConfig() + risk_manager = RiskManager(config, mock_broker) + + with patch.object(RiskManager, "_is_market_hours", return_value=True): + await process_signal( + short_signal, + risk_manager, + mock_broker, + publisher, + counters, + ) + + messages = await redis_client.xrange(TRADES_STREAM) + assert len(messages) == 1 + + _msg_id, fields = messages[0] + data = json.loads(fields[b"data"]) + execution = TradeExecution.model_validate(data) + + assert execution.ticker == "TSLA" + assert execution.side == OrderSide.SELL + assert execution.status == OrderStatus.FILLED + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_risk_rejection_does_not_publish( + redis_client: Redis, + sample_signal: TradeSignal, + mock_account: AccountInfo, +): + """When risk checks fail (outside market hours), no TradeExecution + should be published. 
+ """ + publisher = StreamPublisher(redis_client, TRADES_STREAM) + counters = _make_counters() + + mock_broker = AsyncMock() + mock_broker.get_account = AsyncMock(return_value=mock_account) + mock_broker.get_positions = AsyncMock(return_value=[]) + + config = TradeExecutorConfig() + risk_manager = RiskManager(config, mock_broker) + + # Market is closed -> risk check fails + with patch.object(RiskManager, "_is_market_hours", return_value=False): + await process_signal( + sample_signal, + risk_manager, + mock_broker, + publisher, + counters, + ) + + # No order should have been submitted + mock_broker.submit_order.assert_not_called() + + # No messages on the trades stream + messages = await redis_client.xrange(TRADES_STREAM) + assert len(messages) == 0 + + # Rejection counter should be incremented + assert counters["rejections"].total == 1 + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_max_positions_rejection( + redis_client: Redis, + sample_signal: TradeSignal, + mock_account: AccountInfo, +): + """When the maximum number of positions is reached, the signal + should be rejected. 
+ """ + publisher = StreamPublisher(redis_client, TRADES_STREAM) + counters = _make_counters() + + # Create enough mock positions to exceed the limit + existing_positions = [ + PositionInfo( + ticker=f"STOCK{i}", + qty=10.0, + avg_entry=100.0, + current_price=105.0, + unrealized_pnl=50.0, + market_value=1050.0, + ) + for i in range(25) # Default max is 20 + ] + + mock_broker = AsyncMock() + mock_broker.get_account = AsyncMock(return_value=mock_account) + mock_broker.get_positions = AsyncMock(return_value=existing_positions) + + config = TradeExecutorConfig() + risk_manager = RiskManager(config, mock_broker) + + with patch.object(RiskManager, "_is_market_hours", return_value=True): + await process_signal( + sample_signal, + risk_manager, + mock_broker, + publisher, + counters, + ) + + mock_broker.submit_order.assert_not_called() + + messages = await redis_client.xrange(TRADES_STREAM) + assert len(messages) == 0 + + assert counters["rejections"].total == 1 + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_publish_signal_and_consume_execution_roundtrip( + redis_client: Redis, + sample_signal: TradeSignal, + mock_account: AccountInfo, + mock_order_result: OrderResult, +): + """End-to-end: publish a signal to the signals stream, process it, + and verify the execution can be read back from the trades stream. 
+ """ + # Publish the signal to the signals stream + signal_publisher = StreamPublisher(redis_client, SIGNALS_STREAM) + await signal_publisher.publish(sample_signal.model_dump(mode="json")) + + # Verify the signal is on the stream + signal_messages = await redis_client.xrange(SIGNALS_STREAM) + assert len(signal_messages) == 1 + + # Parse it back to verify serialization + _msg_id, fields = signal_messages[0] + data = json.loads(fields[b"data"]) + parsed_signal = TradeSignal.model_validate(data) + assert parsed_signal.ticker == "AAPL" + assert parsed_signal.direction == SignalDirection.LONG + + # Process the signal through the executor + trades_publisher = StreamPublisher(redis_client, TRADES_STREAM) + counters = _make_counters() + + mock_broker = AsyncMock() + mock_broker.get_account = AsyncMock(return_value=mock_account) + mock_broker.get_positions = AsyncMock(return_value=[]) + mock_broker.submit_order = AsyncMock(return_value=mock_order_result) + + config = TradeExecutorConfig() + risk_manager = RiskManager(config, mock_broker) + + with patch.object(RiskManager, "_is_market_hours", return_value=True): + await process_signal( + parsed_signal, + risk_manager, + mock_broker, + trades_publisher, + counters, + ) + + # Read the execution from the trades stream + trade_messages = await redis_client.xrange(TRADES_STREAM) + assert len(trade_messages) == 1 + + _msg_id, fields = trade_messages[0] + data = json.loads(fields[b"data"]) + execution = TradeExecution.model_validate(data) + + assert execution.ticker == "AAPL" + assert execution.side == OrderSide.BUY + assert execution.status == OrderStatus.FILLED + assert execution.price == 190.50