# Source: trading/tests/services/test_sentiment_analyzer.py
# (Repository-viewer listing: 639 lines, 23 KiB, Python.)

"""Tests for the sentiment analyzer service.
Covers FinBERT analyzer, Ollama analyzer, ticker extraction, and the main
service flow including DB persistence.
"""
from __future__ import annotations
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from sqlalchemy.exc import IntegrityError
from services.sentiment_analyzer.analyzers.finbert import FinBERTAnalyzer
from services.sentiment_analyzer.analyzers.ollama_analyzer import OllamaAnalyzer
from services.sentiment_analyzer.config import SentimentAnalyzerConfig
from services.sentiment_analyzer.main import process_article
from services.sentiment_analyzer.ticker_extractor import extract_tickers
from shared.schemas.news import RawArticle
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_raw_article(**overrides) -> RawArticle:
    """Build a ``RawArticle`` for tests from a fixed set of field values.

    Every keyword argument replaces the canned value of the field with the
    same name; fields that are not overridden keep the values listed below.
    """
    canned_fields = dict(
        source="test",
        url="https://example.com/article",
        title="Test Article About $AAPL",
        content="Apple Inc announced strong earnings.",
        published_at=datetime(2026, 1, 15, tzinfo=timezone.utc),
        fetched_at=datetime(2026, 1, 15, 0, 5, tzinfo=timezone.utc),
        content_hash="abc123",
    )
    # Later mapping wins, so caller-supplied overrides shadow the defaults.
    return RawArticle(**{**canned_fields, **overrides})
def _make_pipeline_result(label: str, score: float) -> list[list[dict]]:
    """Build a return value matching transformers pipeline(return_all_scores=True).

    The dominant ``label`` receives ``score``; the remaining probability mass
    (``1.0 - score``) is split evenly between the other two classes. Entries
    are always emitted in the order positive, negative, neutral.

    Args:
        label: The dominant class: "positive", "negative" or "neutral".
        score: Probability assigned to ``label``.

    Returns:
        A one-element batch: a list holding one list of
        ``{"label": ..., "score": ...}`` dicts, one dict per class.

    Raises:
        ValueError: If ``label`` is not one of the three known classes.
            Without this check an unknown label would silently yield a
            distribution that omits the label and does not sum to 1.
    """
    class_names = ("positive", "negative", "neutral")
    if label not in class_names:
        raise ValueError(
            f"Unknown sentiment label {label!r}; expected one of {class_names}"
        )
    # Probability left over after the dominant class, shared by the other two.
    leftover_share = (1.0 - score) / (len(class_names) - 1)
    return [
        [
            {"label": name, "score": score if name == label else leftover_share}
            for name in class_names
        ]
    ]
# ---------------------------------------------------------------------------
# FinBERT Analyzer Tests
# ---------------------------------------------------------------------------
class TestFinBERTAnalyzer:
    """Exercise FinBERTAnalyzer.analyze with the transformers pipeline mocked out."""

    @staticmethod
    def _analyzer_with_pipeline(label: str, score: float) -> tuple[FinBERTAnalyzer, MagicMock]:
        """Return an analyzer together with the mock pipeline injected into it.

        The mock pipeline always answers with a canned distribution in which
        ``label`` carries probability ``score``.
        """
        fake_pipeline = MagicMock(return_value=_make_pipeline_result(label, score))
        subject = FinBERTAnalyzer(model_name="test-model")
        # Swap in the mock so no real model is invoked by analyze().
        subject._pipeline = fake_pipeline
        return subject, fake_pipeline

    @pytest.mark.asyncio
    async def test_finbert_positive_sentiment(self):
        """A dominant positive label must produce a score above zero."""
        subject, fake_pipeline = self._analyzer_with_pipeline("positive", 0.9)
        headline = "Apple beats earnings expectations"
        body = "Apple reported revenue above analyst estimates."

        score, confidence = await subject.analyze(headline, body)

        assert score > 0.0, f"Expected positive score, got {score}"
        assert confidence == pytest.approx(0.9, abs=0.01)
        fake_pipeline.assert_called_once()

    @pytest.mark.asyncio
    async def test_finbert_negative_sentiment(self):
        """A dominant negative label must produce a score below zero."""
        subject, _ = self._analyzer_with_pipeline("negative", 0.85)
        headline = "Major bank reports massive losses"
        body = "The bank lost $2 billion in the quarter."

        score, confidence = await subject.analyze(headline, body)

        assert score < 0.0, f"Expected negative score, got {score}"
        assert confidence == pytest.approx(0.85, abs=0.01)

    @pytest.mark.asyncio
    async def test_finbert_neutral_sentiment(self):
        """A dominant neutral label must produce a score close to zero."""
        subject, _ = self._analyzer_with_pipeline("neutral", 0.8)
        headline = "Company releases quarterly report"
        body = "The quarterly report was filed with the SEC."

        score, confidence = await subject.analyze(headline, body)

        # With neutral dominant the score is expected to sit near zero; the
        # tolerance allows for whatever residual the analyzer derives from the
        # leftover probability shared by positive and negative.
        assert abs(score) < 0.2, f"Expected near-zero score, got {score}"
        assert confidence == pytest.approx(0.8, abs=0.01)
