Add integration tests for the news pipeline (test_news_pipeline.py) and trading flow (test_trading_flow.py) using real Redis with mocked FinBERT and Alpaca. Add seed_strategies.py to insert default strategies (momentum, mean_reversion, news_driven) with equal weights. Add smoke_test.sh for end-to-end stack validation. Update pyproject.toml with integration marker and scripts package discovery.
399 lines
12 KiB
Python
399 lines
12 KiB
Python
"""Integration test: signal generator -> trade executor flow.
|
|
|
|
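Feeds mock TradeSignals through the executor's ``process_signal`` and
verifies that a TradeExecution appears on the ``test:trades:executed``
Redis stream. The round-trip test additionally publishes the signal to
``test:signals:generated`` and reads it back before processing it.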
|
|
|
|
Requires a running Redis instance (from docker-compose).
|
|
The Alpaca broker is mocked.
|
|
|
|
Run with:
|
|
pytest tests/integration/test_trading_flow.py -v -m integration
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
from redis.asyncio import Redis
|
|
|
|
from shared.redis_streams import StreamPublisher
|
|
from shared.schemas.trading import (
|
|
AccountInfo,
|
|
OrderResult,
|
|
OrderSide,
|
|
OrderStatus,
|
|
PositionInfo,
|
|
SignalDirection,
|
|
TradeExecution,
|
|
TradeSignal,
|
|
)
|
|
from services.trade_executor.config import TradeExecutorConfig
|
|
from services.trade_executor.main import process_signal
|
|
from services.trade_executor.risk_manager import RiskManager
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Connection URL for the test Redis; database index 1 is used to avoid
# conflicts with other data on the same server.
REDIS_URL = "redis://localhost:6379/1"

# Stream keys used by this module. Both are deleted by the ``redis_client``
# fixture before and after every test.
SIGNALS_STREAM = "test:signals:generated"
TRADES_STREAM = "test:trades:executed"
|
|
|
|
|
|
# NOTE(review): this is an async generator under plain ``@pytest.fixture``;
# presumably the project runs pytest-asyncio in a mode that supports that —
# confirm against the pytest configuration.
@pytest.fixture
async def redis_client():
    """Yield a Redis client on the test database with both test streams cleared.

    The signal and trade streams are deleted before the test body runs and
    again once it has finished, so every test starts from (and leaves
    behind) empty streams.

    Responses are not decoded (``decode_responses=False``), so stream field
    names and values come back as ``bytes``.

    The connection is closed in a ``finally`` block so that it is released
    even if one of the cleanup ``DELETE`` calls raises.
    """
    client = Redis.from_url(REDIS_URL, decode_responses=False)
    try:
        await client.delete(SIGNALS_STREAM, TRADES_STREAM)
        yield client
        await client.delete(SIGNALS_STREAM, TRADES_STREAM)
    finally:
        # Always release the connection, including on setup/teardown errors.
        await client.aclose()
|
|
|
|
|
|
@pytest.fixture
def sample_signal() -> TradeSignal:
    """Build a LONG signal for AAPL (strength 0.8) stamped with the current UTC time."""
    sources = ["momentum", "news_driven"]
    context = {"avg_score": 0.85, "current_price": 190.50}
    created_at = datetime.now(timezone.utc)
    return TradeSignal(
        ticker="AAPL",
        direction=SignalDirection.LONG,
        strength=0.8,
        strategy_sources=sources,
        sentiment_context=context,
        timestamp=created_at,
    )
|
|
|
|
|
|
@pytest.fixture
def mock_account() -> AccountInfo:
    """Build an account snapshot with 100k equity and 50k cash."""
    balances = {
        "equity": 100_000.0,
        "cash": 50_000.0,
        "buying_power": 100_000.0,
        "portfolio_value": 100_000.0,
    }
    return AccountInfo(**balances)
|
|
|
|
|
|
@pytest.fixture
def mock_order_result() -> OrderResult:
    """Build a FILLED BUY result: 20 AAPL shares at 190.50, stamped with the current UTC time."""
    filled_at = datetime.now(timezone.utc)
    fill = {
        "order_id": "test-order-001",
        "ticker": "AAPL",
        "side": OrderSide.BUY,
        "qty": 20.0,
        "filled_price": 190.50,
        "status": OrderStatus.FILLED,
        "timestamp": filled_at,
    }
    return OrderResult(**fill)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Mock counters
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class _FakeCounter:
|
|
def __init__(self):
|
|
self.total = 0
|
|
self.attrs: list[dict] = []
|
|
|
|
def add(self, amount: int = 1, attributes: dict | None = None):
|
|
self.total += amount
|
|
if attributes:
|
|
self.attrs.append(attributes)
|
|
|
|
|
|
class _FakeHistogram:
|
|
def __init__(self):
|
|
self.values: list[float] = []
|
|
|
|
def record(self, value: float, attributes: dict | None = None):
|
|
self.values.append(value)
|
|
|
|
|
|
def _make_counters() -> dict:
    """Return a fresh set of fake metric instruments.

    The mapping holds two counters (``trades_executed``, ``rejections``)
    and one histogram (``fill_latency``); the tests hand it to
    ``process_signal`` and inspect it afterwards.
    """
    instruments: dict = {
        name: _FakeCounter() for name in ("trades_executed", "rejections")
    }
    instruments["fill_latency"] = _FakeHistogram()
    return instruments
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.integration
@pytest.mark.asyncio
async def test_signal_produces_trade_execution(
    redis_client: Redis,
    sample_signal: TradeSignal,
    mock_account: AccountInfo,
    mock_order_result: OrderResult,
):
    """A LONG signal that passes the risk checks yields one published execution.

    Runs the sample AAPL signal through ``process_signal`` with a mocked
    broker and the market-hours check forced open, then checks the order
    handed to the broker, the single TradeExecution on the trades stream,
    and the metric instruments.
    """
    # Broker double: empty book, fixed account, and a canned fill.
    broker = AsyncMock()
    broker.get_account = AsyncMock(return_value=mock_account)
    broker.get_positions = AsyncMock(return_value=[])
    broker.submit_order = AsyncMock(return_value=mock_order_result)

    metrics = _make_counters()
    trades_out = StreamPublisher(redis_client, TRADES_STREAM)
    risk = RiskManager(TradeExecutorConfig(), broker)

    # Force the market-hours check to pass so the result does not depend
    # on when the test runs.
    with patch.object(RiskManager, "_is_market_hours", return_value=True):
        await process_signal(sample_signal, risk, broker, trades_out, metrics)

    # Exactly one order went to the broker, and it is a BUY for AAPL.
    broker.submit_order.assert_called_once()
    submitted = broker.submit_order.call_args.args[0]
    assert submitted.ticker == "AAPL"
    assert submitted.side == OrderSide.BUY

    # Exactly one entry was written to the trades stream.
    entries = await redis_client.xrange(TRADES_STREAM)
    assert len(entries) == 1

    # Field keys are bytes because the client does not decode responses.
    _entry_id, payload = entries[0]
    execution = TradeExecution.model_validate(json.loads(payload[b"data"]))

    assert execution.ticker == "AAPL"
    assert execution.side == OrderSide.BUY
    assert execution.qty == 20.0
    assert execution.price == 190.50
    assert execution.status == OrderStatus.FILLED

    # One trade counted and one latency sample recorded.
    assert metrics["trades_executed"].total == 1
    assert len(metrics["fill_latency"].values) == 1
|
|
|
|
|
|
@pytest.mark.integration
@pytest.mark.asyncio
async def test_short_signal_produces_sell_execution(
    redis_client: Redis,
    mock_account: AccountInfo,
):
    """A SHORT signal should produce a SELL order.

    Runs a SHORT TSLA signal through ``process_signal`` with a mocked
    broker and the market-hours check forced open. Verifies both the order
    actually submitted to the broker and the TradeExecution published to
    the trades stream.
    """
    short_signal = TradeSignal(
        ticker="TSLA",
        direction=SignalDirection.SHORT,
        strength=0.7,
        strategy_sources=["mean_reversion"],
        sentiment_context={"avg_score": -0.6, "current_price": 250.00},
        timestamp=datetime.now(timezone.utc),
    )

    sell_result = OrderResult(
        order_id="test-order-002",
        ticker="TSLA",
        side=OrderSide.SELL,
        qty=14.0,
        filled_price=250.00,
        status=OrderStatus.FILLED,
        timestamp=datetime.now(timezone.utc),
    )

    publisher = StreamPublisher(redis_client, TRADES_STREAM)
    counters = _make_counters()

    mock_broker = AsyncMock()
    mock_broker.get_account = AsyncMock(return_value=mock_account)
    mock_broker.get_positions = AsyncMock(return_value=[])
    mock_broker.submit_order = AsyncMock(return_value=sell_result)

    config = TradeExecutorConfig()
    risk_manager = RiskManager(config, mock_broker)

    # Force the market-hours check to pass so the result does not depend
    # on when the test runs.
    with patch.object(RiskManager, "_is_market_hours", return_value=True):
        await process_signal(
            short_signal,
            risk_manager,
            mock_broker,
            publisher,
            counters,
        )

    # Check the order the executor actually built. The canned broker result
    # is a SELL no matter what was submitted, so without this check the
    # SHORT -> SELL mapping would go unverified.
    mock_broker.submit_order.assert_called_once()
    order_arg = mock_broker.submit_order.call_args[0][0]
    assert order_arg.ticker == "TSLA"
    assert order_arg.side == OrderSide.SELL

    messages = await redis_client.xrange(TRADES_STREAM)
    assert len(messages) == 1

    # Field keys are bytes because the client does not decode responses.
    _msg_id, fields = messages[0]
    data = json.loads(fields[b"data"])
    execution = TradeExecution.model_validate(data)

    assert execution.ticker == "TSLA"
    assert execution.side == OrderSide.SELL
    assert execution.status == OrderStatus.FILLED
|
|
|
|
|
|
@pytest.mark.integration
@pytest.mark.asyncio
async def test_risk_rejection_does_not_publish(
    redis_client: Redis,
    sample_signal: TradeSignal,
    mock_account: AccountInfo,
):
    """A signal rejected for being outside market hours publishes nothing.

    With the market-hours check forced to fail, no order may reach the
    broker, the trades stream must stay empty, and the rejection counter
    must go up by one.
    """
    broker = AsyncMock()
    broker.get_account = AsyncMock(return_value=mock_account)
    broker.get_positions = AsyncMock(return_value=[])

    metrics = _make_counters()
    trades_out = StreamPublisher(redis_client, TRADES_STREAM)
    risk = RiskManager(TradeExecutorConfig(), broker)

    # Simulate a closed market so the risk check rejects the signal.
    with patch.object(RiskManager, "_is_market_hours", return_value=False):
        await process_signal(sample_signal, risk, broker, trades_out, metrics)

    # The broker never saw an order.
    broker.submit_order.assert_not_called()

    # Nothing was written to the trades stream.
    entries = await redis_client.xrange(TRADES_STREAM)
    assert len(entries) == 0

    # The rejection was counted exactly once.
    assert metrics["rejections"].total == 1
|
|
|
|
|
|
@pytest.mark.integration
@pytest.mark.asyncio
async def test_max_positions_rejection(
    redis_client: Redis,
    sample_signal: TradeSignal,
    mock_account: AccountInfo,
):
    """A signal is rejected once the open-position limit is already exceeded.

    The mocked broker reports 25 open positions while the market-hours
    check is forced open; no order may be submitted, nothing may be
    published, and the rejection counter must go up by one.
    """
    # 25 synthetic holdings. The original author notes the default maximum
    # is 20 — confirm against TradeExecutorConfig if that default changes.
    held = [
        PositionInfo(
            ticker=f"STOCK{n}",
            qty=10.0,
            avg_entry=100.0,
            current_price=105.0,
            unrealized_pnl=50.0,
            market_value=1050.0,
        )
        for n in range(25)
    ]

    broker = AsyncMock()
    broker.get_account = AsyncMock(return_value=mock_account)
    broker.get_positions = AsyncMock(return_value=held)

    metrics = _make_counters()
    trades_out = StreamPublisher(redis_client, TRADES_STREAM)
    risk = RiskManager(TradeExecutorConfig(), broker)

    # Market hours are forced open so they cannot be what rejects the signal.
    with patch.object(RiskManager, "_is_market_hours", return_value=True):
        await process_signal(sample_signal, risk, broker, trades_out, metrics)

    broker.submit_order.assert_not_called()

    entries = await redis_client.xrange(TRADES_STREAM)
    assert len(entries) == 0

    assert metrics["rejections"].total == 1
|
|
|
|
|
|
@pytest.mark.integration
@pytest.mark.asyncio
async def test_publish_signal_and_consume_execution_roundtrip(
    redis_client: Redis,
    sample_signal: TradeSignal,
    mock_account: AccountInfo,
    mock_order_result: OrderResult,
):
    """Round trip: signal onto its stream, through the executor, execution back out.

    The sample signal is published to the signals stream, read back and
    re-validated, then processed with a mocked broker. The resulting
    TradeExecution is read from the trades stream and checked.
    """
    # Step 1: publish the JSON-mode dump of the signal.
    signals_out = StreamPublisher(redis_client, SIGNALS_STREAM)
    await signals_out.publish(sample_signal.model_dump(mode="json"))

    # Step 2: the signals stream now holds exactly that one entry.
    signal_entries = await redis_client.xrange(SIGNALS_STREAM)
    assert len(signal_entries) == 1

    # Step 3: deserialize it again to prove the payload survives the trip.
    # Field keys are bytes because the client does not decode responses.
    _signal_id, signal_payload = signal_entries[0]
    parsed_signal = TradeSignal.model_validate(
        json.loads(signal_payload[b"data"])
    )
    assert parsed_signal.ticker == "AAPL"
    assert parsed_signal.direction == SignalDirection.LONG

    # Step 4: run the re-parsed signal through the executor.
    broker = AsyncMock()
    broker.get_account = AsyncMock(return_value=mock_account)
    broker.get_positions = AsyncMock(return_value=[])
    broker.submit_order = AsyncMock(return_value=mock_order_result)

    metrics = _make_counters()
    trades_out = StreamPublisher(redis_client, TRADES_STREAM)
    risk = RiskManager(TradeExecutorConfig(), broker)

    # Force the market-hours check to pass so the result does not depend
    # on when the test runs.
    with patch.object(RiskManager, "_is_market_hours", return_value=True):
        await process_signal(parsed_signal, risk, broker, trades_out, metrics)

    # Step 5: read the execution back from the trades stream.
    trade_entries = await redis_client.xrange(TRADES_STREAM)
    assert len(trade_entries) == 1

    _trade_id, trade_payload = trade_entries[0]
    execution = TradeExecution.model_validate(
        json.loads(trade_payload[b"data"])
    )

    assert execution.ticker == "AAPL"
    assert execution.side == OrderSide.BUY
    assert execution.status == OrderStatus.FILLED
    assert execution.price == 190.50
|