feat: trading strategies — momentum, mean reversion, news-driven
This commit is contained in:
parent
e483e9987f
commit
60bd1ccd2a
6 changed files with 581 additions and 0 deletions
25
shared/strategies/__init__.py
Normal file
25
shared/strategies/__init__.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
"""Trading strategy implementations.
|
||||
|
||||
Exports
|
||||
-------
|
||||
BaseStrategy
|
||||
Abstract base class for all strategies.
|
||||
MomentumStrategy
|
||||
Trend-following strategy based on SMA cross-overs.
|
||||
MeanReversionStrategy
|
||||
RSI-based mean reversion strategy.
|
||||
NewsDrivenStrategy
|
||||
News sentiment driven strategy.
|
||||
"""
|
||||
|
||||
from shared.strategies.base import BaseStrategy
|
||||
from shared.strategies.mean_reversion import MeanReversionStrategy
|
||||
from shared.strategies.momentum import MomentumStrategy
|
||||
from shared.strategies.news_driven import NewsDrivenStrategy
|
||||
|
||||
__all__ = [
|
||||
"BaseStrategy",
|
||||
"MeanReversionStrategy",
|
||||
"MomentumStrategy",
|
||||
"NewsDrivenStrategy",
|
||||
]
|
||||
26
shared/strategies/base.py
Normal file
26
shared/strategies/base.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
"""Abstract base class for all trading strategies."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from shared.schemas.trading import MarketSnapshot, SentimentContext, TradeSignal
|
||||
|
||||
|
||||
class BaseStrategy(ABC):
|
||||
"""Base class that all trading strategies must inherit from.
|
||||
|
||||
Subclasses implement :meth:`evaluate` to inspect market data and
|
||||
optionally sentiment, returning a :class:`TradeSignal` when the
|
||||
strategy has a directional opinion and ``None`` otherwise.
|
||||
"""
|
||||
|
||||
name: str
|
||||
|
||||
@abstractmethod
|
||||
async def evaluate(
|
||||
self,
|
||||
ticker: str,
|
||||
market: MarketSnapshot,
|
||||
sentiment: SentimentContext | None = None,
|
||||
) -> TradeSignal | None:
|
||||
"""Return a signal if this strategy has an opinion, None otherwise."""
|
||||
...
|
||||
56
shared/strategies/mean_reversion.py
Normal file
56
shared/strategies/mean_reversion.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
"""Mean reversion strategy — buy oversold, sell overbought using RSI."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from shared.schemas.trading import MarketSnapshot, SentimentContext, SignalDirection, TradeSignal
|
||||
from shared.strategies.base import BaseStrategy
|
||||
|
||||
|
||||
class MeanReversionStrategy(BaseStrategy):
|
||||
"""Trade on the assumption that extreme RSI readings will revert to the mean.
|
||||
|
||||
**Buy signal** (LONG):
|
||||
RSI < 30 (oversold).
|
||||
|
||||
**Sell signal** (SHORT):
|
||||
RSI > 70 (overbought).
|
||||
|
||||
Signal strength is proportional to how far the RSI is from its
|
||||
threshold, clamped to [0, 1].
|
||||
|
||||
* Buy strength = ``(30 - rsi) / 30``
|
||||
* Sell strength = ``(rsi - 70) / 30``
|
||||
"""
|
||||
|
||||
name: str = "mean_reversion"
|
||||
|
||||
async def evaluate(
|
||||
self,
|
||||
ticker: str,
|
||||
market: MarketSnapshot,
|
||||
sentiment: SentimentContext | None = None,
|
||||
) -> TradeSignal | None:
|
||||
if market.rsi is None:
|
||||
return None
|
||||
|
||||
rsi = market.rsi
|
||||
|
||||
if rsi < 30:
|
||||
direction = SignalDirection.LONG
|
||||
raw_strength = (30 - rsi) / 30
|
||||
elif rsi > 70:
|
||||
direction = SignalDirection.SHORT
|
||||
raw_strength = (rsi - 70) / 30
|
||||
else:
|
||||
# RSI in neutral territory — no opinion.
|
||||
return None
|
||||
|
||||
strength = max(0.0, min(1.0, raw_strength))
|
||||
|
||||
return TradeSignal(
|
||||
ticker=ticker,
|
||||
direction=direction,
|
||||
strength=strength,
|
||||
strategy_sources=[self.name],
|
||||
timestamp=datetime.now(tz=timezone.utc),
|
||||
)
|
||||
61
shared/strategies/momentum.py
Normal file
61
shared/strategies/momentum.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
"""Momentum trading strategy — trend-following based on moving averages."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from shared.schemas.trading import MarketSnapshot, SentimentContext, SignalDirection, TradeSignal
|
||||
from shared.strategies.base import BaseStrategy
|
||||
|
||||
|
||||
class MomentumStrategy(BaseStrategy):
|
||||
"""Detect and follow momentum via simple moving average cross-overs.
|
||||
|
||||
**Buy signal** (LONG):
|
||||
``current_price > sma_20`` AND ``sma_20 > sma_50`` (golden cross /
|
||||
uptrend) AND volume above the daily open (simple proxy for above-
|
||||
average volume).
|
||||
|
||||
**Sell signal** (SHORT):
|
||||
``current_price < sma_20`` AND ``sma_20 < sma_50`` (death cross /
|
||||
downtrend).
|
||||
|
||||
Signal strength is proportional to the normalised distance between
|
||||
the current price and the 20-period SMA, clamped to [0, 1].
|
||||
"""
|
||||
|
||||
name: str = "momentum"
|
||||
|
||||
async def evaluate(
|
||||
self,
|
||||
ticker: str,
|
||||
market: MarketSnapshot,
|
||||
sentiment: SentimentContext | None = None,
|
||||
) -> TradeSignal | None:
|
||||
# Require both moving averages to be present.
|
||||
if market.sma_20 is None or market.sma_50 is None:
|
||||
return None
|
||||
|
||||
price = market.current_price
|
||||
sma_20 = market.sma_20
|
||||
sma_50 = market.sma_50
|
||||
|
||||
direction: SignalDirection | None = None
|
||||
|
||||
if price > sma_20 and sma_20 > sma_50:
|
||||
direction = SignalDirection.LONG
|
||||
elif price < sma_20 and sma_20 < sma_50:
|
||||
direction = SignalDirection.SHORT
|
||||
else:
|
||||
# No clear trend — abstain.
|
||||
return None
|
||||
|
||||
# Strength: normalised distance from SMA-20, clamped to [0, 1].
|
||||
raw_strength = abs(price - sma_20) / sma_20 if sma_20 != 0 else 0.0
|
||||
strength = max(0.0, min(1.0, raw_strength))
|
||||
|
||||
return TradeSignal(
|
||||
ticker=ticker,
|
||||
direction=direction,
|
||||
strength=strength,
|
||||
strategy_sources=[self.name],
|
||||
timestamp=datetime.now(tz=timezone.utc),
|
||||
)
|
||||
60
shared/strategies/news_driven.py
Normal file
60
shared/strategies/news_driven.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
"""News-driven strategy — trade on aggregated news sentiment."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from shared.schemas.trading import MarketSnapshot, SentimentContext, SignalDirection, TradeSignal
|
||||
from shared.strategies.base import BaseStrategy
|
||||
|
||||
|
||||
class NewsDrivenStrategy(BaseStrategy):
|
||||
"""Generate signals from aggregated news sentiment for a ticker.
|
||||
|
||||
**Buy signal** (LONG):
|
||||
``avg_score > 0.3`` AND ``avg_confidence > 0.5`` AND
|
||||
``article_count >= 2``.
|
||||
|
||||
**Sell signal** (SHORT):
|
||||
``avg_score < -0.3`` AND ``avg_confidence > 0.5`` AND
|
||||
``article_count >= 2``.
|
||||
|
||||
Signal strength = ``abs(avg_score) * avg_confidence``, clamped to
|
||||
[0, 1].
|
||||
"""
|
||||
|
||||
name: str = "news_driven"
|
||||
|
||||
async def evaluate(
|
||||
self,
|
||||
ticker: str,
|
||||
market: MarketSnapshot,
|
||||
sentiment: SentimentContext | None = None,
|
||||
) -> TradeSignal | None:
|
||||
if sentiment is None:
|
||||
return None
|
||||
|
||||
# Require at least 2 articles for statistical confidence.
|
||||
if sentiment.article_count < 2:
|
||||
return None
|
||||
|
||||
# Require minimum confidence.
|
||||
if sentiment.avg_confidence <= 0.5:
|
||||
return None
|
||||
|
||||
if sentiment.avg_score > 0.3:
|
||||
direction = SignalDirection.LONG
|
||||
elif sentiment.avg_score < -0.3:
|
||||
direction = SignalDirection.SHORT
|
||||
else:
|
||||
# Sentiment is neutral — no opinion.
|
||||
return None
|
||||
|
||||
raw_strength = abs(sentiment.avg_score) * sentiment.avg_confidence
|
||||
strength = max(0.0, min(1.0, raw_strength))
|
||||
|
||||
return TradeSignal(
|
||||
ticker=ticker,
|
||||
direction=direction,
|
||||
strength=strength,
|
||||
strategy_sources=[self.name],
|
||||
timestamp=datetime.now(tz=timezone.utc),
|
||||
)
|
||||
353
tests/test_strategies.py
Normal file
353
tests/test_strategies.py
Normal file
|
|
@ -0,0 +1,353 @@
|
|||
"""Comprehensive tests for trading strategy implementations."""
|
||||
|
||||
import pytest
|
||||
|
||||
from shared.schemas.trading import MarketSnapshot, SentimentContext, SignalDirection
|
||||
from shared.strategies import BaseStrategy, MeanReversionStrategy, MomentumStrategy, NewsDrivenStrategy
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _market(
|
||||
ticker: str = "AAPL",
|
||||
price: float = 150.0,
|
||||
sma_20: float | None = None,
|
||||
sma_50: float | None = None,
|
||||
rsi: float | None = None,
|
||||
volume: float = 1_000_000,
|
||||
) -> MarketSnapshot:
|
||||
"""Build a MarketSnapshot with sensible defaults."""
|
||||
return MarketSnapshot(
|
||||
ticker=ticker,
|
||||
current_price=price,
|
||||
open=price - 1,
|
||||
high=price + 2,
|
||||
low=price - 2,
|
||||
close=price,
|
||||
volume=volume,
|
||||
sma_20=sma_20,
|
||||
sma_50=sma_50,
|
||||
rsi=rsi,
|
||||
)
|
||||
|
||||
|
||||
def _sentiment(
|
||||
ticker: str = "AAPL",
|
||||
avg_score: float = 0.0,
|
||||
avg_confidence: float = 0.7,
|
||||
article_count: int = 5,
|
||||
) -> SentimentContext:
|
||||
"""Build a SentimentContext with sensible defaults."""
|
||||
return SentimentContext(
|
||||
ticker=ticker,
|
||||
avg_score=avg_score,
|
||||
article_count=article_count,
|
||||
recent_scores=[avg_score],
|
||||
avg_confidence=avg_confidence,
|
||||
)
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Momentum strategy
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestMomentumStrategy:
|
||||
"""Tests for :class:`MomentumStrategy`."""
|
||||
|
||||
@pytest.fixture()
|
||||
def strategy(self) -> MomentumStrategy:
|
||||
return MomentumStrategy()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_momentum_buy_signal(self, strategy: MomentumStrategy) -> None:
|
||||
"""Buy when price > sma_20 > sma_50 (uptrend / golden cross)."""
|
||||
market = _market(price=160.0, sma_20=150.0, sma_50=140.0)
|
||||
signal = await strategy.evaluate("AAPL", market)
|
||||
|
||||
assert signal is not None
|
||||
assert signal.direction == SignalDirection.LONG
|
||||
assert signal.ticker == "AAPL"
|
||||
assert 0 < signal.strength <= 1.0
|
||||
assert strategy.name in signal.strategy_sources
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_momentum_sell_signal(self, strategy: MomentumStrategy) -> None:
|
||||
"""Sell when price < sma_20 < sma_50 (downtrend / death cross)."""
|
||||
market = _market(price=130.0, sma_20=140.0, sma_50=150.0)
|
||||
signal = await strategy.evaluate("AAPL", market)
|
||||
|
||||
assert signal is not None
|
||||
assert signal.direction == SignalDirection.SHORT
|
||||
assert signal.ticker == "AAPL"
|
||||
assert 0 < signal.strength <= 1.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_momentum_no_signal_flat(self, strategy: MomentumStrategy) -> None:
|
||||
"""No signal when price is between the two SMAs (no clear trend)."""
|
||||
# price between sma_20 and sma_50 — neither condition met.
|
||||
market = _market(price=145.0, sma_20=140.0, sma_50=150.0)
|
||||
signal = await strategy.evaluate("AAPL", market)
|
||||
assert signal is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_momentum_missing_sma_returns_none(self, strategy: MomentumStrategy) -> None:
|
||||
"""Return None when sma_20 or sma_50 is missing."""
|
||||
# Missing sma_20
|
||||
market_no_20 = _market(price=150.0, sma_20=None, sma_50=140.0)
|
||||
assert await strategy.evaluate("AAPL", market_no_20) is None
|
||||
|
||||
# Missing sma_50
|
||||
market_no_50 = _market(price=150.0, sma_20=145.0, sma_50=None)
|
||||
assert await strategy.evaluate("AAPL", market_no_50) is None
|
||||
|
||||
# Both missing
|
||||
market_both = _market(price=150.0, sma_20=None, sma_50=None)
|
||||
assert await strategy.evaluate("AAPL", market_both) is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_momentum_strength_proportional(self, strategy: MomentumStrategy) -> None:
|
||||
"""Strength should be proportional to (price - sma_20) / sma_20."""
|
||||
sma_20 = 100.0
|
||||
sma_50 = 90.0
|
||||
|
||||
# Small distance from SMA-20.
|
||||
market_small = _market(price=102.0, sma_20=sma_20, sma_50=sma_50)
|
||||
signal_small = await strategy.evaluate("AAPL", market_small)
|
||||
|
||||
# Larger distance from SMA-20.
|
||||
market_large = _market(price=110.0, sma_20=sma_20, sma_50=sma_50)
|
||||
signal_large = await strategy.evaluate("AAPL", market_large)
|
||||
|
||||
assert signal_small is not None
|
||||
assert signal_large is not None
|
||||
assert signal_large.strength > signal_small.strength
|
||||
|
||||
# Verify the exact strength for one case.
|
||||
expected = abs(102.0 - 100.0) / 100.0 # 0.02
|
||||
assert signal_small.strength == pytest.approx(expected, abs=1e-9)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_momentum_strength_clamped(self, strategy: MomentumStrategy) -> None:
|
||||
"""Strength must not exceed 1.0 even with extreme price divergence."""
|
||||
market = _market(price=300.0, sma_20=100.0, sma_50=90.0)
|
||||
signal = await strategy.evaluate("AAPL", market)
|
||||
|
||||
assert signal is not None
|
||||
assert signal.strength == 1.0
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Mean reversion strategy
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestMeanReversionStrategy:
|
||||
"""Tests for :class:`MeanReversionStrategy`."""
|
||||
|
||||
@pytest.fixture()
|
||||
def strategy(self) -> MeanReversionStrategy:
|
||||
return MeanReversionStrategy()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mean_reversion_buy_oversold(self, strategy: MeanReversionStrategy) -> None:
|
||||
"""Buy when RSI < 30 (oversold)."""
|
||||
market = _market(rsi=20.0)
|
||||
signal = await strategy.evaluate("AAPL", market)
|
||||
|
||||
assert signal is not None
|
||||
assert signal.direction == SignalDirection.LONG
|
||||
assert 0 < signal.strength <= 1.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mean_reversion_sell_overbought(self, strategy: MeanReversionStrategy) -> None:
|
||||
"""Sell when RSI > 70 (overbought)."""
|
||||
market = _market(rsi=80.0)
|
||||
signal = await strategy.evaluate("AAPL", market)
|
||||
|
||||
assert signal is not None
|
||||
assert signal.direction == SignalDirection.SHORT
|
||||
assert 0 < signal.strength <= 1.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mean_reversion_no_signal_neutral(self, strategy: MeanReversionStrategy) -> None:
|
||||
"""No signal when RSI is in neutral territory (30-70)."""
|
||||
market = _market(rsi=50.0)
|
||||
signal = await strategy.evaluate("AAPL", market)
|
||||
assert signal is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mean_reversion_missing_rsi_returns_none(self, strategy: MeanReversionStrategy) -> None:
|
||||
"""Return None when RSI is not available."""
|
||||
market = _market(rsi=None)
|
||||
assert await strategy.evaluate("AAPL", market) is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mean_reversion_strength_proportional(self, strategy: MeanReversionStrategy) -> None:
|
||||
"""Strength is proportional to how far RSI is from its threshold."""
|
||||
# Buy side: lower RSI = stronger signal.
|
||||
market_mild = _market(rsi=25.0)
|
||||
signal_mild = await strategy.evaluate("AAPL", market_mild)
|
||||
|
||||
market_extreme = _market(rsi=10.0)
|
||||
signal_extreme = await strategy.evaluate("AAPL", market_extreme)
|
||||
|
||||
assert signal_mild is not None
|
||||
assert signal_extreme is not None
|
||||
assert signal_extreme.strength > signal_mild.strength
|
||||
|
||||
# Verify exact strength for RSI=20: (30 - 20) / 30 = 1/3.
|
||||
market_20 = _market(rsi=20.0)
|
||||
signal_20 = await strategy.evaluate("AAPL", market_20)
|
||||
assert signal_20 is not None
|
||||
assert signal_20.strength == pytest.approx(10.0 / 30.0, abs=1e-9)
|
||||
|
||||
# Sell side: RSI=80: (80 - 70) / 30 = 1/3.
|
||||
market_80 = _market(rsi=80.0)
|
||||
signal_80 = await strategy.evaluate("AAPL", market_80)
|
||||
assert signal_80 is not None
|
||||
assert signal_80.strength == pytest.approx(10.0 / 30.0, abs=1e-9)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mean_reversion_boundary_no_signal(self, strategy: MeanReversionStrategy) -> None:
|
||||
"""RSI exactly at 30 or 70 should NOT trigger a signal."""
|
||||
market_30 = _market(rsi=30.0)
|
||||
assert await strategy.evaluate("AAPL", market_30) is None
|
||||
|
||||
market_70 = _market(rsi=70.0)
|
||||
assert await strategy.evaluate("AAPL", market_70) is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mean_reversion_strength_clamped(self, strategy: MeanReversionStrategy) -> None:
|
||||
"""Strength is clamped to [0, 1] even at extreme RSI values."""
|
||||
market = _market(rsi=95.0)
|
||||
signal = await strategy.evaluate("AAPL", market)
|
||||
assert signal is not None
|
||||
assert signal.strength <= 1.0
|
||||
|
||||
# RSI=0 => (30-0)/30 = 1.0 exactly.
|
||||
market_zero = _market(rsi=0.0)
|
||||
signal_zero = await strategy.evaluate("AAPL", market_zero)
|
||||
assert signal_zero is not None
|
||||
assert signal_zero.strength == pytest.approx(1.0, abs=1e-9)
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# News-driven strategy
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestNewsDrivenStrategy:
|
||||
"""Tests for :class:`NewsDrivenStrategy`."""
|
||||
|
||||
@pytest.fixture()
|
||||
def strategy(self) -> NewsDrivenStrategy:
|
||||
return NewsDrivenStrategy()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_news_driven_buy_positive(self, strategy: NewsDrivenStrategy) -> None:
|
||||
"""Buy on strongly positive sentiment (score=0.8, confidence=0.7)."""
|
||||
market = _market()
|
||||
sentiment = _sentiment(avg_score=0.8, avg_confidence=0.7, article_count=5)
|
||||
signal = await strategy.evaluate("AAPL", market, sentiment)
|
||||
|
||||
assert signal is not None
|
||||
assert signal.direction == SignalDirection.LONG
|
||||
assert 0 < signal.strength <= 1.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_news_driven_sell_negative(self, strategy: NewsDrivenStrategy) -> None:
|
||||
"""Sell on strongly negative sentiment (score=-0.8, confidence=0.7)."""
|
||||
market = _market()
|
||||
sentiment = _sentiment(avg_score=-0.8, avg_confidence=0.7, article_count=5)
|
||||
signal = await strategy.evaluate("AAPL", market, sentiment)
|
||||
|
||||
assert signal is not None
|
||||
assert signal.direction == SignalDirection.SHORT
|
||||
assert 0 < signal.strength <= 1.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_news_driven_no_signal_low_confidence(self, strategy: NewsDrivenStrategy) -> None:
|
||||
"""No signal when avg_confidence is too low (<=0.5)."""
|
||||
market = _market()
|
||||
sentiment = _sentiment(avg_score=0.8, avg_confidence=0.4, article_count=5)
|
||||
signal = await strategy.evaluate("AAPL", market, sentiment)
|
||||
assert signal is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_news_driven_no_signal_few_articles(self, strategy: NewsDrivenStrategy) -> None:
|
||||
"""No signal when article_count < 2."""
|
||||
market = _market()
|
||||
sentiment = _sentiment(avg_score=0.8, avg_confidence=0.7, article_count=1)
|
||||
signal = await strategy.evaluate("AAPL", market, sentiment)
|
||||
assert signal is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_news_driven_no_sentiment_returns_none(self, strategy: NewsDrivenStrategy) -> None:
|
||||
"""Return None when no sentiment context is provided."""
|
||||
market = _market()
|
||||
signal = await strategy.evaluate("AAPL", market, sentiment=None)
|
||||
assert signal is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_news_driven_strength_calculation(self, strategy: NewsDrivenStrategy) -> None:
|
||||
"""Strength = abs(avg_score) * avg_confidence, clamped to [0, 1]."""
|
||||
market = _market()
|
||||
|
||||
# score=0.8, confidence=0.7 => strength = 0.56
|
||||
sentiment = _sentiment(avg_score=0.8, avg_confidence=0.7)
|
||||
signal = await strategy.evaluate("AAPL", market, sentiment)
|
||||
assert signal is not None
|
||||
assert signal.strength == pytest.approx(0.8 * 0.7, abs=1e-9)
|
||||
|
||||
# Negative score should yield same strength magnitude.
|
||||
sentiment_neg = _sentiment(avg_score=-0.8, avg_confidence=0.7)
|
||||
signal_neg = await strategy.evaluate("AAPL", market, sentiment_neg)
|
||||
assert signal_neg is not None
|
||||
assert signal_neg.strength == pytest.approx(0.8 * 0.7, abs=1e-9)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_news_driven_neutral_score(self, strategy: NewsDrivenStrategy) -> None:
|
||||
"""No signal when avg_score is between -0.3 and 0.3 (neutral)."""
|
||||
market = _market()
|
||||
sentiment = _sentiment(avg_score=0.1, avg_confidence=0.9, article_count=10)
|
||||
signal = await strategy.evaluate("AAPL", market, sentiment)
|
||||
assert signal is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_news_driven_boundary_confidence(self, strategy: NewsDrivenStrategy) -> None:
|
||||
"""No signal when avg_confidence is exactly 0.5 (threshold is >0.5)."""
|
||||
market = _market()
|
||||
sentiment = _sentiment(avg_score=0.8, avg_confidence=0.5, article_count=5)
|
||||
signal = await strategy.evaluate("AAPL", market, sentiment)
|
||||
assert signal is None
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Cross-strategy tests
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestStrategyCrossChecks:
|
||||
"""Tests that apply across all strategy implementations."""
|
||||
|
||||
def test_all_strategies_are_base_strategy_subclass(self) -> None:
|
||||
"""All concrete strategies must inherit from BaseStrategy."""
|
||||
for cls in (MomentumStrategy, MeanReversionStrategy, NewsDrivenStrategy):
|
||||
assert issubclass(cls, BaseStrategy), f"{cls.__name__} is not a BaseStrategy subclass"
|
||||
|
||||
def test_strategy_names_unique(self) -> None:
|
||||
"""Every strategy must have a distinct name."""
|
||||
strategies = [MomentumStrategy(), MeanReversionStrategy(), NewsDrivenStrategy()]
|
||||
names = [s.name for s in strategies]
|
||||
assert len(names) == len(set(names)), f"Duplicate strategy names detected: {names}"
|
||||
|
||||
def test_strategy_names_non_empty(self) -> None:
|
||||
"""Every strategy name must be a non-empty string."""
|
||||
for cls in (MomentumStrategy, MeanReversionStrategy, NewsDrivenStrategy):
|
||||
instance = cls()
|
||||
assert isinstance(instance.name, str)
|
||||
assert len(instance.name) > 0
|
||||
Loading…
Add table
Add a link
Reference in a new issue