# ---------------------------------------------------------------------------
# Ollama Analyzer Tests
# ---------------------------------------------------------------------------
class TestOllamaAnalyzer:
    """Exercise OllamaAnalyzer.analyze with the ollama client mocked out."""

    @staticmethod
    def _analyzer_with_client(client: AsyncMock) -> OllamaAnalyzer:
        """Return an analyzer whose private client has been replaced by ``client``."""
        subject = OllamaAnalyzer(model="test-model")
        subject._client = client
        return subject

    @classmethod
    def _analyzer_replying(cls, content: str) -> OllamaAnalyzer:
        """Return an analyzer whose mocked chat call answers with ``content``."""
        client = AsyncMock()
        client.chat.return_value = {"message": {"content": content}}
        return cls._analyzer_with_client(client)

    @pytest.mark.asyncio
    async def test_ollama_successful_analysis(self):
        """A well-formed JSON reply is parsed into score and confidence."""
        subject = self._analyzer_replying(
            '{"sentiment_score": 0.75, "confidence": 0.85, "entities": ["AAPL"]}'
        )

        score, confidence = await subject.analyze("Good news for Apple", "Apple stock surges.")

        assert score == pytest.approx(0.75)
        assert confidence == pytest.approx(0.85)

    @pytest.mark.asyncio
    async def test_ollama_parse_error_returns_zero(self):
        """A reply that is not JSON falls back to (0.0, 0.0)."""
        subject = self._analyzer_replying(
            "I think the sentiment is positive but I'm not sure."
        )

        score, confidence = await subject.analyze("Some headline", "Some content")

        assert score == 0.0
        assert confidence == 0.0

    @pytest.mark.asyncio
    async def test_ollama_connection_error_returns_zero(self):
        """A ConnectionError raised by the client falls back to (0.0, 0.0)."""
        client = AsyncMock()
        client.chat.side_effect = ConnectionError("Cannot reach Ollama")
        subject = self._analyzer_with_client(client)

        score, confidence = await subject.analyze("Some headline", "Some content")

        assert score == 0.0
        assert confidence == 0.0

    @pytest.mark.asyncio
    async def test_ollama_markdown_code_fence(self):
        """A JSON reply wrapped in a markdown code fence is still parsed."""
        subject = self._analyzer_replying(
            '```json\n{"sentiment_score": -0.5, "confidence": 0.7, "entities": []}\n```'
        )

        score, confidence = await subject.analyze("Bad news", "Markets tumble.")

        assert score == pytest.approx(-0.5)
        assert confidence == pytest.approx(0.7)
# ---------------------------------------------------------------------------
# Ticker Extraction Tests
# ---------------------------------------------------------------------------
class TestTickerExtraction:
    """Checks for the ``extract_tickers`` helper."""

    def test_ticker_extraction_dollar_sign(self):
        """A cashtag such as $AAPL yields AAPL."""
        found = extract_tickers("Big news for $AAPL today.")
        assert "AAPL" in found

    def test_ticker_extraction_exchange_prefix(self):
        """NASDAQ:TSLA yields TSLA."""
        found = extract_tickers("Check out NASDAQ:TSLA performance.")
        assert "TSLA" in found

    def test_ticker_extraction_nyse_prefix(self):
        """NYSE:AAPL yields AAPL."""
        found = extract_tickers("NYSE:AAPL is trading higher.")
        assert "AAPL" in found

    def test_ticker_extraction_filters_false_positives(self):
        """Common finance acronyms must not be reported as tickers."""
        text = "The CEO announced a new IPO. The ETF was approved by the SEC on NYSE."
        found = extract_tickers(text)
        for acronym in ("CEO", "IPO", "ETF", "SEC", "NYSE"):
            assert acronym not in found

    def test_ticker_extraction_deduplicates(self):
        """A ticker mentioned several times is reported exactly once."""
        text = "$AAPL rose 5%. $AAPL is now above $200. NASDAQ:AAPL is great."
        found = extract_tickers(text)
        assert found.count("AAPL") == 1

    def test_ticker_extraction_multiple_tickers(self):
        """Every distinct ticker in the text is reported."""
        found = extract_tickers("$AAPL and $MSFT both reported earnings. $GOOG is next.")
        for symbol in ("AAPL", "MSFT", "GOOG"):
            assert symbol in found

    def test_ticker_extraction_empty_text(self):
        """An empty string yields an empty list."""
        found = extract_tickers("")
        assert found == []

    def test_ticker_extraction_no_tickers(self):
        """Plain prose without ticker-like patterns yields nothing."""
        found = extract_tickers("The market was flat today with no major movers.")
        # The sentence has no cashtag, exchange prefix or upper-case symbol,
        # so nothing should be extracted.
        assert len(found) == 0
