# trading/tests/services/test_news_fetcher.py
"""Tests for the news fetcher service — RSS, Reddit, deduplication, publishing."""
import hashlib
from datetime import datetime, timezone
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from shared.schemas.news import RawArticle
from services.news_fetcher.sources.rss import RSSSource
from services.news_fetcher.sources.reddit import RedditSource
from services.news_fetcher.main import _deduplicate_and_publish, SEEN_HASHES_KEY
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
# Canned RSS 2.0 document with two finance items (AAPL, TSLA).
# NOTE(review): not referenced by any test in this file — the RSS tests build
# feedparser-style entry dicts directly via _make_fake_entry instead of
# parsing raw XML. Keep for future parsing tests, or remove if truly dead.
FIXTURE_RSS_XML = """\
<?xml version="1.0" encoding="UTF-8"?>
<rss version="2.0">
<channel>
<title>Test Finance Feed</title>
<item>
<title>AAPL hits record high</title>
<link>https://example.com/aapl</link>
<description>Apple stock reached an all-time high today.</description>
<pubDate>Sat, 22 Feb 2026 12:00:00 GMT</pubDate>
</item>
<item>
<title>TSLA earnings beat</title>
<link>https://example.com/tsla</link>
<description>Tesla reported stronger-than-expected earnings.</description>
<pubDate>Sat, 22 Feb 2026 13:00:00 GMT</pubDate>
</item>
</channel>
</rss>
"""
def _make_fake_entry(title: str, link: str, summary: str, published: str | None = None) -> dict:
"""Return a dict matching feedparser entry structure."""
entry = {"title": title, "link": link, "summary": summary}
if published:
entry["published"] = published
return entry
def _make_reddit_post(title, selftext, url, permalink, score, created_utc):
"""Return a SimpleNamespace mimicking an asyncpraw Submission."""
return SimpleNamespace(
title=title,
selftext=selftext,
url=url,
permalink=permalink,
score=score,
created_utc=created_utc,
)
# ---------------------------------------------------------------------------
# RSS tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_rss_source_parses_feed():
    """Each configured feed is parsed and its entries mapped to RawArticles."""
    parsed = MagicMock()
    parsed.bozo = False
    entry_specs = [
        (
            "AAPL hits record high",
            "https://example.com/aapl",
            "Apple stock reached an all-time high today.",
            "Sat, 22 Feb 2026 12:00:00 GMT",
        ),
        (
            "TSLA earnings beat",
            "https://example.com/tsla",
            "Tesla reported stronger-than-expected earnings.",
            "Sat, 22 Feb 2026 13:00:00 GMT",
        ),
    ]
    parsed.entries = [_make_fake_entry(*spec) for spec in entry_specs]
    with patch(
        "services.news_fetcher.sources.rss.feedparser.parse",
        return_value=parsed,
    ) as mock_parse:
        articles = await RSSSource(feeds=["https://example.com/feed"]).fetch()
    mock_parse.assert_called_once_with("https://example.com/feed")
    assert len(articles) == 2
    first, second = articles
    assert first.source == "rss"
    assert first.title == "AAPL hits record high"
    assert first.url == "https://example.com/aapl"
    assert first.content == "Apple stock reached an all-time high today."
    # Content hash is sha256 over url + title.
    digest = hashlib.sha256("https://example.com/aaplAAPL hits record high".encode()).hexdigest()
    assert first.content_hash == digest
    assert first.published_at is not None
    assert second.title == "TSLA earnings beat"
@pytest.mark.asyncio
async def test_rss_source_handles_bad_feed():
    """If parsing a feed raises, the source swallows the error and yields nothing."""
    parse_path = "services.news_fetcher.sources.rss.feedparser.parse"
    with patch(parse_path, side_effect=Exception("network timeout")):
        source = RSSSource(feeds=["https://bad-feed.example.com/rss"])
        assert await source.fetch() == []
@pytest.mark.asyncio
async def test_rss_source_handles_bozo_feed():
    """A malformed (bozo) feed without entries produces no articles."""
    broken = MagicMock()
    broken.bozo = True
    broken.entries = []
    broken.bozo_exception = "malformed XML"
    with patch("services.news_fetcher.sources.rss.feedparser.parse", return_value=broken):
        result = await RSSSource(feeds=["https://broken.example.com/rss"]).fetch()
    assert result == []
@pytest.mark.asyncio
async def test_rss_source_multiple_feeds():
    """Fetching several feeds concatenates their articles into one list."""
    def _parsed_feed(title, link, body):
        # One-entry parsed-feed stub per source feed.
        parsed = MagicMock()
        parsed.bozo = False
        parsed.entries = [_make_fake_entry(title, link, body)]
        return parsed

    feeds = [
        _parsed_feed("A", "https://a.com", "content a"),
        _parsed_feed("B", "https://b.com", "content b"),
    ]
    with patch(
        "services.news_fetcher.sources.rss.feedparser.parse",
        side_effect=feeds,
    ):
        source = RSSSource(feeds=["https://feed1.com", "https://feed2.com"])
        articles = await source.fetch()
    assert len(articles) == 2
    assert {a.title for a in articles} == {"A", "B"}
# ---------------------------------------------------------------------------
# Reddit tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_reddit_source_converts_posts():
    """A hot submission is mapped onto a RawArticle field by field."""
    submission = _make_reddit_post(
        title="GME to the moon",
        selftext="Diamond hands forever",
        url="https://reddit.com/r/wsb/1",
        permalink="/r/wallstreetbets/comments/abc123/gme_to_the_moon/",
        score=500,
        created_utc=1740200000.0,
    )

    async def _hot(limit=25):
        # Async generator standing in for subreddit.hot().
        yield submission

    fake_subreddit = AsyncMock()
    fake_subreddit.hot = _hot
    fake_reddit = AsyncMock()
    fake_reddit.subreddit = AsyncMock(return_value=fake_subreddit)
    fake_reddit.close = AsyncMock()
    with patch("asyncpraw.Reddit", return_value=fake_reddit):
        fetched = await RedditSource(
            subreddits=["wallstreetbets"],
            client_id="test_id",
            client_secret="test_secret",
            user_agent="test-agent",
            min_score=10,
        ).fetch()
    assert len(fetched) == 1
    article = fetched[0]
    assert article.source == "reddit"
    assert article.title == "GME to the moon"
    assert article.content == "Diamond hands forever"
    # Hash is sha256 over the permalink alone.
    permalink_hash = hashlib.sha256(
        "/r/wallstreetbets/comments/abc123/gme_to_the_moon/".encode()
    ).hexdigest()
    assert article.content_hash == permalink_hash
    assert "reddit.com" in article.url
    # The client must be closed after fetching.
    fake_reddit.close.assert_awaited_once()
@pytest.mark.asyncio
async def test_reddit_source_filters_by_score():
    """Submissions scoring under min_score never become articles."""
    posts = [
        _make_reddit_post("High", "text", "url", "/r/stocks/high", 100, 1740200000.0),
        _make_reddit_post("Low", "text", "url", "/r/stocks/low", 5, 1740200000.0),
    ]

    async def _hot(limit=25):
        for submission in posts:
            yield submission

    fake_subreddit = AsyncMock()
    fake_subreddit.hot = _hot
    fake_reddit = AsyncMock()
    fake_reddit.subreddit = AsyncMock(return_value=fake_subreddit)
    fake_reddit.close = AsyncMock()
    with patch("asyncpraw.Reddit", return_value=fake_reddit):
        source = RedditSource(
            subreddits=["stocks"],
            client_id="id",
            client_secret="secret",
            user_agent="agent",
            min_score=10,
        )
        articles = await source.fetch()
    # Only the 100-score post clears the min_score=10 threshold.
    assert len(articles) == 1
    assert articles[0].title == "High"
