# trading/tests/test_strategies.py

"""Comprehensive tests for trading strategy implementations."""
import pytest
from shared.schemas.trading import MarketSnapshot, SentimentContext, SignalDirection
from shared.strategies import BaseStrategy, MeanReversionStrategy, MomentumStrategy, NewsDrivenStrategy
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _market(
    ticker: str = "AAPL",
    price: float = 150.0,
    sma_20: float | None = None,
    sma_50: float | None = None,
    rsi: float | None = None,
    volume: float = 1_000_000,
) -> MarketSnapshot:
    """Return a test ``MarketSnapshot`` centred on a single *price*.

    The OHLC bar is derived from *price*: open one point below, high/low two
    points either side, and close equal to *price*. Indicator fields
    (``sma_20``, ``sma_50``, ``rsi``) are passed through unchanged so each
    test can opt into exactly the indicators it needs.
    """
    bar_open = price - 1
    bar_high = price + 2
    bar_low = price - 2
    bar_close = price
    return MarketSnapshot(
        ticker=ticker,
        current_price=price,
        open=bar_open,
        high=bar_high,
        low=bar_low,
        close=bar_close,
        volume=volume,
        rsi=rsi,
        sma_20=sma_20,
        sma_50=sma_50,
    )
def _sentiment(
    ticker: str = "AAPL",
    avg_score: float = 0.0,
    avg_confidence: float = 0.7,
    article_count: int = 5,
) -> SentimentContext:
    """Return a test ``SentimentContext`` for *ticker*.

    ``recent_scores`` is a one-element list holding *avg_score*, so the
    history is consistent with the reported average.
    """
    history = [avg_score]
    return SentimentContext(
        ticker=ticker,
        avg_score=avg_score,
        avg_confidence=avg_confidence,
        article_count=article_count,
        recent_scores=history,
    )
# ===================================================================
# Momentum strategy
# ===================================================================
class TestMomentumStrategy:
    """Behavioural tests for :class:`MomentumStrategy`.

    These cases pin trend detection from the ordering of price, SMA-20 and
    SMA-50, the abstain paths, and how strength scales with the distance
    between price and SMA-20.
    """

    @pytest.fixture()
    def strategy(self) -> MomentumStrategy:
        """Provide a freshly constructed strategy for each test."""
        return MomentumStrategy()

    @pytest.mark.asyncio
    async def test_momentum_buy_signal(self, strategy: MomentumStrategy) -> None:
        """A stacked uptrend (price > sma_20 > sma_50) yields a LONG signal."""
        uptrend = _market(price=160.0, sma_20=150.0, sma_50=140.0)
        outcome = await strategy.evaluate("AAPL", uptrend)

        assert outcome is not None
        assert outcome.direction == SignalDirection.LONG
        assert outcome.ticker == "AAPL"
        assert 0 < outcome.strength <= 1.0
        assert strategy.name in outcome.strategy_sources

    @pytest.mark.asyncio
    async def test_momentum_sell_signal(self, strategy: MomentumStrategy) -> None:
        """A stacked downtrend (price < sma_20 < sma_50) yields a SHORT signal."""
        downtrend = _market(price=130.0, sma_20=140.0, sma_50=150.0)
        outcome = await strategy.evaluate("AAPL", downtrend)

        assert outcome is not None
        assert outcome.direction == SignalDirection.SHORT
        assert outcome.ticker == "AAPL"
        assert 0 < outcome.strength <= 1.0

    @pytest.mark.asyncio
    async def test_momentum_no_signal_flat(self, strategy: MomentumStrategy) -> None:
        """A mixed ordering (sma_20 < price < sma_50) is not a trend: no signal."""
        # Above the short average but below the long one, so neither the
        # up-trend nor the down-trend ordering holds.
        mixed = _market(price=145.0, sma_20=140.0, sma_50=150.0)
        assert await strategy.evaluate("AAPL", mixed) is None

    @pytest.mark.asyncio
    async def test_momentum_missing_sma_returns_none(self, strategy: MomentumStrategy) -> None:
        """The strategy abstains unless both moving averages are present."""
        incomplete = (
            (None, 140.0),  # sma_20 absent
            (145.0, None),  # sma_50 absent
            (None, None),  # both absent
        )
        for short_avg, long_avg in incomplete:
            snapshot = _market(price=150.0, sma_20=short_avg, sma_50=long_avg)
            assert await strategy.evaluate("AAPL", snapshot) is None

    @pytest.mark.asyncio
    async def test_momentum_strength_proportional(self, strategy: MomentumStrategy) -> None:
        """Strength grows with, and equals, |price - sma_20| / sma_20."""
        short_avg, long_avg = 100.0, 90.0

        near = await strategy.evaluate(
            "AAPL", _market(price=102.0, sma_20=short_avg, sma_50=long_avg)
        )
        far = await strategy.evaluate(
            "AAPL", _market(price=110.0, sma_20=short_avg, sma_50=long_avg)
        )

        assert near is not None
        assert far is not None
        assert far.strength > near.strength

        # Two points above a 100 SMA-20 => 0.02.
        assert near.strength == pytest.approx(abs(102.0 - 100.0) / 100.0, abs=1e-9)

    @pytest.mark.asyncio
    async def test_momentum_strength_clamped(self, strategy: MomentumStrategy) -> None:
        """An extreme divergence from SMA-20 still caps strength at 1.0."""
        extreme = _market(price=300.0, sma_20=100.0, sma_50=90.0)
        outcome = await strategy.evaluate("AAPL", extreme)

        assert outcome is not None
        assert outcome.strength == 1.0
