feat: sentiment analyzer — FinBERT + Ollama tiered analysis

This commit is contained in:
Viktor Barzin 2026-02-22 15:27:06 +00:00
parent 9f46071502
commit 6952a829ae
No known key found for this signature in database
GPG key ID: 0EB088298288D958
11 changed files with 976 additions and 1 deletions

View file

@ -27,7 +27,7 @@ requires = ["setuptools>=70.0"]
build-backend = "setuptools.build_meta"
[tool.setuptools.packages.find]
include = ["shared*", "tests*"]
include = ["shared*", "services*", "tests*"]
[tool.pytest.ini_options]
asyncio_mode = "auto"

0
services/__init__.py Normal file
View file

View file

@ -0,0 +1 @@
"""Sentiment Analyzer service — FinBERT + Ollama tiered analysis."""

View file

@ -0,0 +1 @@
"""Sentiment analysis backends (FinBERT, Ollama)."""

View file

@ -0,0 +1,113 @@
"""FinBERT-based financial sentiment analyzer.
Uses the ProsusAI/finbert model via the HuggingFace transformers library
to classify article text as positive, negative, or neutral.
"""
from __future__ import annotations
import asyncio
import logging
from typing import Any
logger = logging.getLogger(__name__)
class FinBERTAnalyzer:
    """Lazy-loading wrapper around a transformers text-classification pipeline.

    The ``transformers`` import and the pipeline construction happen only
    once, on the first call to :meth:`analyze`. Both run on a worker thread
    together with the inference itself, so the event loop is never blocked
    while the model is being loaded.
    """

    def __init__(self, model_name: str = "ProsusAI/finbert", max_content_length: int = 512) -> None:
        self.model_name = model_name
        self.max_content_length = max_content_length
        # Built lazily by _load_pipeline(); tests may inject a stub here.
        self._pipeline: Any | None = None

    def _load_pipeline(self) -> Any:
        """Return the cached pipeline, constructing it on first use."""
        if self._pipeline is None:
            from transformers import pipeline  # type: ignore[import-untyped]

            logger.info("Loading FinBERT model: %s", self.model_name)
            self._pipeline = pipeline(
                "sentiment-analysis",
                model=self.model_name,
                # ``top_k=None`` asks for the score of every label; it is the
                # replacement for the deprecated ``return_all_scores=True``.
                top_k=None,
            )
            logger.info("FinBERT model loaded successfully")
        return self._pipeline

    async def analyze(self, title: str, content: str) -> tuple[float, float]:
        """Score the sentiment of an article.

        Parameters
        ----------
        title:
            Article headline.
        content:
            Article body text.

        Returns
        -------
        tuple[float, float]
            ``(sentiment_score, confidence)`` as computed by
            :meth:`_parse_scores`.

        The headline and body are joined, cut to ``max_content_length * 4``
        characters (rough estimate: 1 token ~ 4 chars for English) so that
        enormous strings are never handed to the tokenizer, and then passed
        to the pipeline with ``truncation=True`` and
        ``max_length=max_content_length`` for the precise token-level cut.
        """
        text = f"{title}. {content}"
        char_limit = self.max_content_length * 4
        text = text[:char_limit]

        def _infer() -> Any:
            # Loading happens here, inside the worker thread, so the first
            # call does not stall the event loop while the model is built.
            pipe = self._load_pipeline()
            return pipe(text, truncation=True, max_length=self.max_content_length)

        loop = asyncio.get_running_loop()
        results = await loop.run_in_executor(None, _infer)
        return self._parse_scores(results)

    @staticmethod
    def _parse_scores(results: list[Any]) -> tuple[float, float]:
        """Map pipeline output to ``(score, confidence)``.

        Accepts both output shapes an all-scores pipeline can produce for a
        single input: the nested form
        ``[[{"label": "positive", "score": 0.85}, ...]]`` and the flat form
        ``[{"label": "positive", "score": 0.85}, ...]``.

        Mapping:

        - ``"positive"`` -> +1
        - ``"negative"`` -> -1
        - ``"neutral"`` -> 0

        The sentiment score is the sum of label polarities weighted by their
        probabilities; unknown labels contribute nothing. Confidence is the
        largest probability. Empty output yields ``(0.0, 0.0)``.
        """
        label_map = {"positive": 1.0, "negative": -1.0, "neutral": 0.0}
        if not results:
            return 0.0, 0.0
        # Unwrap the per-input list when present; otherwise the list already
        # holds the label dicts.
        scores = results[0] if isinstance(results[0], list) else results
        sentiment_score = 0.0
        confidence = 0.0
        for entry in scores:
            prob = float(entry["score"])
            sentiment_score += label_map.get(entry["label"].lower(), 0.0) * prob
            confidence = max(confidence, prob)
        return sentiment_score, confidence

