trading/tests/services/test_signal_generator.py

360 lines
13 KiB
Python
Raw Normal View History

"""Tests for the Signal Generator service.
Covers MarketDataManager (SMA, RSI, snapshot) and WeightedEnsemble
(signal combination, threshold filtering, strategy source tagging).
"""
from __future__ import annotations

from datetime import datetime, timedelta, timezone

import pytest

from services.signal_generator.ensemble import WeightedEnsemble
from services.signal_generator.market_data import MarketDataManager
from shared.schemas.trading import (
    MarketSnapshot,
    OHLCVBar,
    SentimentContext,
    SignalDirection,
    TradeSignal,
)
from shared.strategies.base import BaseStrategy
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_bar(close: float, *, ts_offset: int = 0) -> OHLCVBar:
    """Create an ``OHLCVBar`` with the given close price.

    The bar is stamped ``ts_offset`` minutes after 2026-01-01 10:00 UTC, so
    consecutive offsets give strictly increasing timestamps. Adding a
    ``timedelta`` (instead of passing ``ts_offset`` as the ``minute`` field)
    keeps the helper valid for offsets of 60 and above, which the ``minute``
    field would reject with ``ValueError``.

    Args:
        close: Closing price. Open/high/low are derived from it with fixed
            offsets, so the bar always satisfies
            ``low <= open <= close <= high``.
        ts_offset: Minutes after the base timestamp.
    """
    base_time = datetime(2026, 1, 1, 10, 0, tzinfo=timezone.utc)
    return OHLCVBar(
        timestamp=base_time + timedelta(minutes=ts_offset),
        open=close - 0.5,
        high=close + 1.0,
        low=close - 1.0,
        close=close,
        volume=1000.0,
    )
class _StubStrategy(BaseStrategy):
    """Strategy double whose ``evaluate`` always yields a fixed signal.

    Lets ensemble tests dictate exactly what each named strategy emits,
    including ``None`` to model a strategy that does not fire.
    """

    def __init__(self, name: str, signal: TradeSignal | None) -> None:
        # NOTE(review): BaseStrategy.__init__ is not invoked here; confirm the
        # base class needs no initialisation of its own.
        self.name = name
        self._canned_signal = signal

    async def evaluate(
        self,
        ticker: str,
        market: MarketSnapshot,
        sentiment: SentimentContext | None = None,
    ) -> TradeSignal | None:
        """Ignore every input and hand back the preconfigured signal."""
        return self._canned_signal
def _make_signal(
    direction: SignalDirection = SignalDirection.LONG,
    strength: float = 0.8,
    sources: list[str] | None = None,
    *,
    ticker: str = "AAPL",
) -> TradeSignal:
    """Build a ``TradeSignal`` stamped with the current UTC time.

    Args:
        direction: Direction of the signal.
        strength: Strength value passed straight through to ``TradeSignal``.
        sources: Strategy source tags. Only ``None`` falls back to
            ``["test"]``; any other value (including an explicit empty list)
            is used as given. The list is copied so later mutation of the
            caller's list cannot leak into the signal.
        ticker: Ticker symbol (keyword-only); defaults to ``"AAPL"``, the
            ticker used throughout this module.
    """
    return TradeSignal(
        ticker=ticker,
        direction=direction,
        strength=strength,
        strategy_sources=["test"] if sources is None else list(sources),
        timestamp=datetime.now(timezone.utc),
    )