# ---------------------------------------------------------------------------
# Ollama Fallback Routing Test
# ---------------------------------------------------------------------------
class TestFallbackRouting:
    """Test that Ollama is called when FinBERT confidence is below threshold."""

    @staticmethod
    async def _process_with(finbert_result, ollama_result):
        """Run ``process_article`` once against mocked analyzers.

        Args:
            finbert_result: ``(score, confidence)`` returned by the FinBERT mock.
            ollama_result: ``(score, confidence)`` returned by the Ollama mock.

        Returns:
            ``(finbert, ollama, counters)`` so callers can assert on the calls
            made during processing.
        """
        finbert = AsyncMock(spec=FinBERTAnalyzer)
        finbert.analyze = AsyncMock(return_value=finbert_result)
        ollama = AsyncMock(spec=OllamaAnalyzer)
        ollama.analyze = AsyncMock(return_value=ollama_result)
        publisher = AsyncMock()
        publisher.publish = AsyncMock(return_value=b"1-0")
        config = SentimentAnalyzerConfig(
            finbert_confidence_threshold=0.6,
            otel_metrics_port=0,
        )
        # Shared helper keeps the instrument names identical across test classes.
        counters = _make_counters()
        article = _make_raw_article(title="Test $AAPL Article", content="Apple stock rises.")

        await process_article(article, finbert, ollama, publisher, config, counters)
        return finbert, ollama, counters

    @pytest.mark.asyncio
    async def test_ollama_fallback_on_low_confidence(self):
        """When FinBERT confidence < threshold, Ollama should be called."""
        # FinBERT confidence 0.4 is below the 0.6 threshold; Ollama answers
        # with a higher confidence.
        finbert, ollama, counters = await self._process_with(
            finbert_result=(0.1, 0.4),
            ollama_result=(0.8, 0.9),
        )

        # Both analyzers run, and each usage counter is bumped exactly once.
        finbert.analyze.assert_called_once()
        ollama.analyze.assert_called_once()
        counters["finbert_count"].add.assert_called_once_with(1)
        counters["ollama_count"].add.assert_called_once_with(1)

    @pytest.mark.asyncio
    async def test_no_ollama_on_high_confidence(self):
        """When FinBERT confidence >= threshold, Ollama should NOT be called."""
        # FinBERT confidence 0.9 clears the 0.6 threshold.
        finbert, ollama, counters = await self._process_with(
            finbert_result=(0.8, 0.9),
            ollama_result=(0.5, 0.7),
        )

        finbert.analyze.assert_called_once()
        ollama.analyze.assert_not_called()
        counters["ollama_count"].add.assert_not_called()