View file

@ -0,0 +1,102 @@
"""Ollama-based sentiment analyzer (LLM fallback).
Used when the FinBERT model's confidence is below the configured threshold.
Sends a structured prompt to a local Ollama instance and parses JSON output.
"""
from __future__ import annotations

import json
import logging
import math
logger = logging.getLogger(__name__)

# Instructions sent as the system message on every request. They pin the
# reply to a single JSON object so that _parse_response() can decode it.
_SYSTEM_PROMPT = (
    "You are a financial sentiment analysis assistant. "
    "You will be given a news article title and content. "
    "Analyze the sentiment and respond with ONLY valid JSON in this exact format:\n"
    '{"sentiment_score": <float between -1.0 and 1.0>, '
    '"confidence": <float between 0.0 and 1.0>, '
    '"entities": [<list of mentioned company/ticker names>]}\n'
    "Where sentiment_score: -1.0 = very negative, 0.0 = neutral, 1.0 = very positive.\n"
    "Respond with ONLY the JSON object, no other text."
)


class OllamaAnalyzer:
    """Fallback sentiment analyzer using a local Ollama LLM."""

    def __init__(self, model: str = "mistral", host: str = "http://localhost:11434") -> None:
        self.model = model
        self.host = host
        # Created lazily by _get_client(); tests may inject a stub here.
        self._client: object | None = None

    def _get_client(self) -> object:
        """Lazily create the Ollama async client."""
        if self._client is None:
            import ollama  # type: ignore[import-untyped]

            self._client = ollama.AsyncClient(host=self.host)
        return self._client

    async def analyze(self, title: str, content: str) -> tuple[float, float]:
        """Analyze sentiment using the Ollama LLM.

        Parameters
        ----------
        title:
            Article headline.
        content:
            Article body text.

        Returns
        -------
        tuple[float, float]
            ``(sentiment_score, confidence)``. On any parse error or
            communication failure, returns ``(0.0, 0.0)`` as a safe fallback.
        """
        user_prompt = f"Title: {title}\n\nContent: {content}"
        try:
            client = self._get_client()
            response = await client.chat(  # type: ignore[union-attr]
                model=self.model,
                messages=[
                    {"role": "system", "content": _SYSTEM_PROMPT},
                    {"role": "user", "content": user_prompt},
                ],
            )
            raw_text: str = response["message"]["content"]  # type: ignore[index]
            return self._parse_response(raw_text)
        except Exception:
            # Deliberately broad: this analyzer is a best-effort fallback and
            # must never take the caller down with it.
            logger.exception("Ollama analysis failed")
            return 0.0, 0.0

    @staticmethod
    def _parse_response(raw_text: str) -> tuple[float, float]:
        """Extract ``sentiment_score`` and ``confidence`` from LLM output.

        Tolerates surrounding whitespace, markdown code fences, and prose
        before or after the JSON object. Both values are clamped to their
        valid ranges (``[-1.0, 1.0]`` and ``[0.0, 1.0]``).

        Returns ``(0.0, 0.0)`` on any parsing failure, including a ``NaN``
        value: Python's JSON decoder accepts ``NaN``, and clamping it with
        ``min``/``max`` would silently turn it into the upper bound.
        """
        try:
            text = raw_text.strip()
            if text.startswith("```"):
                # Drop the ```json ... ``` wrapper lines, keep the payload.
                kept = [ln for ln in text.split("\n") if not ln.strip().startswith("```")]
                text = "\n".join(kept).strip()
            try:
                data = json.loads(text)
            except json.JSONDecodeError:
                # The model may wrap the object in chatter; retry on the
                # outermost {...} span before giving up.
                start, end = text.find("{"), text.rfind("}")
                if start == -1 or end <= start:
                    raise
                data = json.loads(text[start : end + 1])
            score = float(data["sentiment_score"])
            confidence = float(data["confidence"])
            if math.isnan(score) or math.isnan(confidence):
                raise ValueError("NaN in Ollama response")
            # Clamp to valid ranges.
            score = max(-1.0, min(1.0, score))
            confidence = max(0.0, min(1.0, confidence))
            return score, confidence
        except (json.JSONDecodeError, KeyError, TypeError, ValueError):
            logger.warning("Failed to parse Ollama response: %s", raw_text[:200])
            return 0.0, 0.0

View file

