Merge branch 'worktree-agent-a6b241b2'
This commit is contained in:
commit
e483e9987f
8 changed files with 975 additions and 0 deletions
1
services/sentiment_analyzer/__init__.py
Normal file
1
services/sentiment_analyzer/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
"""Sentiment Analyzer service — FinBERT + Ollama tiered analysis."""
|
||||||
1
services/sentiment_analyzer/analyzers/__init__.py
Normal file
1
services/sentiment_analyzer/analyzers/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
"""Sentiment analysis backends (FinBERT, Ollama)."""
|
||||||
113
services/sentiment_analyzer/analyzers/finbert.py
Normal file
113
services/sentiment_analyzer/analyzers/finbert.py
Normal file
|
|
@ -0,0 +1,113 @@
|
||||||
|
"""FinBERT-based financial sentiment analyzer.
|
||||||
|
|
||||||
|
Uses the ProsusAI/finbert model via the HuggingFace transformers library
|
||||||
|
to classify article text as positive, negative, or neutral.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FinBERTAnalyzer:
|
||||||
|
"""Lazy-loading wrapper around a transformers sentiment-analysis pipeline.
|
||||||
|
|
||||||
|
The heavy ``transformers`` + ``torch`` imports and model download happen
|
||||||
|
only once, on the first call to :meth:`analyze`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model_name: str = "ProsusAI/finbert", max_content_length: int = 512) -> None:
|
||||||
|
self.model_name = model_name
|
||||||
|
self.max_content_length = max_content_length
|
||||||
|
self._pipeline: Any | None = None
|
||||||
|
|
||||||
|
def _load_pipeline(self) -> Any:
|
||||||
|
"""Lazily load the transformers pipeline on first use."""
|
||||||
|
if self._pipeline is None:
|
||||||
|
from transformers import pipeline # type: ignore[import-untyped]
|
||||||
|
|
||||||
|
logger.info("Loading FinBERT model: %s", self.model_name)
|
||||||
|
self._pipeline = pipeline(
|
||||||
|
"sentiment-analysis",
|
||||||
|
model=self.model_name,
|
||||||
|
return_all_scores=True,
|
||||||
|
)
|
||||||
|
logger.info("FinBERT model loaded successfully")
|
||||||
|
return self._pipeline
|
||||||
|
|
||||||
|
async def analyze(self, title: str, content: str) -> tuple[float, float]:
|
||||||
|
"""Score the sentiment of an article.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
title:
|
||||||
|
Article headline.
|
||||||
|
content:
|
||||||
|
Article body text.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
tuple[float, float]
|
||||||
|
``(sentiment_score, confidence)`` where *sentiment_score* is in
|
||||||
|
``[-1.0, 1.0]`` and *confidence* is in ``[0.0, 1.0]``.
|
||||||
|
|
||||||
|
The input text is truncated to ``max_content_length`` tokens by
|
||||||
|
passing it through the model's tokenizer truncation (handled
|
||||||
|
automatically by the transformers pipeline).
|
||||||
|
"""
|
||||||
|
pipe = self._load_pipeline()
|
||||||
|
|
||||||
|
# Combine title and content; the pipeline will truncate to model max tokens.
|
||||||
|
text = f"{title}. {content}"
|
||||||
|
# Truncate to a reasonable character length proportional to max_content_length
|
||||||
|
# tokens (rough estimate: 1 token ~ 4 chars for English). The pipeline
|
||||||
|
# tokenizer will do the precise truncation, but this avoids sending
|
||||||
|
# enormous strings.
|
||||||
|
char_limit = self.max_content_length * 4
|
||||||
|
text = text[:char_limit]
|
||||||
|
|
||||||
|
# Run the blocking model inference in a thread pool so we don't block
|
||||||
|
# the event loop.
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
results = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: pipe(text, truncation=True, max_length=self.max_content_length),
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._parse_scores(results)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_scores(results: list[list[dict[str, Any]]]) -> tuple[float, float]:
|
||||||
|
"""Map pipeline output to ``(score, confidence)``.
|
||||||
|
|
||||||
|
The ``return_all_scores=True`` pipeline returns a list of lists of dicts:
|
||||||
|
``[[{"label": "positive", "score": 0.85}, ...]]``.
|
||||||
|
|
||||||
|
Mapping:
|
||||||
|
- ``"positive"`` -> +1
|
||||||
|
- ``"negative"`` -> -1
|
||||||
|
- ``"neutral"`` -> 0
|
||||||
|
|
||||||
|
The sentiment score is the weighted sum of label polarities scaled by
|
||||||
|
their softmax probabilities. Confidence is the maximum softmax
|
||||||
|
probability.
|
||||||
|
"""
|
||||||
|
label_map = {"positive": 1.0, "negative": -1.0, "neutral": 0.0}
|
||||||
|
|
||||||
|
# results is [[{label, score}, ...]]
|
||||||
|
scores = results[0]
|
||||||
|
|
||||||
|
sentiment_score = 0.0
|
||||||
|
confidence = 0.0
|
||||||
|
for entry in scores:
|
||||||
|
label = entry["label"].lower()
|
||||||
|
prob = entry["score"]
|
||||||
|
sentiment_score += label_map.get(label, 0.0) * prob
|
||||||
|
if prob > confidence:
|
||||||
|
confidence = prob
|
||||||
|
|
||||||
|
return sentiment_score, confidence
|
||||||
102
services/sentiment_analyzer/analyzers/ollama_analyzer.py
Normal file
102
services/sentiment_analyzer/analyzers/ollama_analyzer.py
Normal file
|
|
@ -0,0 +1,102 @@
|
||||||
|
"""Ollama-based sentiment analyzer (LLM fallback).
|
||||||
|
|
||||||
|
Used when the FinBERT model's confidence is below the configured threshold.
|
||||||
|
Sends a structured prompt to a local Ollama instance and parses JSON output.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_SYSTEM_PROMPT = (
|
||||||
|
"You are a financial sentiment analysis assistant. "
|
||||||
|
"You will be given a news article title and content. "
|
||||||
|
"Analyze the sentiment and respond with ONLY valid JSON in this exact format:\n"
|
||||||
|
'{"sentiment_score": <float between -1.0 and 1.0>, '
|
||||||
|
'"confidence": <float between 0.0 and 1.0>, '
|
||||||
|
'"entities": [<list of mentioned company/ticker names>]}\n'
|
||||||
|
"Where sentiment_score: -1.0 = very negative, 0.0 = neutral, 1.0 = very positive.\n"
|
||||||
|
"Respond with ONLY the JSON object, no other text."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaAnalyzer:
|
||||||
|
"""Fallback sentiment analyzer using a local Ollama LLM."""
|
||||||
|
|
||||||
|
def __init__(self, model: str = "mistral", host: str = "http://localhost:11434") -> None:
|
||||||
|
self.model = model
|
||||||
|
self.host = host
|
||||||
|
self._client: object | None = None
|
||||||
|
|
||||||
|
def _get_client(self) -> object:
|
||||||
|
"""Lazily create the Ollama async client."""
|
||||||
|
if self._client is None:
|
||||||
|
import ollama # type: ignore[import-untyped]
|
||||||
|
|
||||||
|
self._client = ollama.AsyncClient(host=self.host)
|
||||||
|
return self._client
|
||||||
|
|
||||||
|
async def analyze(self, title: str, content: str) -> tuple[float, float]:
|
||||||
|
"""Analyze sentiment using the Ollama LLM.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
title:
|
||||||
|
Article headline.
|
||||||
|
content:
|
||||||
|
Article body text.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
tuple[float, float]
|
||||||
|
``(sentiment_score, confidence)``. On any parse error or
|
||||||
|
communication failure, returns ``(0.0, 0.0)`` as a safe fallback.
|
||||||
|
"""
|
||||||
|
user_prompt = f"Title: {title}\n\nContent: {content}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
client = self._get_client()
|
||||||
|
response = await client.chat( # type: ignore[union-attr]
|
||||||
|
model=self.model,
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": _SYSTEM_PROMPT},
|
||||||
|
{"role": "user", "content": user_prompt},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
raw_text: str = response["message"]["content"] # type: ignore[index]
|
||||||
|
return self._parse_response(raw_text)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Ollama analysis failed")
|
||||||
|
return 0.0, 0.0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_response(raw_text: str) -> tuple[float, float]:
|
||||||
|
"""Extract sentiment_score and confidence from LLM JSON output.
|
||||||
|
|
||||||
|
Robust against markdown code fences and leading/trailing whitespace.
|
||||||
|
Returns ``(0.0, 0.0)`` on any parsing failure.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Strip potential markdown code fences.
|
||||||
|
text = raw_text.strip()
|
||||||
|
if text.startswith("```"):
|
||||||
|
# Remove ```json ... ``` wrapper
|
||||||
|
lines = text.split("\n")
|
||||||
|
lines = [ln for ln in lines if not ln.strip().startswith("```")]
|
||||||
|
text = "\n".join(lines).strip()
|
||||||
|
|
||||||
|
data = json.loads(text)
|
||||||
|
score = float(data["sentiment_score"])
|
||||||
|
confidence = float(data["confidence"])
|
||||||
|
|
||||||
|
# Clamp to valid ranges.
|
||||||
|
score = max(-1.0, min(1.0, score))
|
||||||
|
confidence = max(0.0, min(1.0, confidence))
|
||||||
|
|
||||||
|
return score, confidence
|
||||||
|
except (json.JSONDecodeError, KeyError, TypeError, ValueError):
|
||||||
|
logger.warning("Failed to parse Ollama response: %s", raw_text[:200])
|
||||||
|
return 0.0, 0.0
|
||||||
15
services/sentiment_analyzer/config.py
Normal file
15
services/sentiment_analyzer/config.py
Normal file
|
|
@ -0,0 +1,15 @@
|
||||||
|
"""Configuration for the sentiment analyzer service."""
|
||||||
|
|
||||||
|
from shared.config import BaseConfig
|
||||||
|
|
||||||
|
|
||||||
|
class SentimentAnalyzerConfig(BaseConfig):
|
||||||
|
"""Extends BaseConfig with sentiment-analysis-specific settings."""
|
||||||
|
|
||||||
|
finbert_model: str = "ProsusAI/finbert"
|
||||||
|
finbert_confidence_threshold: float = 0.6
|
||||||
|
ollama_model: str = "mistral"
|
||||||
|
ollama_host: str = "http://localhost:11434"
|
||||||
|
max_content_length: int = 512
|
||||||
|
|
||||||
|
model_config = {"env_prefix": "TRADING_"}
|
||||||
169
services/sentiment_analyzer/main.py
Normal file
169
services/sentiment_analyzer/main.py
Normal file
|
|
@ -0,0 +1,169 @@
|
||||||
|
"""Sentiment Analyzer service — main entry point.
|
||||||
|
|
||||||
|
Consumes ``news:raw`` articles from Redis Streams, scores them using a
|
||||||
|
tiered approach (FinBERT first, Ollama fallback for low-confidence results),
|
||||||
|
extracts ticker mentions, and publishes ``ScoredArticle`` messages to
|
||||||
|
``news:scored``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
from redis.asyncio import Redis
|
||||||
|
|
||||||
|
from services.sentiment_analyzer.analyzers.finbert import FinBERTAnalyzer
|
||||||
|
from services.sentiment_analyzer.analyzers.ollama_analyzer import OllamaAnalyzer
|
||||||
|
from services.sentiment_analyzer.config import SentimentAnalyzerConfig
|
||||||
|
from services.sentiment_analyzer.ticker_extractor import extract_tickers
|
||||||
|
from shared.redis_streams import StreamConsumer, StreamPublisher
|
||||||
|
from shared.schemas.news import RawArticle, ScoredArticle
|
||||||
|
from shared.telemetry import setup_telemetry
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def process_article(
|
||||||
|
article: RawArticle,
|
||||||
|
finbert: FinBERTAnalyzer,
|
||||||
|
ollama: OllamaAnalyzer,
|
||||||
|
publisher: StreamPublisher,
|
||||||
|
config: SentimentAnalyzerConfig,
|
||||||
|
counters: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Score a single article and publish one ScoredArticle per extracted ticker.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
article:
|
||||||
|
The raw article consumed from the ``news:raw`` stream.
|
||||||
|
finbert:
|
||||||
|
FinBERT analyzer instance.
|
||||||
|
ollama:
|
||||||
|
Ollama analyzer instance (used as fallback).
|
||||||
|
publisher:
|
||||||
|
Publishes results to ``news:scored``.
|
||||||
|
config:
|
||||||
|
Service configuration (confidence threshold, etc.).
|
||||||
|
counters:
|
||||||
|
Dict of OpenTelemetry counter/histogram instruments.
|
||||||
|
"""
|
||||||
|
start = time.monotonic()
|
||||||
|
|
||||||
|
# --- Step 1: Run FinBERT ---
|
||||||
|
score, confidence = await finbert.analyze(article.title, article.content)
|
||||||
|
model_used = "finbert"
|
||||||
|
counters["finbert_count"].add(1)
|
||||||
|
|
||||||
|
# --- Step 2: Fallback to Ollama if confidence is too low ---
|
||||||
|
if confidence < config.finbert_confidence_threshold:
|
||||||
|
logger.info(
|
||||||
|
"FinBERT confidence %.2f below threshold %.2f — falling back to Ollama",
|
||||||
|
confidence,
|
||||||
|
config.finbert_confidence_threshold,
|
||||||
|
)
|
||||||
|
score, confidence = await ollama.analyze(article.title, article.content)
|
||||||
|
model_used = "ollama"
|
||||||
|
counters["ollama_count"].add(1)
|
||||||
|
|
||||||
|
elapsed = time.monotonic() - start
|
||||||
|
counters["inference_latency"].record(elapsed)
|
||||||
|
|
||||||
|
# --- Step 3: Extract tickers ---
|
||||||
|
combined_text = f"{article.title} {article.content}"
|
||||||
|
tickers = extract_tickers(combined_text)
|
||||||
|
|
||||||
|
if not tickers:
|
||||||
|
logger.debug("No tickers found in article: %s", article.title[:80])
|
||||||
|
# Still count the article as scored even if no tickers found.
|
||||||
|
counters["articles_scored"].add(1)
|
||||||
|
return
|
||||||
|
|
||||||
|
# --- Step 4: Publish one ScoredArticle per ticker ---
|
||||||
|
for ticker in tickers:
|
||||||
|
scored = ScoredArticle(
|
||||||
|
source=article.source,
|
||||||
|
url=article.url,
|
||||||
|
title=article.title,
|
||||||
|
content=article.content,
|
||||||
|
published_at=article.published_at,
|
||||||
|
fetched_at=article.fetched_at,
|
||||||
|
content_hash=article.content_hash,
|
||||||
|
ticker=ticker,
|
||||||
|
sentiment_score=score,
|
||||||
|
confidence=confidence,
|
||||||
|
model_used=model_used,
|
||||||
|
entities=tickers,
|
||||||
|
)
|
||||||
|
await publisher.publish(scored.model_dump(mode="json"))
|
||||||
|
logger.debug("Published scored article for %s (score=%.2f)", ticker, score)
|
||||||
|
|
||||||
|
counters["articles_scored"].add(1)
|
||||||
|
|
||||||
|
|
||||||
|
async def run(config: SentimentAnalyzerConfig | None = None) -> None:
|
||||||
|
"""Main service loop.
|
||||||
|
|
||||||
|
Connects to Redis, initialises analysers and telemetry, then
|
||||||
|
continuously consumes from ``news:raw`` and publishes to ``news:scored``.
|
||||||
|
"""
|
||||||
|
if config is None:
|
||||||
|
config = SentimentAnalyzerConfig()
|
||||||
|
|
||||||
|
logging.basicConfig(level=config.log_level)
|
||||||
|
logger.info("Starting Sentiment Analyzer service")
|
||||||
|
|
||||||
|
# --- Telemetry ---
|
||||||
|
meter = setup_telemetry("sentiment-analyzer", config.otel_metrics_port)
|
||||||
|
counters = {
|
||||||
|
"articles_scored": meter.create_counter(
|
||||||
|
"articles_scored",
|
||||||
|
description="Total articles scored by the sentiment analyzer",
|
||||||
|
),
|
||||||
|
"finbert_count": meter.create_counter(
|
||||||
|
"finbert_count",
|
||||||
|
description="Number of articles scored by FinBERT",
|
||||||
|
),
|
||||||
|
"ollama_count": meter.create_counter(
|
||||||
|
"ollama_count",
|
||||||
|
description="Number of articles scored by Ollama (fallback)",
|
||||||
|
),
|
||||||
|
"inference_latency": meter.create_histogram(
|
||||||
|
"inference_latency_seconds",
|
||||||
|
description="Time spent on sentiment inference per article",
|
||||||
|
unit="s",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
# --- Redis ---
|
||||||
|
redis = Redis.from_url(config.redis_url, decode_responses=False)
|
||||||
|
consumer = StreamConsumer(redis, "news:raw", "sentiment-analyzer", "worker-1")
|
||||||
|
publisher = StreamPublisher(redis, "news:scored")
|
||||||
|
|
||||||
|
# --- Analyzers ---
|
||||||
|
finbert = FinBERTAnalyzer(
|
||||||
|
model_name=config.finbert_model,
|
||||||
|
max_content_length=config.max_content_length,
|
||||||
|
)
|
||||||
|
ollama = OllamaAnalyzer(model=config.ollama_model, host=config.ollama_host)
|
||||||
|
|
||||||
|
logger.info("Consuming from news:raw, publishing to news:scored")
|
||||||
|
|
||||||
|
# --- Consume loop ---
|
||||||
|
async for _msg_id, data in consumer.consume():
|
||||||
|
try:
|
||||||
|
article = RawArticle.model_validate(data)
|
||||||
|
await process_article(article, finbert, ollama, publisher, config, counters)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error processing article: %s", data.get("title", "<unknown>"))
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
"""CLI entry point."""
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
163
services/sentiment_analyzer/ticker_extractor.py
Normal file
163
services/sentiment_analyzer/ticker_extractor.py
Normal file
|
|
@ -0,0 +1,163 @@
|
||||||
|
"""Extract stock ticker symbols from free-form text.
|
||||||
|
|
||||||
|
Handles common formats:
|
||||||
|
- Dollar-prefixed: ``$AAPL``
|
||||||
|
- Exchange-prefixed: ``NASDAQ:AAPL``, ``NYSE:TSLA``
|
||||||
|
- Standalone uppercase words that look like tickers (1-5 uppercase letters)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
# Common false positives: short English words, financial abbreviations, and
|
||||||
|
# exchange names that match the 1-5 uppercase letter pattern.
|
||||||
|
_FALSE_POSITIVES: frozenset[str] = frozenset(
|
||||||
|
{
|
||||||
|
# Common English words / pronouns
|
||||||
|
"A",
|
||||||
|
"I",
|
||||||
|
"AM",
|
||||||
|
"AN",
|
||||||
|
"AS",
|
||||||
|
"AT",
|
||||||
|
"BE",
|
||||||
|
"BY",
|
||||||
|
"DO",
|
||||||
|
"GO",
|
||||||
|
"IF",
|
||||||
|
"IN",
|
||||||
|
"IS",
|
||||||
|
"IT",
|
||||||
|
"ME",
|
||||||
|
"MY",
|
||||||
|
"NO",
|
||||||
|
"OF",
|
||||||
|
"OK",
|
||||||
|
"ON",
|
||||||
|
"OR",
|
||||||
|
"SO",
|
||||||
|
"TO",
|
||||||
|
"UP",
|
||||||
|
"US",
|
||||||
|
"WE",
|
||||||
|
"PM",
|
||||||
|
"THE",
|
||||||
|
"AND",
|
||||||
|
"FOR",
|
||||||
|
"NOT",
|
||||||
|
"BUT",
|
||||||
|
"ARE",
|
||||||
|
"WAS",
|
||||||
|
"HAS",
|
||||||
|
"HAD",
|
||||||
|
"ALL",
|
||||||
|
"CAN",
|
||||||
|
"HER",
|
||||||
|
"HIS",
|
||||||
|
"HOW",
|
||||||
|
"ITS",
|
||||||
|
"MAY",
|
||||||
|
"NEW",
|
||||||
|
"NOW",
|
||||||
|
"OLD",
|
||||||
|
"OUR",
|
||||||
|
"OUT",
|
||||||
|
"OWN",
|
||||||
|
"SAY",
|
||||||
|
"SHE",
|
||||||
|
"TOO",
|
||||||
|
"USE",
|
||||||
|
# Time-related
|
||||||
|
"EST",
|
||||||
|
"PST",
|
||||||
|
"CST",
|
||||||
|
"MST",
|
||||||
|
"UTC",
|
||||||
|
"GMT",
|
||||||
|
# Financial jargon
|
||||||
|
"CEO",
|
||||||
|
"CFO",
|
||||||
|
"COO",
|
||||||
|
"CTO",
|
||||||
|
"IPO",
|
||||||
|
"ETF",
|
||||||
|
"SEC",
|
||||||
|
"NYSE",
|
||||||
|
"AMEX",
|
||||||
|
"DJIA",
|
||||||
|
"GDP",
|
||||||
|
"CPI",
|
||||||
|
"FED",
|
||||||
|
"FOMC",
|
||||||
|
"FDA",
|
||||||
|
"EPS",
|
||||||
|
"P&L",
|
||||||
|
"ROI",
|
||||||
|
"YTD",
|
||||||
|
"QOQ",
|
||||||
|
"YOY",
|
||||||
|
"ATH",
|
||||||
|
"ATL",
|
||||||
|
"RSI",
|
||||||
|
"SMA",
|
||||||
|
"EMA",
|
||||||
|
"IOT",
|
||||||
|
"API",
|
||||||
|
"AI",
|
||||||
|
"ML",
|
||||||
|
"US",
|
||||||
|
"USA",
|
||||||
|
"UK",
|
||||||
|
"EU",
|
||||||
|
"IMF",
|
||||||
|
"FTC",
|
||||||
|
"DOJ",
|
||||||
|
"IRS",
|
||||||
|
"DOT",
|
||||||
|
"SPAC",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Pattern 1: $AAPL (dollar-sign prefix)
|
||||||
|
_DOLLAR_PATTERN = re.compile(r"\$([A-Z]{1,5})\b")
|
||||||
|
|
||||||
|
# Pattern 2: NASDAQ:AAPL, NYSE:TSLA (exchange prefix)
|
||||||
|
_EXCHANGE_PATTERN = re.compile(r"\b(?:NASDAQ|NYSE|AMEX|OTC|BATS|ARCA):([A-Z]{1,5})\b")
|
||||||
|
|
||||||
|
# Pattern 3: standalone uppercase words at word boundaries (1-5 chars)
|
||||||
|
_STANDALONE_PATTERN = re.compile(r"\b([A-Z]{1,5})\b")
|
||||||
|
|
||||||
|
|
||||||
|
def extract_tickers(text: str) -> list[str]:
|
||||||
|
"""Extract deduplicated stock ticker symbols from *text*.
|
||||||
|
|
||||||
|
Returns a list of unique ticker strings in the order they were first
|
||||||
|
encountered. False positives (common English words, acronyms) are
|
||||||
|
filtered out.
|
||||||
|
"""
|
||||||
|
seen: set[str] = set()
|
||||||
|
result: list[str] = []
|
||||||
|
|
||||||
|
def _add(ticker: str) -> None:
|
||||||
|
if ticker not in seen and ticker not in _FALSE_POSITIVES:
|
||||||
|
seen.add(ticker)
|
||||||
|
result.append(ticker)
|
||||||
|
|
||||||
|
# Dollar-sign tickers have the highest signal — always include.
|
||||||
|
for match in _DOLLAR_PATTERN.finditer(text):
|
||||||
|
_add(match.group(1))
|
||||||
|
|
||||||
|
# Exchange-prefixed tickers are also high confidence.
|
||||||
|
for match in _EXCHANGE_PATTERN.finditer(text):
|
||||||
|
_add(match.group(1))
|
||||||
|
|
||||||
|
# Standalone uppercase words: only include if they look like real tickers
|
||||||
|
# (not in the false positives list). We restrict to 2-5 chars to reduce
|
||||||
|
# noise, unless they were already captured by the dollar/exchange patterns.
|
||||||
|
for match in _STANDALONE_PATTERN.finditer(text):
|
||||||
|
candidate = match.group(1)
|
||||||
|
if len(candidate) >= 2:
|
||||||
|
_add(candidate)
|
||||||
|
|
||||||
|
return result
|
||||||
411
tests/services/test_sentiment_analyzer.py
Normal file
411
tests/services/test_sentiment_analyzer.py
Normal file
|
|
@ -0,0 +1,411 @@
|
||||||
|
"""Tests for the sentiment analyzer service.
|
||||||
|
|
||||||
|
Covers FinBERT analyzer, Ollama analyzer, ticker extraction, and the main
|
||||||
|
service flow.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from services.sentiment_analyzer.analyzers.finbert import FinBERTAnalyzer
|
||||||
|
from services.sentiment_analyzer.analyzers.ollama_analyzer import OllamaAnalyzer
|
||||||
|
from services.sentiment_analyzer.config import SentimentAnalyzerConfig
|
||||||
|
from services.sentiment_analyzer.main import process_article
|
||||||
|
from services.sentiment_analyzer.ticker_extractor import extract_tickers
|
||||||
|
from shared.schemas.news import RawArticle
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_raw_article(**overrides) -> RawArticle:
|
||||||
|
defaults = {
|
||||||
|
"source": "test",
|
||||||
|
"url": "https://example.com/article",
|
||||||
|
"title": "Test Article About $AAPL",
|
||||||
|
"content": "Apple Inc announced strong earnings.",
|
||||||
|
"published_at": datetime(2026, 1, 15, tzinfo=timezone.utc),
|
||||||
|
"fetched_at": datetime(2026, 1, 15, 0, 5, tzinfo=timezone.utc),
|
||||||
|
"content_hash": "abc123",
|
||||||
|
}
|
||||||
|
defaults.update(overrides)
|
||||||
|
return RawArticle(**defaults)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_pipeline_result(label: str, score: float) -> list[list[dict]]:
|
||||||
|
"""Build a return value matching transformers pipeline(return_all_scores=True)."""
|
||||||
|
labels = {"positive": score if label == "positive" else 0.0,
|
||||||
|
"negative": score if label == "negative" else 0.0,
|
||||||
|
"neutral": score if label == "neutral" else 0.0}
|
||||||
|
# Distribute remaining probability
|
||||||
|
remaining = 1.0 - score
|
||||||
|
other_labels = [l for l in labels if l != label]
|
||||||
|
for ol in other_labels:
|
||||||
|
labels[ol] = remaining / len(other_labels)
|
||||||
|
|
||||||
|
return [[{"label": l, "score": s} for l, s in labels.items()]]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# FinBERT Analyzer Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestFinBERTAnalyzer:
|
||||||
|
"""Tests for FinBERTAnalyzer with a mocked transformers pipeline."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_finbert_positive_sentiment(self):
|
||||||
|
"""Positive article should yield a positive score."""
|
||||||
|
mock_pipe = MagicMock()
|
||||||
|
mock_pipe.return_value = _make_pipeline_result("positive", 0.9)
|
||||||
|
|
||||||
|
analyzer = FinBERTAnalyzer(model_name="test-model")
|
||||||
|
analyzer._pipeline = mock_pipe
|
||||||
|
|
||||||
|
score, confidence = await analyzer.analyze(
|
||||||
|
"Apple beats earnings expectations",
|
||||||
|
"Apple reported revenue above analyst estimates.",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert score > 0.0, f"Expected positive score, got {score}"
|
||||||
|
assert confidence == pytest.approx(0.9, abs=0.01)
|
||||||
|
mock_pipe.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_finbert_negative_sentiment(self):
|
||||||
|
"""Negative article should yield a negative score."""
|
||||||
|
mock_pipe = MagicMock()
|
||||||
|
mock_pipe.return_value = _make_pipeline_result("negative", 0.85)
|
||||||
|
|
||||||
|
analyzer = FinBERTAnalyzer(model_name="test-model")
|
||||||
|
analyzer._pipeline = mock_pipe
|
||||||
|
|
||||||
|
score, confidence = await analyzer.analyze(
|
||||||
|
"Major bank reports massive losses",
|
||||||
|
"The bank lost $2 billion in the quarter.",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert score < 0.0, f"Expected negative score, got {score}"
|
||||||
|
assert confidence == pytest.approx(0.85, abs=0.01)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_finbert_neutral_sentiment(self):
|
||||||
|
"""Neutral article should yield a near-zero score."""
|
||||||
|
mock_pipe = MagicMock()
|
||||||
|
mock_pipe.return_value = _make_pipeline_result("neutral", 0.8)
|
||||||
|
|
||||||
|
analyzer = FinBERTAnalyzer(model_name="test-model")
|
||||||
|
analyzer._pipeline = mock_pipe
|
||||||
|
|
||||||
|
score, confidence = await analyzer.analyze(
|
||||||
|
"Company releases quarterly report",
|
||||||
|
"The quarterly report was filed with the SEC.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Neutral dominant => score close to zero (neutral maps to 0).
|
||||||
|
# The small residual comes from the remaining probability split
|
||||||
|
# between positive and negative.
|
||||||
|
assert abs(score) < 0.2, f"Expected near-zero score, got {score}"
|
||||||
|
assert confidence == pytest.approx(0.8, abs=0.01)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Ollama Analyzer Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestOllamaAnalyzer:
|
||||||
|
"""Tests for OllamaAnalyzer with a mocked ollama client."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ollama_successful_analysis(self):
|
||||||
|
"""Valid JSON response should be parsed correctly."""
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.chat.return_value = {
|
||||||
|
"message": {
|
||||||
|
"content": '{"sentiment_score": 0.75, "confidence": 0.85, "entities": ["AAPL"]}'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
analyzer = OllamaAnalyzer(model="test-model")
|
||||||
|
analyzer._client = mock_client
|
||||||
|
|
||||||
|
score, confidence = await analyzer.analyze("Good news for Apple", "Apple stock surges.")
|
||||||
|
|
||||||
|
assert score == pytest.approx(0.75)
|
||||||
|
assert confidence == pytest.approx(0.85)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ollama_parse_error_returns_zero(self):
|
||||||
|
"""Invalid JSON should return (0.0, 0.0) fallback."""
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.chat.return_value = {
|
||||||
|
"message": {"content": "I think the sentiment is positive but I'm not sure."}
|
||||||
|
}
|
||||||
|
|
||||||
|
analyzer = OllamaAnalyzer(model="test-model")
|
||||||
|
analyzer._client = mock_client
|
||||||
|
|
||||||
|
score, confidence = await analyzer.analyze("Some headline", "Some content")
|
||||||
|
|
||||||
|
assert score == 0.0
|
||||||
|
assert confidence == 0.0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ollama_connection_error_returns_zero(self):
|
||||||
|
"""Network/connection errors should return (0.0, 0.0) fallback."""
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.chat.side_effect = ConnectionError("Cannot reach Ollama")
|
||||||
|
|
||||||
|
analyzer = OllamaAnalyzer(model="test-model")
|
||||||
|
analyzer._client = mock_client
|
||||||
|
|
||||||
|
score, confidence = await analyzer.analyze("Some headline", "Some content")
|
||||||
|
|
||||||
|
assert score == 0.0
|
||||||
|
assert confidence == 0.0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ollama_markdown_code_fence(self):
|
||||||
|
"""JSON wrapped in markdown code fences should still be parsed."""
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.chat.return_value = {
|
||||||
|
"message": {
|
||||||
|
"content": '```json\n{"sentiment_score": -0.5, "confidence": 0.7, "entities": []}\n```'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
analyzer = OllamaAnalyzer(model="test-model")
|
||||||
|
analyzer._client = mock_client
|
||||||
|
|
||||||
|
score, confidence = await analyzer.analyze("Bad news", "Markets tumble.")
|
||||||
|
|
||||||
|
assert score == pytest.approx(-0.5)
|
||||||
|
assert confidence == pytest.approx(0.7)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Ticker Extraction Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestTickerExtraction:
|
||||||
|
"""Tests for the ticker extraction utility."""
|
||||||
|
|
||||||
|
def test_ticker_extraction_dollar_sign(self):
|
||||||
|
"""$AAPL should extract AAPL."""
|
||||||
|
tickers = extract_tickers("Big news for $AAPL today.")
|
||||||
|
assert "AAPL" in tickers
|
||||||
|
|
||||||
|
def test_ticker_extraction_exchange_prefix(self):
|
||||||
|
"""NASDAQ:TSLA should extract TSLA."""
|
||||||
|
tickers = extract_tickers("Check out NASDAQ:TSLA performance.")
|
||||||
|
assert "TSLA" in tickers
|
||||||
|
|
||||||
|
def test_ticker_extraction_nyse_prefix(self):
|
||||||
|
"""NYSE:AAPL should extract AAPL."""
|
||||||
|
tickers = extract_tickers("NYSE:AAPL is trading higher.")
|
||||||
|
assert "AAPL" in tickers
|
||||||
|
|
||||||
|
def test_ticker_extraction_filters_false_positives(self):
|
||||||
|
"""Common words like CEO, IPO, ETF, SEC, NYSE should be filtered."""
|
||||||
|
tickers = extract_tickers(
|
||||||
|
"The CEO announced a new IPO. The ETF was approved by the SEC on NYSE."
|
||||||
|
)
|
||||||
|
assert "CEO" not in tickers
|
||||||
|
assert "IPO" not in tickers
|
||||||
|
assert "ETF" not in tickers
|
||||||
|
assert "SEC" not in tickers
|
||||||
|
assert "NYSE" not in tickers
|
||||||
|
|
||||||
|
def test_ticker_extraction_deduplicates(self):
|
||||||
|
"""Repeated mentions of the same ticker should appear only once."""
|
||||||
|
tickers = extract_tickers("$AAPL rose 5%. $AAPL is now above $200. NASDAQ:AAPL is great.")
|
||||||
|
assert tickers.count("AAPL") == 1
|
||||||
|
|
||||||
|
def test_ticker_extraction_multiple_tickers(self):
|
||||||
|
"""Multiple different tickers should all be extracted."""
|
||||||
|
tickers = extract_tickers("$AAPL and $MSFT both reported earnings. $GOOG is next.")
|
||||||
|
assert "AAPL" in tickers
|
||||||
|
assert "MSFT" in tickers
|
||||||
|
assert "GOOG" in tickers
|
||||||
|
|
||||||
|
def test_ticker_extraction_empty_text(self):
|
||||||
|
"""Empty text should return no tickers."""
|
||||||
|
assert extract_tickers("") == []
|
||||||
|
|
||||||
|
def test_ticker_extraction_no_tickers(self):
|
||||||
|
"""Text with no ticker-like patterns should return empty list."""
|
||||||
|
tickers = extract_tickers("The market was flat today with no major movers.")
|
||||||
|
# Should be empty — all uppercase words are filtered as false positives or too short.
|
||||||
|
assert len(tickers) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Ollama Fallback Routing Test
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestFallbackRouting:
|
||||||
|
"""Test that Ollama is called when FinBERT confidence is below threshold."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ollama_fallback_on_low_confidence(self):
|
||||||
|
"""When FinBERT confidence < threshold, Ollama should be called."""
|
||||||
|
# FinBERT returns low confidence
|
||||||
|
finbert = AsyncMock(spec=FinBERTAnalyzer)
|
||||||
|
finbert.analyze = AsyncMock(return_value=(0.1, 0.4)) # confidence 0.4 < 0.6 threshold
|
||||||
|
|
||||||
|
# Ollama returns higher confidence
|
||||||
|
ollama = AsyncMock(spec=OllamaAnalyzer)
|
||||||
|
ollama.analyze = AsyncMock(return_value=(0.8, 0.9))
|
||||||
|
|
||||||
|
publisher = AsyncMock()
|
||||||
|
publisher.publish = AsyncMock(return_value=b"1-0")
|
||||||
|
|
||||||
|
config = SentimentAnalyzerConfig(
|
||||||
|
finbert_confidence_threshold=0.6,
|
||||||
|
otel_metrics_port=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock counters
|
||||||
|
counters = {
|
||||||
|
"articles_scored": MagicMock(),
|
||||||
|
"finbert_count": MagicMock(),
|
||||||
|
"ollama_count": MagicMock(),
|
||||||
|
"inference_latency": MagicMock(),
|
||||||
|
}
|
||||||
|
|
||||||
|
article = _make_raw_article(title="Test $AAPL Article", content="Apple stock rises.")
|
||||||
|
|
||||||
|
await process_article(article, finbert, ollama, publisher, config, counters)
|
||||||
|
|
||||||
|
# Both should have been called
|
||||||
|
finbert.analyze.assert_called_once()
|
||||||
|
ollama.analyze.assert_called_once()
|
||||||
|
counters["finbert_count"].add.assert_called_once_with(1)
|
||||||
|
counters["ollama_count"].add.assert_called_once_with(1)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_ollama_on_high_confidence(self):
|
||||||
|
"""When FinBERT confidence >= threshold, Ollama should NOT be called."""
|
||||||
|
finbert = AsyncMock(spec=FinBERTAnalyzer)
|
||||||
|
finbert.analyze = AsyncMock(return_value=(0.8, 0.9)) # confidence 0.9 >= 0.6
|
||||||
|
|
||||||
|
ollama = AsyncMock(spec=OllamaAnalyzer)
|
||||||
|
ollama.analyze = AsyncMock(return_value=(0.5, 0.7))
|
||||||
|
|
||||||
|
publisher = AsyncMock()
|
||||||
|
publisher.publish = AsyncMock(return_value=b"1-0")
|
||||||
|
|
||||||
|
config = SentimentAnalyzerConfig(
|
||||||
|
finbert_confidence_threshold=0.6,
|
||||||
|
otel_metrics_port=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
counters = {
|
||||||
|
"articles_scored": MagicMock(),
|
||||||
|
"finbert_count": MagicMock(),
|
||||||
|
"ollama_count": MagicMock(),
|
||||||
|
"inference_latency": MagicMock(),
|
||||||
|
}
|
||||||
|
|
||||||
|
article = _make_raw_article(title="Test $AAPL Article", content="Apple stock rises.")
|
||||||
|
|
||||||
|
await process_article(article, finbert, ollama, publisher, config, counters)
|
||||||
|
|
||||||
|
finbert.analyze.assert_called_once()
|
||||||
|
ollama.analyze.assert_not_called()
|
||||||
|
counters["ollama_count"].add.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Main Flow / Integration Test
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestMainFlow:
|
||||||
|
"""Test the full process_article flow with mocked analyzers and Redis."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_main_flow_publishes_scored_articles(self):
|
||||||
|
"""process_article should publish a ScoredArticle for each ticker found."""
|
||||||
|
finbert = AsyncMock(spec=FinBERTAnalyzer)
|
||||||
|
finbert.analyze = AsyncMock(return_value=(0.75, 0.88))
|
||||||
|
|
||||||
|
ollama = AsyncMock(spec=OllamaAnalyzer)
|
||||||
|
|
||||||
|
publisher = AsyncMock()
|
||||||
|
publisher.publish = AsyncMock(return_value=b"1-0")
|
||||||
|
|
||||||
|
config = SentimentAnalyzerConfig(
|
||||||
|
finbert_confidence_threshold=0.6,
|
||||||
|
otel_metrics_port=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
counters = {
|
||||||
|
"articles_scored": MagicMock(),
|
||||||
|
"finbert_count": MagicMock(),
|
||||||
|
"ollama_count": MagicMock(),
|
||||||
|
"inference_latency": MagicMock(),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Article mentions two tickers
|
||||||
|
article = _make_raw_article(
|
||||||
|
title="$AAPL and $MSFT report strong earnings",
|
||||||
|
content="Both Apple and Microsoft beat estimates.",
|
||||||
|
)
|
||||||
|
|
||||||
|
await process_article(article, finbert, ollama, publisher, config, counters)
|
||||||
|
|
||||||
|
# Should publish one ScoredArticle per ticker
|
||||||
|
assert publisher.publish.call_count == 2
|
||||||
|
counters["articles_scored"].add.assert_called_once_with(1)
|
||||||
|
|
||||||
|
# Verify the published data
|
||||||
|
calls = publisher.publish.call_args_list
|
||||||
|
published_tickers = {call.args[0]["ticker"] for call in calls}
|
||||||
|
assert "AAPL" in published_tickers
|
||||||
|
assert "MSFT" in published_tickers
|
||||||
|
|
||||||
|
# Each published message should have the correct sentiment score
|
||||||
|
for call in calls:
|
||||||
|
data = call.args[0]
|
||||||
|
assert data["sentiment_score"] == pytest.approx(0.75)
|
||||||
|
assert data["confidence"] == pytest.approx(0.88)
|
||||||
|
assert data["model_used"] == "finbert"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_main_flow_no_tickers_no_publish(self):
|
||||||
|
"""Articles with no tickers should not publish anything."""
|
||||||
|
finbert = AsyncMock(spec=FinBERTAnalyzer)
|
||||||
|
finbert.analyze = AsyncMock(return_value=(0.5, 0.9))
|
||||||
|
|
||||||
|
ollama = AsyncMock(spec=OllamaAnalyzer)
|
||||||
|
|
||||||
|
publisher = AsyncMock()
|
||||||
|
publisher.publish = AsyncMock()
|
||||||
|
|
||||||
|
config = SentimentAnalyzerConfig(
|
||||||
|
finbert_confidence_threshold=0.6,
|
||||||
|
otel_metrics_port=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
counters = {
|
||||||
|
"articles_scored": MagicMock(),
|
||||||
|
"finbert_count": MagicMock(),
|
||||||
|
"ollama_count": MagicMock(),
|
||||||
|
"inference_latency": MagicMock(),
|
||||||
|
}
|
||||||
|
|
||||||
|
article = _make_raw_article(
|
||||||
|
title="Market is flat today",
|
||||||
|
content="Nothing much happening in the market.",
|
||||||
|
)
|
||||||
|
|
||||||
|
await process_article(article, finbert, ollama, publisher, config, counters)
|
||||||
|
|
||||||
|
publisher.publish.assert_not_called()
|
||||||
|
# Still counted as scored
|
||||||
|
counters["articles_scored"].add.assert_called_once_with(1)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue