"""Comprehensive tests for trading strategy implementations."""

import pytest

from shared.schemas.trading import MarketSnapshot, SentimentContext, SignalDirection
from shared.strategies import (
    BaseStrategy,
    MeanReversionStrategy,
    MomentumStrategy,
    NewsDrivenStrategy,
)

# Single source of truth for the cross-strategy checks, so a newly added
# strategy only has to be registered in one place.
_ALL_STRATEGY_CLASSES = (MomentumStrategy, MeanReversionStrategy, NewsDrivenStrategy)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _market(
    ticker: str = "AAPL",
    price: float = 150.0,
    sma_20: float | None = None,
    sma_50: float | None = None,
    rsi: float | None = None,
    volume: float = 1_000_000,
) -> MarketSnapshot:
    """Build a MarketSnapshot with sensible defaults.

    The OHLC fields are derived from ``price`` (open = price - 1,
    high = price + 2, low = price - 2, close = price) so callers only
    specify the indicators relevant to the test at hand.
    """
    return MarketSnapshot(
        ticker=ticker,
        current_price=price,
        open=price - 1,
        high=price + 2,
        low=price - 2,
        close=price,
        volume=volume,
        sma_20=sma_20,
        sma_50=sma_50,
        rsi=rsi,
    )


def _sentiment(
    ticker: str = "AAPL",
    avg_score: float = 0.0,
    avg_confidence: float = 0.7,
    article_count: int = 5,
) -> SentimentContext:
    """Build a SentimentContext with sensible defaults.

    ``recent_scores`` is populated with the single value ``avg_score``.
    """
    return SentimentContext(
        ticker=ticker,
        avg_score=avg_score,
        article_count=article_count,
        recent_scores=[avg_score],
        avg_confidence=avg_confidence,
    )


# ===================================================================
# Momentum strategy
# ===================================================================


class TestMomentumStrategy:
    """Tests for :class:`MomentumStrategy`."""

    @pytest.fixture()
    def strategy(self) -> MomentumStrategy:
        """Provide a fresh default-constructed strategy for each test."""
        return MomentumStrategy()

    @pytest.mark.asyncio
    async def test_momentum_buy_signal(self, strategy: MomentumStrategy) -> None:
        """Buy when price > sma_20 > sma_50 (uptrend / golden cross)."""
        market = _market(price=160.0, sma_20=150.0, sma_50=140.0)
        signal = await strategy.evaluate("AAPL", market)

        assert signal is not None
        assert signal.direction == SignalDirection.LONG
        assert signal.ticker == "AAPL"
        assert 0 < signal.strength <= 1.0
        assert strategy.name in signal.strategy_sources

    @pytest.mark.asyncio
    async def test_momentum_sell_signal(self, strategy: MomentumStrategy) -> None:
        """Sell when price < sma_20 < sma_50 (downtrend / death cross)."""
        market = _market(price=130.0, sma_20=140.0, sma_50=150.0)
        signal = await strategy.evaluate("AAPL", market)

        assert signal is not None
        assert signal.direction == SignalDirection.SHORT
        assert signal.ticker == "AAPL"
        assert 0 < signal.strength <= 1.0

    @pytest.mark.asyncio
    async def test_momentum_no_signal_flat(self, strategy: MomentumStrategy) -> None:
        """No signal when price is between the two SMAs (no clear trend)."""
        # price between sma_20 and sma_50 — neither condition met.
        market = _market(price=145.0, sma_20=140.0, sma_50=150.0)
        signal = await strategy.evaluate("AAPL", market)

        assert signal is None

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        ("price", "sma_20", "sma_50"),
        [
            (150.0, None, 140.0),
            (150.0, 145.0, None),
            (150.0, None, None),
        ],
        ids=["missing_sma_20", "missing_sma_50", "missing_both"],
    )
    async def test_momentum_missing_sma_returns_none(
        self,
        strategy: MomentumStrategy,
        price: float,
        sma_20: float | None,
        sma_50: float | None,
    ) -> None:
        """Return None when sma_20 or sma_50 is missing.

        Each combination is its own parametrized case so that one failing
        combination does not hide the result of the others.
        """
        market = _market(price=price, sma_20=sma_20, sma_50=sma_50)

        assert await strategy.evaluate("AAPL", market) is None

    @pytest.mark.asyncio
    async def test_momentum_strength_proportional(self, strategy: MomentumStrategy) -> None:
        """Strength should be proportional to (price - sma_20) / sma_20."""
        sma_20 = 100.0
        sma_50 = 90.0

        # Small distance from SMA-20.
        market_small = _market(price=102.0, sma_20=sma_20, sma_50=sma_50)
        signal_small = await strategy.evaluate("AAPL", market_small)

        # Larger distance from SMA-20.
        market_large = _market(price=110.0, sma_20=sma_20, sma_50=sma_50)
        signal_large = await strategy.evaluate("AAPL", market_large)

        assert signal_small is not None
        assert signal_large is not None
        assert signal_large.strength > signal_small.strength

        # Verify the exact strength for one case.
        expected = abs(102.0 - 100.0) / 100.0  # 0.02
        assert signal_small.strength == pytest.approx(expected, abs=1e-9)

    @pytest.mark.asyncio
    async def test_momentum_strength_clamped(self, strategy: MomentumStrategy) -> None:
        """Strength must not exceed 1.0 even with extreme price divergence."""
        # (300 - 100) / 100 = 2.0 before clamping.
        market = _market(price=300.0, sma_20=100.0, sma_50=90.0)
        signal = await strategy.evaluate("AAPL", market)

        assert signal is not None
        assert signal.strength == 1.0