@ -0,0 +1,15 @@
"""Configuration for the sentiment analyzer service."""
from shared.config import BaseConfig
class SentimentAnalyzerConfig(BaseConfig):
    """Settings specific to the sentiment analyzer, layered on ``BaseConfig``."""

    # --- FinBERT (first tier) ---
    # Model identifier handed to FinBERTAnalyzer.
    finbert_model: str = "ProsusAI/finbert"
    # FinBERT results with confidence below this value trigger the Ollama fallback.
    finbert_confidence_threshold: float = 0.6

    # --- Ollama (fallback tier) ---
    ollama_model: str = "mistral"
    ollama_host: str = "http://localhost:11434"

    # Passed to FinBERTAnalyzer as its max_content_length.
    max_content_length: int = 512

    # NOTE(review): presumably makes settings readable from TRADING_-prefixed
    # environment variables; depends on how BaseConfig is defined, so confirm.
    model_config = {"env_prefix": "TRADING_"}

View file

@ -0,0 +1,169 @@
"""Sentiment Analyzer service — main entry point.
Consumes ``news:raw`` articles from Redis Streams, scores them using a
tiered approach (FinBERT first, Ollama fallback for low-confidence results),
extracts ticker mentions, and publishes ``ScoredArticle`` messages to
``news:scored``.
"""
from __future__ import annotations
import asyncio
import logging
import time
from redis.asyncio import Redis
from services.sentiment_analyzer.analyzers.finbert import FinBERTAnalyzer
from services.sentiment_analyzer.analyzers.ollama_analyzer import OllamaAnalyzer
from services.sentiment_analyzer.config import SentimentAnalyzerConfig
from services.sentiment_analyzer.ticker_extractor import extract_tickers
from shared.redis_streams import StreamConsumer, StreamPublisher
from shared.schemas.news import RawArticle, ScoredArticle
from shared.telemetry import setup_telemetry
logger = logging.getLogger(__name__)
async def process_article(
    article: RawArticle,
    finbert: FinBERTAnalyzer,
    ollama: OllamaAnalyzer,
    publisher: StreamPublisher,
    config: SentimentAnalyzerConfig,
    counters: dict,
) -> None:
    """Score a single article and publish one ScoredArticle per extracted ticker.

    Parameters
    ----------
    article:
        The raw article consumed from the ``news:raw`` stream.
    finbert:
        FinBERT analyzer instance.
    ollama:
        Ollama analyzer instance (used as fallback).
    publisher:
        Publishes results to ``news:scored``.
    config:
        Service configuration (confidence threshold, etc.).
    counters:
        Dict of OpenTelemetry counter/histogram instruments, keyed by
        ``"articles_scored"``, ``"finbert_count"``, ``"ollama_count"`` and
        ``"inference_latency"``.

    Notes
    -----
    The Ollama result replaces the FinBERT result only when Ollama reports a
    confidence above zero. ``OllamaAnalyzer.analyze`` returns ``(0.0, 0.0)``
    when the request or the parsing fails, and adopting that sentinel would
    throw away a usable (if low-confidence) FinBERT score.
    """
    start = time.monotonic()

    # --- Step 1: Run FinBERT ---
    score, confidence = await finbert.analyze(article.title, article.content)
    model_used = "finbert"
    counters["finbert_count"].add(1)

    # --- Step 2: Fallback to Ollama if confidence is too low ---
    if confidence < config.finbert_confidence_threshold:
        logger.info(
            "FinBERT confidence %.2f below threshold %.2f — falling back to Ollama",
            confidence,
            config.finbert_confidence_threshold,
        )
        ollama_score, ollama_confidence = await ollama.analyze(article.title, article.content)
        # Counts every fallback attempt, successful or not.
        counters["ollama_count"].add(1)
        if ollama_confidence > 0.0:
            score, confidence = ollama_score, ollama_confidence
            model_used = "ollama"
        else:
            logger.warning(
                "Ollama fallback returned no usable result; keeping FinBERT score %.2f (confidence %.2f)",
                score,
                confidence,
            )

    elapsed = time.monotonic() - start
    counters["inference_latency"].record(elapsed)

    # --- Step 3: Extract tickers ---
    combined_text = f"{article.title} {article.content}"
    tickers = extract_tickers(combined_text)
    if not tickers:
        logger.debug("No tickers found in article: %s", article.title[:80])
        # Still count the article as scored even if no tickers found.
        counters["articles_scored"].add(1)
        return

    # --- Step 4: Publish one ScoredArticle per ticker ---
    for ticker in tickers:
        scored = ScoredArticle(
            source=article.source,
            url=article.url,
            title=article.title,
            content=article.content,
            published_at=article.published_at,
            fetched_at=article.fetched_at,
            content_hash=article.content_hash,
            ticker=ticker,
            sentiment_score=score,
            confidence=confidence,
            model_used=model_used,
            entities=tickers,
        )
        await publisher.publish(scored.model_dump(mode="json"))
        logger.debug("Published scored article for %s (score=%.2f)", ticker, score)

    # One increment per article, regardless of how many tickers were published.
    counters["articles_scored"].add(1)
