feat: news fetcher service — RSS and Reddit sources
This commit is contained in:
parent
9f46071502
commit
90b52a5144
10 changed files with 722 additions and 2 deletions
|
|
@ -16,7 +16,7 @@ dependencies = [
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
api = ["fastapi>=0.110", "uvicorn[standard]>=0.27", "websockets>=12.0", "py-webauthn>=2.0", "pyjwt[crypto]>=2.8"]
|
api = ["fastapi>=0.110", "uvicorn[standard]>=0.27", "websockets>=12.0", "py-webauthn>=2.0", "pyjwt[crypto]>=2.8"]
|
||||||
news = ["feedparser>=6.0", "praw>=7.7", "httpx>=0.27"]
|
news = ["feedparser>=6.0", "praw>=7.7", "asyncpraw>=7.7", "httpx>=0.27"]
|
||||||
sentiment = ["transformers>=4.38", "torch>=2.2", "ollama>=0.1"]
|
sentiment = ["transformers>=4.38", "torch>=2.2", "ollama>=0.1"]
|
||||||
trading = ["alpaca-py>=0.21"]
|
trading = ["alpaca-py>=0.21"]
|
||||||
backtester = ["numpy>=1.26", "pandas>=2.2"]
|
backtester = ["numpy>=1.26", "pandas>=2.2"]
|
||||||
|
|
@ -27,7 +27,7 @@ requires = ["setuptools>=70.0"]
|
||||||
build-backend = "setuptools.build_meta"
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
include = ["shared*", "tests*"]
|
include = ["shared*", "services*", "tests*"]
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
asyncio_mode = "auto"
|
asyncio_mode = "auto"
|
||||||
|
|
|
||||||
0
services/__init__.py
Normal file
0
services/__init__.py
Normal file
1
services/news_fetcher/__init__.py
Normal file
1
services/news_fetcher/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
"""News fetcher service — polls RSS feeds and Reddit for financial news."""
|
||||||
28
services/news_fetcher/config.py
Normal file
28
services/news_fetcher/config.py
Normal file
|
|
@ -0,0 +1,28 @@
|
||||||
|
"""Configuration for the news fetcher service."""
|
||||||
|
|
||||||
|
from shared.config import BaseConfig
|
||||||
|
|
||||||
|
|
||||||
|
class NewsFetcherConfig(BaseConfig):
|
||||||
|
"""News fetcher settings.
|
||||||
|
|
||||||
|
Extends :class:`BaseConfig` with RSS feed URLs, poll intervals,
|
||||||
|
and Reddit API credentials. All settings can be overridden via
|
||||||
|
environment variables prefixed with ``TRADING_``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# RSS settings
|
||||||
|
rss_feeds: list[str] = [
|
||||||
|
"https://finance.yahoo.com/news/rssindex",
|
||||||
|
"https://feeds.reuters.com/reuters/businessNews",
|
||||||
|
"https://feeds.content.dowjones.io/public/rss/mw_topstories",
|
||||||
|
]
|
||||||
|
rss_poll_interval_seconds: int = 300
|
||||||
|
|
||||||
|
# Reddit settings
|
||||||
|
reddit_subreddits: list[str] = ["wallstreetbets", "stocks", "investing"]
|
||||||
|
reddit_poll_interval_seconds: int = 600
|
||||||
|
reddit_min_score: int = 10
|
||||||
|
reddit_client_id: str = ""
|
||||||
|
reddit_client_secret: str = ""
|
||||||
|
reddit_user_agent: str = "trading-bot/0.1"
|
||||||
152
services/news_fetcher/main.py
Normal file
152
services/news_fetcher/main.py
Normal file
|
|
@ -0,0 +1,152 @@
|
||||||
|
"""News fetcher service entry point.
|
||||||
|
|
||||||
|
Polls RSS feeds and Reddit on independent schedules, deduplicates
|
||||||
|
articles by content hash (via a Redis SET), and publishes new articles
|
||||||
|
to the ``news:raw`` Redis Stream.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from redis.asyncio import Redis
|
||||||
|
|
||||||
|
from shared.redis_streams import StreamPublisher
|
||||||
|
from shared.telemetry import setup_telemetry
|
||||||
|
from services.news_fetcher.config import NewsFetcherConfig
|
||||||
|
from services.news_fetcher.sources.rss import RSSSource
|
||||||
|
from services.news_fetcher.sources.reddit import RedditSource
|
||||||
|
from shared.schemas.news import RawArticle
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
SEEN_HASHES_KEY = "news:seen_hashes"
|
||||||
|
NEWS_RAW_STREAM = "news:raw"
|
||||||
|
|
||||||
|
|
||||||
|
async def _deduplicate_and_publish(
|
||||||
|
articles: list[RawArticle],
|
||||||
|
redis: Redis,
|
||||||
|
publisher: StreamPublisher,
|
||||||
|
articles_fetched_counter,
|
||||||
|
fetch_errors_counter,
|
||||||
|
) -> int:
|
||||||
|
"""Add unseen articles to the ``news:raw`` stream.
|
||||||
|
|
||||||
|
Returns the number of newly published articles.
|
||||||
|
"""
|
||||||
|
published = 0
|
||||||
|
for article in articles:
|
||||||
|
# SADD returns 1 if the member was added (i.e. not already present)
|
||||||
|
added = await redis.sadd(SEEN_HASHES_KEY, article.content_hash)
|
||||||
|
if added:
|
||||||
|
await publisher.publish(article.model_dump(mode="json"))
|
||||||
|
published += 1
|
||||||
|
if published:
|
||||||
|
articles_fetched_counter.add(published)
|
||||||
|
return published
|
||||||
|
|
||||||
|
|
||||||
|
async def _poll_rss(
|
||||||
|
source: RSSSource,
|
||||||
|
interval: int,
|
||||||
|
redis: Redis,
|
||||||
|
publisher: StreamPublisher,
|
||||||
|
articles_fetched_counter,
|
||||||
|
fetch_errors_counter,
|
||||||
|
) -> None:
|
||||||
|
"""Continuously poll RSS feeds at *interval* seconds."""
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
logger.info("Polling RSS feeds …")
|
||||||
|
articles = await source.fetch()
|
||||||
|
count = await _deduplicate_and_publish(
|
||||||
|
articles, redis, publisher, articles_fetched_counter, fetch_errors_counter
|
||||||
|
)
|
||||||
|
logger.info("RSS poll complete: %d new articles published", count)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("RSS poll cycle failed")
|
||||||
|
fetch_errors_counter.add(1)
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
|
||||||
|
|
||||||
|
async def _poll_reddit(
|
||||||
|
source: RedditSource,
|
||||||
|
interval: int,
|
||||||
|
redis: Redis,
|
||||||
|
publisher: StreamPublisher,
|
||||||
|
articles_fetched_counter,
|
||||||
|
fetch_errors_counter,
|
||||||
|
) -> None:
|
||||||
|
"""Continuously poll Reddit at *interval* seconds."""
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
logger.info("Polling Reddit …")
|
||||||
|
articles = await source.fetch()
|
||||||
|
count = await _deduplicate_and_publish(
|
||||||
|
articles, redis, publisher, articles_fetched_counter, fetch_errors_counter
|
||||||
|
)
|
||||||
|
logger.info("Reddit poll complete: %d new articles published", count)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Reddit poll cycle failed")
|
||||||
|
fetch_errors_counter.add(1)
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
|
||||||
|
|
||||||
|
async def run() -> None:
|
||||||
|
"""Boot the news fetcher and start polling."""
|
||||||
|
config = NewsFetcherConfig()
|
||||||
|
|
||||||
|
logging.basicConfig(level=config.log_level)
|
||||||
|
logger.info("Starting news fetcher service")
|
||||||
|
|
||||||
|
# Telemetry
|
||||||
|
meter = setup_telemetry("news-fetcher", config.otel_metrics_port)
|
||||||
|
articles_fetched_counter = meter.create_counter(
|
||||||
|
"news.articles_fetched",
|
||||||
|
description="Total articles fetched and published",
|
||||||
|
)
|
||||||
|
fetch_errors_counter = meter.create_counter(
|
||||||
|
"news.fetch_errors",
|
||||||
|
description="Total fetch-cycle errors",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Redis
|
||||||
|
redis = Redis.from_url(config.redis_url, decode_responses=True)
|
||||||
|
publisher = StreamPublisher(redis, NEWS_RAW_STREAM)
|
||||||
|
|
||||||
|
# Sources
|
||||||
|
rss_source = RSSSource(feeds=config.rss_feeds)
|
||||||
|
reddit_source = RedditSource(
|
||||||
|
subreddits=config.reddit_subreddits,
|
||||||
|
client_id=config.reddit_client_id,
|
||||||
|
client_secret=config.reddit_client_secret,
|
||||||
|
user_agent=config.reddit_user_agent,
|
||||||
|
min_score=config.reddit_min_score,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run pollers concurrently
|
||||||
|
async with asyncio.TaskGroup() as tg:
|
||||||
|
tg.create_task(
|
||||||
|
_poll_rss(
|
||||||
|
rss_source,
|
||||||
|
config.rss_poll_interval_seconds,
|
||||||
|
redis,
|
||||||
|
publisher,
|
||||||
|
articles_fetched_counter,
|
||||||
|
fetch_errors_counter,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
tg.create_task(
|
||||||
|
_poll_reddit(
|
||||||
|
reddit_source,
|
||||||
|
config.reddit_poll_interval_seconds,
|
||||||
|
redis,
|
||||||
|
publisher,
|
||||||
|
articles_fetched_counter,
|
||||||
|
fetch_errors_counter,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(run())
|
||||||
1
services/news_fetcher/sources/__init__.py
Normal file
1
services/news_fetcher/sources/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
"""News source adapters (RSS, Reddit)."""
|
||||||
76
services/news_fetcher/sources/reddit.py
Normal file
76
services/news_fetcher/sources/reddit.py
Normal file
|
|
@ -0,0 +1,76 @@
|
||||||
|
"""Reddit source — fetches hot posts from financial subreddits via asyncpraw."""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from shared.schemas.news import RawArticle
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RedditSource:
|
||||||
|
"""Fetches hot posts from Reddit and converts them to :class:`RawArticle`."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
subreddits: list[str],
|
||||||
|
client_id: str,
|
||||||
|
client_secret: str,
|
||||||
|
user_agent: str,
|
||||||
|
min_score: int = 10,
|
||||||
|
) -> None:
|
||||||
|
self.subreddits = subreddits
|
||||||
|
self.client_id = client_id
|
||||||
|
self.client_secret = client_secret
|
||||||
|
self.user_agent = user_agent
|
||||||
|
self.min_score = min_score
|
||||||
|
|
||||||
|
async def fetch(self) -> list[RawArticle]:
|
||||||
|
"""Return hot posts above *min_score* from each configured subreddit.
|
||||||
|
|
||||||
|
Uses ``asyncpraw`` so the caller must run within an ``async`` context.
|
||||||
|
Each Reddit instance is created and closed within this call to avoid
|
||||||
|
leaking sessions across poll cycles.
|
||||||
|
"""
|
||||||
|
import asyncpraw # lazy import so the dep is optional at import time
|
||||||
|
|
||||||
|
articles: list[RawArticle] = []
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
reddit = asyncpraw.Reddit(
|
||||||
|
client_id=self.client_id,
|
||||||
|
client_secret=self.client_secret,
|
||||||
|
user_agent=self.user_agent,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
for sub_name in self.subreddits:
|
||||||
|
try:
|
||||||
|
subreddit = await reddit.subreddit(sub_name)
|
||||||
|
async for post in subreddit.hot(limit=25):
|
||||||
|
if post.score < self.min_score:
|
||||||
|
continue
|
||||||
|
|
||||||
|
content = post.selftext if post.selftext else post.url
|
||||||
|
permalink = post.permalink
|
||||||
|
content_hash = hashlib.sha256(permalink.encode()).hexdigest()
|
||||||
|
published_at = datetime.fromtimestamp(post.created_utc, tz=timezone.utc)
|
||||||
|
|
||||||
|
articles.append(
|
||||||
|
RawArticle(
|
||||||
|
source="reddit",
|
||||||
|
url=f"https://reddit.com{permalink}",
|
||||||
|
title=post.title,
|
||||||
|
content=content,
|
||||||
|
published_at=published_at,
|
||||||
|
fetched_at=now,
|
||||||
|
content_hash=content_hash,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to fetch subreddit r/%s", sub_name)
|
||||||
|
continue
|
||||||
|
finally:
|
||||||
|
await reddit.close()
|
||||||
|
|
||||||
|
return articles
|
||||||
71
services/news_fetcher/sources/rss.py
Normal file
71
services/news_fetcher/sources/rss.py
Normal file
|
|
@ -0,0 +1,71 @@
|
||||||
|
"""RSS feed source — fetches articles from configurable RSS feed URLs."""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from email.utils import parsedate_to_datetime
|
||||||
|
|
||||||
|
import feedparser
|
||||||
|
|
||||||
|
from shared.schemas.news import RawArticle
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RSSSource:
|
||||||
|
"""Fetches and converts RSS feed entries to :class:`RawArticle` instances."""
|
||||||
|
|
||||||
|
def __init__(self, feeds: list[str]) -> None:
|
||||||
|
self.feeds = feeds
|
||||||
|
|
||||||
|
async def fetch(self) -> list[RawArticle]:
|
||||||
|
"""Parse every configured feed and return a list of raw articles.
|
||||||
|
|
||||||
|
Feeds that fail to parse are logged and skipped so that a single
|
||||||
|
broken feed does not prevent the others from being collected.
|
||||||
|
"""
|
||||||
|
articles: list[RawArticle] = []
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
for feed_url in self.feeds:
|
||||||
|
try:
|
||||||
|
parsed = feedparser.parse(feed_url)
|
||||||
|
if parsed.bozo and not parsed.entries:
|
||||||
|
logger.warning("Feed %s returned bozo error: %s", feed_url, parsed.bozo_exception)
|
||||||
|
continue
|
||||||
|
|
||||||
|
for entry in parsed.entries:
|
||||||
|
title = entry.get("title", "")
|
||||||
|
link = entry.get("link", "")
|
||||||
|
content = entry.get("summary", "") or entry.get("description", "")
|
||||||
|
|
||||||
|
published_at = self._parse_published(entry)
|
||||||
|
content_hash = hashlib.sha256(f"{link}{title}".encode()).hexdigest()
|
||||||
|
|
||||||
|
articles.append(
|
||||||
|
RawArticle(
|
||||||
|
source="rss",
|
||||||
|
url=link,
|
||||||
|
title=title,
|
||||||
|
content=content,
|
||||||
|
published_at=published_at,
|
||||||
|
fetched_at=now,
|
||||||
|
content_hash=content_hash,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to fetch RSS feed %s", feed_url)
|
||||||
|
continue
|
||||||
|
|
||||||
|
return articles
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_published(entry: dict) -> datetime | None:
|
||||||
|
"""Best-effort parsing of the entry's publication date."""
|
||||||
|
raw = entry.get("published") or entry.get("updated")
|
||||||
|
if not raw:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return parsedate_to_datetime(raw)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
0
tests/services/__init__.py
Normal file
0
tests/services/__init__.py
Normal file
391
tests/services/test_news_fetcher.py
Normal file
391
tests/services/test_news_fetcher.py
Normal file
|
|
@ -0,0 +1,391 @@
|
||||||
|
"""Tests for the news fetcher service — RSS, Reddit, deduplication, publishing."""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from shared.schemas.news import RawArticle
|
||||||
|
from services.news_fetcher.sources.rss import RSSSource
|
||||||
|
from services.news_fetcher.sources.reddit import RedditSource
|
||||||
|
from services.news_fetcher.main import _deduplicate_and_publish, SEEN_HASHES_KEY
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
FIXTURE_RSS_XML = """\
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<rss version="2.0">
|
||||||
|
<channel>
|
||||||
|
<title>Test Finance Feed</title>
|
||||||
|
<item>
|
||||||
|
<title>AAPL hits record high</title>
|
||||||
|
<link>https://example.com/aapl</link>
|
||||||
|
<description>Apple stock reached an all-time high today.</description>
|
||||||
|
<pubDate>Sat, 22 Feb 2026 12:00:00 GMT</pubDate>
|
||||||
|
</item>
|
||||||
|
<item>
|
||||||
|
<title>TSLA earnings beat</title>
|
||||||
|
<link>https://example.com/tsla</link>
|
||||||
|
<description>Tesla reported stronger-than-expected earnings.</description>
|
||||||
|
<pubDate>Sat, 22 Feb 2026 13:00:00 GMT</pubDate>
|
||||||
|
</item>
|
||||||
|
</channel>
|
||||||
|
</rss>
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _make_fake_entry(title: str, link: str, summary: str, published: str | None = None) -> dict:
|
||||||
|
"""Return a dict matching feedparser entry structure."""
|
||||||
|
entry = {"title": title, "link": link, "summary": summary}
|
||||||
|
if published:
|
||||||
|
entry["published"] = published
|
||||||
|
return entry
|
||||||
|
|
||||||
|
|
||||||
|
def _make_reddit_post(title, selftext, url, permalink, score, created_utc):
|
||||||
|
"""Return a SimpleNamespace mimicking an asyncpraw Submission."""
|
||||||
|
return SimpleNamespace(
|
||||||
|
title=title,
|
||||||
|
selftext=selftext,
|
||||||
|
url=url,
|
||||||
|
permalink=permalink,
|
||||||
|
score=score,
|
||||||
|
created_utc=created_utc,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# RSS tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rss_source_parses_feed():
|
||||||
|
"""feedparser.parse is called for each feed and entries become RawArticles."""
|
||||||
|
fake_parsed = MagicMock()
|
||||||
|
fake_parsed.bozo = False
|
||||||
|
fake_parsed.entries = [
|
||||||
|
_make_fake_entry(
|
||||||
|
"AAPL hits record high",
|
||||||
|
"https://example.com/aapl",
|
||||||
|
"Apple stock reached an all-time high today.",
|
||||||
|
"Sat, 22 Feb 2026 12:00:00 GMT",
|
||||||
|
),
|
||||||
|
_make_fake_entry(
|
||||||
|
"TSLA earnings beat",
|
||||||
|
"https://example.com/tsla",
|
||||||
|
"Tesla reported stronger-than-expected earnings.",
|
||||||
|
"Sat, 22 Feb 2026 13:00:00 GMT",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch("services.news_fetcher.sources.rss.feedparser.parse", return_value=fake_parsed) as mock_parse:
|
||||||
|
source = RSSSource(feeds=["https://example.com/feed"])
|
||||||
|
articles = await source.fetch()
|
||||||
|
|
||||||
|
mock_parse.assert_called_once_with("https://example.com/feed")
|
||||||
|
assert len(articles) == 2
|
||||||
|
|
||||||
|
assert articles[0].source == "rss"
|
||||||
|
assert articles[0].title == "AAPL hits record high"
|
||||||
|
assert articles[0].url == "https://example.com/aapl"
|
||||||
|
assert articles[0].content == "Apple stock reached an all-time high today."
|
||||||
|
expected_hash = hashlib.sha256("https://example.com/aaplAAPL hits record high".encode()).hexdigest()
|
||||||
|
assert articles[0].content_hash == expected_hash
|
||||||
|
assert articles[0].published_at is not None
|
||||||
|
|
||||||
|
assert articles[1].title == "TSLA earnings beat"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rss_source_handles_bad_feed():
|
||||||
|
"""A feed that raises an exception is skipped; an empty list is returned."""
|
||||||
|
with patch(
|
||||||
|
"services.news_fetcher.sources.rss.feedparser.parse",
|
||||||
|
side_effect=Exception("network timeout"),
|
||||||
|
):
|
||||||
|
source = RSSSource(feeds=["https://bad-feed.example.com/rss"])
|
||||||
|
articles = await source.fetch()
|
||||||
|
|
||||||
|
assert articles == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rss_source_handles_bozo_feed():
|
||||||
|
"""A bozo feed with no entries is skipped gracefully."""
|
||||||
|
fake_parsed = MagicMock()
|
||||||
|
fake_parsed.bozo = True
|
||||||
|
fake_parsed.entries = []
|
||||||
|
fake_parsed.bozo_exception = "malformed XML"
|
||||||
|
|
||||||
|
with patch("services.news_fetcher.sources.rss.feedparser.parse", return_value=fake_parsed):
|
||||||
|
source = RSSSource(feeds=["https://broken.example.com/rss"])
|
||||||
|
articles = await source.fetch()
|
||||||
|
|
||||||
|
assert articles == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rss_source_multiple_feeds():
|
||||||
|
"""Articles from multiple feeds are combined."""
|
||||||
|
feed1 = MagicMock()
|
||||||
|
feed1.bozo = False
|
||||||
|
feed1.entries = [_make_fake_entry("A", "https://a.com", "content a")]
|
||||||
|
|
||||||
|
feed2 = MagicMock()
|
||||||
|
feed2.bozo = False
|
||||||
|
feed2.entries = [_make_fake_entry("B", "https://b.com", "content b")]
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"services.news_fetcher.sources.rss.feedparser.parse",
|
||||||
|
side_effect=[feed1, feed2],
|
||||||
|
):
|
||||||
|
source = RSSSource(feeds=["https://feed1.com", "https://feed2.com"])
|
||||||
|
articles = await source.fetch()
|
||||||
|
|
||||||
|
assert len(articles) == 2
|
||||||
|
assert {a.title for a in articles} == {"A", "B"}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Reddit tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reddit_source_converts_posts():
|
||||||
|
"""Hot posts are converted to RawArticle with correct fields."""
|
||||||
|
post = _make_reddit_post(
|
||||||
|
title="GME to the moon",
|
||||||
|
selftext="Diamond hands forever",
|
||||||
|
url="https://reddit.com/r/wsb/1",
|
||||||
|
permalink="/r/wallstreetbets/comments/abc123/gme_to_the_moon/",
|
||||||
|
score=500,
|
||||||
|
created_utc=1740200000.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create an async iterator from the posts
|
||||||
|
async def _hot(limit=25):
|
||||||
|
for p in [post]:
|
||||||
|
yield p
|
||||||
|
|
||||||
|
fake_subreddit = AsyncMock()
|
||||||
|
fake_subreddit.hot = _hot
|
||||||
|
|
||||||
|
fake_reddit = AsyncMock()
|
||||||
|
fake_reddit.subreddit = AsyncMock(return_value=fake_subreddit)
|
||||||
|
fake_reddit.close = AsyncMock()
|
||||||
|
|
||||||
|
with patch("asyncpraw.Reddit", return_value=fake_reddit):
|
||||||
|
source = RedditSource(
|
||||||
|
subreddits=["wallstreetbets"],
|
||||||
|
client_id="test_id",
|
||||||
|
client_secret="test_secret",
|
||||||
|
user_agent="test-agent",
|
||||||
|
min_score=10,
|
||||||
|
)
|
||||||
|
articles = await source.fetch()
|
||||||
|
|
||||||
|
assert len(articles) == 1
|
||||||
|
assert articles[0].source == "reddit"
|
||||||
|
assert articles[0].title == "GME to the moon"
|
||||||
|
assert articles[0].content == "Diamond hands forever"
|
||||||
|
expected_hash = hashlib.sha256(
|
||||||
|
"/r/wallstreetbets/comments/abc123/gme_to_the_moon/".encode()
|
||||||
|
).hexdigest()
|
||||||
|
assert articles[0].content_hash == expected_hash
|
||||||
|
assert "reddit.com" in articles[0].url
|
||||||
|
fake_reddit.close.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reddit_source_filters_by_score():
|
||||||
|
"""Posts below min_score are excluded."""
|
||||||
|
high_score = _make_reddit_post("High", "text", "url", "/r/stocks/high", 100, 1740200000.0)
|
||||||
|
low_score = _make_reddit_post("Low", "text", "url", "/r/stocks/low", 5, 1740200000.0)
|
||||||
|
|
||||||
|
async def _hot(limit=25):
|
||||||
|
for p in [high_score, low_score]:
|
||||||
|
yield p
|
||||||
|
|
||||||
|
fake_subreddit = AsyncMock()
|
||||||
|
fake_subreddit.hot = _hot
|
||||||
|
|
||||||
|
fake_reddit = AsyncMock()
|
||||||
|
fake_reddit.subreddit = AsyncMock(return_value=fake_subreddit)
|
||||||
|
fake_reddit.close = AsyncMock()
|
||||||
|
|
||||||
|
with patch("asyncpraw.Reddit", return_value=fake_reddit):
|
||||||
|
source = RedditSource(
|
||||||
|
subreddits=["stocks"],
|
||||||
|
client_id="id",
|
||||||
|
client_secret="secret",
|
||||||
|
user_agent="agent",
|
||||||
|
min_score=10,
|
||||||
|
)
|
||||||
|
articles = await source.fetch()
|
||||||
|
|
||||||
|
assert len(articles) == 1
|
||||||
|
assert articles[0].title == "High"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reddit_source_uses_url_when_no_selftext():
|
||||||
|
"""When selftext is empty, the post URL is used as content."""
|
||||||
|
post = _make_reddit_post(
|
||||||
|
title="Link post",
|
||||||
|
selftext="",
|
||||||
|
url="https://example.com/article",
|
||||||
|
permalink="/r/investing/link",
|
||||||
|
score=50,
|
||||||
|
created_utc=1740200000.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _hot(limit=25):
|
||||||
|
yield post
|
||||||
|
|
||||||
|
fake_subreddit = AsyncMock()
|
||||||
|
fake_subreddit.hot = _hot
|
||||||
|
|
||||||
|
fake_reddit = AsyncMock()
|
||||||
|
fake_reddit.subreddit = AsyncMock(return_value=fake_subreddit)
|
||||||
|
fake_reddit.close = AsyncMock()
|
||||||
|
|
||||||
|
with patch("asyncpraw.Reddit", return_value=fake_reddit):
|
||||||
|
source = RedditSource(
|
||||||
|
subreddits=["investing"],
|
||||||
|
client_id="id",
|
||||||
|
client_secret="secret",
|
||||||
|
user_agent="agent",
|
||||||
|
min_score=10,
|
||||||
|
)
|
||||||
|
articles = await source.fetch()
|
||||||
|
|
||||||
|
assert len(articles) == 1
|
||||||
|
assert articles[0].content == "https://example.com/article"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Deduplication tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_deduplication_skips_seen_hashes():
|
||||||
|
"""Articles with previously-seen content_hash are not published."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
articles = [
|
||||||
|
RawArticle(
|
||||||
|
source="rss",
|
||||||
|
url="https://example.com/1",
|
||||||
|
title="First",
|
||||||
|
content="Content 1",
|
||||||
|
fetched_at=now,
|
||||||
|
content_hash="hash_new",
|
||||||
|
),
|
||||||
|
RawArticle(
|
||||||
|
source="rss",
|
||||||
|
url="https://example.com/2",
|
||||||
|
title="Already seen",
|
||||||
|
content="Content 2",
|
||||||
|
fetched_at=now,
|
||||||
|
content_hash="hash_old",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
redis = AsyncMock()
|
||||||
|
# First article is new (SADD returns 1), second is duplicate (returns 0)
|
||||||
|
redis.sadd = AsyncMock(side_effect=[1, 0])
|
||||||
|
|
||||||
|
publisher = AsyncMock()
|
||||||
|
publisher.publish = AsyncMock()
|
||||||
|
|
||||||
|
counter = MagicMock()
|
||||||
|
counter.add = MagicMock()
|
||||||
|
error_counter = MagicMock()
|
||||||
|
|
||||||
|
count = await _deduplicate_and_publish(articles, redis, publisher, counter, error_counter)
|
||||||
|
|
||||||
|
assert count == 1
|
||||||
|
publisher.publish.assert_called_once()
|
||||||
|
# Verify the hash was checked for both articles
|
||||||
|
assert redis.sadd.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_deduplication_publishes_all_new():
|
||||||
|
"""When all articles are new, all are published."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
articles = [
|
||||||
|
RawArticle(
|
||||||
|
source="rss",
|
||||||
|
url=f"https://example.com/{i}",
|
||||||
|
title=f"Article {i}",
|
||||||
|
content=f"Content {i}",
|
||||||
|
fetched_at=now,
|
||||||
|
content_hash=f"hash_{i}",
|
||||||
|
)
|
||||||
|
for i in range(3)
|
||||||
|
]
|
||||||
|
|
||||||
|
redis = AsyncMock()
|
||||||
|
redis.sadd = AsyncMock(return_value=1)
|
||||||
|
|
||||||
|
publisher = AsyncMock()
|
||||||
|
publisher.publish = AsyncMock()
|
||||||
|
|
||||||
|
counter = MagicMock()
|
||||||
|
counter.add = MagicMock()
|
||||||
|
error_counter = MagicMock()
|
||||||
|
|
||||||
|
count = await _deduplicate_and_publish(articles, redis, publisher, counter, error_counter)
|
||||||
|
|
||||||
|
assert count == 3
|
||||||
|
assert publisher.publish.call_count == 3
|
||||||
|
counter.add.assert_called_once_with(3)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Main service integration test (mocked sources + redis)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_main_publishes_to_stream():
|
||||||
|
"""End-to-end: mocked sources produce articles which get published."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
fake_article = RawArticle(
|
||||||
|
source="rss",
|
||||||
|
url="https://example.com/test",
|
||||||
|
title="Test",
|
||||||
|
content="Test content",
|
||||||
|
fetched_at=now,
|
||||||
|
content_hash="unique_hash",
|
||||||
|
)
|
||||||
|
|
||||||
|
redis = AsyncMock()
|
||||||
|
redis.sadd = AsyncMock(return_value=1)
|
||||||
|
|
||||||
|
publisher = AsyncMock()
|
||||||
|
publisher.publish = AsyncMock()
|
||||||
|
|
||||||
|
counter = MagicMock()
|
||||||
|
counter.add = MagicMock()
|
||||||
|
error_counter = MagicMock()
|
||||||
|
|
||||||
|
count = await _deduplicate_and_publish(
|
||||||
|
[fake_article], redis, publisher, counter, error_counter
|
||||||
|
)
|
||||||
|
|
||||||
|
assert count == 1
|
||||||
|
publisher.publish.assert_called_once()
|
||||||
|
# Verify the published data matches the article
|
||||||
|
call_args = publisher.publish.call_args[0][0]
|
||||||
|
assert call_args["title"] == "Test"
|
||||||
|
assert call_args["source"] == "rss"
|
||||||
|
assert call_args["content_hash"] == "unique_hash"
|
||||||
Loading…
Add table
Add a link
Reference in a new issue