# ---------------------------------------------------------------------------
# MarketDataManager — SMA
# ---------------------------------------------------------------------------
class TestMarketDataManagerSMA:
    """Simple-moving-average fields of the ``MarketDataManager`` snapshot."""

    @staticmethod
    def _feed(mgr, closes, *, first_offset=0):
        """Append one AAPL bar per close price, one minute apart."""
        for offset, close in enumerate(closes, start=first_offset):
            mgr.add_bar("AAPL", _make_bar(float(close), ts_offset=offset))

    def test_sma_basic(self):
        """SMA-20 should equal the mean of the last 20 close prices."""
        closes = range(1, 21)  # 1, 2, ..., 20
        mgr = MarketDataManager()
        self._feed(mgr, closes)

        snap = mgr.get_snapshot("AAPL")
        assert snap is not None
        assert snap.sma_20 == pytest.approx(sum(closes) / len(closes))

    def test_sma_returns_none_insufficient_data(self):
        """SMA-20 should be None when fewer than 20 bars exist."""
        mgr = MarketDataManager()
        self._feed(mgr, [100.0] * 10)

        snap = mgr.get_snapshot("AAPL")
        assert snap is not None
        assert snap.sma_20 is None

    def test_sma_50_requires_50_bars(self):
        """SMA-50 should be None with only 30 bars, present with 50."""
        mgr = MarketDataManager()
        self._feed(mgr, range(1, 31))

        partial = mgr.get_snapshot("AAPL")
        assert partial is not None
        assert partial.sma_50 is None

        # Top up to exactly 50 bars, continuing both price and time sequences.
        self._feed(mgr, range(31, 51), first_offset=30)

        full = mgr.get_snapshot("AAPL")
        assert full is not None
        assert full.sma_50 is not None
        assert full.sma_50 == pytest.approx(sum(range(1, 51)) / 50)
# ---------------------------------------------------------------------------
# MarketDataManager — RSI
# ---------------------------------------------------------------------------
class TestMarketDataManagerRSI:
    """Tests for RSI computation inside MarketDataManager.

    NOTE(review): the boundary tests assume a 14-period RSI (14 price
    changes, hence 15 bars), as stated in the insufficient-data docstring;
    confirm against MarketDataManager's configured period.
    """

    def test_rsi_all_gains(self):
        """RSI should be 100 when all price changes are positive."""
        mgr = MarketDataManager()
        for i in range(20):
            mgr.add_bar("AAPL", _make_bar(100.0 + i, ts_offset=i))
        snap = mgr.get_snapshot("AAPL")
        assert snap is not None
        assert snap.rsi == pytest.approx(100.0)

    def test_rsi_all_losses(self):
        """RSI should be 0 when all price changes are negative."""
        mgr = MarketDataManager()
        for i in range(20):
            mgr.add_bar("AAPL", _make_bar(200.0 - i, ts_offset=i))
        snap = mgr.get_snapshot("AAPL")
        assert snap is not None
        assert snap.rsi == pytest.approx(0.0)

    def test_rsi_mixed(self):
        """RSI should be between 0 and 100 with mixed gains and losses."""
        mgr = MarketDataManager()
        prices = [44, 44.34, 44.09, 43.61, 44.33, 44.83, 45.10, 45.42,
                  45.84, 46.08, 45.89, 46.03, 45.61, 46.28, 46.28, 46.00]
        for i, p in enumerate(prices):
            mgr.add_bar("AAPL", _make_bar(p, ts_offset=i))
        snap = mgr.get_snapshot("AAPL")
        assert snap is not None
        assert snap.rsi is not None
        assert 0 < snap.rsi < 100

    def test_rsi_returns_none_insufficient_data(self):
        """RSI should be None when fewer than 15 bars exist (need 14+1)."""
        mgr = MarketDataManager()
        # One bar short of the requirement, so the boundary itself is
        # exercised rather than a comfortably small bar count. Rising prices
        # keep this identical to the next test except for the missing bar.
        for i in range(14):
            mgr.add_bar("AAPL", _make_bar(100.0 + i, ts_offset=i))
        snap = mgr.get_snapshot("AAPL")
        assert snap is not None
        assert snap.rsi is None

    def test_rsi_available_at_minimum_bars(self):
        """RSI should be computed once exactly 15 bars (14 changes) exist."""
        mgr = MarketDataManager()
        for i in range(15):
            mgr.add_bar("AAPL", _make_bar(100.0 + i, ts_offset=i))
        snap = mgr.get_snapshot("AAPL")
        assert snap is not None
        # All changes are gains, matching test_rsi_all_gains' expectation.
        assert snap.rsi == pytest.approx(100.0)