async def run(config: SentimentAnalyzerConfig | None = None) -> None:
    """Main service loop.

    Connects to Redis, initialises analysers and telemetry, then
    continuously consumes from ``news:raw`` and publishes to ``news:scored``.

    Parameters
    ----------
    config:
        Service configuration. A default ``SentimentAnalyzerConfig()`` is
        built when omitted.

    A failure while handling one message is logged and the loop moves on to
    the next message. The Redis client is closed when the loop ends for any
    reason, including cancellation.
    """
    if config is None:
        config = SentimentAnalyzerConfig()
    logging.basicConfig(level=config.log_level)
    logger.info("Starting Sentiment Analyzer service")

    # --- Telemetry ---
    meter = setup_telemetry("sentiment-analyzer", config.otel_metrics_port)
    counters = {
        "articles_scored": meter.create_counter(
            "articles_scored",
            description="Total articles scored by the sentiment analyzer",
        ),
        "finbert_count": meter.create_counter(
            "finbert_count",
            description="Number of articles scored by FinBERT",
        ),
        "ollama_count": meter.create_counter(
            "ollama_count",
            description="Number of articles scored by Ollama (fallback)",
        ),
        "inference_latency": meter.create_histogram(
            "inference_latency_seconds",
            description="Time spent on sentiment inference per article",
            unit="s",
        ),
    }

    # --- Redis ---
    redis = Redis.from_url(config.redis_url, decode_responses=False)
    try:
        consumer = StreamConsumer(redis, "news:raw", "sentiment-analyzer", "worker-1")
        publisher = StreamPublisher(redis, "news:scored")

        # --- Analyzers ---
        finbert = FinBERTAnalyzer(
            model_name=config.finbert_model,
            max_content_length=config.max_content_length,
        )
        ollama = OllamaAnalyzer(model=config.ollama_model, host=config.ollama_host)
        logger.info("Consuming from news:raw, publishing to news:scored")

        # --- Consume loop ---
        async for _msg_id, data in consumer.consume():
            try:
                article = RawArticle.model_validate(data)
                await process_article(article, finbert, ollama, publisher, config, counters)
            except Exception:
                # Guard the lookup: a payload without ``.get`` must not raise
                # from inside the handler and end the loop.
                title = data.get("title", "<unknown>") if hasattr(data, "get") else "<unknown>"
                logger.exception("Error processing article: %s", title)
    finally:
        # ``aclose`` is the current name in redis-py's asyncio client; fall
        # back to ``close`` for releases that predate it.
        close = getattr(redis, "aclose", None) or redis.close
        await close()
def main() -> None:
    """Command-line entry point: drive :func:`run` to completion on a fresh event loop."""
    service = run()
    asyncio.run(service)


if __name__ == "__main__":
    main()

View file

@ -0,0 +1,163 @@
"""Extract stock ticker symbols from free-form text.
Handles common formats:
- Dollar-prefixed: ``$AAPL``
- Exchange-prefixed: ``NASDAQ:AAPL``, ``NYSE:TSLA``
- Standalone uppercase words that look like tickers (1-5 uppercase letters)
"""
from __future__ import annotations
import re
# Exchange prefixes recognised in the ``EXCHANGE:TICKER`` form. The same names
# are also treated as false positives so that the prefix itself is never
# reported as a ticker by the standalone pattern.
_EXCHANGES: tuple[str, ...] = ("NASDAQ", "NYSE", "AMEX", "OTC", "BATS", "ARCA")