# ===================================================================
# Mean reversion strategy
# ===================================================================
class TestMeanReversionStrategy:
    """Tests for :class:`MeanReversionStrategy`.

    The behaviour pinned here is:

    - LONG when RSI < 30.
    - SHORT when RSI > 70.
    - No signal anywhere in the inclusive 30-70 band.
    - Strength equal to the RSI's distance past the threshold divided by 30.
    """

    @pytest.fixture()
    def strategy(self) -> MeanReversionStrategy:
        return MeanReversionStrategy()

    @pytest.mark.asyncio
    async def test_mean_reversion_buy_oversold(self, strategy: MeanReversionStrategy) -> None:
        """Buy when RSI < 30 (oversold)."""
        market = _market(rsi=20.0)
        signal = await strategy.evaluate("AAPL", market)
        assert signal is not None
        assert signal.direction == SignalDirection.LONG
        assert 0 < signal.strength <= 1.0

    @pytest.mark.asyncio
    async def test_mean_reversion_sell_overbought(self, strategy: MeanReversionStrategy) -> None:
        """Sell when RSI > 70 (overbought)."""
        market = _market(rsi=80.0)
        signal = await strategy.evaluate("AAPL", market)
        assert signal is not None
        assert signal.direction == SignalDirection.SHORT
        assert 0 < signal.strength <= 1.0

    @pytest.mark.asyncio
    async def test_mean_reversion_no_signal_neutral(self, strategy: MeanReversionStrategy) -> None:
        """No signal when RSI is in neutral territory (30-70)."""
        market = _market(rsi=50.0)
        signal = await strategy.evaluate("AAPL", market)
        assert signal is None

    @pytest.mark.asyncio
    async def test_mean_reversion_missing_rsi_returns_none(self, strategy: MeanReversionStrategy) -> None:
        """Return None when RSI is not available."""
        market = _market(rsi=None)
        assert await strategy.evaluate("AAPL", market) is None

    @pytest.mark.asyncio
    async def test_mean_reversion_strength_proportional(self, strategy: MeanReversionStrategy) -> None:
        """Strength is proportional to how far RSI is from its threshold, on both sides."""
        # Buy side: lower RSI = stronger signal.
        signal_mild = await strategy.evaluate("AAPL", _market(rsi=25.0))
        signal_extreme = await strategy.evaluate("AAPL", _market(rsi=10.0))
        assert signal_mild is not None
        assert signal_extreme is not None
        assert signal_extreme.strength > signal_mild.strength

        # Sell side: higher RSI = stronger signal (mirror of the buy-side check).
        signal_warm = await strategy.evaluate("AAPL", _market(rsi=75.0))
        signal_hot = await strategy.evaluate("AAPL", _market(rsi=90.0))
        assert signal_warm is not None
        assert signal_hot is not None
        assert signal_hot.strength > signal_warm.strength

        # Verify exact strength for RSI=20: (30 - 20) / 30 = 1/3.
        signal_20 = await strategy.evaluate("AAPL", _market(rsi=20.0))
        assert signal_20 is not None
        assert signal_20.strength == pytest.approx(10.0 / 30.0, abs=1e-9)

        # Sell side: RSI=80: (80 - 70) / 30 = 1/3.
        signal_80 = await strategy.evaluate("AAPL", _market(rsi=80.0))
        assert signal_80 is not None
        assert signal_80.strength == pytest.approx(10.0 / 30.0, abs=1e-9)

    @pytest.mark.asyncio
    async def test_mean_reversion_boundary_no_signal(self, strategy: MeanReversionStrategy) -> None:
        """RSI exactly at 30 or 70 should NOT trigger a signal."""
        market_30 = _market(rsi=30.0)
        assert await strategy.evaluate("AAPL", market_30) is None
        market_70 = _market(rsi=70.0)
        assert await strategy.evaluate("AAPL", market_70) is None

    @pytest.mark.asyncio
    async def test_mean_reversion_strength_clamped(self, strategy: MeanReversionStrategy) -> None:
        """Strength stays within (0, 1] across the full RSI range.

        With a 30-point scale on each side, the extremes RSI=0 and RSI=100 must
        land exactly on the 1.0 upper bound. Intermediate extremes must keep
        their exact proportional value rather than merely "something <= 1".
        """
        # RSI=95 => (95 - 70) / 30 = 25/30. A bare `<= 1.0` check would accept
        # any in-range number, so pin the exact value instead.
        signal_95 = await strategy.evaluate("AAPL", _market(rsi=95.0))
        assert signal_95 is not None
        assert signal_95.direction == SignalDirection.SHORT
        assert signal_95.strength == pytest.approx(25.0 / 30.0, abs=1e-9)

        # RSI=100 => (100 - 70) / 30 = 1.0 exactly (sell-side upper bound).
        signal_100 = await strategy.evaluate("AAPL", _market(rsi=100.0))
        assert signal_100 is not None
        assert signal_100.direction == SignalDirection.SHORT
        assert signal_100.strength == pytest.approx(1.0, abs=1e-9)

        # RSI=0 => (30 - 0) / 30 = 1.0 exactly (buy-side upper bound).
        signal_zero = await strategy.evaluate("AAPL", _market(rsi=0.0))
        assert signal_zero is not None
        assert signal_zero.direction == SignalDirection.LONG
        assert signal_zero.strength == pytest.approx(1.0, abs=1e-9)