# ---------------------------------------------------------------------------
# MarketDataManager — snapshot
# ---------------------------------------------------------------------------
class TestMarketDataManagerSnapshot:
    """Behaviour of ``MarketDataManager.get_snapshot`` itself."""

    def test_snapshot_returns_none_for_unknown_ticker(self):
        """A ticker that never received a bar has no snapshot."""
        assert MarketDataManager().get_snapshot("UNKNOWN") is None

    def test_snapshot_uses_latest_bar_for_price(self):
        """``current_price`` follows the close of the most recently added bar."""
        mgr = MarketDataManager()
        for offset, close in enumerate((100.0, 105.0)):
            mgr.add_bar("AAPL", _make_bar(close, ts_offset=offset))

        snap = mgr.get_snapshot("AAPL")
        assert snap is not None
        assert snap.current_price == 105.0

    def test_snapshot_contains_bars(self):
        """The snapshot exposes one entry per bar added (five here)."""
        mgr = MarketDataManager()
        for offset in range(5):
            mgr.add_bar("AAPL", _make_bar(100.0 + offset, ts_offset=offset))

        snap = mgr.get_snapshot("AAPL")
        assert snap is not None
        assert len(snap.bars) == 5
# ---------------------------------------------------------------------------
# WeightedEnsemble — combines signals
# ---------------------------------------------------------------------------
class TestEnsembleCombinesSignals:
    """WeightedEnsemble merges per-strategy signals into a single signal."""

    @staticmethod
    def _market():
        """Build the fixed AAPL snapshot (the stub strategies ignore it)."""
        return MarketSnapshot(
            ticker="AAPL",
            current_price=150.0,
            open=149.0,
            high=151.0,
            low=148.0,
            close=150.0,
            volume=1000,
        )

    @pytest.mark.asyncio
    async def test_combines_two_long_signals(self):
        """Two LONG signals should produce a combined LONG signal."""
        ensemble = WeightedEnsemble(
            [
                _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.8)),
                _StubStrategy("beta", _make_signal(SignalDirection.LONG, 0.6)),
            ],
            threshold=0.0,
        )
        result = await ensemble.evaluate(
            "AAPL", self._market(), None, {"alpha": 0.5, "beta": 0.5}
        )

        assert result is not None
        assert result.direction == SignalDirection.LONG
        # Weighted average: (0.8 * 0.5 + 0.6 * 0.5) / (0.5 + 0.5) = 0.7
        assert result.strength == pytest.approx(0.7, abs=0.01)

    @pytest.mark.asyncio
    async def test_opposing_signals_net_direction(self):
        """When strategies disagree, direction follows the stronger weighted side."""
        ensemble = WeightedEnsemble(
            [
                _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.9)),
                _StubStrategy("beta", _make_signal(SignalDirection.SHORT, 0.3)),
            ],
            threshold=0.0,
        )
        result = await ensemble.evaluate(
            "AAPL", self._market(), None, {"alpha": 0.5, "beta": 0.5}
        )

        assert result is not None
        # Equal weights, and alpha's LONG outweighs beta's SHORT.
        assert result.direction == SignalDirection.LONG
# ---------------------------------------------------------------------------
# WeightedEnsemble — threshold filtering
# ---------------------------------------------------------------------------
class TestEnsembleThresholdFiltering:
    """Test that weak combined signals are filtered out by the threshold."""

    @pytest.mark.asyncio
    async def test_below_threshold_returns_none(self):
        """Opposing signals that cancel out should fall below the threshold.

        Strengths are chosen so that only cancellation can filter the signal:
        each strength alone, and their unsigned average (0.75), clear the 0.3
        threshold. The signed weighted average
        (0.8 * 0.5 - 0.7 * 0.5) / (0.5 + 0.5) = 0.05 does not.
        """
        s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.8))
        s2 = _StubStrategy("beta", _make_signal(SignalDirection.SHORT, 0.7))
        ensemble = WeightedEnsemble([s1, s2], threshold=0.3)
        market = MarketSnapshot(
            ticker="AAPL", current_price=150.0,
            open=149.0, high=151.0, low=148.0, close=150.0, volume=1000,
        )
        weights = {"alpha": 0.5, "beta": 0.5}
        signal = await ensemble.evaluate("AAPL", market, None, weights)
        assert signal is None

    @pytest.mark.asyncio
    async def test_above_threshold_returns_signal(self):
        """Strong combined signal above threshold should yield a signal."""
        s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.9))
        ensemble = WeightedEnsemble([s1], threshold=0.3)
        market = MarketSnapshot(
            ticker="AAPL", current_price=150.0,
            open=149.0, high=151.0, low=148.0, close=150.0, volume=1000,
        )
        weights = {"alpha": 1.0}
        signal = await ensemble.evaluate("AAPL", market, None, weights)
        assert signal is not None
        # The only contributor is LONG, so the surviving signal must be too.
        assert signal.direction == SignalDirection.LONG
        assert signal.strength >= 0.3