# Common false positives: short English words, financial abbreviations, and
# exchange names that match the 1-5 uppercase letter pattern. Only applied to
# standalone uppercase words, never to explicitly marked tickers.
_FALSE_POSITIVES: frozenset[str] = frozenset(
    {
        # Common English words / pronouns
        "A",
        "I",
        "AM",
        "AN",
        "AS",
        "AT",
        "BE",
        "BY",
        "DO",
        "GO",
        "IF",
        "IN",
        "IS",
        "IT",
        "ME",
        "MY",
        "NO",
        "OF",
        "OK",
        "ON",
        "OR",
        "SO",
        "TO",
        "UP",
        "US",
        "WE",
        "PM",
        "THE",
        "AND",
        "FOR",
        "NOT",
        "BUT",
        "ARE",
        "WAS",
        "HAS",
        "HAD",
        "ALL",
        "CAN",
        "HER",
        "HIS",
        "HOW",
        "ITS",
        "MAY",
        "NEW",
        "NOW",
        "OLD",
        "OUR",
        "OUT",
        "OWN",
        "SAY",
        "SHE",
        "TOO",
        "USE",
        # Time-related
        "EST",
        "PST",
        "CST",
        "MST",
        "UTC",
        "GMT",
        # Financial jargon
        "CEO",
        "CFO",
        "COO",
        "CTO",
        "IPO",
        "ETF",
        "SEC",
        "NYSE",
        "AMEX",
        "DJIA",
        "GDP",
        "CPI",
        "FED",
        "FOMC",
        "FDA",
        "EPS",
        "P&L",  # cannot match the letters-only patterns; kept for reference
        "ROI",
        "YTD",
        "QOQ",
        "YOY",
        "ATH",
        "ATL",
        "RSI",
        "SMA",
        "EMA",
        "IOT",
        "API",
        "AI",
        "ML",
        "USA",
        "UK",
        "EU",
        "IMF",
        "FTC",
        "DOJ",
        "IRS",
        "DOT",
        "SPAC",
    }
    | set(_EXCHANGES)
)

# Pattern 1: $AAPL (dollar-sign prefix)
_DOLLAR_PATTERN = re.compile(r"\$([A-Z]{1,5})\b")

# Pattern 2: NASDAQ:AAPL, NYSE:TSLA (exchange prefix)
_EXCHANGE_PATTERN = re.compile(r"\b(?:" + "|".join(_EXCHANGES) + r"):([A-Z]{1,5})\b")

# Pattern 3: standalone uppercase words at word boundaries (1-5 chars)
_STANDALONE_PATTERN = re.compile(r"\b([A-Z]{1,5})\b")


def extract_tickers(text: str) -> list[str]:
    """Extract deduplicated stock ticker symbols from *text*.

    Three passes run in order of decreasing confidence:

    1. ``$``-prefixed symbols (``$AAPL``),
    2. exchange-prefixed symbols (``NASDAQ:AAPL``),
    3. standalone all-uppercase words of 2-5 letters.

    Symbols found by the first two passes are explicit and always kept, even
    when they collide with a common word (``$AI``, ``NYSE:IT``). Standalone
    words are dropped when they appear in the false-positive list.

    Returns the unique symbols ordered by pass, and by position in *text*
    within each pass.
    """
    seen: set[str] = set()
    result: list[str] = []

    def _add(ticker: str, *, explicit: bool = False) -> None:
        if ticker in seen:
            return
        # Only heuristic (standalone) matches are subject to the stop list.
        if not explicit and ticker in _FALSE_POSITIVES:
            return
        seen.add(ticker)
        result.append(ticker)

    # Dollar-sign tickers have the highest signal — always include.
    for match in _DOLLAR_PATTERN.finditer(text):
        _add(match.group(1), explicit=True)

    # Exchange-prefixed tickers are also high confidence.
    for match in _EXCHANGE_PATTERN.finditer(text):
        _add(match.group(1), explicit=True)

    # Standalone uppercase words: restricted to 2-5 chars to reduce noise;
    # single letters are only accepted through the explicit forms above.
    for match in _STANDALONE_PATTERN.finditer(text):
        candidate = match.group(1)
        if len(candidate) >= 2:
            _add(candidate)

    return result

View file

View file

@ -0,0 +1,411 @@
"""Tests for the sentiment analyzer service.
Covers FinBERT analyzer, Ollama analyzer, ticker extraction, and the main
service flow.
"""
from __future__ import annotations
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from services.sentiment_analyzer.analyzers.finbert import FinBERTAnalyzer
from services.sentiment_analyzer.analyzers.ollama_analyzer import OllamaAnalyzer
from services.sentiment_analyzer.config import SentimentAnalyzerConfig
from services.sentiment_analyzer.main import process_article
from services.sentiment_analyzer.ticker_extractor import extract_tickers
from shared.schemas.news import RawArticle
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_raw_article(**overrides) -> RawArticle:
    """Build a ``RawArticle`` from canned field values, with *overrides* taking precedence."""
    base_fields = {
        "source": "test",
        "url": "https://example.com/article",
        "title": "Test Article About $AAPL",
        "content": "Apple Inc announced strong earnings.",
        "published_at": datetime(2026, 1, 15, tzinfo=timezone.utc),
        "fetched_at": datetime(2026, 1, 15, 0, 5, tzinfo=timezone.utc),
        "content_hash": "abc123",
    }
    return RawArticle(**{**base_fields, **overrides})
