From 90b52a51444ed0d9a59880ca1b8a9a6a7fa17c0a Mon Sep 17 00:00:00 2001
From: Viktor Barzin
Date: Sun, 22 Feb 2026 15:25:27 +0000
Subject: [PATCH] =?UTF-8?q?feat:=20news=20fetcher=20service=20=E2=80=94=20?=
 =?UTF-8?q?RSS=20and=20Reddit=20sources?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 pyproject.toml                             |   4 +-
 services/__init__.py                       |   0
 services/news_fetcher/__init__.py          |   1 +
 services/news_fetcher/config.py            |  33 ++
 services/news_fetcher/main.py              | 134 +++++++
 services/news_fetcher/sources/__init__.py  |   1 +
 services/news_fetcher/sources/reddit.py    |  82 +++++
 services/news_fetcher/sources/rss.py       |  78 ++++
 tests/services/__init__.py                 |   0
 tests/services/test_news_fetcher.py        | 407 ++++++++++++++++++++
 10 files changed, 738 insertions(+), 2 deletions(-)
 create mode 100644 services/__init__.py
 create mode 100644 services/news_fetcher/__init__.py
 create mode 100644 services/news_fetcher/config.py
 create mode 100644 services/news_fetcher/main.py
 create mode 100644 services/news_fetcher/sources/__init__.py
 create mode 100644 services/news_fetcher/sources/reddit.py
 create mode 100644 services/news_fetcher/sources/rss.py
 create mode 100644 tests/services/__init__.py
 create mode 100644 tests/services/test_news_fetcher.py

diff --git a/pyproject.toml b/pyproject.toml
index e7d0c55..c0c3043 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,7 +16,7 @@ dependencies = [
 
 [project.optional-dependencies]
 api = ["fastapi>=0.110", "uvicorn[standard]>=0.27", "websockets>=12.0", "py-webauthn>=2.0", "pyjwt[crypto]>=2.8"]
-news = ["feedparser>=6.0", "praw>=7.7", "httpx>=0.27"]
+news = ["feedparser>=6.0", "praw>=7.7", "asyncpraw>=7.7", "httpx>=0.27"]
 sentiment = ["transformers>=4.38", "torch>=2.2", "ollama>=0.1"]
 trading = ["alpaca-py>=0.21"]
 backtester = ["numpy>=1.26", "pandas>=2.2"]
@@ -27,7 +27,7 @@ requires = ["setuptools>=70.0"]
 build-backend = "setuptools.build_meta"
 
 [tool.setuptools.packages.find]
-include = ["shared*", "tests*"]
+include = ["shared*", "services*", "tests*"]
 
 [tool.pytest.ini_options]
 asyncio_mode = "auto"
diff --git a/services/__init__.py b/services/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/services/news_fetcher/__init__.py b/services/news_fetcher/__init__.py
new file mode 100644
index 0000000..86fff9d
--- /dev/null
+++ b/services/news_fetcher/__init__.py
@@ -0,0 +1 @@
+"""News fetcher service — polls RSS feeds and Reddit for financial news."""
diff --git a/services/news_fetcher/config.py b/services/news_fetcher/config.py
new file mode 100644
index 0000000..76bd606
--- /dev/null
+++ b/services/news_fetcher/config.py
@@ -0,0 +1,33 @@
+"""Configuration for the news fetcher service."""
+
+from shared.config import BaseConfig
+
+
+class NewsFetcherConfig(BaseConfig):
+    """News fetcher settings.
+
+    Extends :class:`BaseConfig` with RSS feed URLs, poll intervals,
+    and Reddit API credentials. All settings can be overridden via
+    environment variables prefixed with ``TRADING_``.
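+
+    For example (hypothetical values)::
+
+        TRADING_RSS_POLL_INTERVAL_SECONDS=120
+        TRADING_REDDIT_MIN_SCORE=25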
+    """
+
+    # RSS settings
+    rss_feeds: list[str] = [
+        "https://finance.yahoo.com/news/rssindex",
+        "https://feeds.reuters.com/reuters/businessNews",
+        "https://feeds.content.dowjones.io/public/rss/mw_topstories",
+    ]
+    rss_poll_interval_seconds: int = 300
+
+    # Reddit settings
+    reddit_subreddits: list[str] = ["wallstreetbets", "stocks", "investing"]
+    reddit_poll_interval_seconds: int = 600
+    reddit_min_score: int = 10
+    reddit_client_id: str = ""
+    reddit_client_secret: str = ""
+    reddit_user_agent: str = "trading-bot/0.1"
diff --git a/services/news_fetcher/main.py b/services/news_fetcher/main.py
new file mode 100644
index 0000000..e90603d
--- /dev/null
+++ b/services/news_fetcher/main.py
@@ -0,0 +1,134 @@
+"""News fetcher service entry point.
+
+Polls RSS feeds and Reddit on independent schedules, deduplicates
+articles by content hash (via a Redis SET), and publishes new articles
+to the ``news:raw`` Redis Stream.
+"""
+
+import asyncio
+import logging
+
+from redis.asyncio import Redis
+
+from services.news_fetcher.config import NewsFetcherConfig
+from services.news_fetcher.sources.reddit import RedditSource
+from services.news_fetcher.sources.rss import RSSSource
+from shared.redis_streams import StreamPublisher
+from shared.schemas.news import RawArticle
+from shared.telemetry import setup_telemetry
+
+logger = logging.getLogger(__name__)
+
+SEEN_HASHES_KEY = "news:seen_hashes"
+NEWS_RAW_STREAM = "news:raw"
+
+
+async def _deduplicate_and_publish(
+    articles: list[RawArticle],
+    redis: Redis,
+    publisher: StreamPublisher,
+    articles_fetched_counter,
+    fetch_errors_counter,
+) -> int:
+    """Add unseen articles to the ``news:raw`` stream.
+
+    Returns the number of newly published articles.
+    """
+    published = 0
+    for article in articles:
+        # SADD returns 1 if the member was added (i.e. not already present).
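+        # SADD is atomic, so the check-and-mark step cannot race between
+        # the two pollers; note the seen-hash set grows without bound.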
+        added = await redis.sadd(SEEN_HASHES_KEY, article.content_hash)
+        if added:
+            await publisher.publish(article.model_dump(mode="json"))
+            published += 1
+    if published:
+        articles_fetched_counter.add(published)
+    return published
+
+
+async def _poll_source(
+    name: str,
+    source: RSSSource | RedditSource,
+    interval: int,
+    redis: Redis,
+    publisher: StreamPublisher,
+    articles_fetched_counter,
+    fetch_errors_counter,
+) -> None:
+    """Continuously poll *source* every *interval* seconds."""
+    while True:
+        try:
+            logger.info("Polling %s …", name)
+            articles = await source.fetch()
+            count = await _deduplicate_and_publish(
+                articles, redis, publisher, articles_fetched_counter, fetch_errors_counter
+            )
+            logger.info("%s poll complete: %d new articles published", name, count)
+        except Exception:
+            logger.exception("%s poll cycle failed", name)
+            fetch_errors_counter.add(1)
+        await asyncio.sleep(interval)
+
+
+async def run() -> None:
+    """Boot the news fetcher and start polling."""
+    config = NewsFetcherConfig()
+
+    logging.basicConfig(level=config.log_level)
+    logger.info("Starting news fetcher service")
+
+    # Telemetry
+    meter = setup_telemetry("news-fetcher", config.otel_metrics_port)
+    articles_fetched_counter = meter.create_counter(
+        "news.articles_fetched",
+        description="Total articles fetched and published",
+    )
+    fetch_errors_counter = meter.create_counter(
+        "news.fetch_errors",
+        description="Total fetch-cycle errors",
+    )
+
+    # Redis
+    redis = Redis.from_url(config.redis_url, decode_responses=True)
+    publisher = StreamPublisher(redis, NEWS_RAW_STREAM)
+
+    # Sources
+    rss_source = RSSSource(feeds=config.rss_feeds)
+    reddit_source = RedditSource(
+        subreddits=config.reddit_subreddits,
+        client_id=config.reddit_client_id,
+        client_secret=config.reddit_client_secret,
+        user_agent=config.reddit_user_agent,
+        min_score=config.reddit_min_score,
+    )
+
+    # Run both pollers concurrently; if one task fails, the TaskGroup cancels the other
+    async with asyncio.TaskGroup() as tg:
+        tg.create_task(
+            _poll_source(
+                "RSS",
+                rss_source,
+                config.rss_poll_interval_seconds,
+                redis,
+                publisher,
+                articles_fetched_counter,
+                fetch_errors_counter,
+            )
+        )
+        tg.create_task(
+            _poll_source(
+                "Reddit",
+                reddit_source,
+                config.reddit_poll_interval_seconds,
+                redis,
+                publisher,
+                articles_fetched_counter,
+                fetch_errors_counter,
+            )
+        )
+
+
+if __name__ == "__main__":
+    asyncio.run(run())
diff --git a/services/news_fetcher/sources/__init__.py b/services/news_fetcher/sources/__init__.py
new file mode 100644
index 0000000..fd1bcee
--- /dev/null
+++ b/services/news_fetcher/sources/__init__.py
@@ -0,0 +1 @@
+"""News source adapters (RSS, Reddit)."""
diff --git a/services/news_fetcher/sources/reddit.py b/services/news_fetcher/sources/reddit.py
new file mode 100644
index 0000000..bd88639
--- /dev/null
+++ b/services/news_fetcher/sources/reddit.py
@@ -0,0 +1,82 @@
+"""Reddit source — fetches hot posts from financial subreddits via asyncpraw."""
+
+import hashlib
+import logging
+from datetime import datetime, timezone
+
+from shared.schemas.news import RawArticle
+
+logger = logging.getLogger(__name__)
+
+
+class RedditSource:
+    """Fetches hot posts from Reddit and converts them to :class:`RawArticle`.
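+
+    Example (credentials are placeholders)::
+
+        source = RedditSource(["stocks"], "my_id", "my_secret", "bot/0.1")
+        articles = await source.fetch()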
+    """
+
+    def __init__(
+        self,
+        subreddits: list[str],
+        client_id: str,
+        client_secret: str,
+        user_agent: str,
+        min_score: int = 10,
+    ) -> None:
+        self.subreddits = subreddits
+        self.client_id = client_id
+        self.client_secret = client_secret
+        self.user_agent = user_agent
+        self.min_score = min_score
+
+    async def fetch(self) -> list[RawArticle]:
+        """Return hot posts above *min_score* from each configured subreddit.
+
+        Uses ``asyncpraw`` so the caller must run within an ``async`` context.
+        Each Reddit instance is created and closed within this call to avoid
+        leaking sessions across poll cycles.
+        """
+        import asyncpraw  # lazy import so the dep is optional at import time
+
+        articles: list[RawArticle] = []
+        now = datetime.now(timezone.utc)
+
+        reddit = asyncpraw.Reddit(
+            client_id=self.client_id,
+            client_secret=self.client_secret,
+            user_agent=self.user_agent,
+        )
+        try:
+            for sub_name in self.subreddits:
+                try:
+                    subreddit = await reddit.subreddit(sub_name)
+                    async for post in subreddit.hot(limit=25):
+                        if post.score < self.min_score:
+                            continue
+
+                        content = post.selftext or post.url
+                        permalink = post.permalink
+                        content_hash = hashlib.sha256(permalink.encode()).hexdigest()
+                        published_at = datetime.fromtimestamp(post.created_utc, tz=timezone.utc)
+
+                        articles.append(
+                            RawArticle(
+                                source="reddit",
+                                url=f"https://reddit.com{permalink}",
+                                title=post.title,
+                                content=content,
+                                published_at=published_at,
+                                fetched_at=now,
+                                content_hash=content_hash,
+                            )
+                        )
+                except Exception:
+                    logger.exception("Failed to fetch subreddit r/%s", sub_name)
+                    continue
+        finally:
+            await reddit.close()
+
+        return articles
diff --git a/services/news_fetcher/sources/rss.py b/services/news_fetcher/sources/rss.py
new file mode 100644
index 0000000..95766f1
--- /dev/null
+++ b/services/news_fetcher/sources/rss.py
@@ -0,0 +1,78 @@
+"""RSS feed source — fetches articles from configurable RSS feed URLs."""
+
+import asyncio
+import hashlib
+import logging
+from datetime import datetime, timezone
+from email.utils import parsedate_to_datetime
+
+import feedparser
+
+from shared.schemas.news import RawArticle
+
+logger = logging.getLogger(__name__)
+
+
+class RSSSource:
+    """Fetches and converts RSS feed entries to :class:`RawArticle` instances."""
+
+    def __init__(self, feeds: list[str]) -> None:
+        self.feeds = feeds
+
+    async def fetch(self) -> list[RawArticle]:
+        """Parse every configured feed and return a list of raw articles.
+
+        Feeds that fail to parse are logged and skipped so that a single
+        broken feed does not prevent the others from being collected.
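+
+        Example (feed URL is a placeholder)::
+
+            source = RSSSource(feeds=["https://example.com/feed.xml"])
+            articles = await source.fetch()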
+        """
+        articles: list[RawArticle] = []
+        now = datetime.now(timezone.utc)
+
+        for feed_url in self.feeds:
+            try:
+                # feedparser.parse blocks on network I/O; run it off the event loop
+                parsed = await asyncio.to_thread(feedparser.parse, feed_url)
+                if parsed.bozo and not parsed.entries:
+                    logger.warning("Feed %s returned bozo error: %s", feed_url, parsed.bozo_exception)
+                    continue
+
+                for entry in parsed.entries:
+                    title = entry.get("title", "")
+                    link = entry.get("link", "")
+                    content = entry.get("summary", "") or entry.get("description", "")
+
+                    published_at = self._parse_published(entry)
+                    content_hash = hashlib.sha256(f"{link}{title}".encode()).hexdigest()
+
+                    articles.append(
+                        RawArticle(
+                            source="rss",
+                            url=link,
+                            title=title,
+                            content=content,
+                            published_at=published_at,
+                            fetched_at=now,
+                            content_hash=content_hash,
+                        )
+                    )
+            except Exception:
+                logger.exception("Failed to fetch RSS feed %s", feed_url)
+                continue
+
+        return articles
+
+    @staticmethod
+    def _parse_published(entry: dict) -> datetime | None:
+        """Best-effort parsing of the entry's publication date."""
+        raw = entry.get("published") or entry.get("updated")
+        if not raw:
+            return None
+        try:
+            return parsedate_to_datetime(raw)
+        except Exception:
+            return None
diff --git a/tests/services/__init__.py b/tests/services/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/services/test_news_fetcher.py b/tests/services/test_news_fetcher.py
new file mode 100644
index 0000000..61402c1
--- /dev/null
+++ b/tests/services/test_news_fetcher.py
@@ -0,0 +1,407 @@
+"""Tests for the news fetcher service — RSS, Reddit, deduplication, publishing."""
+
+import hashlib
+from datetime import datetime, timezone
+from types import SimpleNamespace
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from shared.schemas.news import RawArticle
+from services.news_fetcher.sources.rss import RSSSource
+from services.news_fetcher.sources.reddit import RedditSource
+from services.news_fetcher.main import _deduplicate_and_publish
+
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+FIXTURE_RSS_XML = """\
+<rss version="2.0">
+  <channel>
+    <title>Test Finance Feed</title>
+    <item>
+      <title>AAPL hits record high</title>
+      <link>https://example.com/aapl</link>
+      <description>Apple stock reached an all-time high today.</description>
+      <pubDate>Sat, 22 Feb 2026 12:00:00 GMT</pubDate>
+    </item>
+    <item>
+      <title>TSLA earnings beat</title>
+      <link>https://example.com/tsla</link>
+      <description>Tesla reported stronger-than-expected earnings.</description>
+      <pubDate>Sat, 22 Feb 2026 13:00:00 GMT</pubDate>
+    </item>
+  </channel>
+</rss>
+"""
+
+
+def _make_fake_entry(title: str, link: str, summary: str, published: str | None = None) -> dict:
+    """Return a dict matching feedparser entry structure."""
+    entry = {"title": title, "link": link, "summary": summary}
+    if published:
+        entry["published"] = published
+    return entry
+
+
+def _make_reddit_post(title, selftext, url, permalink, score, created_utc):
+    """Return a SimpleNamespace mimicking an asyncpraw Submission."""
+    return SimpleNamespace(
+        title=title,
+        selftext=selftext,
+        url=url,
+        permalink=permalink,
+        score=score,
+        created_utc=created_utc,
+    )
+
+
+# ---------------------------------------------------------------------------
+# RSS tests
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_rss_source_parses_feed():
+    """feedparser.parse is called for each feed and entries become RawArticles."""
+    fake_parsed = MagicMock()
+    fake_parsed.bozo = False
+    fake_parsed.entries = [
+        _make_fake_entry(
+            "AAPL hits record high",
+            "https://example.com/aapl",
+            "Apple stock reached an all-time high today.",
+            "Sat, 22 Feb 2026 12:00:00 GMT",
+        ),
+        _make_fake_entry(
+            "TSLA earnings beat",
+            "https://example.com/tsla",
+            "Tesla reported stronger-than-expected earnings.",
+            "Sat, 22 Feb 2026 13:00:00 GMT",
+        ),
+    ]
+
+    with patch("services.news_fetcher.sources.rss.feedparser.parse", return_value=fake_parsed) as mock_parse:
+        source = RSSSource(feeds=["https://example.com/feed"])
+        articles = await source.fetch()
+
+    mock_parse.assert_called_once_with("https://example.com/feed")
+    assert len(articles) == 2
+
+    assert articles[0].source == "rss"
+    assert articles[0].title == "AAPL hits record high"
+    assert articles[0].url == "https://example.com/aapl"
+    assert articles[0].content == "Apple stock reached an all-time high today."
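+    # The expected hash mirrors rss.py: sha256 over the link and title concatenated.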
+    expected_hash = hashlib.sha256("https://example.com/aaplAAPL hits record high".encode()).hexdigest()
+    assert articles[0].content_hash == expected_hash
+    assert articles[0].published_at is not None
+
+    assert articles[1].title == "TSLA earnings beat"
+
+
+@pytest.mark.asyncio
+async def test_rss_source_handles_bad_feed():
+    """A feed that raises an exception is skipped; an empty list is returned."""
+    with patch(
+        "services.news_fetcher.sources.rss.feedparser.parse",
+        side_effect=Exception("network timeout"),
+    ):
+        source = RSSSource(feeds=["https://bad-feed.example.com/rss"])
+        articles = await source.fetch()
+
+    assert articles == []
+
+
+@pytest.mark.asyncio
+async def test_rss_source_handles_bozo_feed():
+    """A bozo feed with no entries is skipped gracefully."""
+    fake_parsed = MagicMock()
+    fake_parsed.bozo = True
+    fake_parsed.entries = []
+    fake_parsed.bozo_exception = "malformed XML"
+
+    with patch("services.news_fetcher.sources.rss.feedparser.parse", return_value=fake_parsed):
+        source = RSSSource(feeds=["https://broken.example.com/rss"])
+        articles = await source.fetch()
+
+    assert articles == []
+
+
+@pytest.mark.asyncio
+async def test_rss_source_multiple_feeds():
+    """Articles from multiple feeds are combined."""
+    feed1 = MagicMock()
+    feed1.bozo = False
+    feed1.entries = [_make_fake_entry("A", "https://a.com", "content a")]
+
+    feed2 = MagicMock()
+    feed2.bozo = False
+    feed2.entries = [_make_fake_entry("B", "https://b.com", "content b")]
+
+    with patch(
+        "services.news_fetcher.sources.rss.feedparser.parse",
+        side_effect=[feed1, feed2],
+    ):
+        source = RSSSource(feeds=["https://feed1.com", "https://feed2.com"])
+        articles = await source.fetch()
+
+    assert len(articles) == 2
+    assert {a.title for a in articles} == {"A", "B"}
+
+
+# ---------------------------------------------------------------------------
+# Reddit tests
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_reddit_source_converts_posts():
+    """Hot posts are converted to RawArticle with correct fields."""
+    post = _make_reddit_post(
+        title="GME to the moon",
+        selftext="Diamond hands forever",
+        url="https://reddit.com/r/wsb/1",
+        permalink="/r/wallstreetbets/comments/abc123/gme_to_the_moon/",
+        score=500,
+        created_utc=1740200000.0,
+    )
+
+    # Create an async iterator from the posts
+    async def _hot(limit=25):
+        for p in [post]:
+            yield p
+
+    fake_subreddit = AsyncMock()
+    fake_subreddit.hot = _hot
+
+    fake_reddit = AsyncMock()
+    fake_reddit.subreddit = AsyncMock(return_value=fake_subreddit)
+    fake_reddit.close = AsyncMock()
+
+    with patch("asyncpraw.Reddit", return_value=fake_reddit):
+        source = RedditSource(
+            subreddits=["wallstreetbets"],
+            client_id="test_id",
+            client_secret="test_secret",
+            user_agent="test-agent",
+            min_score=10,
+        )
+        articles = await source.fetch()
+
+    assert len(articles) == 1
+    assert articles[0].source == "reddit"
+    assert articles[0].title == "GME to the moon"
+    assert articles[0].content == "Diamond hands forever"
+    expected_hash = hashlib.sha256(
+        "/r/wallstreetbets/comments/abc123/gme_to_the_moon/".encode()
+    ).hexdigest()
+    assert articles[0].content_hash == expected_hash
+    assert "reddit.com" in articles[0].url
+    fake_reddit.close.assert_awaited_once()
+
+
+@pytest.mark.asyncio
+async def test_reddit_source_filters_by_score():
+    """Posts below min_score are excluded."""
+    high_score = _make_reddit_post("High", "text", "url", "/r/stocks/high", 100, 1740200000.0)
+    low_score = _make_reddit_post("Low", "text", "url", "/r/stocks/low", 5, 1740200000.0)
+
+    async def _hot(limit=25):
+        for p in [high_score, low_score]:
+            yield p
+
+    fake_subreddit = AsyncMock()
+    fake_subreddit.hot = _hot
+
+    fake_reddit = AsyncMock()
+    fake_reddit.subreddit = AsyncMock(return_value=fake_subreddit)
+    fake_reddit.close = AsyncMock()
+
+    with patch("asyncpraw.Reddit", return_value=fake_reddit):
+        source = RedditSource(
+            subreddits=["stocks"],
+            client_id="id",
+            client_secret="secret",
+            user_agent="agent",
+            min_score=10,
+        )
+        articles = await source.fetch()
+
+    assert len(articles) == 1
+    assert articles[0].title == "High"
+
+
+@pytest.mark.asyncio
+async def test_reddit_source_uses_url_when_no_selftext():
+    """When selftext is empty, the post URL is used as content."""
+    post = _make_reddit_post(
+        title="Link post",
+        selftext="",
+        url="https://example.com/article",
+        permalink="/r/investing/link",
+        score=50,
+        created_utc=1740200000.0,
+    )
+
+    async def _hot(limit=25):
+        yield post
+
+    fake_subreddit = AsyncMock()
+    fake_subreddit.hot = _hot
+
+    fake_reddit = AsyncMock()
+    fake_reddit.subreddit = AsyncMock(return_value=fake_subreddit)
+    fake_reddit.close = AsyncMock()
+
+    with patch("asyncpraw.Reddit", return_value=fake_reddit):
+        source = RedditSource(
+            subreddits=["investing"],
+            client_id="id",
+            client_secret="secret",
+            user_agent="agent",
+            min_score=10,
+        )
+        articles = await source.fetch()
+
+    assert len(articles) == 1
+    assert articles[0].content == "https://example.com/article"
+
+
+# ---------------------------------------------------------------------------
+# Deduplication tests
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_deduplication_skips_seen_hashes():
+    """Articles with previously-seen content_hash are not published."""
+    now = datetime.now(timezone.utc)
+    articles = [
+        RawArticle(
+            source="rss",
+            url="https://example.com/1",
+            title="First",
+            content="Content 1",
+            fetched_at=now,
+            content_hash="hash_new",
+        ),
+        RawArticle(
+            source="rss",
+            url="https://example.com/2",
+            title="Already seen",
+            content="Content 2",
+            fetched_at=now,
+            content_hash="hash_old",
+        ),
+    ]
+
+    redis = AsyncMock()
+    # First article is new (SADD returns 1), second is duplicate (returns 0)
+    redis.sadd = AsyncMock(side_effect=[1, 0])
+
+    publisher = AsyncMock()
+    publisher.publish = AsyncMock()
+
+    counter = MagicMock()
+    counter.add = MagicMock()
+    error_counter = MagicMock()
+
+    count = await _deduplicate_and_publish(articles, redis, publisher, counter, error_counter)
+
+    assert count == 1
+    publisher.publish.assert_called_once()
+    # Verify the hash was checked for both articles
+    assert redis.sadd.call_count == 2
+
+
+@pytest.mark.asyncio
+async def test_deduplication_publishes_all_new():
+    """When all articles are new, all are published."""
+    now = datetime.now(timezone.utc)
+    articles = [
+        RawArticle(
+            source="rss",
+            url=f"https://example.com/{i}",
+            title=f"Article {i}",
+            content=f"Content {i}",
+            fetched_at=now,
+            content_hash=f"hash_{i}",
+        )
+        for i in range(3)
+    ]
+
+    redis = AsyncMock()
+    redis.sadd = AsyncMock(return_value=1)
+
+    publisher = AsyncMock()
+    publisher.publish = AsyncMock()
+
+    counter = MagicMock()
+    counter.add = MagicMock()
+    error_counter = MagicMock()
+
+    count = await _deduplicate_and_publish(articles, redis, publisher, counter, error_counter)
+
+    assert count == 3
+    assert publisher.publish.call_count == 3
+    counter.add.assert_called_once_with(3)
+
+
+# ---------------------------------------------------------------------------
+# Publishing payload test (mocked redis + publisher)
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_main_publishes_to_stream():
+    """A new article is published once and its payload matches the serialized fields."""
+    now = datetime.now(timezone.utc)
+    fake_article = RawArticle(
+        source="rss",
+        url="https://example.com/test",
+        title="Test",
+        content="Test content",
+        fetched_at=now,
+        content_hash="unique_hash",
+    )
+
+    redis = AsyncMock()
+    redis.sadd = AsyncMock(return_value=1)
+
+    publisher = AsyncMock()
+    publisher.publish = AsyncMock()
+
+    counter = MagicMock()
+    counter.add = MagicMock()
+    error_counter = MagicMock()
+
+    count = await _deduplicate_and_publish(
+        [fake_article], redis, publisher, counter, error_counter
+    )
+
+    assert count == 1
+    publisher.publish.assert_called_once()
+    # Verify the published data matches the article
+    call_args = publisher.publish.call_args[0][0]
+    assert call_args["title"] == "Test"
+    assert call_args["source"] == "rss"
+    assert call_args["content_hash"] == "unique_hash"
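+
+
+def test_fixture_xml_is_well_formed():
+    """The raw RSS fixture parses cleanly with the real feedparser."""
+    import feedparser  # local import; module-level test deps stay unchanged
+
+    parsed = feedparser.parse(FIXTURE_RSS_XML)
+    assert not parsed.bozo
+    assert len(parsed.entries) == 2
+
+
+def test_rss_parse_published_handles_bad_dates():
+    """Missing or unparseable dates yield None (guard for _parse_published)."""
+    assert RSSSource._parse_published({}) is None
+    assert RSSSource._parse_published({"published": "not a date"}) is None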