"""Tests for the Market Data service. Covers configuration defaults, bar parsing from Alpaca response format, historical bar fetching, poll logic, and Redis publish behaviour. All Alpaca SDK imports are mocked to avoid requiring pytz and other transitive dependencies in the test environment. """ from __future__ import annotations import sys from datetime import datetime, timezone from types import ModuleType from unittest.mock import AsyncMock, MagicMock, patch import pytest from services.market_data.config import MarketDataConfig # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- class _FakeBar: """Mimics an Alpaca Bar object with the required attributes.""" def __init__( self, timestamp: datetime, open: float, high: float, low: float, close: float, volume: float, ) -> None: self.timestamp = timestamp self.open = open self.high = high self.low = low self.close = close self.volume = volume def _make_fake_bar(close: float = 150.0) -> _FakeBar: return _FakeBar( timestamp=datetime(2026, 1, 15, 10, 30, tzinfo=timezone.utc), open=149.0, high=151.0, low=148.0, close=close, volume=10000.0, ) class _FakeBarSet(dict): """Mimics an Alpaca BarSet (dict-like) mapping ticker to list of bars.""" pass class _FakeTimeFrame: """Mimics alpaca.data.timeframe.TimeFrame.""" def __init__(self, amount, unit): self.amount = amount self.unit = unit class _FakeTimeFrameUnit: """Mimics alpaca.data.timeframe.TimeFrameUnit.""" Minute = "Minute" Hour = "Hour" Day = "Day" class _FakeStockBarsRequest: """Mimics alpaca.data.requests.StockBarsRequest.""" def __init__(self, **kwargs): for k, v in kwargs.items(): setattr(self, k, v) def _install_alpaca_mocks(): """Install mock modules for the alpaca SDK so we can import market_data.main.""" timeframe_mod = ModuleType("alpaca.data.timeframe") timeframe_mod.TimeFrame = _FakeTimeFrame timeframe_mod.TimeFrameUnit = _FakeTimeFrameUnit requests_mod = ModuleType("alpaca.data.requests") requests_mod.StockBarsRequest = _FakeStockBarsRequest historical_mod = ModuleType("alpaca.data.historical") historical_mod.StockHistoricalDataClient = MagicMock enums_mod = ModuleType("alpaca.data.enums") enums_mod.DataFeed = MagicMock() # Build the package hierarchy alpaca_mod = sys.modules.get("alpaca") or ModuleType("alpaca") data_mod = sys.modules.get("alpaca.data") or ModuleType("alpaca.data") sys.modules.setdefault("alpaca", alpaca_mod) sys.modules["alpaca.data"] = data_mod sys.modules["alpaca.data.timeframe"] = timeframe_mod sys.modules["alpaca.data.requests"] = requests_mod sys.modules["alpaca.data.historical"] = historical_mod sys.modules["alpaca.data.enums"] = enums_mod # Install mocks before importing from market_data.main _install_alpaca_mocks() from services.market_data.main import ( # noqa: E402 MARKET_BARS_STREAM, _bar_to_dict, _fetch_historical_bars, _parse_timeframe, _poll_latest_bars, ) async def _to_thread_passthrough(func, *args, **kwargs): """Replacement for asyncio.to_thread that calls synchronously.""" return func(*args, **kwargs) # --------------------------------------------------------------------------- # Config defaults # --------------------------------------------------------------------------- class TestMarketDataConfig: """Tests for MarketDataConfig defaults.""" def test_default_watchlist(self): config = MarketDataConfig() assert config.watchlist == ["AAPL", "TSLA", "NVDA", "MSFT", "GOOGL"] def test_default_bar_timeframe(self): config = MarketDataConfig() assert config.bar_timeframe == "5Min" def test_default_poll_interval(self): config = MarketDataConfig() assert config.poll_interval_seconds == 60 def test_default_historical_bars(self): config = MarketDataConfig() assert config.historical_bars == 100 def test_default_alpaca_keys_empty(self): config = MarketDataConfig() assert config.alpaca_api_key == "" assert config.alpaca_secret_key == "" def test_inherits_base_config_defaults(self): config = MarketDataConfig() assert config.log_level == "INFO" assert config.otel_metrics_port == 9090 # --------------------------------------------------------------------------- # Timeframe parsing # --------------------------------------------------------------------------- class TestParseTimeframe: """Tests for the _parse_timeframe helper.""" def test_parse_5min(self): tf = _parse_timeframe("5Min") assert tf is not None assert tf.amount == 5 def test_parse_1min(self): tf = _parse_timeframe("1Min") assert tf is not None assert tf.amount == 1 def test_parse_15min(self): tf = _parse_timeframe("15Min") assert tf is not None assert tf.amount == 15 def test_parse_1hour(self): tf = _parse_timeframe("1Hour") assert tf is not None assert tf.amount == 1 def test_parse_1day(self): tf = _parse_timeframe("1Day") assert tf is not None assert tf.amount == 1 def test_parse_invalid_raises(self): with pytest.raises(ValueError, match="Unsupported timeframe"): _parse_timeframe("3Min") # --------------------------------------------------------------------------- # Bar to dict conversion # --------------------------------------------------------------------------- class TestBarToDict: """Tests for _bar_to_dict conversion.""" def test_basic_conversion(self): bar = _make_fake_bar(close=155.5) result = _bar_to_dict("AAPL", bar) assert result["ticker"] == "AAPL" assert result["close"] == 155.5 assert result["open"] == 149.0 assert result["high"] == 151.0 assert result["low"] == 148.0 assert result["volume"] == 10000.0 assert "timestamp" in result def test_timestamp_is_iso_format(self): bar = _make_fake_bar() result = _bar_to_dict("TSLA", bar) # Should be parseable back parsed = datetime.fromisoformat(result["timestamp"]) assert parsed.tzinfo is not None def test_all_fields_are_floats(self): bar = _make_fake_bar() result = _bar_to_dict("AAPL", bar) for field in ("open", "high", "low", "close", "volume"): assert isinstance(result[field], float) def test_ticker_is_passed_through(self): bar = _make_fake_bar() result = _bar_to_dict("NVDA", bar) assert result["ticker"] == "NVDA" # --------------------------------------------------------------------------- # Fetch historical bars # --------------------------------------------------------------------------- class TestFetchHistoricalBars: """Tests for _fetch_historical_bars with mocked Alpaca client.""" @pytest.mark.asyncio async def test_publishes_bars_for_each_ticker(self): """Should publish all historical bars for each ticker in the watchlist.""" bars_aapl = [_make_fake_bar(150.0), _make_fake_bar(151.0)] bars_tsla = [_make_fake_bar(200.0)] bar_set = _FakeBarSet({"AAPL": bars_aapl, "TSLA": bars_tsla}) mock_client = MagicMock() mock_client.get_stock_bars = MagicMock(return_value=bar_set) publisher = AsyncMock() publisher.publish = AsyncMock() counter = MagicMock() with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough): total = await _fetch_historical_bars( mock_client, ["AAPL", "TSLA"], MagicMock(), 100, publisher, counter ) assert total == 3 assert publisher.publish.call_count == 3 counter.add.assert_called_once_with(3) @pytest.mark.asyncio async def test_handles_empty_response(self): """Should handle tickers with no bars gracefully.""" bar_set = _FakeBarSet({}) mock_client = MagicMock() mock_client.get_stock_bars = MagicMock(return_value=bar_set) publisher = AsyncMock() counter = MagicMock() with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough): total = await _fetch_historical_bars( mock_client, ["AAPL"], MagicMock(), 100, publisher, counter ) assert total == 0 publisher.publish.assert_not_called() @pytest.mark.asyncio async def test_handles_client_exception(self): """Should log and continue if one ticker fails.""" mock_client = MagicMock() mock_client.get_stock_bars = MagicMock(side_effect=Exception("API error")) publisher = AsyncMock() counter = MagicMock() with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough): total = await _fetch_historical_bars( mock_client, ["AAPL", "TSLA"], MagicMock(), 100, publisher, counter ) assert total == 0 publisher.publish.assert_not_called() @pytest.mark.asyncio async def test_published_message_format(self): """Published messages should have the expected fields.""" bar = _make_fake_bar(175.0) bar_set = _FakeBarSet({"MSFT": [bar]}) mock_client = MagicMock() mock_client.get_stock_bars = MagicMock(return_value=bar_set) publisher = AsyncMock() publisher.publish = AsyncMock() counter = MagicMock() with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough): await _fetch_historical_bars( mock_client, ["MSFT"], MagicMock(), 100, publisher, counter ) msg = publisher.publish.call_args[0][0] assert msg["ticker"] == "MSFT" assert msg["close"] == 175.0 assert "timestamp" in msg assert "open" in msg assert "high" in msg assert "low" in msg assert "volume" in msg # --------------------------------------------------------------------------- # Poll latest bars # --------------------------------------------------------------------------- class TestPollLatestBars: """Tests for _poll_latest_bars with mocked Alpaca client.""" @pytest.mark.asyncio async def test_publishes_latest_bar_per_ticker(self): """Should publish one bar per ticker.""" bar_set = _FakeBarSet({"AAPL": [_make_fake_bar(155.0)]}) mock_client = MagicMock() mock_client.get_stock_bars = MagicMock(return_value=bar_set) publisher = AsyncMock() publisher.publish = AsyncMock() counter = MagicMock() with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough): count = await _poll_latest_bars( mock_client, ["AAPL"], MagicMock(), publisher, counter ) assert count == 1 assert publisher.publish.call_count == 1 published_msg = publisher.publish.call_args[0][0] assert published_msg["ticker"] == "AAPL" assert published_msg["close"] == 155.0 @pytest.mark.asyncio async def test_publishes_only_most_recent_bar(self): """When multiple bars returned, should only publish the last one.""" bars = [_make_fake_bar(150.0), _make_fake_bar(155.0)] bar_set = _FakeBarSet({"AAPL": bars}) mock_client = MagicMock() mock_client.get_stock_bars = MagicMock(return_value=bar_set) publisher = AsyncMock() publisher.publish = AsyncMock() counter = MagicMock() with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough): count = await _poll_latest_bars( mock_client, ["AAPL"], MagicMock(), publisher, counter ) assert count == 1 published_msg = publisher.publish.call_args[0][0] assert published_msg["close"] == 155.0 @pytest.mark.asyncio async def test_handles_no_bars_for_ticker(self): """Should return 0 when a ticker has no bars.""" bar_set = _FakeBarSet({"AAPL": []}) mock_client = MagicMock() mock_client.get_stock_bars = MagicMock(return_value=bar_set) publisher = AsyncMock() counter = MagicMock() with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough): count = await _poll_latest_bars( mock_client, ["AAPL"], MagicMock(), publisher, counter ) assert count == 0 publisher.publish.assert_not_called() @pytest.mark.asyncio async def test_handles_ticker_not_in_response(self): """Should handle gracefully when the response doesn't contain the ticker.""" bar_set = _FakeBarSet({}) mock_client = MagicMock() mock_client.get_stock_bars = MagicMock(return_value=bar_set) publisher = AsyncMock() counter = MagicMock() with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough): count = await _poll_latest_bars( mock_client, ["AAPL"], MagicMock(), publisher, counter ) assert count == 0 publisher.publish.assert_not_called() @pytest.mark.asyncio async def test_multiple_tickers(self): """Should poll and publish bars for all tickers in watchlist.""" bar_set_aapl = _FakeBarSet({"AAPL": [_make_fake_bar(150.0)]}) bar_set_tsla = _FakeBarSet({"TSLA": [_make_fake_bar(250.0)]}) mock_client = MagicMock() # Return different bar sets for each call mock_client.get_stock_bars = MagicMock( side_effect=[bar_set_aapl, bar_set_tsla] ) publisher = AsyncMock() publisher.publish = AsyncMock() counter = MagicMock() with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough): count = await _poll_latest_bars( mock_client, ["AAPL", "TSLA"], MagicMock(), publisher, counter ) assert count == 2 assert publisher.publish.call_count == 2 # --------------------------------------------------------------------------- # Stream name constant # --------------------------------------------------------------------------- class TestStreamConstants: """Verify the stream name constant.""" def test_market_bars_stream_name(self): assert MARKET_BARS_STREAM == "market:bars"