def _make_pipeline_result(label: str, score: float) -> list[list[dict]]:
"""Build a return value matching transformers pipeline(return_all_scores=True)."""
labels = {"positive": score if label == "positive" else 0.0,
"negative": score if label == "negative" else 0.0,
"neutral": score if label == "neutral" else 0.0}
# Distribute remaining probability
remaining = 1.0 - score
other_labels = [l for l in labels if l != label]
for ol in other_labels:
labels[ol] = remaining / len(other_labels)
return [[{"label": l, "score": s} for l, s in labels.items()]]
# ---------------------------------------------------------------------------
# FinBERT Analyzer Tests
# ---------------------------------------------------------------------------
class TestFinBERTAnalyzer:
    """FinBERTAnalyzer behaviour with the transformers pipeline replaced by a mock."""

    @staticmethod
    def _stubbed(label: str, probability: float) -> tuple[FinBERTAnalyzer, MagicMock]:
        """Return an analyzer whose pipeline is a mock answering with *label* dominant."""
        pipe = MagicMock(return_value=_make_pipeline_result(label, probability))
        analyzer = FinBERTAnalyzer(model_name="test-model")
        analyzer._pipeline = pipe
        return analyzer, pipe

    @pytest.mark.asyncio
    async def test_finbert_positive_sentiment(self):
        """A dominant positive label produces a score above zero."""
        analyzer, pipe = self._stubbed("positive", 0.9)

        score, confidence = await analyzer.analyze(
            "Apple beats earnings expectations",
            "Apple reported revenue above analyst estimates.",
        )

        assert score > 0.0, f"Expected positive score, got {score}"
        assert confidence == pytest.approx(0.9, abs=0.01)
        pipe.assert_called_once()

    @pytest.mark.asyncio
    async def test_finbert_negative_sentiment(self):
        """A dominant negative label produces a score below zero."""
        analyzer, _pipe = self._stubbed("negative", 0.85)

        score, confidence = await analyzer.analyze(
            "Major bank reports massive losses",
            "The bank lost $2 billion in the quarter.",
        )

        assert score < 0.0, f"Expected negative score, got {score}"
        assert confidence == pytest.approx(0.85, abs=0.01)

    @pytest.mark.asyncio
    async def test_finbert_neutral_sentiment(self):
        """A dominant neutral label produces a score close to zero."""
        analyzer, _pipe = self._stubbed("neutral", 0.8)

        score, confidence = await analyzer.analyze(
            "Company releases quarterly report",
            "The quarterly report was filed with the SEC.",
        )

        # Neutral carries polarity 0; any residual comes from the leftover
        # probability split between positive and negative.
        assert abs(score) < 0.2, f"Expected near-zero score, got {score}"
        assert confidence == pytest.approx(0.8, abs=0.01)
# ---------------------------------------------------------------------------
# Ollama Analyzer Tests
# ---------------------------------------------------------------------------
class TestOllamaAnalyzer:
    """OllamaAnalyzer behaviour with the Ollama client replaced by an AsyncMock."""

    @staticmethod
    def _replying(content: str) -> OllamaAnalyzer:
        """Return an analyzer whose client answers every chat with *content*."""
        client = AsyncMock()
        client.chat.return_value = {"message": {"content": content}}
        analyzer = OllamaAnalyzer(model="test-model")
        analyzer._client = client
        return analyzer

    @pytest.mark.asyncio
    async def test_ollama_successful_analysis(self):
        """A well-formed JSON reply is decoded into score and confidence."""
        analyzer = self._replying(
            '{"sentiment_score": 0.75, "confidence": 0.85, "entities": ["AAPL"]}'
        )

        score, confidence = await analyzer.analyze("Good news for Apple", "Apple stock surges.")

        assert score == pytest.approx(0.75)
        assert confidence == pytest.approx(0.85)

    @pytest.mark.asyncio
    async def test_ollama_parse_error_returns_zero(self):
        """A reply that is not JSON yields the (0.0, 0.0) fallback."""
        analyzer = self._replying("I think the sentiment is positive but I'm not sure.")

        score, confidence = await analyzer.analyze("Some headline", "Some content")

        assert score == 0.0
        assert confidence == 0.0

    @pytest.mark.asyncio
    async def test_ollama_connection_error_returns_zero(self):
        """A client that raises on chat yields the (0.0, 0.0) fallback."""
        client = AsyncMock()
        client.chat.side_effect = ConnectionError("Cannot reach Ollama")
        analyzer = OllamaAnalyzer(model="test-model")
        analyzer._client = client

        score, confidence = await analyzer.analyze("Some headline", "Some content")

        assert score == 0.0
        assert confidence == 0.0

    @pytest.mark.asyncio
    async def test_ollama_markdown_code_fence(self):
        """A JSON reply wrapped in a markdown code fence is still decoded."""
        analyzer = self._replying(
            '```json\n{"sentiment_score": -0.5, "confidence": 0.7, "entities": []}\n```'
        )

        score, confidence = await analyzer.analyze("Bad news", "Markets tumble.")

        assert score == pytest.approx(-0.5)
        assert confidence == pytest.approx(0.7)