# ---------------------------------------------------------------------------
# Main Flow / Integration Test
# ---------------------------------------------------------------------------
class TestMainFlow:
    """Test the full process_article flow with mocked analyzers and Redis."""

    @staticmethod
    def _config() -> SentimentAnalyzerConfig:
        """Return the analyzer configuration shared by the tests in this class."""
        return SentimentAnalyzerConfig(
            finbert_confidence_threshold=0.6,
            otel_metrics_port=0,
        )

    @pytest.mark.asyncio
    async def test_main_flow_publishes_scored_articles(self):
        """process_article should publish a ScoredArticle for each ticker found."""
        finbert = AsyncMock(spec=FinBERTAnalyzer)
        finbert.analyze = AsyncMock(return_value=(0.75, 0.88))
        ollama = AsyncMock(spec=OllamaAnalyzer)
        publisher = AsyncMock()
        publisher.publish = AsyncMock(return_value=b"1-0")
        # Shared helper keeps the instrument names identical across test classes.
        counters = _make_counters()
        # Article mentions two tickers.
        article = _make_raw_article(
            title="$AAPL and $MSFT report strong earnings",
            content="Both Apple and Microsoft beat estimates.",
        )

        await process_article(article, finbert, ollama, publisher, self._config(), counters)

        # One message per ticker, but the article itself is counted once.
        assert publisher.publish.call_count == 2
        counters["articles_scored"].add.assert_called_once_with(1)

        # The first positional argument of each publish call is the payload.
        payloads = [published.args[0] for published in publisher.publish.call_args_list]
        published_tickers = {payload["ticker"] for payload in payloads}
        assert "AAPL" in published_tickers
        assert "MSFT" in published_tickers

        # Every payload carries the FinBERT result unchanged.
        for payload in payloads:
            assert payload["sentiment_score"] == pytest.approx(0.75)
            assert payload["confidence"] == pytest.approx(0.88)
            assert payload["model_used"] == "finbert"

    @pytest.mark.asyncio
    async def test_main_flow_no_tickers_no_publish(self):
        """Articles with no tickers should not publish anything."""
        finbert = AsyncMock(spec=FinBERTAnalyzer)
        finbert.analyze = AsyncMock(return_value=(0.5, 0.9))
        ollama = AsyncMock(spec=OllamaAnalyzer)
        publisher = AsyncMock()
        publisher.publish = AsyncMock()
        counters = _make_counters()
        article = _make_raw_article(
            title="Market is flat today",
            content="Nothing much happening in the market.",
        )

        await process_article(article, finbert, ollama, publisher, self._config(), counters)

        publisher.publish.assert_not_called()
        # Still counted as scored even though nothing was published.
        counters["articles_scored"].add.assert_called_once_with(1)
# ---------------------------------------------------------------------------
# Helpers for DB persistence tests
# ---------------------------------------------------------------------------
def _make_counters() -> dict:
    """Return a fresh mapping of instrument name to ``MagicMock``.

    The mocks stand in for the OpenTelemetry counter/histogram instruments
    handed to ``process_article``; each call yields new, independent mocks.
    """
    instrument_names = (
        "articles_scored",
        "finbert_count",
        "ollama_count",
        "inference_latency",
    )
    return {name: MagicMock() for name in instrument_names}
def _make_mock_db_session_factory(session: AsyncMock | None = None) -> MagicMock:
    """Create a mock async_sessionmaker that yields a mock session.

    Args:
        session: The mock AsyncSession to yield. When omitted, a new
            ``AsyncMock`` is created with a synchronous ``add`` and an
            awaitable ``commit``. A caller-supplied session is used as is.

    Returns:
        A ``MagicMock`` factory. Calling it returns an async context manager
        whose ``async with`` target is ``session``.
    """
    if session is None:
        session = AsyncMock()
        session.add = MagicMock()
        session.commit = AsyncMock()
    # factory() should return an async context manager wrapping the session.
    ctx = AsyncMock()
    ctx.__aenter__ = AsyncMock(return_value=session)
    # A falsy __aexit__ result means exceptions raised inside the
    # ``async with`` body are not suppressed.
    ctx.__aexit__ = AsyncMock(return_value=False)
    factory = MagicMock()
    factory.return_value = ctx
    return factory
