Add integration tests for the news pipeline (test_news_pipeline.py) and trading flow (test_trading_flow.py) using real Redis with mocked FinBERT and Alpaca. Add seed_strategies.py to insert default strategies (momentum, mean_reversion, news_driven) with equal weights. Add smoke_test.sh for end-to-end stack validation. Update pyproject.toml with integration marker and scripts package discovery.
299 lines
9.4 KiB
Python
299 lines
9.4 KiB
Python
"""Integration test: news fetcher -> sentiment analyzer pipeline.
|
|
|
|
Publishes a mock RawArticle to the ``news:raw`` Redis stream and verifies
|
|
that a ScoredArticle appears on ``news:scored``.
|
|
|
|
Requires a running Redis instance (from docker-compose).
|
|
FinBERT and Ollama are mocked so the test does not need GPU / model weights.
|
|
|
|
Run with:
|
|
pytest tests/integration/test_news_pipeline.py -v -m integration
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
from redis.asyncio import Redis
|
|
|
|
from shared.redis_streams import StreamConsumer, StreamPublisher
|
|
from shared.schemas.news import RawArticle, ScoredArticle
|
|
from services.sentiment_analyzer.main import process_article
|
|
from services.sentiment_analyzer.config import SentimentAnalyzerConfig
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Redis endpoint for these tests; logical DB 1 is used to avoid conflicts.
REDIS_URL: str = "redis://localhost:6379/1"

# Stream the raw (unscored) articles are published to during the tests.
RAW_STREAM: str = "test:news:raw"

# Stream the sentiment analyzer's scored output is read back from.
SCORED_STREAM: str = "test:news:scored"
|
|
|
|
|
|
@pytest.fixture
async def redis_client():
    """Yield a Redis connection to ``REDIS_URL`` with both test streams emptied.

    ``RAW_STREAM`` and ``SCORED_STREAM`` are deleted before the test body runs
    and again afterwards. The connection is closed in a ``finally`` block so it
    is released even when one of the cleanup calls raises (for example when
    Redis is unreachable).

    NOTE(review): this is an async generator registered with plain
    ``@pytest.fixture``; that presumably relies on pytest-asyncio running in
    ``auto`` mode -- confirm against the project's pytest configuration.
    """
    # decode_responses=False: replies stay as bytes, which is why the tests
    # index message fields with the bytes key b"data".
    client = Redis.from_url(REDIS_URL, decode_responses=False)
    try:
        # Start from a clean slate in case an earlier run left entries behind.
        await client.delete(RAW_STREAM, SCORED_STREAM)
        yield client
        # Remove whatever this test wrote.
        await client.delete(RAW_STREAM, SCORED_STREAM)
    finally:
        await client.aclose()
|
|
|
|
|
|
@pytest.fixture
def sample_article() -> RawArticle:
    """Build an RSS-sourced RawArticle whose title and body mention AAPL."""
    headline = "Apple Inc AAPL reports record quarterly earnings"
    body = (
        "Apple Inc ($AAPL) reported record-breaking quarterly earnings "
        "today, beating analyst estimates by a wide margin. Revenue grew "
        "15% year-over-year driven by strong iPhone and Services demand."
    )
    article = RawArticle(
        source="rss",
        url="https://example.com/aapl-news",
        title=headline,
        content=body,
        # Both timestamps are timezone-aware UTC "now" values.
        published_at=datetime.now(timezone.utc),
        fetched_at=datetime.now(timezone.utc),
        content_hash="test-hash-aapl-001",
    )
    return article
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Mock counters (stand-in for OpenTelemetry instruments)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class _FakeCounter:
|
|
"""Minimal fake that records how many times ``add`` was called."""
|
|
|
|
def __init__(self):
|
|
self.total = 0
|
|
|
|
def add(self, amount: int = 1, attributes: dict | None = None):
|
|
self.total += amount
|
|
|
|
|
|
class _FakeHistogram:
|
|
def __init__(self):
|
|
self.values: list[float] = []
|
|
|
|
def record(self, value: float, attributes: dict | None = None):
|
|
self.values.append(value)
|
|
|
|
|
|
def _make_counters() -> dict:
    """Return a fresh instrument mapping: three fake counters and one fake histogram."""
    instruments: dict = {
        name: _FakeCounter()
        for name in ("articles_scored", "finbert_count", "ollama_count")
    }
    instruments["inference_latency"] = _FakeHistogram()
    return instruments
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.integration
@pytest.mark.asyncio
async def test_raw_article_flows_to_scored(redis_client: Redis, sample_article: RawArticle):
    """Run ``process_article`` on the sample article with a high-confidence
    FinBERT mock and check the ScoredArticle written to the scored stream.

    Also checks that the Ollama mock is never invoked and that the fake
    counters reflect a single FinBERT-scored article.
    """
    # FinBERT double: (sentiment score, confidence) with high confidence.
    finbert = AsyncMock()
    finbert.analyze = AsyncMock(return_value=(0.85, 0.92))

    # Ollama double: present only so the test can assert it is not called.
    ollama = AsyncMock()
    ollama.analyze = AsyncMock(return_value=(0.0, 0.0))

    scored_out = StreamPublisher(redis_client, SCORED_STREAM)
    settings = SentimentAnalyzerConfig()
    metrics = _make_counters()

    await process_article(sample_article, finbert, ollama, scored_out, settings, metrics)

    # Only the FinBERT path should have run.
    finbert.analyze.assert_called_once()
    ollama.analyze.assert_not_called()

    # Read everything currently on the scored stream.
    entries = await redis_client.xrange(SCORED_STREAM)
    assert entries, "Expected at least one message on the scored stream"

    # Decode the first entry's JSON payload back into a ScoredArticle.
    _entry_id, payload = entries[0]
    scored = ScoredArticle.model_validate(json.loads(payload[b"data"]))

    assert scored.ticker == "AAPL"
    assert scored.sentiment_score == pytest.approx(0.85, abs=0.01)
    assert scored.confidence == pytest.approx(0.92, abs=0.01)
    assert scored.model_used == "finbert"
    assert scored.source == "rss"
    assert scored.title == sample_article.title

    # Metric doubles: one article scored, by FinBERT, none by Ollama.
    assert metrics["articles_scored"].total == 1
    assert metrics["finbert_count"].total == 1
    assert metrics["ollama_count"].total == 0
|
|
|
|
|
|
@pytest.mark.integration
@pytest.mark.asyncio
async def test_low_confidence_falls_back_to_ollama(redis_client: Redis, sample_article: RawArticle):
    """Check that a FinBERT confidence under the configured threshold makes
    ``process_article`` publish Ollama's scores instead.
    """
    # FinBERT double reports confidence 0.4, under the 0.6 threshold set below.
    finbert = AsyncMock()
    finbert.analyze = AsyncMock(return_value=(0.3, 0.4))

    # Ollama double supplies the scores expected in the published message.
    ollama = AsyncMock()
    ollama.analyze = AsyncMock(return_value=(0.72, 0.88))

    scored_out = StreamPublisher(redis_client, SCORED_STREAM)
    settings = SentimentAnalyzerConfig()
    settings.finbert_confidence_threshold = 0.6  # 0.4 < 0.6 -> fallback
    metrics = _make_counters()

    await process_article(sample_article, finbert, ollama, scored_out, settings, metrics)

    # Each analyzer should have been invoked exactly once.
    finbert.analyze.assert_called_once()
    ollama.analyze.assert_called_once()

    # The first scored entry must carry Ollama's result.
    entries = await redis_client.xrange(SCORED_STREAM)
    assert entries

    _entry_id, payload = entries[0]
    scored = ScoredArticle.model_validate(json.loads(payload[b"data"]))

    assert scored.model_used == "ollama"
    assert scored.sentiment_score == pytest.approx(0.72, abs=0.01)
    assert scored.confidence == pytest.approx(0.88, abs=0.01)

    # Metric double: the Ollama path was counted once.
    assert metrics["ollama_count"].total == 1
|
|
|
|
|
|
@pytest.mark.integration
@pytest.mark.asyncio
async def test_article_without_tickers_does_not_publish(redis_client: Redis):
    """Check that an article with no recognizable ticker leaves the scored
    stream empty while still incrementing the ``articles_scored`` counter.
    """
    # Neither the title nor the body names a ticker symbol.
    tickerless = RawArticle(
        source="reddit",
        url="https://reddit.com/r/finance/post123",
        title="General market outlook for next week",
        content="The market is looking bullish with strong consumer spending data.",
        published_at=datetime.now(timezone.utc),
        fetched_at=datetime.now(timezone.utc),
        content_hash="test-hash-no-ticker-001",
    )

    finbert = AsyncMock()
    finbert.analyze = AsyncMock(return_value=(0.6, 0.85))
    ollama = AsyncMock()

    scored_out = StreamPublisher(redis_client, SCORED_STREAM)
    settings = SentimentAnalyzerConfig()
    metrics = _make_counters()

    await process_article(tickerless, finbert, ollama, scored_out, settings, metrics)

    # Nothing should have been written to the scored stream.
    entries = await redis_client.xrange(SCORED_STREAM)
    assert not entries

    # The article is still counted as scored.
    assert metrics["articles_scored"].total == 1
|
|
|
|
|
|
@pytest.mark.integration
@pytest.mark.asyncio
async def test_publish_and_consume_roundtrip(redis_client: Redis, sample_article: RawArticle):
    """Round-trip a RawArticle through the raw stream, then score it.

    The article is published to the raw stream, read back with XRANGE and
    re-validated, passed to ``process_article``, and the resulting entry on
    the scored stream is parsed as a ScoredArticle.
    """
    raw_out = StreamPublisher(redis_client, RAW_STREAM)
    scored_out = StreamPublisher(redis_client, SCORED_STREAM)

    # Step 1: put the article on the raw stream as a JSON-compatible dict.
    await raw_out.publish(sample_article.model_dump(mode="json"))

    # Step 2: exactly one entry should now be on the raw stream.
    raw_entries = await redis_client.xrange(RAW_STREAM)
    assert len(raw_entries) == 1

    # Step 3: decode that entry back into a RawArticle.
    _entry_id, payload = raw_entries[0]
    roundtripped = RawArticle.model_validate(json.loads(payload[b"data"]))
    assert roundtripped.title == sample_article.title

    # Step 4: score the decoded article with a high-confidence FinBERT double.
    finbert = AsyncMock()
    finbert.analyze = AsyncMock(return_value=(0.9, 0.95))
    ollama = AsyncMock()

    settings = SentimentAnalyzerConfig()
    metrics = _make_counters()

    await process_article(roundtripped, finbert, ollama, scored_out, settings, metrics)

    # Step 5: the scored stream should hold a parseable ScoredArticle.
    scored_entries = await redis_client.xrange(SCORED_STREAM)
    assert scored_entries

    _entry_id, payload = scored_entries[0]
    scored = ScoredArticle.model_validate(json.loads(payload[b"data"]))
    assert scored.ticker == "AAPL"
    assert scored.sentiment_score == pytest.approx(0.9, abs=0.01)
|