feat: integration tests, seed data, and smoke test script
Add integration tests for the news pipeline (test_news_pipeline.py) and trading flow (test_trading_flow.py) using real Redis with mocked FinBERT and Alpaca. Add seed_strategies.py to insert default strategies (momentum, mean_reversion, news_driven) with equal weights. Add smoke_test.sh for end-to-end stack validation. Update pyproject.toml with integration marker and scripts package discovery.
This commit is contained in:
parent
b255b3edbe
commit
e6ae4bdccd
7 changed files with 948 additions and 1 deletions
|
|
@ -27,11 +27,12 @@ requires = ["setuptools>=70.0"]
|
||||||
build-backend = "setuptools.build_meta"
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
include = ["shared*", "services*", "backtester*", "tests*"]
|
include = ["shared*", "services*", "backtester*", "scripts*", "tests*"]
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
asyncio_mode = "auto"
|
asyncio_mode = "auto"
|
||||||
testpaths = ["tests"]
|
testpaths = ["tests"]
|
||||||
|
markers = ["integration: marks tests requiring docker services (redis, postgres)"]
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 120
|
line-length = 120
|
||||||
|
|
|
||||||
0
scripts/__init__.py
Normal file
0
scripts/__init__.py
Normal file
109
scripts/seed_strategies.py
Normal file
109
scripts/seed_strategies.py
Normal file
|
|
@ -0,0 +1,109 @@
|
||||||
|
"""Seed default trading strategies.
|
||||||
|
|
||||||
|
Inserts three strategies with equal initial weights (0.333 each):
|
||||||
|
- momentum
|
||||||
|
- mean_reversion
|
||||||
|
- news_driven
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python -m scripts.seed_strategies
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from shared.config import BaseConfig
|
||||||
|
from shared.db import create_db
|
||||||
|
from shared.models.trading import Strategy
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Default strategies to seed
|
||||||
|
DEFAULT_STRATEGIES = [
|
||||||
|
{
|
||||||
|
"name": "momentum",
|
||||||
|
"description": (
|
||||||
|
"Buy when price crosses above N-period SMA with increasing volume; "
|
||||||
|
"sell when it crosses below."
|
||||||
|
),
|
||||||
|
"current_weight": 0.333,
|
||||||
|
"active": True,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "mean_reversion",
|
||||||
|
"description": (
|
||||||
|
"Buy when RSI < 30 (oversold); sell when RSI > 70 (overbought). "
|
||||||
|
"Signal strength proportional to RSI extremity."
|
||||||
|
),
|
||||||
|
"current_weight": 0.333,
|
||||||
|
"active": True,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "news_driven",
|
||||||
|
"description": (
|
||||||
|
"Buy on strong positive sentiment (score > 0.7, confidence > 0.6); "
|
||||||
|
"sell on strong negative. Decay factor for stale news (> 4 hours)."
|
||||||
|
),
|
||||||
|
"current_weight": 0.333,
|
||||||
|
"active": True,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def seed(database_url: str | None = None) -> None:
|
||||||
|
"""Insert default strategies if they do not already exist.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
database_url:
|
||||||
|
Override for the database URL. If ``None``, the default from
|
||||||
|
:class:`~shared.config.BaseConfig` is used.
|
||||||
|
"""
|
||||||
|
config = BaseConfig()
|
||||||
|
if database_url:
|
||||||
|
config.database_url = database_url
|
||||||
|
|
||||||
|
_engine, session_factory = create_db(config)
|
||||||
|
|
||||||
|
async with session_factory() as session:
|
||||||
|
for strategy_data in DEFAULT_STRATEGIES:
|
||||||
|
# Check if the strategy already exists by name
|
||||||
|
result = await session.execute(
|
||||||
|
select(Strategy).where(Strategy.name == strategy_data["name"])
|
||||||
|
)
|
||||||
|
existing = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if existing:
|
||||||
|
logger.info(
|
||||||
|
"Strategy '%s' already exists (weight=%.3f), skipping",
|
||||||
|
existing.name,
|
||||||
|
existing.current_weight,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
strategy = Strategy(**strategy_data)
|
||||||
|
session.add(strategy)
|
||||||
|
logger.info(
|
||||||
|
"Inserted strategy '%s' with weight %.3f",
|
||||||
|
strategy_data["name"],
|
||||||
|
strategy_data["current_weight"],
|
||||||
|
)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
await _engine.dispose()
|
||||||
|
logger.info("Strategy seeding complete")
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
"""CLI entry-point."""
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
asyncio.run(seed())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
139
scripts/smoke_test.sh
Executable file
139
scripts/smoke_test.sh
Executable file
|
|
@ -0,0 +1,139 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Smoke test for the full trading-bot Docker Compose stack.
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# ./scripts/smoke_test.sh
|
||||||
|
#
|
||||||
|
# Prerequisites:
|
||||||
|
# - Docker Compose stack must be running (docker compose up -d)
|
||||||
|
#
|
||||||
|
# This script:
|
||||||
|
# 1. Waits for services to become healthy
|
||||||
|
# 2. Hits GET /health -> expects 200
|
||||||
|
# 3. Hits GET /api/portfolio -> expects 401 (unauthenticated)
|
||||||
|
# 4. Hits GET /api/strategies -> expects 401 (unauthenticated)
|
||||||
|
# 5. Checks docker compose ps shows all services running
|
||||||
|
# 6. Exits 0 on success, 1 on failure
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
API_BASE="${API_BASE:-http://localhost:8000}"
|
||||||
|
DASHBOARD_BASE="${DASHBOARD_BASE:-http://localhost:3000}"
|
||||||
|
MAX_RETRIES="${MAX_RETRIES:-30}"
|
||||||
|
RETRY_INTERVAL="${RETRY_INTERVAL:-2}"
|
||||||
|
|
||||||
|
PASS=0
|
||||||
|
FAIL=0
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helper functions
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
log() {
|
||||||
|
echo "[smoke-test] $*"
|
||||||
|
}
|
||||||
|
|
||||||
|
pass() {
|
||||||
|
log "PASS: $*"
|
||||||
|
PASS=$((PASS + 1))
|
||||||
|
}
|
||||||
|
|
||||||
|
fail() {
|
||||||
|
log "FAIL: $*"
|
||||||
|
FAIL=$((FAIL + 1))
|
||||||
|
}
|
||||||
|
|
||||||
|
wait_for_endpoint() {
|
||||||
|
local url="$1"
|
||||||
|
local expected_code="$2"
|
||||||
|
local description="$3"
|
||||||
|
local attempt=0
|
||||||
|
|
||||||
|
while [ "$attempt" -lt "$MAX_RETRIES" ]; do
|
||||||
|
attempt=$((attempt + 1))
|
||||||
|
status_code=$(curl -s -o /dev/null -w "%{http_code}" "$url" 2>/dev/null || echo "000")
|
||||||
|
if [ "$status_code" = "$expected_code" ]; then
|
||||||
|
return 0
|
||||||
|
fi
|
||||||
|
log "Waiting for $description ($url) ... attempt $attempt/$MAX_RETRIES (got $status_code, want $expected_code)"
|
||||||
|
sleep "$RETRY_INTERVAL"
|
||||||
|
done
|
||||||
|
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
check_endpoint() {
|
||||||
|
local url="$1"
|
||||||
|
local expected_code="$2"
|
||||||
|
local description="$3"
|
||||||
|
|
||||||
|
status_code=$(curl -s -o /dev/null -w "%{http_code}" "$url" 2>/dev/null || echo "000")
|
||||||
|
if [ "$status_code" = "$expected_code" ]; then
|
||||||
|
pass "$description -> $status_code"
|
||||||
|
else
|
||||||
|
fail "$description -> expected $expected_code, got $status_code"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 1. Wait for the API gateway health endpoint
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
log "Waiting for API gateway to be healthy ..."
|
||||||
|
if wait_for_endpoint "$API_BASE/health" "200" "API health"; then
|
||||||
|
pass "API gateway is healthy"
|
||||||
|
else
|
||||||
|
fail "API gateway did not become healthy within timeout"
|
||||||
|
log "Aborting — cannot run further checks without a healthy API"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 2. Health check
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
check_endpoint "$API_BASE/health" "200" "GET /health"
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 3. Unauthenticated trading endpoints should return 401/403
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
check_endpoint "$API_BASE/api/portfolio" "401" "GET /api/portfolio (no auth)"
|
||||||
|
check_endpoint "$API_BASE/api/strategies" "401" "GET /api/strategies (no auth)"
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 4. Dashboard responds
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
log "Checking dashboard ..."
|
||||||
|
if wait_for_endpoint "$DASHBOARD_BASE/" "200" "Dashboard"; then
|
||||||
|
pass "Dashboard is serving"
|
||||||
|
else
|
||||||
|
fail "Dashboard did not respond"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 5. Docker Compose services status
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
log "Checking docker compose service status ..."
|
||||||
|
if command -v docker &>/dev/null; then
|
||||||
|
running_count=$(docker compose ps --format json 2>/dev/null | grep -c '"running"' || echo "0")
|
||||||
|
if [ "$running_count" -gt 0 ]; then
|
||||||
|
pass "docker compose shows $running_count running services"
|
||||||
|
else
|
||||||
|
fail "No running services found in docker compose ps"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
log "SKIP: docker command not available"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Summary
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
echo ""
|
||||||
|
log "================================"
|
||||||
|
log "Results: $PASS passed, $FAIL failed"
|
||||||
|
log "================================"
|
||||||
|
|
||||||
|
if [ "$FAIL" -gt 0 ]; then
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
exit 0
|
||||||
0
tests/integration/__init__.py
Normal file
0
tests/integration/__init__.py
Normal file
299
tests/integration/test_news_pipeline.py
Normal file
299
tests/integration/test_news_pipeline.py
Normal file
|
|
@ -0,0 +1,299 @@
|
||||||
|
"""Integration test: news fetcher -> sentiment analyzer pipeline.
|
||||||
|
|
||||||
|
Publishes a mock RawArticle to the ``news:raw`` Redis stream and verifies
|
||||||
|
that a ScoredArticle appears on ``news:scored``.
|
||||||
|
|
||||||
|
Requires a running Redis instance (from docker-compose).
|
||||||
|
FinBERT and Ollama are mocked so the test does not need GPU / model weights.
|
||||||
|
|
||||||
|
Run with:
|
||||||
|
pytest tests/integration/test_news_pipeline.py -v -m integration
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from redis.asyncio import Redis
|
||||||
|
|
||||||
|
from shared.redis_streams import StreamConsumer, StreamPublisher
|
||||||
|
from shared.schemas.news import RawArticle, ScoredArticle
|
||||||
|
from services.sentiment_analyzer.main import process_article
|
||||||
|
from services.sentiment_analyzer.config import SentimentAnalyzerConfig
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
REDIS_URL = "redis://localhost:6379/1" # Use DB 1 to avoid conflicts
|
||||||
|
|
||||||
|
RAW_STREAM = "test:news:raw"
|
||||||
|
SCORED_STREAM = "test:news:scored"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def redis_client():
|
||||||
|
"""Provide a clean Redis connection on DB 1 and clean up streams after."""
|
||||||
|
client = Redis.from_url(REDIS_URL, decode_responses=False)
|
||||||
|
# Ensure streams are clean before the test
|
||||||
|
await client.delete(RAW_STREAM, SCORED_STREAM)
|
||||||
|
yield client
|
||||||
|
# Clean up after
|
||||||
|
await client.delete(RAW_STREAM, SCORED_STREAM)
|
||||||
|
await client.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_article() -> RawArticle:
|
||||||
|
"""Return a sample RawArticle mentioning AAPL."""
|
||||||
|
return RawArticle(
|
||||||
|
source="rss",
|
||||||
|
url="https://example.com/aapl-news",
|
||||||
|
title="Apple Inc AAPL reports record quarterly earnings",
|
||||||
|
content=(
|
||||||
|
"Apple Inc ($AAPL) reported record-breaking quarterly earnings "
|
||||||
|
"today, beating analyst estimates by a wide margin. Revenue grew "
|
||||||
|
"15% year-over-year driven by strong iPhone and Services demand."
|
||||||
|
),
|
||||||
|
published_at=datetime.now(timezone.utc),
|
||||||
|
fetched_at=datetime.now(timezone.utc),
|
||||||
|
content_hash="test-hash-aapl-001",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Mock counters (stand-in for OpenTelemetry instruments)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeCounter:
|
||||||
|
"""Minimal fake that records how many times ``add`` was called."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.total = 0
|
||||||
|
|
||||||
|
def add(self, amount: int = 1, attributes: dict | None = None):
|
||||||
|
self.total += amount
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeHistogram:
|
||||||
|
def __init__(self):
|
||||||
|
self.values: list[float] = []
|
||||||
|
|
||||||
|
def record(self, value: float, attributes: dict | None = None):
|
||||||
|
self.values.append(value)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_counters() -> dict:
|
||||||
|
return {
|
||||||
|
"articles_scored": _FakeCounter(),
|
||||||
|
"finbert_count": _FakeCounter(),
|
||||||
|
"ollama_count": _FakeCounter(),
|
||||||
|
"inference_latency": _FakeHistogram(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_raw_article_flows_to_scored(redis_client: Redis, sample_article: RawArticle):
|
||||||
|
"""Publish a RawArticle to news:raw, run the sentiment analyzer's
|
||||||
|
process_article function, and verify a ScoredArticle is published
|
||||||
|
to news:scored.
|
||||||
|
"""
|
||||||
|
publisher = StreamPublisher(redis_client, SCORED_STREAM)
|
||||||
|
|
||||||
|
# Mock FinBERT to return high-confidence positive sentiment
|
||||||
|
mock_finbert = AsyncMock()
|
||||||
|
mock_finbert.analyze = AsyncMock(return_value=(0.85, 0.92))
|
||||||
|
|
||||||
|
# Mock Ollama (should not be called when FinBERT confidence is high)
|
||||||
|
mock_ollama = AsyncMock()
|
||||||
|
mock_ollama.analyze = AsyncMock(return_value=(0.0, 0.0))
|
||||||
|
|
||||||
|
config = SentimentAnalyzerConfig()
|
||||||
|
counters = _make_counters()
|
||||||
|
|
||||||
|
# Process the article
|
||||||
|
await process_article(
|
||||||
|
sample_article,
|
||||||
|
mock_finbert,
|
||||||
|
mock_ollama,
|
||||||
|
publisher,
|
||||||
|
config,
|
||||||
|
counters,
|
||||||
|
)
|
||||||
|
|
||||||
|
# FinBERT should have been called, Ollama should NOT
|
||||||
|
mock_finbert.analyze.assert_called_once()
|
||||||
|
mock_ollama.analyze.assert_not_called()
|
||||||
|
|
||||||
|
# Verify a ScoredArticle was published to the scored stream
|
||||||
|
messages = await redis_client.xrange(SCORED_STREAM)
|
||||||
|
assert len(messages) >= 1, "Expected at least one message on the scored stream"
|
||||||
|
|
||||||
|
# Parse the first message
|
||||||
|
_msg_id, fields = messages[0]
|
||||||
|
data = json.loads(fields[b"data"])
|
||||||
|
scored = ScoredArticle.model_validate(data)
|
||||||
|
|
||||||
|
assert scored.ticker == "AAPL"
|
||||||
|
assert scored.sentiment_score == pytest.approx(0.85, abs=0.01)
|
||||||
|
assert scored.confidence == pytest.approx(0.92, abs=0.01)
|
||||||
|
assert scored.model_used == "finbert"
|
||||||
|
assert scored.source == "rss"
|
||||||
|
assert scored.title == sample_article.title
|
||||||
|
|
||||||
|
# Counter checks
|
||||||
|
assert counters["articles_scored"].total == 1
|
||||||
|
assert counters["finbert_count"].total == 1
|
||||||
|
assert counters["ollama_count"].total == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_low_confidence_falls_back_to_ollama(redis_client: Redis, sample_article: RawArticle):
|
||||||
|
"""When FinBERT confidence is below the threshold, the sentiment
|
||||||
|
analyzer should fall back to Ollama.
|
||||||
|
"""
|
||||||
|
publisher = StreamPublisher(redis_client, SCORED_STREAM)
|
||||||
|
|
||||||
|
# FinBERT returns low confidence -> triggers Ollama fallback
|
||||||
|
mock_finbert = AsyncMock()
|
||||||
|
mock_finbert.analyze = AsyncMock(return_value=(0.3, 0.4))
|
||||||
|
|
||||||
|
mock_ollama = AsyncMock()
|
||||||
|
mock_ollama.analyze = AsyncMock(return_value=(0.72, 0.88))
|
||||||
|
|
||||||
|
config = SentimentAnalyzerConfig()
|
||||||
|
config.finbert_confidence_threshold = 0.6 # 0.4 < 0.6 -> fallback
|
||||||
|
counters = _make_counters()
|
||||||
|
|
||||||
|
await process_article(
|
||||||
|
sample_article,
|
||||||
|
mock_finbert,
|
||||||
|
mock_ollama,
|
||||||
|
publisher,
|
||||||
|
config,
|
||||||
|
counters,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Both should have been called
|
||||||
|
mock_finbert.analyze.assert_called_once()
|
||||||
|
mock_ollama.analyze.assert_called_once()
|
||||||
|
|
||||||
|
# Verify the published message used Ollama's scores
|
||||||
|
messages = await redis_client.xrange(SCORED_STREAM)
|
||||||
|
assert len(messages) >= 1
|
||||||
|
|
||||||
|
_msg_id, fields = messages[0]
|
||||||
|
data = json.loads(fields[b"data"])
|
||||||
|
scored = ScoredArticle.model_validate(data)
|
||||||
|
|
||||||
|
assert scored.model_used == "ollama"
|
||||||
|
assert scored.sentiment_score == pytest.approx(0.72, abs=0.01)
|
||||||
|
assert scored.confidence == pytest.approx(0.88, abs=0.01)
|
||||||
|
|
||||||
|
# Counter checks
|
||||||
|
assert counters["ollama_count"].total == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_article_without_tickers_does_not_publish(redis_client: Redis):
|
||||||
|
"""An article with no recognizable ticker mentions should not produce
|
||||||
|
any ScoredArticle messages.
|
||||||
|
"""
|
||||||
|
article = RawArticle(
|
||||||
|
source="reddit",
|
||||||
|
url="https://reddit.com/r/finance/post123",
|
||||||
|
title="General market outlook for next week",
|
||||||
|
content="The market is looking bullish with strong consumer spending data.",
|
||||||
|
published_at=datetime.now(timezone.utc),
|
||||||
|
fetched_at=datetime.now(timezone.utc),
|
||||||
|
content_hash="test-hash-no-ticker-001",
|
||||||
|
)
|
||||||
|
|
||||||
|
publisher = StreamPublisher(redis_client, SCORED_STREAM)
|
||||||
|
|
||||||
|
mock_finbert = AsyncMock()
|
||||||
|
mock_finbert.analyze = AsyncMock(return_value=(0.6, 0.85))
|
||||||
|
|
||||||
|
mock_ollama = AsyncMock()
|
||||||
|
|
||||||
|
config = SentimentAnalyzerConfig()
|
||||||
|
counters = _make_counters()
|
||||||
|
|
||||||
|
await process_article(
|
||||||
|
article,
|
||||||
|
mock_finbert,
|
||||||
|
mock_ollama,
|
||||||
|
publisher,
|
||||||
|
config,
|
||||||
|
counters,
|
||||||
|
)
|
||||||
|
|
||||||
|
# No tickers extracted -> no messages on the scored stream
|
||||||
|
messages = await redis_client.xrange(SCORED_STREAM)
|
||||||
|
assert len(messages) == 0
|
||||||
|
|
||||||
|
# Article was still counted as scored
|
||||||
|
assert counters["articles_scored"].total == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_publish_and_consume_roundtrip(redis_client: Redis, sample_article: RawArticle):
|
||||||
|
"""End-to-end: publish a RawArticle to the raw stream, consume it via
|
||||||
|
StreamConsumer, process it, and verify the scored output is consumable.
|
||||||
|
"""
|
||||||
|
raw_publisher = StreamPublisher(redis_client, RAW_STREAM)
|
||||||
|
scored_publisher = StreamPublisher(redis_client, SCORED_STREAM)
|
||||||
|
|
||||||
|
# Publish the raw article
|
||||||
|
await raw_publisher.publish(sample_article.model_dump(mode="json"))
|
||||||
|
|
||||||
|
# Verify it's on the raw stream
|
||||||
|
raw_messages = await redis_client.xrange(RAW_STREAM)
|
||||||
|
assert len(raw_messages) == 1
|
||||||
|
|
||||||
|
# Parse it back
|
||||||
|
_msg_id, fields = raw_messages[0]
|
||||||
|
data = json.loads(fields[b"data"])
|
||||||
|
parsed = RawArticle.model_validate(data)
|
||||||
|
assert parsed.title == sample_article.title
|
||||||
|
|
||||||
|
# Now process it through the analyzer
|
||||||
|
mock_finbert = AsyncMock()
|
||||||
|
mock_finbert.analyze = AsyncMock(return_value=(0.9, 0.95))
|
||||||
|
mock_ollama = AsyncMock()
|
||||||
|
|
||||||
|
config = SentimentAnalyzerConfig()
|
||||||
|
counters = _make_counters()
|
||||||
|
|
||||||
|
await process_article(
|
||||||
|
parsed,
|
||||||
|
mock_finbert,
|
||||||
|
mock_ollama,
|
||||||
|
scored_publisher,
|
||||||
|
config,
|
||||||
|
counters,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify scored output
|
||||||
|
scored_messages = await redis_client.xrange(SCORED_STREAM)
|
||||||
|
assert len(scored_messages) >= 1
|
||||||
|
|
||||||
|
_msg_id, fields = scored_messages[0]
|
||||||
|
scored_data = json.loads(fields[b"data"])
|
||||||
|
scored = ScoredArticle.model_validate(scored_data)
|
||||||
|
assert scored.ticker == "AAPL"
|
||||||
|
assert scored.sentiment_score == pytest.approx(0.9, abs=0.01)
|
||||||
399
tests/integration/test_trading_flow.py
Normal file
399
tests/integration/test_trading_flow.py
Normal file
|
|
@ -0,0 +1,399 @@
|
||||||
|
"""Integration test: signal generator -> trade executor flow.
|
||||||
|
|
||||||
|
Publishes a mock TradeSignal to the ``signals:generated`` Redis stream
|
||||||
|
and verifies that a TradeExecution appears on ``trades:executed``.
|
||||||
|
|
||||||
|
Requires a running Redis instance (from docker-compose).
|
||||||
|
The Alpaca broker is mocked.
|
||||||
|
|
||||||
|
Run with:
|
||||||
|
pytest tests/integration/test_trading_flow.py -v -m integration
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from redis.asyncio import Redis
|
||||||
|
|
||||||
|
from shared.redis_streams import StreamPublisher
|
||||||
|
from shared.schemas.trading import (
|
||||||
|
AccountInfo,
|
||||||
|
OrderResult,
|
||||||
|
OrderSide,
|
||||||
|
OrderStatus,
|
||||||
|
PositionInfo,
|
||||||
|
SignalDirection,
|
||||||
|
TradeExecution,
|
||||||
|
TradeSignal,
|
||||||
|
)
|
||||||
|
from services.trade_executor.config import TradeExecutorConfig
|
||||||
|
from services.trade_executor.main import process_signal
|
||||||
|
from services.trade_executor.risk_manager import RiskManager
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
REDIS_URL = "redis://localhost:6379/1" # Use DB 1 to avoid conflicts
|
||||||
|
|
||||||
|
SIGNALS_STREAM = "test:signals:generated"
|
||||||
|
TRADES_STREAM = "test:trades:executed"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def redis_client():
|
||||||
|
"""Provide a clean Redis connection on DB 1 and clean up streams after."""
|
||||||
|
client = Redis.from_url(REDIS_URL, decode_responses=False)
|
||||||
|
await client.delete(SIGNALS_STREAM, TRADES_STREAM)
|
||||||
|
yield client
|
||||||
|
await client.delete(SIGNALS_STREAM, TRADES_STREAM)
|
||||||
|
await client.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_signal() -> TradeSignal:
|
||||||
|
"""Return a sample trade signal for AAPL."""
|
||||||
|
return TradeSignal(
|
||||||
|
ticker="AAPL",
|
||||||
|
direction=SignalDirection.LONG,
|
||||||
|
strength=0.8,
|
||||||
|
strategy_sources=["momentum", "news_driven"],
|
||||||
|
sentiment_context={"avg_score": 0.85, "current_price": 190.50},
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_account() -> AccountInfo:
|
||||||
|
"""Return a mock account with 100k equity."""
|
||||||
|
return AccountInfo(
|
||||||
|
equity=100_000.0,
|
||||||
|
cash=50_000.0,
|
||||||
|
buying_power=100_000.0,
|
||||||
|
portfolio_value=100_000.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_order_result() -> OrderResult:
|
||||||
|
"""Return a mock filled order result."""
|
||||||
|
return OrderResult(
|
||||||
|
order_id="test-order-001",
|
||||||
|
ticker="AAPL",
|
||||||
|
side=OrderSide.BUY,
|
||||||
|
qty=20.0,
|
||||||
|
filled_price=190.50,
|
||||||
|
status=OrderStatus.FILLED,
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Mock counters
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeCounter:
|
||||||
|
def __init__(self):
|
||||||
|
self.total = 0
|
||||||
|
self.attrs: list[dict] = []
|
||||||
|
|
||||||
|
def add(self, amount: int = 1, attributes: dict | None = None):
|
||||||
|
self.total += amount
|
||||||
|
if attributes:
|
||||||
|
self.attrs.append(attributes)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeHistogram:
|
||||||
|
def __init__(self):
|
||||||
|
self.values: list[float] = []
|
||||||
|
|
||||||
|
def record(self, value: float, attributes: dict | None = None):
|
||||||
|
self.values.append(value)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_counters() -> dict:
|
||||||
|
return {
|
||||||
|
"trades_executed": _FakeCounter(),
|
||||||
|
"rejections": _FakeCounter(),
|
||||||
|
"fill_latency": _FakeHistogram(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_signal_produces_trade_execution(
|
||||||
|
redis_client: Redis,
|
||||||
|
sample_signal: TradeSignal,
|
||||||
|
mock_account: AccountInfo,
|
||||||
|
mock_order_result: OrderResult,
|
||||||
|
):
|
||||||
|
"""Process a trade signal through the executor and verify a
|
||||||
|
TradeExecution is published to the trades:executed stream.
|
||||||
|
"""
|
||||||
|
publisher = StreamPublisher(redis_client, TRADES_STREAM)
|
||||||
|
counters = _make_counters()
|
||||||
|
|
||||||
|
# Create mock broker
|
||||||
|
mock_broker = AsyncMock()
|
||||||
|
mock_broker.get_account = AsyncMock(return_value=mock_account)
|
||||||
|
mock_broker.get_positions = AsyncMock(return_value=[])
|
||||||
|
mock_broker.submit_order = AsyncMock(return_value=mock_order_result)
|
||||||
|
|
||||||
|
# Create risk manager with the mock broker, patching market hours check
|
||||||
|
config = TradeExecutorConfig()
|
||||||
|
risk_manager = RiskManager(config, mock_broker)
|
||||||
|
|
||||||
|
# Patch _is_market_hours to always return True
|
||||||
|
with patch.object(RiskManager, "_is_market_hours", return_value=True):
|
||||||
|
await process_signal(
|
||||||
|
sample_signal,
|
||||||
|
risk_manager,
|
||||||
|
mock_broker,
|
||||||
|
publisher,
|
||||||
|
counters,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the broker was called
|
||||||
|
mock_broker.submit_order.assert_called_once()
|
||||||
|
order_arg = mock_broker.submit_order.call_args[0][0]
|
||||||
|
assert order_arg.ticker == "AAPL"
|
||||||
|
assert order_arg.side == OrderSide.BUY
|
||||||
|
|
||||||
|
# Verify a TradeExecution was published
|
||||||
|
messages = await redis_client.xrange(TRADES_STREAM)
|
||||||
|
assert len(messages) == 1
|
||||||
|
|
||||||
|
_msg_id, fields = messages[0]
|
||||||
|
data = json.loads(fields[b"data"])
|
||||||
|
execution = TradeExecution.model_validate(data)
|
||||||
|
|
||||||
|
assert execution.ticker == "AAPL"
|
||||||
|
assert execution.side == OrderSide.BUY
|
||||||
|
assert execution.qty == 20.0
|
||||||
|
assert execution.price == 190.50
|
||||||
|
assert execution.status == OrderStatus.FILLED
|
||||||
|
|
||||||
|
# Counter checks
|
||||||
|
assert counters["trades_executed"].total == 1
|
||||||
|
assert len(counters["fill_latency"].values) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_short_signal_produces_sell_execution(
|
||||||
|
redis_client: Redis,
|
||||||
|
mock_account: AccountInfo,
|
||||||
|
):
|
||||||
|
"""A SHORT signal should produce a SELL order."""
|
||||||
|
short_signal = TradeSignal(
|
||||||
|
ticker="TSLA",
|
||||||
|
direction=SignalDirection.SHORT,
|
||||||
|
strength=0.7,
|
||||||
|
strategy_sources=["mean_reversion"],
|
||||||
|
sentiment_context={"avg_score": -0.6, "current_price": 250.00},
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
sell_result = OrderResult(
|
||||||
|
order_id="test-order-002",
|
||||||
|
ticker="TSLA",
|
||||||
|
side=OrderSide.SELL,
|
||||||
|
qty=14.0,
|
||||||
|
filled_price=250.00,
|
||||||
|
status=OrderStatus.FILLED,
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
publisher = StreamPublisher(redis_client, TRADES_STREAM)
|
||||||
|
counters = _make_counters()
|
||||||
|
|
||||||
|
mock_broker = AsyncMock()
|
||||||
|
mock_broker.get_account = AsyncMock(return_value=mock_account)
|
||||||
|
mock_broker.get_positions = AsyncMock(return_value=[])
|
||||||
|
mock_broker.submit_order = AsyncMock(return_value=sell_result)
|
||||||
|
|
||||||
|
config = TradeExecutorConfig()
|
||||||
|
risk_manager = RiskManager(config, mock_broker)
|
||||||
|
|
||||||
|
with patch.object(RiskManager, "_is_market_hours", return_value=True):
|
||||||
|
await process_signal(
|
||||||
|
short_signal,
|
||||||
|
risk_manager,
|
||||||
|
mock_broker,
|
||||||
|
publisher,
|
||||||
|
counters,
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = await redis_client.xrange(TRADES_STREAM)
|
||||||
|
assert len(messages) == 1
|
||||||
|
|
||||||
|
_msg_id, fields = messages[0]
|
||||||
|
data = json.loads(fields[b"data"])
|
||||||
|
execution = TradeExecution.model_validate(data)
|
||||||
|
|
||||||
|
assert execution.ticker == "TSLA"
|
||||||
|
assert execution.side == OrderSide.SELL
|
||||||
|
assert execution.status == OrderStatus.FILLED
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_risk_rejection_does_not_publish(
|
||||||
|
redis_client: Redis,
|
||||||
|
sample_signal: TradeSignal,
|
||||||
|
mock_account: AccountInfo,
|
||||||
|
):
|
||||||
|
"""When risk checks fail (outside market hours), no TradeExecution
|
||||||
|
should be published.
|
||||||
|
"""
|
||||||
|
publisher = StreamPublisher(redis_client, TRADES_STREAM)
|
||||||
|
counters = _make_counters()
|
||||||
|
|
||||||
|
mock_broker = AsyncMock()
|
||||||
|
mock_broker.get_account = AsyncMock(return_value=mock_account)
|
||||||
|
mock_broker.get_positions = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
config = TradeExecutorConfig()
|
||||||
|
risk_manager = RiskManager(config, mock_broker)
|
||||||
|
|
||||||
|
# Market is closed -> risk check fails
|
||||||
|
with patch.object(RiskManager, "_is_market_hours", return_value=False):
|
||||||
|
await process_signal(
|
||||||
|
sample_signal,
|
||||||
|
risk_manager,
|
||||||
|
mock_broker,
|
||||||
|
publisher,
|
||||||
|
counters,
|
||||||
|
)
|
||||||
|
|
||||||
|
# No order should have been submitted
|
||||||
|
mock_broker.submit_order.assert_not_called()
|
||||||
|
|
||||||
|
# No messages on the trades stream
|
||||||
|
messages = await redis_client.xrange(TRADES_STREAM)
|
||||||
|
assert len(messages) == 0
|
||||||
|
|
||||||
|
# Rejection counter should be incremented
|
||||||
|
assert counters["rejections"].total == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_max_positions_rejection(
|
||||||
|
redis_client: Redis,
|
||||||
|
sample_signal: TradeSignal,
|
||||||
|
mock_account: AccountInfo,
|
||||||
|
):
|
||||||
|
"""When the maximum number of positions is reached, the signal
|
||||||
|
should be rejected.
|
||||||
|
"""
|
||||||
|
publisher = StreamPublisher(redis_client, TRADES_STREAM)
|
||||||
|
counters = _make_counters()
|
||||||
|
|
||||||
|
# Create enough mock positions to exceed the limit
|
||||||
|
existing_positions = [
|
||||||
|
PositionInfo(
|
||||||
|
ticker=f"STOCK{i}",
|
||||||
|
qty=10.0,
|
||||||
|
avg_entry=100.0,
|
||||||
|
current_price=105.0,
|
||||||
|
unrealized_pnl=50.0,
|
||||||
|
market_value=1050.0,
|
||||||
|
)
|
||||||
|
for i in range(25) # Default max is 20
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_broker = AsyncMock()
|
||||||
|
mock_broker.get_account = AsyncMock(return_value=mock_account)
|
||||||
|
mock_broker.get_positions = AsyncMock(return_value=existing_positions)
|
||||||
|
|
||||||
|
config = TradeExecutorConfig()
|
||||||
|
risk_manager = RiskManager(config, mock_broker)
|
||||||
|
|
||||||
|
with patch.object(RiskManager, "_is_market_hours", return_value=True):
|
||||||
|
await process_signal(
|
||||||
|
sample_signal,
|
||||||
|
risk_manager,
|
||||||
|
mock_broker,
|
||||||
|
publisher,
|
||||||
|
counters,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_broker.submit_order.assert_not_called()
|
||||||
|
|
||||||
|
messages = await redis_client.xrange(TRADES_STREAM)
|
||||||
|
assert len(messages) == 0
|
||||||
|
|
||||||
|
assert counters["rejections"].total == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_publish_signal_and_consume_execution_roundtrip(
|
||||||
|
redis_client: Redis,
|
||||||
|
sample_signal: TradeSignal,
|
||||||
|
mock_account: AccountInfo,
|
||||||
|
mock_order_result: OrderResult,
|
||||||
|
):
|
||||||
|
"""End-to-end: publish a signal to the signals stream, process it,
|
||||||
|
and verify the execution can be read back from the trades stream.
|
||||||
|
"""
|
||||||
|
# Publish the signal to the signals stream
|
||||||
|
signal_publisher = StreamPublisher(redis_client, SIGNALS_STREAM)
|
||||||
|
await signal_publisher.publish(sample_signal.model_dump(mode="json"))
|
||||||
|
|
||||||
|
# Verify the signal is on the stream
|
||||||
|
signal_messages = await redis_client.xrange(SIGNALS_STREAM)
|
||||||
|
assert len(signal_messages) == 1
|
||||||
|
|
||||||
|
# Parse it back to verify serialization
|
||||||
|
_msg_id, fields = signal_messages[0]
|
||||||
|
data = json.loads(fields[b"data"])
|
||||||
|
parsed_signal = TradeSignal.model_validate(data)
|
||||||
|
assert parsed_signal.ticker == "AAPL"
|
||||||
|
assert parsed_signal.direction == SignalDirection.LONG
|
||||||
|
|
||||||
|
# Process the signal through the executor
|
||||||
|
trades_publisher = StreamPublisher(redis_client, TRADES_STREAM)
|
||||||
|
counters = _make_counters()
|
||||||
|
|
||||||
|
mock_broker = AsyncMock()
|
||||||
|
mock_broker.get_account = AsyncMock(return_value=mock_account)
|
||||||
|
mock_broker.get_positions = AsyncMock(return_value=[])
|
||||||
|
mock_broker.submit_order = AsyncMock(return_value=mock_order_result)
|
||||||
|
|
||||||
|
config = TradeExecutorConfig()
|
||||||
|
risk_manager = RiskManager(config, mock_broker)
|
||||||
|
|
||||||
|
with patch.object(RiskManager, "_is_market_hours", return_value=True):
|
||||||
|
await process_signal(
|
||||||
|
parsed_signal,
|
||||||
|
risk_manager,
|
||||||
|
mock_broker,
|
||||||
|
trades_publisher,
|
||||||
|
counters,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Read the execution from the trades stream
|
||||||
|
trade_messages = await redis_client.xrange(TRADES_STREAM)
|
||||||
|
assert len(trade_messages) == 1
|
||||||
|
|
||||||
|
_msg_id, fields = trade_messages[0]
|
||||||
|
data = json.loads(fields[b"data"])
|
||||||
|
execution = TradeExecution.model_validate(data)
|
||||||
|
|
||||||
|
assert execution.ticker == "AAPL"
|
||||||
|
assert execution.side == OrderSide.BUY
|
||||||
|
assert execution.status == OrderStatus.FILLED
|
||||||
|
assert execution.price == 190.50
|
||||||
Loading…
Add table
Add a link
Reference in a new issue