feat: sentiment analyzer — FinBERT + Ollama tiered analysis
This commit is contained in:
parent
9f46071502
commit
6952a829ae
11 changed files with 976 additions and 1 deletions
|
|
@ -27,7 +27,7 @@ requires = ["setuptools>=70.0"]
|
|||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["shared*", "tests*"]
|
||||
include = ["shared*", "services*", "tests*"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
|
|
|
|||
0
services/__init__.py
Normal file
0
services/__init__.py
Normal file
1
services/sentiment_analyzer/__init__.py
Normal file
1
services/sentiment_analyzer/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Sentiment Analyzer service — FinBERT + Ollama tiered analysis."""
|
||||
1
services/sentiment_analyzer/analyzers/__init__.py
Normal file
1
services/sentiment_analyzer/analyzers/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Sentiment analysis backends (FinBERT, Ollama)."""
|
||||
113
services/sentiment_analyzer/analyzers/finbert.py
Normal file
113
services/sentiment_analyzer/analyzers/finbert.py
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
"""FinBERT-based financial sentiment analyzer.
|
||||
|
||||
Uses the ProsusAI/finbert model via the HuggingFace transformers library
|
||||
to classify article text as positive, negative, or neutral.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FinBERTAnalyzer:
|
||||
"""Lazy-loading wrapper around a transformers sentiment-analysis pipeline.
|
||||
|
||||
The heavy ``transformers`` + ``torch`` imports and model download happen
|
||||
only once, on the first call to :meth:`analyze`.
|
||||
"""
|
||||
|
||||
def __init__(self, model_name: str = "ProsusAI/finbert", max_content_length: int = 512) -> None:
|
||||
self.model_name = model_name
|
||||
self.max_content_length = max_content_length
|
||||
self._pipeline: Any | None = None
|
||||
|
||||
def _load_pipeline(self) -> Any:
|
||||
"""Lazily load the transformers pipeline on first use."""
|
||||
if self._pipeline is None:
|
||||
from transformers import pipeline # type: ignore[import-untyped]
|
||||
|
||||
logger.info("Loading FinBERT model: %s", self.model_name)
|
||||
self._pipeline = pipeline(
|
||||
"sentiment-analysis",
|
||||
model=self.model_name,
|
||||
return_all_scores=True,
|
||||
)
|
||||
logger.info("FinBERT model loaded successfully")
|
||||
return self._pipeline
|
||||
|
||||
async def analyze(self, title: str, content: str) -> tuple[float, float]:
|
||||
"""Score the sentiment of an article.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
title:
|
||||
Article headline.
|
||||
content:
|
||||
Article body text.
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple[float, float]
|
||||
``(sentiment_score, confidence)`` where *sentiment_score* is in
|
||||
``[-1.0, 1.0]`` and *confidence* is in ``[0.0, 1.0]``.
|
||||
|
||||
The input text is truncated to ``max_content_length`` tokens by
|
||||
passing it through the model's tokenizer truncation (handled
|
||||
automatically by the transformers pipeline).
|
||||
"""
|
||||
pipe = self._load_pipeline()
|
||||
|
||||
# Combine title and content; the pipeline will truncate to model max tokens.
|
||||
text = f"{title}. {content}"
|
||||
# Truncate to a reasonable character length proportional to max_content_length
|
||||
# tokens (rough estimate: 1 token ~ 4 chars for English). The pipeline
|
||||
# tokenizer will do the precise truncation, but this avoids sending
|
||||
# enormous strings.
|
||||
char_limit = self.max_content_length * 4
|
||||
text = text[:char_limit]
|
||||
|
||||
# Run the blocking model inference in a thread pool so we don't block
|
||||
# the event loop.
|
||||
loop = asyncio.get_running_loop()
|
||||
results = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: pipe(text, truncation=True, max_length=self.max_content_length),
|
||||
)
|
||||
|
||||
return self._parse_scores(results)
|
||||
|
||||
@staticmethod
|
||||
def _parse_scores(results: list[list[dict[str, Any]]]) -> tuple[float, float]:
|
||||
"""Map pipeline output to ``(score, confidence)``.
|
||||
|
||||
The ``return_all_scores=True`` pipeline returns a list of lists of dicts:
|
||||
``[[{"label": "positive", "score": 0.85}, ...]]``.
|
||||
|
||||
Mapping:
|
||||
- ``"positive"`` -> +1
|
||||
- ``"negative"`` -> -1
|
||||
- ``"neutral"`` -> 0
|
||||
|
||||
The sentiment score is the weighted sum of label polarities scaled by
|
||||
their softmax probabilities. Confidence is the maximum softmax
|
||||
probability.
|
||||
"""
|
||||
label_map = {"positive": 1.0, "negative": -1.0, "neutral": 0.0}
|
||||
|
||||
# results is [[{label, score}, ...]]
|
||||
scores = results[0]
|
||||
|
||||
sentiment_score = 0.0
|
||||
confidence = 0.0
|
||||
for entry in scores:
|
||||
label = entry["label"].lower()
|
||||
prob = entry["score"]
|
||||
sentiment_score += label_map.get(label, 0.0) * prob
|
||||
if prob > confidence:
|
||||
confidence = prob
|
||||
|
||||
return sentiment_score, confidence
|
||||
102
services/sentiment_analyzer/analyzers/ollama_analyzer.py
Normal file
102
services/sentiment_analyzer/analyzers/ollama_analyzer.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
"""Ollama-based sentiment analyzer (LLM fallback).
|
||||
|
||||
Used when the FinBERT model's confidence is below the configured threshold.
|
||||
Sends a structured prompt to a local Ollama instance and parses JSON output.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SYSTEM_PROMPT = (
|
||||
"You are a financial sentiment analysis assistant. "
|
||||
"You will be given a news article title and content. "
|
||||
"Analyze the sentiment and respond with ONLY valid JSON in this exact format:\n"
|
||||
'{"sentiment_score": <float between -1.0 and 1.0>, '
|
||||
'"confidence": <float between 0.0 and 1.0>, '
|
||||
'"entities": [<list of mentioned company/ticker names>]}\n'
|
||||
"Where sentiment_score: -1.0 = very negative, 0.0 = neutral, 1.0 = very positive.\n"
|
||||
"Respond with ONLY the JSON object, no other text."
|
||||
)
|
||||
|
||||
|
||||
class OllamaAnalyzer:
|
||||
"""Fallback sentiment analyzer using a local Ollama LLM."""
|
||||
|
||||
def __init__(self, model: str = "mistral", host: str = "http://localhost:11434") -> None:
|
||||
self.model = model
|
||||
self.host = host
|
||||
self._client: object | None = None
|
||||
|
||||
def _get_client(self) -> object:
|
||||
"""Lazily create the Ollama async client."""
|
||||
if self._client is None:
|
||||
import ollama # type: ignore[import-untyped]
|
||||
|
||||
self._client = ollama.AsyncClient(host=self.host)
|
||||
return self._client
|
||||
|
||||
async def analyze(self, title: str, content: str) -> tuple[float, float]:
|
||||
"""Analyze sentiment using the Ollama LLM.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
title:
|
||||
Article headline.
|
||||
content:
|
||||
Article body text.
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple[float, float]
|
||||
``(sentiment_score, confidence)``. On any parse error or
|
||||
communication failure, returns ``(0.0, 0.0)`` as a safe fallback.
|
||||
"""
|
||||
user_prompt = f"Title: {title}\n\nContent: {content}"
|
||||
|
||||
try:
|
||||
client = self._get_client()
|
||||
response = await client.chat( # type: ignore[union-attr]
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": _SYSTEM_PROMPT},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
)
|
||||
raw_text: str = response["message"]["content"] # type: ignore[index]
|
||||
return self._parse_response(raw_text)
|
||||
except Exception:
|
||||
logger.exception("Ollama analysis failed")
|
||||
return 0.0, 0.0
|
||||
|
||||
@staticmethod
|
||||
def _parse_response(raw_text: str) -> tuple[float, float]:
|
||||
"""Extract sentiment_score and confidence from LLM JSON output.
|
||||
|
||||
Robust against markdown code fences and leading/trailing whitespace.
|
||||
Returns ``(0.0, 0.0)`` on any parsing failure.
|
||||
"""
|
||||
try:
|
||||
# Strip potential markdown code fences.
|
||||
text = raw_text.strip()
|
||||
if text.startswith("```"):
|
||||
# Remove ```json ... ``` wrapper
|
||||
lines = text.split("\n")
|
||||
lines = [ln for ln in lines if not ln.strip().startswith("```")]
|
||||
text = "\n".join(lines).strip()
|
||||
|
||||
data = json.loads(text)
|
||||
score = float(data["sentiment_score"])
|
||||
confidence = float(data["confidence"])
|
||||
|
||||
# Clamp to valid ranges.
|
||||
score = max(-1.0, min(1.0, score))
|
||||
confidence = max(0.0, min(1.0, confidence))
|
||||
|
||||
return score, confidence
|
||||
except (json.JSONDecodeError, KeyError, TypeError, ValueError):
|
||||
logger.warning("Failed to parse Ollama response: %s", raw_text[:200])
|
||||
return 0.0, 0.0
|
||||
15
services/sentiment_analyzer/config.py
Normal file
15
services/sentiment_analyzer/config.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
"""Configuration for the sentiment analyzer service."""
|
||||
|
||||
from shared.config import BaseConfig
|
||||
|
||||
|
||||
class SentimentAnalyzerConfig(BaseConfig):
|
||||
"""Extends BaseConfig with sentiment-analysis-specific settings."""
|
||||
|
||||
finbert_model: str = "ProsusAI/finbert"
|
||||
finbert_confidence_threshold: float = 0.6
|
||||
ollama_model: str = "mistral"
|
||||
ollama_host: str = "http://localhost:11434"
|
||||
max_content_length: int = 512
|
||||
|
||||
model_config = {"env_prefix": "TRADING_"}
|
||||
169
services/sentiment_analyzer/main.py
Normal file
169
services/sentiment_analyzer/main.py
Normal file
|
|
@ -0,0 +1,169 @@
|
|||
"""Sentiment Analyzer service — main entry point.
|
||||
|
||||
Consumes ``news:raw`` articles from Redis Streams, scores them using a
|
||||
tiered approach (FinBERT first, Ollama fallback for low-confidence results),
|
||||
extracts ticker mentions, and publishes ``ScoredArticle`` messages to
|
||||
``news:scored``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from services.sentiment_analyzer.analyzers.finbert import FinBERTAnalyzer
|
||||
from services.sentiment_analyzer.analyzers.ollama_analyzer import OllamaAnalyzer
|
||||
from services.sentiment_analyzer.config import SentimentAnalyzerConfig
|
||||
from services.sentiment_analyzer.ticker_extractor import extract_tickers
|
||||
from shared.redis_streams import StreamConsumer, StreamPublisher
|
||||
from shared.schemas.news import RawArticle, ScoredArticle
|
||||
from shared.telemetry import setup_telemetry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def process_article(
|
||||
article: RawArticle,
|
||||
finbert: FinBERTAnalyzer,
|
||||
ollama: OllamaAnalyzer,
|
||||
publisher: StreamPublisher,
|
||||
config: SentimentAnalyzerConfig,
|
||||
counters: dict,
|
||||
) -> None:
|
||||
"""Score a single article and publish one ScoredArticle per extracted ticker.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
article:
|
||||
The raw article consumed from the ``news:raw`` stream.
|
||||
finbert:
|
||||
FinBERT analyzer instance.
|
||||
ollama:
|
||||
Ollama analyzer instance (used as fallback).
|
||||
publisher:
|
||||
Publishes results to ``news:scored``.
|
||||
config:
|
||||
Service configuration (confidence threshold, etc.).
|
||||
counters:
|
||||
Dict of OpenTelemetry counter/histogram instruments.
|
||||
"""
|
||||
start = time.monotonic()
|
||||
|
||||
# --- Step 1: Run FinBERT ---
|
||||
score, confidence = await finbert.analyze(article.title, article.content)
|
||||
model_used = "finbert"
|
||||
counters["finbert_count"].add(1)
|
||||
|
||||
# --- Step 2: Fallback to Ollama if confidence is too low ---
|
||||
if confidence < config.finbert_confidence_threshold:
|
||||
logger.info(
|
||||
"FinBERT confidence %.2f below threshold %.2f — falling back to Ollama",
|
||||
confidence,
|
||||
config.finbert_confidence_threshold,
|
||||
)
|
||||
score, confidence = await ollama.analyze(article.title, article.content)
|
||||
model_used = "ollama"
|
||||
counters["ollama_count"].add(1)
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
counters["inference_latency"].record(elapsed)
|
||||
|
||||
# --- Step 3: Extract tickers ---
|
||||
combined_text = f"{article.title} {article.content}"
|
||||
tickers = extract_tickers(combined_text)
|
||||
|
||||
if not tickers:
|
||||
logger.debug("No tickers found in article: %s", article.title[:80])
|
||||
# Still count the article as scored even if no tickers found.
|
||||
counters["articles_scored"].add(1)
|
||||
return
|
||||
|
||||
# --- Step 4: Publish one ScoredArticle per ticker ---
|
||||
for ticker in tickers:
|
||||
scored = ScoredArticle(
|
||||
source=article.source,
|
||||
url=article.url,
|
||||
title=article.title,
|
||||
content=article.content,
|
||||
published_at=article.published_at,
|
||||
fetched_at=article.fetched_at,
|
||||
content_hash=article.content_hash,
|
||||
ticker=ticker,
|
||||
sentiment_score=score,
|
||||
confidence=confidence,
|
||||
model_used=model_used,
|
||||
entities=tickers,
|
||||
)
|
||||
await publisher.publish(scored.model_dump(mode="json"))
|
||||
logger.debug("Published scored article for %s (score=%.2f)", ticker, score)
|
||||
|
||||
counters["articles_scored"].add(1)
|
||||
|
||||
|
||||
async def run(config: SentimentAnalyzerConfig | None = None) -> None:
|
||||
"""Main service loop.
|
||||
|
||||
Connects to Redis, initialises analysers and telemetry, then
|
||||
continuously consumes from ``news:raw`` and publishes to ``news:scored``.
|
||||
"""
|
||||
if config is None:
|
||||
config = SentimentAnalyzerConfig()
|
||||
|
||||
logging.basicConfig(level=config.log_level)
|
||||
logger.info("Starting Sentiment Analyzer service")
|
||||
|
||||
# --- Telemetry ---
|
||||
meter = setup_telemetry("sentiment-analyzer", config.otel_metrics_port)
|
||||
counters = {
|
||||
"articles_scored": meter.create_counter(
|
||||
"articles_scored",
|
||||
description="Total articles scored by the sentiment analyzer",
|
||||
),
|
||||
"finbert_count": meter.create_counter(
|
||||
"finbert_count",
|
||||
description="Number of articles scored by FinBERT",
|
||||
),
|
||||
"ollama_count": meter.create_counter(
|
||||
"ollama_count",
|
||||
description="Number of articles scored by Ollama (fallback)",
|
||||
),
|
||||
"inference_latency": meter.create_histogram(
|
||||
"inference_latency_seconds",
|
||||
description="Time spent on sentiment inference per article",
|
||||
unit="s",
|
||||
),
|
||||
}
|
||||
|
||||
# --- Redis ---
|
||||
redis = Redis.from_url(config.redis_url, decode_responses=False)
|
||||
consumer = StreamConsumer(redis, "news:raw", "sentiment-analyzer", "worker-1")
|
||||
publisher = StreamPublisher(redis, "news:scored")
|
||||
|
||||
# --- Analyzers ---
|
||||
finbert = FinBERTAnalyzer(
|
||||
model_name=config.finbert_model,
|
||||
max_content_length=config.max_content_length,
|
||||
)
|
||||
ollama = OllamaAnalyzer(model=config.ollama_model, host=config.ollama_host)
|
||||
|
||||
logger.info("Consuming from news:raw, publishing to news:scored")
|
||||
|
||||
# --- Consume loop ---
|
||||
async for _msg_id, data in consumer.consume():
|
||||
try:
|
||||
article = RawArticle.model_validate(data)
|
||||
await process_article(article, finbert, ollama, publisher, config, counters)
|
||||
except Exception:
|
||||
logger.exception("Error processing article: %s", data.get("title", "<unknown>"))
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""CLI entry point."""
|
||||
asyncio.run(run())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
163
services/sentiment_analyzer/ticker_extractor.py
Normal file
163
services/sentiment_analyzer/ticker_extractor.py
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
"""Extract stock ticker symbols from free-form text.
|
||||
|
||||
Handles common formats:
|
||||
- Dollar-prefixed: ``$AAPL``
|
||||
- Exchange-prefixed: ``NASDAQ:AAPL``, ``NYSE:TSLA``
|
||||
- Standalone uppercase words that look like tickers (1-5 uppercase letters)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
# Common false positives: short English words, financial abbreviations, and
|
||||
# exchange names that match the 1-5 uppercase letter pattern.
|
||||
_FALSE_POSITIVES: frozenset[str] = frozenset(
|
||||
{
|
||||
# Common English words / pronouns
|
||||
"A",
|
||||
"I",
|
||||
"AM",
|
||||
"AN",
|
||||
"AS",
|
||||
"AT",
|
||||
"BE",
|
||||
"BY",
|
||||
"DO",
|
||||
"GO",
|
||||
"IF",
|
||||
"IN",
|
||||
"IS",
|
||||
"IT",
|
||||
"ME",
|
||||
"MY",
|
||||
"NO",
|
||||
"OF",
|
||||
"OK",
|
||||
"ON",
|
||||
"OR",
|
||||
"SO",
|
||||
"TO",
|
||||
"UP",
|
||||
"US",
|
||||
"WE",
|
||||
"PM",
|
||||
"THE",
|
||||
"AND",
|
||||
"FOR",
|
||||
"NOT",
|
||||
"BUT",
|
||||
"ARE",
|
||||
"WAS",
|
||||
"HAS",
|
||||
"HAD",
|
||||
"ALL",
|
||||
"CAN",
|
||||
"HER",
|
||||
"HIS",
|
||||
"HOW",
|
||||
"ITS",
|
||||
"MAY",
|
||||
"NEW",
|
||||
"NOW",
|
||||
"OLD",
|
||||
"OUR",
|
||||
"OUT",
|
||||
"OWN",
|
||||
"SAY",
|
||||
"SHE",
|
||||
"TOO",
|
||||
"USE",
|
||||
# Time-related
|
||||
"EST",
|
||||
"PST",
|
||||
"CST",
|
||||
"MST",
|
||||
"UTC",
|
||||
"GMT",
|
||||
# Financial jargon
|
||||
"CEO",
|
||||
"CFO",
|
||||
"COO",
|
||||
"CTO",
|
||||
"IPO",
|
||||
"ETF",
|
||||
"SEC",
|
||||
"NYSE",
|
||||
"AMEX",
|
||||
"DJIA",
|
||||
"GDP",
|
||||
"CPI",
|
||||
"FED",
|
||||
"FOMC",
|
||||
"FDA",
|
||||
"EPS",
|
||||
"P&L",
|
||||
"ROI",
|
||||
"YTD",
|
||||
"QOQ",
|
||||
"YOY",
|
||||
"ATH",
|
||||
"ATL",
|
||||
"RSI",
|
||||
"SMA",
|
||||
"EMA",
|
||||
"IOT",
|
||||
"API",
|
||||
"AI",
|
||||
"ML",
|
||||
"US",
|
||||
"USA",
|
||||
"UK",
|
||||
"EU",
|
||||
"IMF",
|
||||
"FTC",
|
||||
"DOJ",
|
||||
"IRS",
|
||||
"DOT",
|
||||
"SPAC",
|
||||
}
|
||||
)
|
||||
|
||||
# Pattern 1: $AAPL (dollar-sign prefix)
|
||||
_DOLLAR_PATTERN = re.compile(r"\$([A-Z]{1,5})\b")
|
||||
|
||||
# Pattern 2: NASDAQ:AAPL, NYSE:TSLA (exchange prefix)
|
||||
_EXCHANGE_PATTERN = re.compile(r"\b(?:NASDAQ|NYSE|AMEX|OTC|BATS|ARCA):([A-Z]{1,5})\b")
|
||||
|
||||
# Pattern 3: standalone uppercase words at word boundaries (1-5 chars)
|
||||
_STANDALONE_PATTERN = re.compile(r"\b([A-Z]{1,5})\b")
|
||||
|
||||
|
||||
def extract_tickers(text: str) -> list[str]:
|
||||
"""Extract deduplicated stock ticker symbols from *text*.
|
||||
|
||||
Returns a list of unique ticker strings in the order they were first
|
||||
encountered. False positives (common English words, acronyms) are
|
||||
filtered out.
|
||||
"""
|
||||
seen: set[str] = set()
|
||||
result: list[str] = []
|
||||
|
||||
def _add(ticker: str) -> None:
|
||||
if ticker not in seen and ticker not in _FALSE_POSITIVES:
|
||||
seen.add(ticker)
|
||||
result.append(ticker)
|
||||
|
||||
# Dollar-sign tickers have the highest signal — always include.
|
||||
for match in _DOLLAR_PATTERN.finditer(text):
|
||||
_add(match.group(1))
|
||||
|
||||
# Exchange-prefixed tickers are also high confidence.
|
||||
for match in _EXCHANGE_PATTERN.finditer(text):
|
||||
_add(match.group(1))
|
||||
|
||||
# Standalone uppercase words: only include if they look like real tickers
|
||||
# (not in the false positives list). We restrict to 2-5 chars to reduce
|
||||
# noise, unless they were already captured by the dollar/exchange patterns.
|
||||
for match in _STANDALONE_PATTERN.finditer(text):
|
||||
candidate = match.group(1)
|
||||
if len(candidate) >= 2:
|
||||
_add(candidate)
|
||||
|
||||
return result
|
||||
0
tests/services/__init__.py
Normal file
0
tests/services/__init__.py
Normal file
411
tests/services/test_sentiment_analyzer.py
Normal file
411
tests/services/test_sentiment_analyzer.py
Normal file
|
|
@ -0,0 +1,411 @@
|
|||
"""Tests for the sentiment analyzer service.
|
||||
|
||||
Covers FinBERT analyzer, Ollama analyzer, ticker extraction, and the main
|
||||
service flow.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from services.sentiment_analyzer.analyzers.finbert import FinBERTAnalyzer
|
||||
from services.sentiment_analyzer.analyzers.ollama_analyzer import OllamaAnalyzer
|
||||
from services.sentiment_analyzer.config import SentimentAnalyzerConfig
|
||||
from services.sentiment_analyzer.main import process_article
|
||||
from services.sentiment_analyzer.ticker_extractor import extract_tickers
|
||||
from shared.schemas.news import RawArticle
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_raw_article(**overrides) -> RawArticle:
|
||||
defaults = {
|
||||
"source": "test",
|
||||
"url": "https://example.com/article",
|
||||
"title": "Test Article About $AAPL",
|
||||
"content": "Apple Inc announced strong earnings.",
|
||||
"published_at": datetime(2026, 1, 15, tzinfo=timezone.utc),
|
||||
"fetched_at": datetime(2026, 1, 15, 0, 5, tzinfo=timezone.utc),
|
||||
"content_hash": "abc123",
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return RawArticle(**defaults)
|
||||
|
||||
|
||||
def _make_pipeline_result(label: str, score: float) -> list[list[dict]]:
|
||||
"""Build a return value matching transformers pipeline(return_all_scores=True)."""
|
||||
labels = {"positive": score if label == "positive" else 0.0,
|
||||
"negative": score if label == "negative" else 0.0,
|
||||
"neutral": score if label == "neutral" else 0.0}
|
||||
# Distribute remaining probability
|
||||
remaining = 1.0 - score
|
||||
other_labels = [l for l in labels if l != label]
|
||||
for ol in other_labels:
|
||||
labels[ol] = remaining / len(other_labels)
|
||||
|
||||
return [[{"label": l, "score": s} for l, s in labels.items()]]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FinBERT Analyzer Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFinBERTAnalyzer:
|
||||
"""Tests for FinBERTAnalyzer with a mocked transformers pipeline."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_finbert_positive_sentiment(self):
|
||||
"""Positive article should yield a positive score."""
|
||||
mock_pipe = MagicMock()
|
||||
mock_pipe.return_value = _make_pipeline_result("positive", 0.9)
|
||||
|
||||
analyzer = FinBERTAnalyzer(model_name="test-model")
|
||||
analyzer._pipeline = mock_pipe
|
||||
|
||||
score, confidence = await analyzer.analyze(
|
||||
"Apple beats earnings expectations",
|
||||
"Apple reported revenue above analyst estimates.",
|
||||
)
|
||||
|
||||
assert score > 0.0, f"Expected positive score, got {score}"
|
||||
assert confidence == pytest.approx(0.9, abs=0.01)
|
||||
mock_pipe.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_finbert_negative_sentiment(self):
|
||||
"""Negative article should yield a negative score."""
|
||||
mock_pipe = MagicMock()
|
||||
mock_pipe.return_value = _make_pipeline_result("negative", 0.85)
|
||||
|
||||
analyzer = FinBERTAnalyzer(model_name="test-model")
|
||||
analyzer._pipeline = mock_pipe
|
||||
|
||||
score, confidence = await analyzer.analyze(
|
||||
"Major bank reports massive losses",
|
||||
"The bank lost $2 billion in the quarter.",
|
||||
)
|
||||
|
||||
assert score < 0.0, f"Expected negative score, got {score}"
|
||||
assert confidence == pytest.approx(0.85, abs=0.01)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_finbert_neutral_sentiment(self):
|
||||
"""Neutral article should yield a near-zero score."""
|
||||
mock_pipe = MagicMock()
|
||||
mock_pipe.return_value = _make_pipeline_result("neutral", 0.8)
|
||||
|
||||
analyzer = FinBERTAnalyzer(model_name="test-model")
|
||||
analyzer._pipeline = mock_pipe
|
||||
|
||||
score, confidence = await analyzer.analyze(
|
||||
"Company releases quarterly report",
|
||||
"The quarterly report was filed with the SEC.",
|
||||
)
|
||||
|
||||
# Neutral dominant => score close to zero (neutral maps to 0).
|
||||
# The small residual comes from the remaining probability split
|
||||
# between positive and negative.
|
||||
assert abs(score) < 0.2, f"Expected near-zero score, got {score}"
|
||||
assert confidence == pytest.approx(0.8, abs=0.01)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Ollama Analyzer Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestOllamaAnalyzer:
|
||||
"""Tests for OllamaAnalyzer with a mocked ollama client."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ollama_successful_analysis(self):
|
||||
"""Valid JSON response should be parsed correctly."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.return_value = {
|
||||
"message": {
|
||||
"content": '{"sentiment_score": 0.75, "confidence": 0.85, "entities": ["AAPL"]}'
|
||||
}
|
||||
}
|
||||
|
||||
analyzer = OllamaAnalyzer(model="test-model")
|
||||
analyzer._client = mock_client
|
||||
|
||||
score, confidence = await analyzer.analyze("Good news for Apple", "Apple stock surges.")
|
||||
|
||||
assert score == pytest.approx(0.75)
|
||||
assert confidence == pytest.approx(0.85)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ollama_parse_error_returns_zero(self):
|
||||
"""Invalid JSON should return (0.0, 0.0) fallback."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.return_value = {
|
||||
"message": {"content": "I think the sentiment is positive but I'm not sure."}
|
||||
}
|
||||
|
||||
analyzer = OllamaAnalyzer(model="test-model")
|
||||
analyzer._client = mock_client
|
||||
|
||||
score, confidence = await analyzer.analyze("Some headline", "Some content")
|
||||
|
||||
assert score == 0.0
|
||||
assert confidence == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ollama_connection_error_returns_zero(self):
|
||||
"""Network/connection errors should return (0.0, 0.0) fallback."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.side_effect = ConnectionError("Cannot reach Ollama")
|
||||
|
||||
analyzer = OllamaAnalyzer(model="test-model")
|
||||
analyzer._client = mock_client
|
||||
|
||||
score, confidence = await analyzer.analyze("Some headline", "Some content")
|
||||
|
||||
assert score == 0.0
|
||||
assert confidence == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ollama_markdown_code_fence(self):
|
||||
"""JSON wrapped in markdown code fences should still be parsed."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.return_value = {
|
||||
"message": {
|
||||
"content": '```json\n{"sentiment_score": -0.5, "confidence": 0.7, "entities": []}\n```'
|
||||
}
|
||||
}
|
||||
|
||||
analyzer = OllamaAnalyzer(model="test-model")
|
||||
analyzer._client = mock_client
|
||||
|
||||
score, confidence = await analyzer.analyze("Bad news", "Markets tumble.")
|
||||
|
||||
assert score == pytest.approx(-0.5)
|
||||
assert confidence == pytest.approx(0.7)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Ticker Extraction Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTickerExtraction:
|
||||
"""Tests for the ticker extraction utility."""
|
||||
|
||||
def test_ticker_extraction_dollar_sign(self):
|
||||
"""$AAPL should extract AAPL."""
|
||||
tickers = extract_tickers("Big news for $AAPL today.")
|
||||
assert "AAPL" in tickers
|
||||
|
||||
def test_ticker_extraction_exchange_prefix(self):
|
||||
"""NASDAQ:TSLA should extract TSLA."""
|
||||
tickers = extract_tickers("Check out NASDAQ:TSLA performance.")
|
||||
assert "TSLA" in tickers
|
||||
|
||||
def test_ticker_extraction_nyse_prefix(self):
|
||||
"""NYSE:AAPL should extract AAPL."""
|
||||
tickers = extract_tickers("NYSE:AAPL is trading higher.")
|
||||
assert "AAPL" in tickers
|
||||
|
||||
def test_ticker_extraction_filters_false_positives(self):
|
||||
"""Common words like CEO, IPO, ETF, SEC, NYSE should be filtered."""
|
||||
tickers = extract_tickers(
|
||||
"The CEO announced a new IPO. The ETF was approved by the SEC on NYSE."
|
||||
)
|
||||
assert "CEO" not in tickers
|
||||
assert "IPO" not in tickers
|
||||
assert "ETF" not in tickers
|
||||
assert "SEC" not in tickers
|
||||
assert "NYSE" not in tickers
|
||||
|
||||
def test_ticker_extraction_deduplicates(self):
|
||||
"""Repeated mentions of the same ticker should appear only once."""
|
||||
tickers = extract_tickers("$AAPL rose 5%. $AAPL is now above $200. NASDAQ:AAPL is great.")
|
||||
assert tickers.count("AAPL") == 1
|
||||
|
||||
def test_ticker_extraction_multiple_tickers(self):
|
||||
"""Multiple different tickers should all be extracted."""
|
||||
tickers = extract_tickers("$AAPL and $MSFT both reported earnings. $GOOG is next.")
|
||||
assert "AAPL" in tickers
|
||||
assert "MSFT" in tickers
|
||||
assert "GOOG" in tickers
|
||||
|
||||
def test_ticker_extraction_empty_text(self):
|
||||
"""Empty text should return no tickers."""
|
||||
assert extract_tickers("") == []
|
||||
|
||||
def test_ticker_extraction_no_tickers(self):
|
||||
"""Text with no ticker-like patterns should return empty list."""
|
||||
tickers = extract_tickers("The market was flat today with no major movers.")
|
||||
# Should be empty — all uppercase words are filtered as false positives or too short.
|
||||
assert len(tickers) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Ollama Fallback Routing Test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFallbackRouting:
|
||||
"""Test that Ollama is called when FinBERT confidence is below threshold."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ollama_fallback_on_low_confidence(self):
|
||||
"""When FinBERT confidence < threshold, Ollama should be called."""
|
||||
# FinBERT returns low confidence
|
||||
finbert = AsyncMock(spec=FinBERTAnalyzer)
|
||||
finbert.analyze = AsyncMock(return_value=(0.1, 0.4)) # confidence 0.4 < 0.6 threshold
|
||||
|
||||
# Ollama returns higher confidence
|
||||
ollama = AsyncMock(spec=OllamaAnalyzer)
|
||||
ollama.analyze = AsyncMock(return_value=(0.8, 0.9))
|
||||
|
||||
publisher = AsyncMock()
|
||||
publisher.publish = AsyncMock(return_value=b"1-0")
|
||||
|
||||
config = SentimentAnalyzerConfig(
|
||||
finbert_confidence_threshold=0.6,
|
||||
otel_metrics_port=0,
|
||||
)
|
||||
|
||||
# Mock counters
|
||||
counters = {
|
||||
"articles_scored": MagicMock(),
|
||||
"finbert_count": MagicMock(),
|
||||
"ollama_count": MagicMock(),
|
||||
"inference_latency": MagicMock(),
|
||||
}
|
||||
|
||||
article = _make_raw_article(title="Test $AAPL Article", content="Apple stock rises.")
|
||||
|
||||
await process_article(article, finbert, ollama, publisher, config, counters)
|
||||
|
||||
# Both should have been called
|
||||
finbert.analyze.assert_called_once()
|
||||
ollama.analyze.assert_called_once()
|
||||
counters["finbert_count"].add.assert_called_once_with(1)
|
||||
counters["ollama_count"].add.assert_called_once_with(1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_ollama_on_high_confidence(self):
|
||||
"""When FinBERT confidence >= threshold, Ollama should NOT be called."""
|
||||
finbert = AsyncMock(spec=FinBERTAnalyzer)
|
||||
finbert.analyze = AsyncMock(return_value=(0.8, 0.9)) # confidence 0.9 >= 0.6
|
||||
|
||||
ollama = AsyncMock(spec=OllamaAnalyzer)
|
||||
ollama.analyze = AsyncMock(return_value=(0.5, 0.7))
|
||||
|
||||
publisher = AsyncMock()
|
||||
publisher.publish = AsyncMock(return_value=b"1-0")
|
||||
|
||||
config = SentimentAnalyzerConfig(
|
||||
finbert_confidence_threshold=0.6,
|
||||
otel_metrics_port=0,
|
||||
)
|
||||
|
||||
counters = {
|
||||
"articles_scored": MagicMock(),
|
||||
"finbert_count": MagicMock(),
|
||||
"ollama_count": MagicMock(),
|
||||
"inference_latency": MagicMock(),
|
||||
}
|
||||
|
||||
article = _make_raw_article(title="Test $AAPL Article", content="Apple stock rises.")
|
||||
|
||||
await process_article(article, finbert, ollama, publisher, config, counters)
|
||||
|
||||
finbert.analyze.assert_called_once()
|
||||
ollama.analyze.assert_not_called()
|
||||
counters["ollama_count"].add.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main Flow / Integration Test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMainFlow:
|
||||
"""Test the full process_article flow with mocked analyzers and Redis."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_main_flow_publishes_scored_articles(self):
|
||||
"""process_article should publish a ScoredArticle for each ticker found."""
|
||||
finbert = AsyncMock(spec=FinBERTAnalyzer)
|
||||
finbert.analyze = AsyncMock(return_value=(0.75, 0.88))
|
||||
|
||||
ollama = AsyncMock(spec=OllamaAnalyzer)
|
||||
|
||||
publisher = AsyncMock()
|
||||
publisher.publish = AsyncMock(return_value=b"1-0")
|
||||
|
||||
config = SentimentAnalyzerConfig(
|
||||
finbert_confidence_threshold=0.6,
|
||||
otel_metrics_port=0,
|
||||
)
|
||||
|
||||
counters = {
|
||||
"articles_scored": MagicMock(),
|
||||
"finbert_count": MagicMock(),
|
||||
"ollama_count": MagicMock(),
|
||||
"inference_latency": MagicMock(),
|
||||
}
|
||||
|
||||
# Article mentions two tickers
|
||||
article = _make_raw_article(
|
||||
title="$AAPL and $MSFT report strong earnings",
|
||||
content="Both Apple and Microsoft beat estimates.",
|
||||
)
|
||||
|
||||
await process_article(article, finbert, ollama, publisher, config, counters)
|
||||
|
||||
# Should publish one ScoredArticle per ticker
|
||||
assert publisher.publish.call_count == 2
|
||||
counters["articles_scored"].add.assert_called_once_with(1)
|
||||
|
||||
# Verify the published data
|
||||
calls = publisher.publish.call_args_list
|
||||
published_tickers = {call.args[0]["ticker"] for call in calls}
|
||||
assert "AAPL" in published_tickers
|
||||
assert "MSFT" in published_tickers
|
||||
|
||||
# Each published message should have the correct sentiment score
|
||||
for call in calls:
|
||||
data = call.args[0]
|
||||
assert data["sentiment_score"] == pytest.approx(0.75)
|
||||
assert data["confidence"] == pytest.approx(0.88)
|
||||
assert data["model_used"] == "finbert"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_main_flow_no_tickers_no_publish(self):
|
||||
"""Articles with no tickers should not publish anything."""
|
||||
finbert = AsyncMock(spec=FinBERTAnalyzer)
|
||||
finbert.analyze = AsyncMock(return_value=(0.5, 0.9))
|
||||
|
||||
ollama = AsyncMock(spec=OllamaAnalyzer)
|
||||
|
||||
publisher = AsyncMock()
|
||||
publisher.publish = AsyncMock()
|
||||
|
||||
config = SentimentAnalyzerConfig(
|
||||
finbert_confidence_threshold=0.6,
|
||||
otel_metrics_port=0,
|
||||
)
|
||||
|
||||
counters = {
|
||||
"articles_scored": MagicMock(),
|
||||
"finbert_count": MagicMock(),
|
||||
"ollama_count": MagicMock(),
|
||||
"inference_latency": MagicMock(),
|
||||
}
|
||||
|
||||
article = _make_raw_article(
|
||||
title="Market is flat today",
|
||||
content="Nothing much happening in the market.",
|
||||
)
|
||||
|
||||
await process_article(article, finbert, ollama, publisher, config, counters)
|
||||
|
||||
publisher.publish.assert_not_called()
|
||||
# Still counted as scored
|
||||
counters["articles_scored"].add.assert_called_once_with(1)
|
||||
Loading…
Add table
Add a link
Reference in a new issue