# ===================================================================
# News-driven strategy
# ===================================================================
class TestNewsDrivenStrategy:
    """Tests for :class:`NewsDrivenStrategy`.

    The behaviour pinned here is:

    - A signal requires sentiment context, at least one article, confidence
      strictly above 0.3, and an average score outside the neutral band
      (-0.15, 0.15).
    - Strength is abs(avg_score) * avg_confidence.
    """

    @pytest.fixture()
    def strategy(self) -> NewsDrivenStrategy:
        return NewsDrivenStrategy()

    @pytest.mark.asyncio
    async def test_news_driven_buy_positive(self, strategy: NewsDrivenStrategy) -> None:
        """Buy on strongly positive sentiment (score=0.8, confidence=0.7)."""
        market = _market()
        sentiment = _sentiment(avg_score=0.8, avg_confidence=0.7, article_count=5)
        signal = await strategy.evaluate("AAPL", market, sentiment)
        assert signal is not None
        assert signal.direction == SignalDirection.LONG
        assert 0 < signal.strength <= 1.0

    @pytest.mark.asyncio
    async def test_news_driven_sell_negative(self, strategy: NewsDrivenStrategy) -> None:
        """Sell on strongly negative sentiment (score=-0.8, confidence=0.7)."""
        market = _market()
        sentiment = _sentiment(avg_score=-0.8, avg_confidence=0.7, article_count=5)
        signal = await strategy.evaluate("AAPL", market, sentiment)
        assert signal is not None
        assert signal.direction == SignalDirection.SHORT
        assert 0 < signal.strength <= 1.0

    @pytest.mark.asyncio
    async def test_news_driven_no_signal_low_confidence(self, strategy: NewsDrivenStrategy) -> None:
        """No signal when avg_confidence is too low (<=0.3)."""
        market = _market()
        sentiment = _sentiment(avg_score=0.8, avg_confidence=0.2, article_count=5)
        signal = await strategy.evaluate("AAPL", market, sentiment)
        assert signal is None

    @pytest.mark.asyncio
    async def test_news_driven_no_signal_few_articles(self, strategy: NewsDrivenStrategy) -> None:
        """No signal when article_count < 1."""
        market = _market()
        sentiment = _sentiment(avg_score=0.8, avg_confidence=0.7, article_count=0)
        signal = await strategy.evaluate("AAPL", market, sentiment)
        assert signal is None

    @pytest.mark.asyncio
    async def test_news_driven_no_sentiment_returns_none(self, strategy: NewsDrivenStrategy) -> None:
        """Return None when no sentiment context is provided."""
        market = _market()
        signal = await strategy.evaluate("AAPL", market, sentiment=None)
        assert signal is None

    @pytest.mark.asyncio
    async def test_news_driven_strength_calculation(self, strategy: NewsDrivenStrategy) -> None:
        """Strength = abs(avg_score) * avg_confidence, clamped to [0, 1]."""
        market = _market()
        # score=0.8, confidence=0.7 => strength = 0.56
        sentiment = _sentiment(avg_score=0.8, avg_confidence=0.7)
        signal = await strategy.evaluate("AAPL", market, sentiment)
        assert signal is not None
        assert signal.strength == pytest.approx(0.8 * 0.7, abs=1e-9)
        # Negative score should yield same strength magnitude.
        sentiment_neg = _sentiment(avg_score=-0.8, avg_confidence=0.7)
        signal_neg = await strategy.evaluate("AAPL", market, sentiment_neg)
        assert signal_neg is not None
        assert signal_neg.strength == pytest.approx(0.8 * 0.7, abs=1e-9)

    @pytest.mark.asyncio
    async def test_news_driven_neutral_score(self, strategy: NewsDrivenStrategy) -> None:
        """No signal when avg_score is between -0.15 and 0.15 (neutral), on either side of zero."""
        market = _market()
        # Exercise both halves of the neutral band; previously only the positive
        # side was covered, so an asymmetric threshold bug would go unnoticed.
        for score in (0.1, -0.1):
            sentiment = _sentiment(avg_score=score, avg_confidence=0.9, article_count=10)
            signal = await strategy.evaluate("AAPL", market, sentiment)
            assert signal is None, f"unexpected signal for neutral score {score}"

    @pytest.mark.asyncio
    async def test_news_driven_boundary_confidence(self, strategy: NewsDrivenStrategy) -> None:
        """No signal when avg_confidence is exactly 0.3 (threshold is >0.3)."""
        market = _market()
        sentiment = _sentiment(avg_score=0.8, avg_confidence=0.3, article_count=5)
        signal = await strategy.evaluate("AAPL", market, sentiment)
        assert signal is None
# ===================================================================
# Cross-strategy tests
# ===================================================================
class TestStrategyCrossChecks:
    """Invariants every concrete strategy implementation must satisfy."""

    def test_all_strategies_are_base_strategy_subclass(self) -> None:
        """Each concrete strategy class derives from BaseStrategy."""
        concrete = (MomentumStrategy, MeanReversionStrategy, NewsDrivenStrategy)
        for strategy_cls in concrete:
            assert issubclass(strategy_cls, BaseStrategy), (
                f"{strategy_cls.__name__} is not a BaseStrategy subclass"
            )

    def test_strategy_names_unique(self) -> None:
        """No two strategies may report the same name."""
        concrete = (MomentumStrategy, MeanReversionStrategy, NewsDrivenStrategy)
        names = [kind().name for kind in concrete]
        assert len(set(names)) == len(names), f"Duplicate strategy names detected: {names}"

    def test_strategy_names_non_empty(self) -> None:
        """Each strategy exposes a non-empty string as its name."""
        concrete = (MomentumStrategy, MeanReversionStrategy, NewsDrivenStrategy)
        for kind in concrete:
            label = kind().name
            assert isinstance(label, str)
            assert len(label) > 0