- Point Ollama to local instance via host.docker.internal, use gemma3 model - Remove Docker Ollama service (using host's Ollama instead) - Add company-name-to-ticker mapping (Apple→AAPL, Tesla→TSLA, etc.) for RSS articles - Lower signal thresholds for faster feedback with paper trading: - FinBERT confidence: 0.6→0.4, signal strength: 0.3→0.15 - News strategy: article_count 2→1, confidence 0.5→0.3, score ±0.3→±0.15 - Fix market data BarSet access bug (BarSet.__contains__ returns False incorrectly) - Fix market data SIP feed error by switching to IEX feed for free Alpaca accounts - Fix nginx proxy routing for /api/auth/* to api-gateway /auth/* - Add seed_sample_data script - Update tests for new thresholds and alpaca mock modules
466 lines
15 KiB
Python
466 lines
15 KiB
Python
"""Tests for the Market Data service.
|
|
|
|
Covers configuration defaults, bar parsing from Alpaca response format,
|
|
historical bar fetching, poll logic, and Redis publish behaviour.
|
|
|
|
All Alpaca SDK imports are mocked to avoid requiring pytz and other
|
|
transitive dependencies in the test environment.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import sys
|
|
from datetime import datetime, timezone
|
|
from types import ModuleType
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from services.market_data.config import MarketDataConfig
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class _FakeBar:
    """Lightweight stand-in for an Alpaca ``Bar`` object.

    Stores the timestamp and the OHLCV values handed to the constructor as
    plain instance attributes with the same names.
    """

    def __init__(
        self,
        timestamp: datetime,
        open: float,
        high: float,
        low: float,
        close: float,
        volume: float,
    ) -> None:
        # ``open`` shadows the builtin here; the name is kept because callers
        # pass the value by keyword.
        self.timestamp = timestamp
        (self.open, self.high, self.low, self.close, self.volume) = (
            open,
            high,
            low,
            close,
            volume,
        )
|
|
|
|
|
|
def _make_fake_bar(close: float = 150.0) -> _FakeBar:
    """Build a ``_FakeBar`` with fixed sample values and the given close.

    Everything except ``close`` is hard-coded: the timestamp is
    2026-01-15 10:30 UTC, open/high/low are 149/151/148 and volume is 10000.
    """
    fixed_fields = {
        "timestamp": datetime(2026, 1, 15, 10, 30, tzinfo=timezone.utc),
        "open": 149.0,
        "high": 151.0,
        "low": 148.0,
        "volume": 10000.0,
    }
    return _FakeBar(close=close, **fixed_fields)
|
|
|
|
|
|
class _FakeBarSet(dict):
    """Dict-backed stand-in for an Alpaca ``BarSet``.

    Maps a ticker symbol to its list of bars; all behaviour is inherited
    from ``dict``.
    """

    # NOTE(review): a plain dict answers ``in`` checks correctly, while the
    # real BarSet's ``__contains__`` reportedly returns False -- confirm the
    # code under test does not rely on ``in`` against a bar set.
|
|
|
|
|
|
class _FakeTimeFrame:
    """Stand-in for ``alpaca.data.timeframe.TimeFrame``.

    Simply records the ``amount`` and ``unit`` it was constructed with.
    """

    def __init__(self, amount, unit):
        self.amount, self.unit = amount, unit
|
|
|
|
|
|
class _FakeTimeFrameUnit:
    """Stand-in for ``alpaca.data.timeframe.TimeFrameUnit``.

    Exposes the three units as plain string class attributes.
    """

    Minute, Hour, Day = "Minute", "Hour", "Day"
|
|
|
|
|
|
class _FakeStockBarsRequest:
    """Stand-in for ``alpaca.data.requests.StockBarsRequest``.

    Accepts arbitrary keyword arguments and exposes each one as an instance
    attribute of the same name.
    """

    def __init__(self, **kwargs):
        # Keyword names are always strings, so they can be written straight
        # into the instance namespace.
        vars(self).update(kwargs)
|
|
|
|
|
|
def _install_alpaca_mocks():
    """Register stub ``alpaca`` SDK modules in ``sys.modules``.

    This lets ``market_data.main`` be imported without the real Alpaca SDK.
    Fresh stubs are installed for ``alpaca.data.timeframe``,
    ``alpaca.data.requests``, ``alpaca.data.historical`` and
    ``alpaca.data.enums``, replacing any existing entries. The ``alpaca`` and
    ``alpaca.data`` parents are reused when already present in
    ``sys.modules`` and created otherwise.

    Every module is also bound as an attribute on its parent, mirroring what
    the import system does for real submodules, so attribute-style access
    such as ``alpaca.data.timeframe.TimeFrame`` resolves as well as
    ``from ... import ...``.
    """

    def _existing_or_new(name):
        # Reuse a module that is already registered (real or stub) so that an
        # earlier registration is not clobbered.
        module = sys.modules.get(name)
        if module is None:
            module = ModuleType(name)
            sys.modules[name] = module
        return module

    alpaca_mod = _existing_or_new("alpaca")
    data_mod = _existing_or_new("alpaca.data")
    alpaca_mod.data = data_mod

    stub_contents = {
        "timeframe": {
            "TimeFrame": _FakeTimeFrame,
            "TimeFrameUnit": _FakeTimeFrameUnit,
        },
        "requests": {"StockBarsRequest": _FakeStockBarsRequest},
        # The MagicMock class itself: instantiating the "client" yields a mock.
        "historical": {"StockHistoricalDataClient": MagicMock},
        "enums": {"DataFeed": MagicMock()},
    }
    for leaf, members in stub_contents.items():
        qualified = f"alpaca.data.{leaf}"
        stub = ModuleType(qualified)
        for attr_name, value in members.items():
            setattr(stub, attr_name, value)
        sys.modules[qualified] = stub
        # Bind on the parent package so ``alpaca.data.<leaf>`` works too.
        setattr(data_mod, leaf, stub)
|
|
|
|
|
|
# The alpaca stubs have to be registered first: the import below is the one
# they exist for, which is also why it sits below module level code (E402).
_install_alpaca_mocks()

from services.market_data.main import (  # noqa: E402
    MARKET_BARS_STREAM,
    _bar_to_dict,
    _fetch_historical_bars,
    _parse_timeframe,
    _poll_latest_bars,
)
|
|
|
|
|
|
async def _to_thread_passthrough(func, *args, **kwargs):
    """Stand-in for ``asyncio.to_thread`` that runs ``func`` inline.

    Calls ``func(*args, **kwargs)`` synchronously and returns its result
    instead of dispatching the call to a worker thread.
    """
    outcome = func(*args, **kwargs)
    return outcome
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Config defaults
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestMarketDataConfig:
    """Default values exposed by a freshly constructed ``MarketDataConfig``."""

    def test_default_watchlist(self):
        expected_symbols = ["AAPL", "TSLA", "NVDA", "MSFT", "GOOGL"]
        assert MarketDataConfig().watchlist == expected_symbols

    def test_default_bar_timeframe(self):
        assert MarketDataConfig().bar_timeframe == "5Min"

    def test_default_poll_interval(self):
        assert MarketDataConfig().poll_interval_seconds == 60

    def test_default_historical_bars(self):
        assert MarketDataConfig().historical_bars == 100

    def test_default_alpaca_keys_empty(self):
        cfg = MarketDataConfig()
        assert cfg.alpaca_api_key == ""
        assert cfg.alpaca_secret_key == ""

    def test_inherits_base_config_defaults(self):
        cfg = MarketDataConfig()
        assert cfg.log_level == "INFO"
        assert cfg.otel_metrics_port == 9090
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Timeframe parsing
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestParseTimeframe:
    """Checks on ``_parse_timeframe`` for supported and unsupported strings."""

    # NOTE(review): only ``amount`` is asserted, so "1Min", "1Hour" and "1Day"
    # are not told apart by these tests -- consider asserting ``unit`` too.

    @staticmethod
    def _check_amount(spec, expected_amount):
        """Parse ``spec`` and assert a result with the expected amount."""
        parsed = _parse_timeframe(spec)
        assert parsed is not None
        assert parsed.amount == expected_amount

    def test_parse_5min(self):
        self._check_amount("5Min", 5)

    def test_parse_1min(self):
        self._check_amount("1Min", 1)

    def test_parse_15min(self):
        self._check_amount("15Min", 15)

    def test_parse_1hour(self):
        self._check_amount("1Hour", 1)

    def test_parse_1day(self):
        self._check_amount("1Day", 1)

    def test_parse_invalid_raises(self):
        unsupported = "3Min"
        with pytest.raises(ValueError, match="Unsupported timeframe"):
            _parse_timeframe(unsupported)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Bar to dict conversion
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestBarToDict:
    """Checks on the dict produced by ``_bar_to_dict``."""

    def test_basic_conversion(self):
        converted = _bar_to_dict("AAPL", _make_fake_bar(close=155.5))

        expected_values = {
            "ticker": "AAPL",
            "close": 155.5,
            "open": 149.0,
            "high": 151.0,
            "low": 148.0,
            "volume": 10000.0,
        }
        for key, value in expected_values.items():
            assert converted[key] == value
        assert "timestamp" in converted

    def test_timestamp_is_iso_format(self):
        converted = _bar_to_dict("TSLA", _make_fake_bar())
        # The timestamp string must round-trip and keep its timezone.
        round_tripped = datetime.fromisoformat(converted["timestamp"])
        assert round_tripped.tzinfo is not None

    def test_all_fields_are_floats(self):
        converted = _bar_to_dict("AAPL", _make_fake_bar())
        numeric_fields = ("open", "high", "low", "close", "volume")
        not_float = [
            name for name in numeric_fields
            if not isinstance(converted[name], float)
        ]
        assert not not_float

    def test_ticker_is_passed_through(self):
        converted = _bar_to_dict("NVDA", _make_fake_bar())
        assert converted["ticker"] == "NVDA"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fetch historical bars
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestFetchHistoricalBars:
    """Exercises ``_fetch_historical_bars`` against a mocked Alpaca client."""

    @staticmethod
    def _client_returning(bar_set):
        """Return a mock client whose ``get_stock_bars`` yields ``bar_set``."""
        client = MagicMock()
        client.get_stock_bars = MagicMock(return_value=bar_set)
        return client

    @staticmethod
    def _sinks():
        """Return a ``(publisher, counter)`` pair of mocks."""
        publisher = AsyncMock()
        publisher.publish = AsyncMock()
        return publisher, MagicMock()

    @staticmethod
    async def _fetch(client, tickers, publisher, counter):
        """Run the fetch with ``asyncio.to_thread`` replaced by an inline call."""
        with patch(
            "services.market_data.main.asyncio.to_thread",
            side_effect=_to_thread_passthrough,
        ):
            return await _fetch_historical_bars(
                client, tickers, MagicMock(), 100, publisher, counter
            )

    @pytest.mark.asyncio
    async def test_publishes_bars_for_each_ticker(self):
        """Every historical bar of every watchlist ticker gets published."""
        bar_set = _FakeBarSet(
            {
                "AAPL": [_make_fake_bar(150.0), _make_fake_bar(151.0)],
                "TSLA": [_make_fake_bar(200.0)],
            }
        )
        publisher, counter = self._sinks()

        total = await self._fetch(
            self._client_returning(bar_set), ["AAPL", "TSLA"], publisher, counter
        )

        assert total == 3
        assert publisher.publish.call_count == 3
        counter.add.assert_called_once_with(3)

    @pytest.mark.asyncio
    async def test_handles_empty_response(self):
        """An empty bar set publishes nothing and reports zero bars."""
        publisher, counter = self._sinks()

        total = await self._fetch(
            self._client_returning(_FakeBarSet({})), ["AAPL"], publisher, counter
        )

        assert total == 0
        publisher.publish.assert_not_called()

    @pytest.mark.asyncio
    async def test_handles_client_exception(self):
        """A client that raises for every ticker yields zero published bars."""
        failing_client = MagicMock()
        failing_client.get_stock_bars = MagicMock(side_effect=Exception("API error"))
        publisher, counter = self._sinks()

        total = await self._fetch(failing_client, ["AAPL", "TSLA"], publisher, counter)

        assert total == 0
        publisher.publish.assert_not_called()

    @pytest.mark.asyncio
    async def test_published_message_format(self):
        """A published message carries the ticker and all bar fields."""
        bar_set = _FakeBarSet({"MSFT": [_make_fake_bar(175.0)]})
        publisher, counter = self._sinks()

        await self._fetch(
            self._client_returning(bar_set), ["MSFT"], publisher, counter
        )

        message = publisher.publish.call_args.args[0]
        assert message["ticker"] == "MSFT"
        assert message["close"] == 175.0
        for required_key in ("timestamp", "open", "high", "low", "volume"):
            assert required_key in message
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Poll latest bars
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestPollLatestBars:
    """Exercises ``_poll_latest_bars`` against a mocked Alpaca client."""

    @staticmethod
    def _client_returning(bar_set):
        """Return a mock client whose ``get_stock_bars`` yields ``bar_set``."""
        client = MagicMock()
        client.get_stock_bars = MagicMock(return_value=bar_set)
        return client

    @staticmethod
    def _sinks():
        """Return a ``(publisher, counter)`` pair of mocks."""
        publisher = AsyncMock()
        publisher.publish = AsyncMock()
        return publisher, MagicMock()

    @staticmethod
    async def _poll(client, tickers, publisher, counter):
        """Run one poll with ``asyncio.to_thread`` replaced by an inline call."""
        with patch(
            "services.market_data.main.asyncio.to_thread",
            side_effect=_to_thread_passthrough,
        ):
            return await _poll_latest_bars(
                client, tickers, MagicMock(), publisher, counter
            )

    @pytest.mark.asyncio
    async def test_publishes_latest_bar_per_ticker(self):
        """A single returned bar is published once for its ticker."""
        bar_set = _FakeBarSet({"AAPL": [_make_fake_bar(155.0)]})
        publisher, counter = self._sinks()

        count = await self._poll(
            self._client_returning(bar_set), ["AAPL"], publisher, counter
        )

        assert count == 1
        assert publisher.publish.call_count == 1

        message = publisher.publish.call_args.args[0]
        assert message["ticker"] == "AAPL"
        assert message["close"] == 155.0

    @pytest.mark.asyncio
    async def test_publishes_only_most_recent_bar(self):
        """With several bars returned, only the last one is published."""
        bar_set = _FakeBarSet(
            {"AAPL": [_make_fake_bar(150.0), _make_fake_bar(155.0)]}
        )
        publisher, counter = self._sinks()

        count = await self._poll(
            self._client_returning(bar_set), ["AAPL"], publisher, counter
        )

        assert count == 1
        message = publisher.publish.call_args.args[0]
        assert message["close"] == 155.0

    @pytest.mark.asyncio
    async def test_handles_no_bars_for_ticker(self):
        """An empty bar list for the ticker publishes nothing."""
        bar_set = _FakeBarSet({"AAPL": []})
        publisher, counter = self._sinks()

        count = await self._poll(
            self._client_returning(bar_set), ["AAPL"], publisher, counter
        )

        assert count == 0
        publisher.publish.assert_not_called()

    @pytest.mark.asyncio
    async def test_handles_ticker_not_in_response(self):
        """A response lacking the requested ticker publishes nothing."""
        publisher, counter = self._sinks()

        count = await self._poll(
            self._client_returning(_FakeBarSet({})), ["AAPL"], publisher, counter
        )

        assert count == 0
        publisher.publish.assert_not_called()

    @pytest.mark.asyncio
    async def test_multiple_tickers(self):
        """Each watchlist ticker is polled and gets one bar published."""
        # Successive ``get_stock_bars`` calls return these bar sets in order.
        responses = [
            _FakeBarSet({"AAPL": [_make_fake_bar(150.0)]}),
            _FakeBarSet({"TSLA": [_make_fake_bar(250.0)]}),
        ]
        client = MagicMock()
        client.get_stock_bars = MagicMock(side_effect=responses)
        publisher, counter = self._sinks()

        count = await self._poll(client, ["AAPL", "TSLA"], publisher, counter)

        assert count == 2
        assert publisher.publish.call_count == 2
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Stream name constant
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestStreamConstants:
    """Pins the value of the market-bars stream name constant."""

    def test_market_bars_stream_name(self):
        expected_name = "market:bars"
        assert MARKET_BARS_STREAM == expected_name
|