From 60bd1ccd2a3b8fd0df06aabad4571414cf45d9a7 Mon Sep 17 00:00:00 2001 From: Viktor Barzin Date: Sun, 22 Feb 2026 15:32:18 +0000 Subject: [PATCH] =?UTF-8?q?feat:=20trading=20strategies=20=E2=80=94=20mome?= =?UTF-8?q?ntum,=20mean=20reversion,=20news-driven?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- shared/strategies/__init__.py | 25 ++ shared/strategies/base.py | 26 ++ shared/strategies/mean_reversion.py | 56 +++++ shared/strategies/momentum.py | 61 +++++ shared/strategies/news_driven.py | 60 +++++ tests/test_strategies.py | 353 ++++++++++++++++++++++++++++ 6 files changed, 581 insertions(+) create mode 100644 shared/strategies/__init__.py create mode 100644 shared/strategies/base.py create mode 100644 shared/strategies/mean_reversion.py create mode 100644 shared/strategies/momentum.py create mode 100644 shared/strategies/news_driven.py create mode 100644 tests/test_strategies.py diff --git a/shared/strategies/__init__.py b/shared/strategies/__init__.py new file mode 100644 index 0000000..0567b2e --- /dev/null +++ b/shared/strategies/__init__.py @@ -0,0 +1,25 @@ +"""Trading strategy implementations. + +Exports +------- +BaseStrategy + Abstract base class for all strategies. +MomentumStrategy + Trend-following strategy based on SMA cross-overs. +MeanReversionStrategy + RSI-based mean reversion strategy. +NewsDrivenStrategy + News sentiment driven strategy. +""" + +from shared.strategies.base import BaseStrategy +from shared.strategies.mean_reversion import MeanReversionStrategy +from shared.strategies.momentum import MomentumStrategy +from shared.strategies.news_driven import NewsDrivenStrategy + +__all__ = [ + "BaseStrategy", + "MeanReversionStrategy", + "MomentumStrategy", + "NewsDrivenStrategy", +] diff --git a/shared/strategies/base.py b/shared/strategies/base.py new file mode 100644 index 0000000..8b78588 --- /dev/null +++ b/shared/strategies/base.py @@ -0,0 +1,26 @@ +"""Abstract base class for all trading strategies.""" + +from abc import ABC, abstractmethod + +from shared.schemas.trading import MarketSnapshot, SentimentContext, TradeSignal + + +class BaseStrategy(ABC): + """Base class that all trading strategies must inherit from. + + Subclasses implement :meth:`evaluate` to inspect market data and + optionally sentiment, returning a :class:`TradeSignal` when the + strategy has a directional opinion and ``None`` otherwise. + """ + + name: str + + @abstractmethod + async def evaluate( + self, + ticker: str, + market: MarketSnapshot, + sentiment: SentimentContext | None = None, + ) -> TradeSignal | None: + """Return a signal if this strategy has an opinion, None otherwise.""" + ... diff --git a/shared/strategies/mean_reversion.py b/shared/strategies/mean_reversion.py new file mode 100644 index 0000000..8c6b171 --- /dev/null +++ b/shared/strategies/mean_reversion.py @@ -0,0 +1,56 @@ +"""Mean reversion strategy — buy oversold, sell overbought using RSI.""" + +from datetime import datetime, timezone + +from shared.schemas.trading import MarketSnapshot, SentimentContext, SignalDirection, TradeSignal +from shared.strategies.base import BaseStrategy + + +class MeanReversionStrategy(BaseStrategy): + """Trade on the assumption that extreme RSI readings will revert to the mean. + + **Buy signal** (LONG): + RSI < 30 (oversold). + + **Sell signal** (SHORT): + RSI > 70 (overbought). + + Signal strength is proportional to how far the RSI is from its + threshold, clamped to [0, 1]. + + * Buy strength = ``(30 - rsi) / 30`` + * Sell strength = ``(rsi - 70) / 30`` + """ + + name: str = "mean_reversion" + + async def evaluate( + self, + ticker: str, + market: MarketSnapshot, + sentiment: SentimentContext | None = None, + ) -> TradeSignal | None: + if market.rsi is None: + return None + + rsi = market.rsi + + if rsi < 30: + direction = SignalDirection.LONG + raw_strength = (30 - rsi) / 30 + elif rsi > 70: + direction = SignalDirection.SHORT + raw_strength = (rsi - 70) / 30 + else: + # RSI in neutral territory — no opinion. + return None + + strength = max(0.0, min(1.0, raw_strength)) + + return TradeSignal( + ticker=ticker, + direction=direction, + strength=strength, + strategy_sources=[self.name], + timestamp=datetime.now(tz=timezone.utc), + ) diff --git a/shared/strategies/momentum.py b/shared/strategies/momentum.py new file mode 100644 index 0000000..3098594 --- /dev/null +++ b/shared/strategies/momentum.py @@ -0,0 +1,61 @@ +"""Momentum trading strategy — trend-following based on moving averages.""" + +from datetime import datetime, timezone + +from shared.schemas.trading import MarketSnapshot, SentimentContext, SignalDirection, TradeSignal +from shared.strategies.base import BaseStrategy + + +class MomentumStrategy(BaseStrategy): + """Detect and follow momentum via simple moving average cross-overs. + + **Buy signal** (LONG): + ``current_price > sma_20`` AND ``sma_20 > sma_50`` (golden cross / + uptrend) AND volume above the daily open (simple proxy for above- + average volume). + + **Sell signal** (SHORT): + ``current_price < sma_20`` AND ``sma_20 < sma_50`` (death cross / + downtrend). + + Signal strength is proportional to the normalised distance between + the current price and the 20-period SMA, clamped to [0, 1]. + """ + + name: str = "momentum" + + async def evaluate( + self, + ticker: str, + market: MarketSnapshot, + sentiment: SentimentContext | None = None, + ) -> TradeSignal | None: + # Require both moving averages to be present. + if market.sma_20 is None or market.sma_50 is None: + return None + + price = market.current_price + sma_20 = market.sma_20 + sma_50 = market.sma_50 + + direction: SignalDirection | None = None + + if price > sma_20 and sma_20 > sma_50: + direction = SignalDirection.LONG + elif price < sma_20 and sma_20 < sma_50: + direction = SignalDirection.SHORT + else: + # No clear trend — abstain. + return None + + # Strength: normalised distance from SMA-20, clamped to [0, 1]. + raw_strength = abs(price - sma_20) / sma_20 if sma_20 != 0 else 0.0 + strength = max(0.0, min(1.0, raw_strength)) + + return TradeSignal( + ticker=ticker, + direction=direction, + strength=strength, + strategy_sources=[self.name], + timestamp=datetime.now(tz=timezone.utc), + ) diff --git a/shared/strategies/news_driven.py b/shared/strategies/news_driven.py new file mode 100644 index 0000000..1ce4a1b --- /dev/null +++ b/shared/strategies/news_driven.py @@ -0,0 +1,60 @@ +"""News-driven strategy — trade on aggregated news sentiment.""" + +from datetime import datetime, timezone + +from shared.schemas.trading import MarketSnapshot, SentimentContext, SignalDirection, TradeSignal +from shared.strategies.base import BaseStrategy + + +class NewsDrivenStrategy(BaseStrategy): + """Generate signals from aggregated news sentiment for a ticker. + + **Buy signal** (LONG): + ``avg_score > 0.3`` AND ``avg_confidence > 0.5`` AND + ``article_count >= 2``. + + **Sell signal** (SHORT): + ``avg_score < -0.3`` AND ``avg_confidence > 0.5`` AND + ``article_count >= 2``. + + Signal strength = ``abs(avg_score) * avg_confidence``, clamped to + [0, 1]. + """ + + name: str = "news_driven" + + async def evaluate( + self, + ticker: str, + market: MarketSnapshot, + sentiment: SentimentContext | None = None, + ) -> TradeSignal | None: + if sentiment is None: + return None + + # Require at least 2 articles for statistical confidence. + if sentiment.article_count < 2: + return None + + # Require minimum confidence. + if sentiment.avg_confidence <= 0.5: + return None + + if sentiment.avg_score > 0.3: + direction = SignalDirection.LONG + elif sentiment.avg_score < -0.3: + direction = SignalDirection.SHORT + else: + # Sentiment is neutral — no opinion. + return None + + raw_strength = abs(sentiment.avg_score) * sentiment.avg_confidence + strength = max(0.0, min(1.0, raw_strength)) + + return TradeSignal( + ticker=ticker, + direction=direction, + strength=strength, + strategy_sources=[self.name], + timestamp=datetime.now(tz=timezone.utc), + ) diff --git a/tests/test_strategies.py b/tests/test_strategies.py new file mode 100644 index 0000000..16b01a5 --- /dev/null +++ b/tests/test_strategies.py @@ -0,0 +1,353 @@ +"""Comprehensive tests for trading strategy implementations.""" + +import pytest + +from shared.schemas.trading import MarketSnapshot, SentimentContext, SignalDirection +from shared.strategies import BaseStrategy, MeanReversionStrategy, MomentumStrategy, NewsDrivenStrategy + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _market( + ticker: str = "AAPL", + price: float = 150.0, + sma_20: float | None = None, + sma_50: float | None = None, + rsi: float | None = None, + volume: float = 1_000_000, +) -> MarketSnapshot: + """Build a MarketSnapshot with sensible defaults.""" + return MarketSnapshot( + ticker=ticker, + current_price=price, + open=price - 1, + high=price + 2, + low=price - 2, + close=price, + volume=volume, + sma_20=sma_20, + sma_50=sma_50, + rsi=rsi, + ) + + +def _sentiment( + ticker: str = "AAPL", + avg_score: float = 0.0, + avg_confidence: float = 0.7, + article_count: int = 5, +) -> SentimentContext: + """Build a SentimentContext with sensible defaults.""" + return SentimentContext( + ticker=ticker, + avg_score=avg_score, + article_count=article_count, + recent_scores=[avg_score], + avg_confidence=avg_confidence, + ) + + +# =================================================================== +# Momentum strategy +# =================================================================== + + +class TestMomentumStrategy: + """Tests for :class:`MomentumStrategy`.""" + + @pytest.fixture() + def strategy(self) -> MomentumStrategy: + return MomentumStrategy() + + @pytest.mark.asyncio + async def test_momentum_buy_signal(self, strategy: MomentumStrategy) -> None: + """Buy when price > sma_20 > sma_50 (uptrend / golden cross).""" + market = _market(price=160.0, sma_20=150.0, sma_50=140.0) + signal = await strategy.evaluate("AAPL", market) + + assert signal is not None + assert signal.direction == SignalDirection.LONG + assert signal.ticker == "AAPL" + assert 0 < signal.strength <= 1.0 + assert strategy.name in signal.strategy_sources + + @pytest.mark.asyncio + async def test_momentum_sell_signal(self, strategy: MomentumStrategy) -> None: + """Sell when price < sma_20 < sma_50 (downtrend / death cross).""" + market = _market(price=130.0, sma_20=140.0, sma_50=150.0) + signal = await strategy.evaluate("AAPL", market) + + assert signal is not None + assert signal.direction == SignalDirection.SHORT + assert signal.ticker == "AAPL" + assert 0 < signal.strength <= 1.0 + + @pytest.mark.asyncio + async def test_momentum_no_signal_flat(self, strategy: MomentumStrategy) -> None: + """No signal when price is between the two SMAs (no clear trend).""" + # price between sma_20 and sma_50 — neither condition met. + market = _market(price=145.0, sma_20=140.0, sma_50=150.0) + signal = await strategy.evaluate("AAPL", market) + assert signal is None + + @pytest.mark.asyncio + async def test_momentum_missing_sma_returns_none(self, strategy: MomentumStrategy) -> None: + """Return None when sma_20 or sma_50 is missing.""" + # Missing sma_20 + market_no_20 = _market(price=150.0, sma_20=None, sma_50=140.0) + assert await strategy.evaluate("AAPL", market_no_20) is None + + # Missing sma_50 + market_no_50 = _market(price=150.0, sma_20=145.0, sma_50=None) + assert await strategy.evaluate("AAPL", market_no_50) is None + + # Both missing + market_both = _market(price=150.0, sma_20=None, sma_50=None) + assert await strategy.evaluate("AAPL", market_both) is None + + @pytest.mark.asyncio + async def test_momentum_strength_proportional(self, strategy: MomentumStrategy) -> None: + """Strength should be proportional to (price - sma_20) / sma_20.""" + sma_20 = 100.0 + sma_50 = 90.0 + + # Small distance from SMA-20. + market_small = _market(price=102.0, sma_20=sma_20, sma_50=sma_50) + signal_small = await strategy.evaluate("AAPL", market_small) + + # Larger distance from SMA-20. + market_large = _market(price=110.0, sma_20=sma_20, sma_50=sma_50) + signal_large = await strategy.evaluate("AAPL", market_large) + + assert signal_small is not None + assert signal_large is not None + assert signal_large.strength > signal_small.strength + + # Verify the exact strength for one case. + expected = abs(102.0 - 100.0) / 100.0 # 0.02 + assert signal_small.strength == pytest.approx(expected, abs=1e-9) + + @pytest.mark.asyncio + async def test_momentum_strength_clamped(self, strategy: MomentumStrategy) -> None: + """Strength must not exceed 1.0 even with extreme price divergence.""" + market = _market(price=300.0, sma_20=100.0, sma_50=90.0) + signal = await strategy.evaluate("AAPL", market) + + assert signal is not None + assert signal.strength == 1.0 + + +# =================================================================== +# Mean reversion strategy +# =================================================================== + + +class TestMeanReversionStrategy: + """Tests for :class:`MeanReversionStrategy`.""" + + @pytest.fixture() + def strategy(self) -> MeanReversionStrategy: + return MeanReversionStrategy() + + @pytest.mark.asyncio + async def test_mean_reversion_buy_oversold(self, strategy: MeanReversionStrategy) -> None: + """Buy when RSI < 30 (oversold).""" + market = _market(rsi=20.0) + signal = await strategy.evaluate("AAPL", market) + + assert signal is not None + assert signal.direction == SignalDirection.LONG + assert 0 < signal.strength <= 1.0 + + @pytest.mark.asyncio + async def test_mean_reversion_sell_overbought(self, strategy: MeanReversionStrategy) -> None: + """Sell when RSI > 70 (overbought).""" + market = _market(rsi=80.0) + signal = await strategy.evaluate("AAPL", market) + + assert signal is not None + assert signal.direction == SignalDirection.SHORT + assert 0 < signal.strength <= 1.0 + + @pytest.mark.asyncio + async def test_mean_reversion_no_signal_neutral(self, strategy: MeanReversionStrategy) -> None: + """No signal when RSI is in neutral territory (30-70).""" + market = _market(rsi=50.0) + signal = await strategy.evaluate("AAPL", market) + assert signal is None + + @pytest.mark.asyncio + async def test_mean_reversion_missing_rsi_returns_none(self, strategy: MeanReversionStrategy) -> None: + """Return None when RSI is not available.""" + market = _market(rsi=None) + assert await strategy.evaluate("AAPL", market) is None + + @pytest.mark.asyncio + async def test_mean_reversion_strength_proportional(self, strategy: MeanReversionStrategy) -> None: + """Strength is proportional to how far RSI is from its threshold.""" + # Buy side: lower RSI = stronger signal. + market_mild = _market(rsi=25.0) + signal_mild = await strategy.evaluate("AAPL", market_mild) + + market_extreme = _market(rsi=10.0) + signal_extreme = await strategy.evaluate("AAPL", market_extreme) + + assert signal_mild is not None + assert signal_extreme is not None + assert signal_extreme.strength > signal_mild.strength + + # Verify exact strength for RSI=20: (30 - 20) / 30 = 1/3. + market_20 = _market(rsi=20.0) + signal_20 = await strategy.evaluate("AAPL", market_20) + assert signal_20 is not None + assert signal_20.strength == pytest.approx(10.0 / 30.0, abs=1e-9) + + # Sell side: RSI=80: (80 - 70) / 30 = 1/3. + market_80 = _market(rsi=80.0) + signal_80 = await strategy.evaluate("AAPL", market_80) + assert signal_80 is not None + assert signal_80.strength == pytest.approx(10.0 / 30.0, abs=1e-9) + + @pytest.mark.asyncio + async def test_mean_reversion_boundary_no_signal(self, strategy: MeanReversionStrategy) -> None: + """RSI exactly at 30 or 70 should NOT trigger a signal.""" + market_30 = _market(rsi=30.0) + assert await strategy.evaluate("AAPL", market_30) is None + + market_70 = _market(rsi=70.0) + assert await strategy.evaluate("AAPL", market_70) is None + + @pytest.mark.asyncio + async def test_mean_reversion_strength_clamped(self, strategy: MeanReversionStrategy) -> None: + """Strength is clamped to [0, 1] even at extreme RSI values.""" + market = _market(rsi=95.0) + signal = await strategy.evaluate("AAPL", market) + assert signal is not None + assert signal.strength <= 1.0 + + # RSI=0 => (30-0)/30 = 1.0 exactly. + market_zero = _market(rsi=0.0) + signal_zero = await strategy.evaluate("AAPL", market_zero) + assert signal_zero is not None + assert signal_zero.strength == pytest.approx(1.0, abs=1e-9) + + +# =================================================================== +# News-driven strategy +# =================================================================== + + +class TestNewsDrivenStrategy: + """Tests for :class:`NewsDrivenStrategy`.""" + + @pytest.fixture() + def strategy(self) -> NewsDrivenStrategy: + return NewsDrivenStrategy() + + @pytest.mark.asyncio + async def test_news_driven_buy_positive(self, strategy: NewsDrivenStrategy) -> None: + """Buy on strongly positive sentiment (score=0.8, confidence=0.7).""" + market = _market() + sentiment = _sentiment(avg_score=0.8, avg_confidence=0.7, article_count=5) + signal = await strategy.evaluate("AAPL", market, sentiment) + + assert signal is not None + assert signal.direction == SignalDirection.LONG + assert 0 < signal.strength <= 1.0 + + @pytest.mark.asyncio + async def test_news_driven_sell_negative(self, strategy: NewsDrivenStrategy) -> None: + """Sell on strongly negative sentiment (score=-0.8, confidence=0.7).""" + market = _market() + sentiment = _sentiment(avg_score=-0.8, avg_confidence=0.7, article_count=5) + signal = await strategy.evaluate("AAPL", market, sentiment) + + assert signal is not None + assert signal.direction == SignalDirection.SHORT + assert 0 < signal.strength <= 1.0 + + @pytest.mark.asyncio + async def test_news_driven_no_signal_low_confidence(self, strategy: NewsDrivenStrategy) -> None: + """No signal when avg_confidence is too low (<=0.5).""" + market = _market() + sentiment = _sentiment(avg_score=0.8, avg_confidence=0.4, article_count=5) + signal = await strategy.evaluate("AAPL", market, sentiment) + assert signal is None + + @pytest.mark.asyncio + async def test_news_driven_no_signal_few_articles(self, strategy: NewsDrivenStrategy) -> None: + """No signal when article_count < 2.""" + market = _market() + sentiment = _sentiment(avg_score=0.8, avg_confidence=0.7, article_count=1) + signal = await strategy.evaluate("AAPL", market, sentiment) + assert signal is None + + @pytest.mark.asyncio + async def test_news_driven_no_sentiment_returns_none(self, strategy: NewsDrivenStrategy) -> None: + """Return None when no sentiment context is provided.""" + market = _market() + signal = await strategy.evaluate("AAPL", market, sentiment=None) + assert signal is None + + @pytest.mark.asyncio + async def test_news_driven_strength_calculation(self, strategy: NewsDrivenStrategy) -> None: + """Strength = abs(avg_score) * avg_confidence, clamped to [0, 1].""" + market = _market() + + # score=0.8, confidence=0.7 => strength = 0.56 + sentiment = _sentiment(avg_score=0.8, avg_confidence=0.7) + signal = await strategy.evaluate("AAPL", market, sentiment) + assert signal is not None + assert signal.strength == pytest.approx(0.8 * 0.7, abs=1e-9) + + # Negative score should yield same strength magnitude. + sentiment_neg = _sentiment(avg_score=-0.8, avg_confidence=0.7) + signal_neg = await strategy.evaluate("AAPL", market, sentiment_neg) + assert signal_neg is not None + assert signal_neg.strength == pytest.approx(0.8 * 0.7, abs=1e-9) + + @pytest.mark.asyncio + async def test_news_driven_neutral_score(self, strategy: NewsDrivenStrategy) -> None: + """No signal when avg_score is between -0.3 and 0.3 (neutral).""" + market = _market() + sentiment = _sentiment(avg_score=0.1, avg_confidence=0.9, article_count=10) + signal = await strategy.evaluate("AAPL", market, sentiment) + assert signal is None + + @pytest.mark.asyncio + async def test_news_driven_boundary_confidence(self, strategy: NewsDrivenStrategy) -> None: + """No signal when avg_confidence is exactly 0.5 (threshold is >0.5).""" + market = _market() + sentiment = _sentiment(avg_score=0.8, avg_confidence=0.5, article_count=5) + signal = await strategy.evaluate("AAPL", market, sentiment) + assert signal is None + + +# =================================================================== +# Cross-strategy tests +# =================================================================== + + +class TestStrategyCrossChecks: + """Tests that apply across all strategy implementations.""" + + def test_all_strategies_are_base_strategy_subclass(self) -> None: + """All concrete strategies must inherit from BaseStrategy.""" + for cls in (MomentumStrategy, MeanReversionStrategy, NewsDrivenStrategy): + assert issubclass(cls, BaseStrategy), f"{cls.__name__} is not a BaseStrategy subclass" + + def test_strategy_names_unique(self) -> None: + """Every strategy must have a distinct name.""" + strategies = [MomentumStrategy(), MeanReversionStrategy(), NewsDrivenStrategy()] + names = [s.name for s in strategies] + assert len(names) == len(set(names)), f"Duplicate strategy names detected: {names}" + + def test_strategy_names_non_empty(self) -> None: + """Every strategy name must be a non-empty string.""" + for cls in (MomentumStrategy, MeanReversionStrategy, NewsDrivenStrategy): + instance = cls() + assert isinstance(instance.name, str) + assert len(instance.name) > 0