# ===================================================================
# Mean reversion strategy
# ===================================================================


class TestMeanReversionStrategy:
    """Tests for :class:`MeanReversionStrategy`."""

    @pytest.fixture()
    def strategy(self) -> MeanReversionStrategy:
        """Provide a fresh default-constructed strategy for each test."""
        return MeanReversionStrategy()

    @pytest.mark.asyncio
    async def test_mean_reversion_buy_oversold(self, strategy: MeanReversionStrategy) -> None:
        """Buy when RSI < 30 (oversold)."""
        market = _market(rsi=20.0)
        signal = await strategy.evaluate("AAPL", market)

        assert signal is not None
        assert signal.direction == SignalDirection.LONG
        assert 0 < signal.strength <= 1.0

    @pytest.mark.asyncio
    async def test_mean_reversion_sell_overbought(self, strategy: MeanReversionStrategy) -> None:
        """Sell when RSI > 70 (overbought)."""
        market = _market(rsi=80.0)
        signal = await strategy.evaluate("AAPL", market)

        assert signal is not None
        assert signal.direction == SignalDirection.SHORT
        assert 0 < signal.strength <= 1.0

    @pytest.mark.asyncio
    async def test_mean_reversion_no_signal_neutral(self, strategy: MeanReversionStrategy) -> None:
        """No signal when RSI is in neutral territory (30-70)."""
        market = _market(rsi=50.0)
        signal = await strategy.evaluate("AAPL", market)

        assert signal is None

    @pytest.mark.asyncio
    async def test_mean_reversion_missing_rsi_returns_none(
        self, strategy: MeanReversionStrategy
    ) -> None:
        """Return None when RSI is not available."""
        market = _market(rsi=None)

        assert await strategy.evaluate("AAPL", market) is None

    @pytest.mark.asyncio
    async def test_mean_reversion_strength_proportional(
        self, strategy: MeanReversionStrategy
    ) -> None:
        """Strength is proportional to how far RSI is from its threshold."""
        # Buy side: lower RSI = stronger signal.
        market_mild = _market(rsi=25.0)
        signal_mild = await strategy.evaluate("AAPL", market_mild)

        market_extreme = _market(rsi=10.0)
        signal_extreme = await strategy.evaluate("AAPL", market_extreme)

        assert signal_mild is not None
        assert signal_extreme is not None
        assert signal_extreme.strength > signal_mild.strength

        # Verify exact strength for RSI=20: (30 - 20) / 30 = 1/3.
        market_20 = _market(rsi=20.0)
        signal_20 = await strategy.evaluate("AAPL", market_20)
        assert signal_20 is not None
        assert signal_20.strength == pytest.approx(10.0 / 30.0, abs=1e-9)

        # Sell side: RSI=80: (80 - 70) / 30 = 1/3.
        market_80 = _market(rsi=80.0)
        signal_80 = await strategy.evaluate("AAPL", market_80)
        assert signal_80 is not None
        assert signal_80.strength == pytest.approx(10.0 / 30.0, abs=1e-9)

    @pytest.mark.asyncio
    @pytest.mark.parametrize("rsi", [30.0, 70.0], ids=["rsi_30", "rsi_70"])
    async def test_mean_reversion_boundary_no_signal(
        self, strategy: MeanReversionStrategy, rsi: float
    ) -> None:
        """RSI exactly at 30 or 70 should NOT trigger a signal.

        Both thresholds are separate parametrized cases so a regression on
        one boundary is reported independently of the other.
        """
        market = _market(rsi=rsi)

        assert await strategy.evaluate("AAPL", market) is None

    @pytest.mark.asyncio
    async def test_mean_reversion_strength_clamped(self, strategy: MeanReversionStrategy) -> None:
        """Strength is clamped to [0, 1] even at extreme RSI values."""
        market = _market(rsi=95.0)
        signal = await strategy.evaluate("AAPL", market)

        assert signal is not None
        assert signal.strength <= 1.0

        # RSI=0 => (30-0)/30 = 1.0 exactly.
        market_zero = _market(rsi=0.0)
        signal_zero = await strategy.evaluate("AAPL", market_zero)
        assert signal_zero is not None
        assert signal_zero.strength == pytest.approx(1.0, abs=1e-9)


# ===================================================================
# News-driven strategy
# ===================================================================