# ---------------------------------------------------------------------------
# WeightedEnsemble — no signals returns None
# ---------------------------------------------------------------------------
class TestEnsembleNoSignals:
    """The ensemble stays silent when none of its strategies fires."""

    @pytest.mark.asyncio
    async def test_all_strategies_return_none(self):
        """Every strategy yielding None means no combined signal at all."""
        silent = [_StubStrategy(name, None) for name in ("alpha", "beta")]
        ensemble = WeightedEnsemble(silent, threshold=0.3)
        snapshot = MarketSnapshot(
            ticker="AAPL",
            current_price=150.0,
            open=149.0,
            high=151.0,
            low=148.0,
            close=150.0,
            volume=1000,
        )

        outcome = await ensemble.evaluate(
            "AAPL", snapshot, None, {"alpha": 0.5, "beta": 0.5}
        )

        assert outcome is None
# ---------------------------------------------------------------------------
# WeightedEnsemble — tags strategy sources
# ---------------------------------------------------------------------------
class TestEnsembleTagsStrategySources:
    """The combined signal records which strategies fed into it."""

    @pytest.mark.asyncio
    async def test_strategy_sources_contains_all_contributors(self):
        """Only strategies that actually produced a signal are tagged."""
        strategies = [
            _StubStrategy(
                "momentum",
                _make_signal(SignalDirection.LONG, 0.7, ["momentum"]),
            ),
            _StubStrategy(
                "news_driven",
                _make_signal(SignalDirection.LONG, 0.6, ["news_driven"]),
            ),
            _StubStrategy("mean_reversion", None),  # stays silent
        ]
        ensemble = WeightedEnsemble(strategies, threshold=0.0)
        snapshot = MarketSnapshot(
            ticker="AAPL",
            current_price=150.0,
            open=149.0,
            high=151.0,
            low=148.0,
            close=150.0,
            volume=1000,
        )
        weights = {"momentum": 0.5, "news_driven": 0.3, "mean_reversion": 0.2}

        combined = await ensemble.evaluate("AAPL", snapshot, None, weights)

        assert combined is not None
        # Exactly the two contributors, and therefore no "mean_reversion" tag.
        tagged = sorted(tag.split(":")[0] for tag in combined.strategy_sources)
        assert tagged == ["momentum", "news_driven"]

    @pytest.mark.asyncio
    async def test_strategy_sources_contain_direction_and_strength(self):
        """Each source tag should be formatted as name:DIRECTION:strength."""
        ensemble = WeightedEnsemble(
            [_StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.75))],
            threshold=0.0,
        )
        snapshot = MarketSnapshot(
            ticker="AAPL",
            current_price=150.0,
            open=149.0,
            high=151.0,
            low=148.0,
            close=150.0,
            volume=1000,
        )

        combined = await ensemble.evaluate("AAPL", snapshot, None, {"alpha": 1.0})

        assert combined is not None
        assert len(combined.strategy_sources) == 1
        name, direction, strength, *_ = combined.strategy_sources[0].split(":")
        assert name == "alpha"
        assert direction == "LONG"
        assert float(strength) == pytest.approx(0.75, abs=0.01)