# ---------------------------------------------------------------------------
# Ticker Extraction Tests
# ---------------------------------------------------------------------------
class TestTickerExtraction:
    """Behaviour of ``extract_tickers`` on the supported ticker notations."""

    def test_ticker_extraction_dollar_sign(self):
        """A dollar-prefixed symbol is picked up."""
        found = extract_tickers("Big news for $AAPL today.")
        assert "AAPL" in found

    def test_ticker_extraction_exchange_prefix(self):
        """A NASDAQ-prefixed symbol is picked up."""
        found = extract_tickers("Check out NASDAQ:TSLA performance.")
        assert "TSLA" in found

    def test_ticker_extraction_nyse_prefix(self):
        """An NYSE-prefixed symbol is picked up."""
        found = extract_tickers("NYSE:AAPL is trading higher.")
        assert "AAPL" in found

    def test_ticker_extraction_filters_false_positives(self):
        """Well-known acronyms (CEO, IPO, ETF, SEC, NYSE) are not reported."""
        found = extract_tickers(
            "The CEO announced a new IPO. The ETF was approved by the SEC on NYSE."
        )
        for acronym in ("CEO", "IPO", "ETF", "SEC", "NYSE"):
            assert acronym not in found

    def test_ticker_extraction_deduplicates(self):
        """A symbol mentioned several times is reported exactly once."""
        found = extract_tickers("$AAPL rose 5%. $AAPL is now above $200. NASDAQ:AAPL is great.")
        assert found.count("AAPL") == 1

    def test_ticker_extraction_multiple_tickers(self):
        """Every distinct symbol in the text is reported."""
        found = extract_tickers("$AAPL and $MSFT both reported earnings. $GOOG is next.")
        for symbol in ("AAPL", "MSFT", "GOOG"):
            assert symbol in found

    def test_ticker_extraction_empty_text(self):
        """An empty string produces an empty list."""
        assert extract_tickers("") == []

    def test_ticker_extraction_no_tickers(self):
        """Ordinary prose without ticker-like tokens produces nothing."""
        found = extract_tickers("The market was flat today with no major movers.")
        # Nothing in this sentence is an all-uppercase word of 2-5 letters.
        assert len(found) == 0