# ---------------------------------------------------------------------------
# DB Persistence Tests
# ---------------------------------------------------------------------------
class TestDBPersistence:
    """Tests for the DB write step in process_article."""

    @staticmethod
    def _collaborators(finbert_result):
        """Build the mocked collaborators passed to ``process_article``.

        Args:
            finbert_result: ``(score, confidence)`` returned by the FinBERT mock.

        Returns:
            ``(finbert, ollama, publisher, config, counters)`` in the positional
            order ``process_article`` takes them after the article.
        """
        finbert = AsyncMock(spec=FinBERTAnalyzer)
        finbert.analyze = AsyncMock(return_value=finbert_result)
        ollama = AsyncMock(spec=OllamaAnalyzer)
        publisher = AsyncMock()
        publisher.publish = AsyncMock(return_value=b"1-0")
        config = SentimentAnalyzerConfig(
            finbert_confidence_threshold=0.6,
            otel_metrics_port=0,
        )
        return finbert, ollama, publisher, config, _make_counters()

    @staticmethod
    def _mock_session(commit_error: Exception | None = None) -> AsyncMock:
        """Return a mock session with a sync ``add`` and an awaitable ``commit``.

        Args:
            commit_error: Exception raised when ``commit`` is awaited, or
                ``None`` for a commit that succeeds.
        """
        session = AsyncMock()
        session.add = MagicMock()
        session.commit = AsyncMock(side_effect=commit_error)
        return session

    @pytest.mark.asyncio
    async def test_db_write_creates_article_and_sentiments(self):
        """When db_session_factory is provided, Article and ArticleSentiment rows are created."""
        from shared.models.news import Article, ArticleSentiment

        finbert, ollama, publisher, config, counters = self._collaborators((0.75, 0.88))
        mock_session = self._mock_session()
        db_factory = _make_mock_db_session_factory(mock_session)
        article = _make_raw_article(
            title="$AAPL and $MSFT report strong earnings",
            content="Both Apple and Microsoft beat estimates.",
        )

        await process_article(
            article, finbert, ollama, publisher, config, counters, db_factory
        )

        # session.add should be called: 1 Article + 2 ArticleSentiments = 3 calls.
        assert mock_session.add.call_count == 3

        # Split the added objects by model type.
        added_objects = [added.args[0] for added in mock_session.add.call_args_list]
        articles = [obj for obj in added_objects if isinstance(obj, Article)]
        sentiments = [obj for obj in added_objects if isinstance(obj, ArticleSentiment)]
        assert len(articles) == 1
        assert len(sentiments) == 2

        # Article row mirrors the raw article.
        db_article = articles[0]
        assert db_article.source == "test"
        assert db_article.url == "https://example.com/article"
        assert db_article.content_hash == "abc123"

        # One sentiment row per ticker, each carrying the FinBERT result.
        tickers_in_sentiments = {row.ticker for row in sentiments}
        assert "AAPL" in tickers_in_sentiments
        assert "MSFT" in tickers_in_sentiments
        for row in sentiments:
            # approx keeps the float comparison consistent with the
            # published-payload checks elsewhere in this file.
            assert row.score == pytest.approx(0.75)
            assert row.confidence == pytest.approx(0.88)
            assert row.model_used == "finbert"

        # session.commit should be awaited once.
        mock_session.commit.assert_awaited_once()
        # Redis publishing should still happen.
        assert publisher.publish.call_count == 2

    @pytest.mark.asyncio
    async def test_db_duplicate_article_handled_gracefully(self):
        """Duplicate content_hash (IntegrityError) should be caught silently."""
        finbert, ollama, publisher, config, counters = self._collaborators((0.5, 0.9))
        # Simulate IntegrityError on commit (duplicate content_hash).
        mock_session = self._mock_session(
            commit_error=IntegrityError("duplicate", {}, Exception())
        )
        db_factory = _make_mock_db_session_factory(mock_session)
        article = _make_raw_article(
            title="$AAPL news",
            content="Apple earnings report.",
        )

        # Should NOT raise — IntegrityError is caught.
        await process_article(
            article, finbert, ollama, publisher, config, counters, db_factory
        )

        # Redis publishing should still have happened.
        assert publisher.publish.call_count >= 1
        counters["articles_scored"].add.assert_called_once_with(1)

    @pytest.mark.asyncio
    async def test_db_none_backward_compatible(self):
        """When db_session_factory is None, process_article works without DB writes."""
        finbert, ollama, publisher, config, counters = self._collaborators((0.6, 0.85))
        article = _make_raw_article(
            title="$GOOG quarterly results",
            content="Google reports revenue growth.",
        )

        # Pass db_session_factory=None explicitly (the default).
        await process_article(
            article, finbert, ollama, publisher, config, counters, db_session_factory=None
        )

        # Redis publishing should work as before.
        assert publisher.publish.call_count >= 1
        counters["articles_scored"].add.assert_called_once_with(1)

    @pytest.mark.asyncio
    async def test_db_error_does_not_break_processing(self):
        """A DB error should be logged but not prevent Redis publishing."""
        finbert, ollama, publisher, config, counters = self._collaborators((0.3, 0.7))
        # Simulate a generic DB error on commit.
        mock_session = self._mock_session(commit_error=RuntimeError("connection lost"))
        db_factory = _make_mock_db_session_factory(mock_session)
        article = _make_raw_article(
            title="$TSLA stock update",
            content="Tesla announces new factory.",
        )

        # Should NOT raise — generic exceptions in DB write are caught.
        await process_article(
            article, finbert, ollama, publisher, config, counters, db_factory
        )

        # Redis publishing should still have happened.
        assert publisher.publish.call_count >= 1
        counters["articles_scored"].add.assert_called_once_with(1)