class TestNewsDrivenStrategy:
    """Tests for :class:`NewsDrivenStrategy`."""

    @pytest.fixture()
    def strategy(self) -> NewsDrivenStrategy:
        """Provide a fresh default-constructed strategy for each test."""
        return NewsDrivenStrategy()

    @pytest.mark.asyncio
    async def test_news_driven_buy_positive(self, strategy: NewsDrivenStrategy) -> None:
        """Buy on strongly positive sentiment (score=0.8, confidence=0.7)."""
        market = _market()
        sentiment = _sentiment(avg_score=0.8, avg_confidence=0.7, article_count=5)
        signal = await strategy.evaluate("AAPL", market, sentiment)

        assert signal is not None
        assert signal.direction == SignalDirection.LONG
        assert 0 < signal.strength <= 1.0

    @pytest.mark.asyncio
    async def test_news_driven_sell_negative(self, strategy: NewsDrivenStrategy) -> None:
        """Sell on strongly negative sentiment (score=-0.8, confidence=0.7)."""
        market = _market()
        sentiment = _sentiment(avg_score=-0.8, avg_confidence=0.7, article_count=5)
        signal = await strategy.evaluate("AAPL", market, sentiment)

        assert signal is not None
        assert signal.direction == SignalDirection.SHORT
        assert 0 < signal.strength <= 1.0

    @pytest.mark.asyncio
    async def test_news_driven_no_signal_low_confidence(
        self, strategy: NewsDrivenStrategy
    ) -> None:
        """No signal when avg_confidence is too low (<=0.3)."""
        market = _market()
        sentiment = _sentiment(avg_score=0.8, avg_confidence=0.2, article_count=5)
        signal = await strategy.evaluate("AAPL", market, sentiment)

        assert signal is None

    @pytest.mark.asyncio
    async def test_news_driven_no_signal_few_articles(self, strategy: NewsDrivenStrategy) -> None:
        """No signal when article_count < 1."""
        market = _market()
        sentiment = _sentiment(avg_score=0.8, avg_confidence=0.7, article_count=0)
        signal = await strategy.evaluate("AAPL", market, sentiment)

        assert signal is None

    @pytest.mark.asyncio
    async def test_news_driven_no_sentiment_returns_none(
        self, strategy: NewsDrivenStrategy
    ) -> None:
        """Return None when no sentiment context is provided."""
        market = _market()
        signal = await strategy.evaluate("AAPL", market, sentiment=None)

        assert signal is None

    @pytest.mark.asyncio
    async def test_news_driven_strength_calculation(self, strategy: NewsDrivenStrategy) -> None:
        """Strength = abs(avg_score) * avg_confidence, clamped to [0, 1]."""
        market = _market()

        # score=0.8, confidence=0.7 => strength = 0.56
        sentiment = _sentiment(avg_score=0.8, avg_confidence=0.7)
        signal = await strategy.evaluate("AAPL", market, sentiment)
        assert signal is not None
        assert signal.strength == pytest.approx(0.8 * 0.7, abs=1e-9)

        # Negative score should yield same strength magnitude.
        sentiment_neg = _sentiment(avg_score=-0.8, avg_confidence=0.7)
        signal_neg = await strategy.evaluate("AAPL", market, sentiment_neg)
        assert signal_neg is not None
        assert signal_neg.strength == pytest.approx(0.8 * 0.7, abs=1e-9)

    @pytest.mark.asyncio
    async def test_news_driven_neutral_score(self, strategy: NewsDrivenStrategy) -> None:
        """No signal when avg_score is between -0.15 and 0.15 (neutral)."""
        market = _market()
        sentiment = _sentiment(avg_score=0.1, avg_confidence=0.9, article_count=10)
        signal = await strategy.evaluate("AAPL", market, sentiment)

        assert signal is None

    @pytest.mark.asyncio
    async def test_news_driven_boundary_confidence(self, strategy: NewsDrivenStrategy) -> None:
        """No signal when avg_confidence is exactly 0.3 (threshold is >0.3)."""
        market = _market()
        sentiment = _sentiment(avg_score=0.8, avg_confidence=0.3, article_count=5)
        signal = await strategy.evaluate("AAPL", market, sentiment)

        assert signal is None


# ===================================================================
# Cross-strategy tests
# ===================================================================


class TestStrategyCrossChecks:
    """Tests that apply across all strategy implementations."""

    @pytest.mark.parametrize(
        "strategy_cls", _ALL_STRATEGY_CLASSES, ids=lambda cls: cls.__name__
    )
    def test_all_strategies_are_base_strategy_subclass(self, strategy_cls: type) -> None:
        """All concrete strategies must inherit from BaseStrategy."""
        assert issubclass(
            strategy_cls, BaseStrategy
        ), f"{strategy_cls.__name__} is not a BaseStrategy subclass"

    def test_strategy_names_unique(self) -> None:
        """Every strategy must have a distinct name."""
        names = [strategy_cls().name for strategy_cls in _ALL_STRATEGY_CLASSES]

        assert len(names) == len(set(names)), f"Duplicate strategy names detected: {names}"

    @pytest.mark.parametrize(
        "strategy_cls", _ALL_STRATEGY_CLASSES, ids=lambda cls: cls.__name__
    )
    def test_strategy_names_non_empty(self, strategy_cls: type) -> None:
        """Every strategy name must be a non-empty string."""
        instance = strategy_cls()

        assert isinstance(instance.name, str)
        assert len(instance.name) > 0