@pytest.mark.asyncio
async def test_reddit_source_uses_url_when_no_selftext():
    """Link posts (empty selftext) fall back to the post URL as content."""
    link_post = _make_reddit_post(
        title="Link post",
        selftext="",
        url="https://example.com/article",
        permalink="/r/investing/link",
        score=50,
        created_utc=1740200000.0,
    )

    async def _hot(limit=25):
        yield link_post

    fake_subreddit = AsyncMock()
    fake_subreddit.hot = _hot
    fake_reddit = AsyncMock()
    fake_reddit.subreddit = AsyncMock(return_value=fake_subreddit)
    fake_reddit.close = AsyncMock()
    with patch("asyncpraw.Reddit", return_value=fake_reddit):
        source = RedditSource(
            subreddits=["investing"],
            client_id="id",
            client_secret="secret",
            user_agent="agent",
            min_score=10,
        )
        articles = await source.fetch()
    assert len(articles) == 1
    assert articles[0].content == "https://example.com/article"
# ---------------------------------------------------------------------------
# Deduplication tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_deduplication_skips_seen_hashes():
    """Articles with previously-seen content_hash are not published.

    Strengthened to also verify WHICH article was published (the unseen one)
    and that the published-articles counter is bumped by exactly the number
    of new items, mirroring test_deduplication_publishes_all_new.
    """
    now = datetime.now(timezone.utc)
    articles = [
        RawArticle(
            source="rss",
            url="https://example.com/1",
            title="First",
            content="Content 1",
            fetched_at=now,
            content_hash="hash_new",
        ),
        RawArticle(
            source="rss",
            url="https://example.com/2",
            title="Already seen",
            content="Content 2",
            fetched_at=now,
            content_hash="hash_old",
        ),
    ]
    redis = AsyncMock()
    # First article is new (SADD returns 1), second is duplicate (returns 0)
    redis.sadd = AsyncMock(side_effect=[1, 0])
    publisher = AsyncMock()
    publisher.publish = AsyncMock()
    counter = MagicMock()
    counter.add = MagicMock()
    error_counter = MagicMock()
    count = await _deduplicate_and_publish(articles, redis, publisher, counter, error_counter)
    assert count == 1
    publisher.publish.assert_called_once()
    # Only the unseen article may be published.
    published = publisher.publish.call_args[0][0]
    assert published["content_hash"] == "hash_new"
    # Verify the hash was checked for both articles
    assert redis.sadd.call_count == 2
    # Counter reflects the single newly published article.
    counter.add.assert_called_once_with(1)
@pytest.mark.asyncio
async def test_deduplication_publishes_all_new():
    """Every article whose hash is unseen gets published and counted."""
    now = datetime.now(timezone.utc)
    articles = [
        RawArticle(
            source="rss",
            url=f"https://example.com/{idx}",
            title=f"Article {idx}",
            content=f"Content {idx}",
            fetched_at=now,
            content_hash=f"hash_{idx}",
        )
        for idx in range(3)
    ]
    redis = AsyncMock()
    # SADD always reports "new member" -> nothing is filtered.
    redis.sadd = AsyncMock(return_value=1)
    publisher = AsyncMock()
    publisher.publish = AsyncMock()
    counter = MagicMock()
    counter.add = MagicMock()
    error_counter = MagicMock()
    published_count = await _deduplicate_and_publish(
        articles, redis, publisher, counter, error_counter
    )
    assert published_count == 3
    assert publisher.publish.call_count == 3
    counter.add.assert_called_once_with(3)
# ---------------------------------------------------------------------------
# Main service integration test (mocked sources + redis)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_main_publishes_to_stream():
    """A fresh article flows through dedup and reaches the publisher intact."""
    article = RawArticle(
        source="rss",
        url="https://example.com/test",
        title="Test",
        content="Test content",
        fetched_at=datetime.now(timezone.utc),
        content_hash="unique_hash",
    )
    redis = AsyncMock()
    redis.sadd = AsyncMock(return_value=1)
    publisher = AsyncMock()
    publisher.publish = AsyncMock()
    counter = MagicMock()
    counter.add = MagicMock()
    error_counter = MagicMock()
    published_count = await _deduplicate_and_publish(
        [article], redis, publisher, counter, error_counter
    )
    assert published_count == 1
    publisher.publish.assert_called_once()
    # The payload handed to the publisher must mirror the article's fields.
    payload = publisher.publish.call_args[0][0]
    assert payload["title"] == "Test"
    assert payload["source"] == "rss"
    assert payload["content_hash"] == "unique_hash"