# ---------------------------------------------------------------------------
# Ollama Fallback Routing Test
# ---------------------------------------------------------------------------
class TestFallbackRouting:
    """Routing between FinBERT and Ollama, driven by the confidence threshold."""

    @staticmethod
    def _mock_counters() -> dict:
        """Return one independent MagicMock per telemetry instrument."""
        instrument_names = ("articles_scored", "finbert_count", "ollama_count", "inference_latency")
        return {name: MagicMock() for name in instrument_names}

    @staticmethod
    def _mock_analyzers(finbert_result, ollama_result):
        """Return ``(finbert, ollama)`` mocks whose ``analyze`` yields the given tuples."""
        finbert = AsyncMock(spec=FinBERTAnalyzer)
        finbert.analyze = AsyncMock(return_value=finbert_result)
        ollama = AsyncMock(spec=OllamaAnalyzer)
        ollama.analyze = AsyncMock(return_value=ollama_result)
        return finbert, ollama

    @staticmethod
    def _mock_publisher():
        """Return a publisher mock whose ``publish`` resolves to a stream id."""
        publisher = AsyncMock()
        publisher.publish = AsyncMock(return_value=b"1-0")
        return publisher

    @pytest.mark.asyncio
    async def test_ollama_fallback_on_low_confidence(self):
        """FinBERT confidence below the threshold triggers the Ollama call."""
        # FinBERT confidence 0.4 < 0.6 threshold; Ollama answers with more.
        finbert, ollama = self._mock_analyzers((0.1, 0.4), (0.8, 0.9))
        publisher = self._mock_publisher()
        config = SentimentAnalyzerConfig(
            finbert_confidence_threshold=0.6,
            otel_metrics_port=0,
        )
        counters = self._mock_counters()
        article = _make_raw_article(title="Test $AAPL Article", content="Apple stock rises.")

        await process_article(article, finbert, ollama, publisher, config, counters)

        # Both analyzers ran and both counters moved once.
        finbert.analyze.assert_called_once()
        ollama.analyze.assert_called_once()
        counters["finbert_count"].add.assert_called_once_with(1)
        counters["ollama_count"].add.assert_called_once_with(1)

    @pytest.mark.asyncio
    async def test_no_ollama_on_high_confidence(self):
        """FinBERT confidence at or above the threshold leaves Ollama untouched."""
        # FinBERT confidence 0.9 >= 0.6 threshold.
        finbert, ollama = self._mock_analyzers((0.8, 0.9), (0.5, 0.7))
        publisher = self._mock_publisher()
        config = SentimentAnalyzerConfig(
            finbert_confidence_threshold=0.6,
            otel_metrics_port=0,
        )
        counters = self._mock_counters()
        article = _make_raw_article(title="Test $AAPL Article", content="Apple stock rises.")

        await process_article(article, finbert, ollama, publisher, config, counters)

        finbert.analyze.assert_called_once()
        ollama.analyze.assert_not_called()
        counters["ollama_count"].add.assert_not_called()
# ---------------------------------------------------------------------------
# Main Flow / Integration Test
# ---------------------------------------------------------------------------
class TestMainFlow:
    """End-to-end checks of ``process_article`` with mocked analyzers and publisher."""

    @staticmethod
    def _mock_counters() -> dict:
        """Return one independent MagicMock per telemetry instrument."""
        instrument_names = ("articles_scored", "finbert_count", "ollama_count", "inference_latency")
        return {name: MagicMock() for name in instrument_names}

    @staticmethod
    def _default_config() -> SentimentAnalyzerConfig:
        """Return the configuration shared by the tests in this class."""
        return SentimentAnalyzerConfig(
            finbert_confidence_threshold=0.6,
            otel_metrics_port=0,
        )

    @pytest.mark.asyncio
    async def test_main_flow_publishes_scored_articles(self):
        """One ScoredArticle is published for every ticker found in the article."""
        finbert = AsyncMock(spec=FinBERTAnalyzer)
        finbert.analyze = AsyncMock(return_value=(0.75, 0.88))
        ollama = AsyncMock(spec=OllamaAnalyzer)
        publisher = AsyncMock()
        publisher.publish = AsyncMock(return_value=b"1-0")
        counters = self._mock_counters()
        # The headline mentions two tickers.
        article = _make_raw_article(
            title="$AAPL and $MSFT report strong earnings",
            content="Both Apple and Microsoft beat estimates.",
        )

        await process_article(article, finbert, ollama, publisher, self._default_config(), counters)

        # Two publishes (one per ticker), but the article is counted once.
        assert publisher.publish.call_count == 2
        counters["articles_scored"].add.assert_called_once_with(1)

        payloads = [call.args[0] for call in publisher.publish.call_args_list]
        published_tickers = {payload["ticker"] for payload in payloads}
        assert "AAPL" in published_tickers
        assert "MSFT" in published_tickers

        # Every payload carries the FinBERT result unchanged.
        for payload in payloads:
            assert payload["sentiment_score"] == pytest.approx(0.75)
            assert payload["confidence"] == pytest.approx(0.88)
            assert payload["model_used"] == "finbert"

    @pytest.mark.asyncio
    async def test_main_flow_no_tickers_no_publish(self):
        """An article without tickers publishes nothing but is still counted."""
        finbert = AsyncMock(spec=FinBERTAnalyzer)
        finbert.analyze = AsyncMock(return_value=(0.5, 0.9))
        ollama = AsyncMock(spec=OllamaAnalyzer)
        publisher = AsyncMock()
        publisher.publish = AsyncMock()
        counters = self._mock_counters()
        article = _make_raw_article(
            title="Market is flat today",
            content="Nothing much happening in the market.",
        )

        await process_article(article, finbert, ollama, publisher, self._default_config(), counters)

        publisher.publish.assert_not_called()
        # Still counted as scored.
        counters["articles_scored"].add.assert_called_once_with(1)