"""Tests for portfolio sync background task. Verifies that the sync loop correctly: - Creates PortfolioSnapshot rows from broker account data - Upserts Position rows from broker positions - Removes Position rows for closed positions - Handles broker errors gracefully - Respects US market hours """ from __future__ import annotations import asyncio from datetime import datetime, time, timezone from unittest.mock import AsyncMock, MagicMock, patch import pytest from services.api_gateway.config import ApiGatewayConfig from services.api_gateway.tasks.portfolio_sync import ( _sync_once, is_market_open, portfolio_sync_loop, ) from shared.schemas.trading import AccountInfo, PositionInfo # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- @pytest.fixture() def config() -> ApiGatewayConfig: return ApiGatewayConfig( jwt_secret_key="test-secret-for-sync", database_url="sqlite+aiosqlite:///:memory:", redis_url="redis://localhost:6379/0", alpaca_api_key="test-key", alpaca_secret_key="test-secret", paper_trading=True, snapshot_interval_seconds=1, ) @pytest.fixture() def config_no_creds() -> ApiGatewayConfig: return ApiGatewayConfig( jwt_secret_key="test-secret-for-sync", database_url="sqlite+aiosqlite:///:memory:", redis_url="redis://localhost:6379/0", alpaca_api_key="", alpaca_secret_key="", ) @pytest.fixture() def mock_account() -> AccountInfo: return AccountInfo( equity=105000.0, cash=50000.0, buying_power=100000.0, portfolio_value=105000.0, ) @pytest.fixture() def mock_positions() -> list[PositionInfo]: return [ PositionInfo( ticker="AAPL", qty=10.0, avg_entry=150.0, current_price=155.0, unrealized_pnl=50.0, market_value=1550.0, ), PositionInfo( ticker="MSFT", qty=5.0, avg_entry=400.0, current_price=410.0, unrealized_pnl=50.0, market_value=2050.0, ), ] @pytest.fixture() def mock_broker(mock_account, mock_positions): broker = AsyncMock() broker.get_account = AsyncMock(return_value=mock_account) broker.get_positions = AsyncMock(return_value=mock_positions) return broker @pytest.fixture() def mock_session(): """Create a mock async session with context manager support.""" session = AsyncMock() session.__aenter__ = AsyncMock(return_value=session) session.__aexit__ = AsyncMock(return_value=False) # Mock the begin() context manager begin_ctx = AsyncMock() begin_ctx.__aenter__ = AsyncMock(return_value=None) begin_ctx.__aexit__ = AsyncMock(return_value=False) session.begin = MagicMock(return_value=begin_ctx) # session.add is synchronous in SQLAlchemy — use MagicMock to avoid warnings session.add = MagicMock() return session @pytest.fixture() def mock_session_factory(mock_session): factory = MagicMock() factory.return_value = mock_session return factory # --------------------------------------------------------------------------- # Market hours tests # --------------------------------------------------------------------------- class TestMarketHours: """Tests for the is_market_open() function.""" def test_weekday_during_market_hours(self) -> None: # Wednesday 2024-01-10 at 10:00 AM ET = 15:00 UTC dt = datetime(2024, 1, 10, 15, 0, 0, tzinfo=timezone.utc) assert is_market_open(dt) is True def test_weekday_before_market_open(self) -> None: # Wednesday 2024-01-10 at 9:00 AM ET = 14:00 UTC dt = datetime(2024, 1, 10, 14, 0, 0, tzinfo=timezone.utc) assert is_market_open(dt) is False def test_weekday_after_market_close(self) -> None: # Wednesday 2024-01-10 at 4:30 PM ET = 21:30 UTC dt = datetime(2024, 1, 10, 21, 30, 0, tzinfo=timezone.utc) assert is_market_open(dt) is False def test_weekend_saturday(self) -> None: # Saturday 2024-01-13 at 12:00 PM ET = 17:00 UTC dt = datetime(2024, 1, 13, 17, 0, 0, tzinfo=timezone.utc) assert is_market_open(dt) is False def test_weekend_sunday(self) -> None: # Sunday 2024-01-14 at 12:00 PM ET = 17:00 UTC dt = datetime(2024, 1, 14, 17, 0, 0, tzinfo=timezone.utc) assert is_market_open(dt) is False def test_market_open_boundary(self) -> None: # Wednesday 2024-01-10 at exactly 9:30 AM ET = 14:30 UTC dt = datetime(2024, 1, 10, 14, 30, 0, tzinfo=timezone.utc) assert is_market_open(dt) is True def test_market_close_boundary(self) -> None: # Wednesday 2024-01-10 at exactly 4:00 PM ET = 21:00 UTC dt = datetime(2024, 1, 10, 21, 0, 0, tzinfo=timezone.utc) assert is_market_open(dt) is False # --------------------------------------------------------------------------- # Snapshot creation tests # --------------------------------------------------------------------------- class TestSyncOnce: """Tests for the _sync_once() function.""" async def test_creates_portfolio_snapshot( self, mock_broker, mock_session_factory, mock_session ) -> None: # Mock the select query to return None (no existing positions) execute_result = MagicMock() execute_result.scalar_one_or_none.return_value = None mock_session.execute = AsyncMock(return_value=execute_result) await _sync_once(mock_broker, mock_session_factory) # Verify the broker was called mock_broker.get_account.assert_awaited_once() mock_broker.get_positions.assert_awaited_once() # Verify session.add was called (snapshot + 2 new positions) assert mock_session.add.call_count == 3 # 1 snapshot + 2 positions # Check the snapshot snapshot_call = mock_session.add.call_args_list[0] snapshot = snapshot_call[0][0] assert snapshot.total_value == 105000.0 assert snapshot.cash == 50000.0 assert snapshot.positions_value == 55000.0 # 105000 - 50000 assert snapshot.daily_pnl == 0.0 async def test_creates_position_rows_for_new_positions( self, mock_broker, mock_session_factory, mock_session ) -> None: # No existing positions in DB execute_result = MagicMock() execute_result.scalar_one_or_none.return_value = None mock_session.execute = AsyncMock(return_value=execute_result) await _sync_once(mock_broker, mock_session_factory) # Positions are added via session.add (after the snapshot) position_calls = mock_session.add.call_args_list[1:] assert len(position_calls) == 2 pos1 = position_calls[0][0][0] assert pos1.ticker == "AAPL" assert pos1.qty == 10.0 assert pos1.avg_entry == 150.0 assert pos1.unrealized_pnl == 50.0 pos2 = position_calls[1][0][0] assert pos2.ticker == "MSFT" assert pos2.qty == 5.0 assert pos2.avg_entry == 400.0 async def test_updates_existing_position( self, mock_broker, mock_session_factory, mock_session ) -> None: # Mock an existing position for AAPL, None for MSFT existing_aapl = MagicMock() existing_aapl.ticker = "AAPL" existing_aapl.qty = 5.0 # old qty existing_aapl.avg_entry = 140.0 # old entry # Day-start snapshot query (returns None = first snapshot today) result_day_start = MagicMock() result_day_start.scalar_one_or_none.return_value = None result_aapl = MagicMock() result_aapl.scalar_one_or_none.return_value = existing_aapl result_msft = MagicMock() result_msft.scalar_one_or_none.return_value = None # Execute calls: day-start snapshot, AAPL lookup, MSFT lookup, DELETE mock_session.execute = AsyncMock( side_effect=[result_day_start, result_aapl, result_msft, MagicMock()] ) await _sync_once(mock_broker, mock_session_factory) # AAPL should be updated in place assert existing_aapl.qty == 10.0 assert existing_aapl.avg_entry == 150.0 assert existing_aapl.unrealized_pnl == 50.0 # MSFT should be added as new (snapshot + MSFT = 2 adds) assert mock_session.add.call_count == 2 # snapshot + new MSFT async def test_removes_closed_positions( self, mock_session_factory, mock_session ) -> None: # Broker returns only AAPL (MSFT was sold) broker = AsyncMock() broker.get_account = AsyncMock( return_value=AccountInfo( equity=100000, cash=90000, buying_power=90000, portfolio_value=100000 ) ) broker.get_positions = AsyncMock( return_value=[ PositionInfo( ticker="AAPL", qty=10.0, avg_entry=150.0, current_price=155.0, unrealized_pnl=50.0, market_value=1550.0, ) ] ) execute_result = MagicMock() execute_result.scalar_one_or_none.return_value = None mock_session.execute = AsyncMock(return_value=execute_result) await _sync_once(broker, mock_session_factory) # The delete statement should have been executed # Find the delete call among execute calls delete_called = False for call in mock_session.execute.call_args_list: stmt = call[0][0] # Check if it's a delete statement (SQLAlchemy Delete object) stmt_str = str(stmt) if "DELETE" in stmt_str.upper(): delete_called = True break assert delete_called, "Expected a DELETE statement for closed positions" async def test_removes_all_positions_when_broker_has_none( self, mock_session_factory, mock_session ) -> None: broker = AsyncMock() broker.get_account = AsyncMock( return_value=AccountInfo( equity=100000, cash=100000, buying_power=100000, portfolio_value=100000 ) ) broker.get_positions = AsyncMock(return_value=[]) mock_session.execute = AsyncMock(return_value=MagicMock()) await _sync_once(broker, mock_session_factory) # Should delete all positions since broker has none delete_called = False for call in mock_session.execute.call_args_list: stmt = call[0][0] stmt_str = str(stmt) if "DELETE" in stmt_str.upper(): delete_called = True break assert delete_called, "Expected a DELETE statement to clear all positions" # --------------------------------------------------------------------------- # Error handling tests # --------------------------------------------------------------------------- class TestSyncErrorHandling: """Tests that the sync loop handles errors gracefully.""" async def test_broker_error_does_not_crash_loop( self, config, mock_session_factory ) -> None: """Broker raises an exception — loop should catch it and continue.""" call_count = 0 async def mock_sync_once(broker, sf): nonlocal call_count call_count += 1 if call_count == 1: raise ConnectionError("Broker API down") # Second call succeeds with ( patch( "services.api_gateway.tasks.portfolio_sync.AlpacaBroker" ) as MockBroker, patch( "services.api_gateway.tasks.portfolio_sync._sync_once", side_effect=mock_sync_once, ), patch( "services.api_gateway.tasks.portfolio_sync.is_market_open", return_value=True, ), ): MockBroker.return_value = AsyncMock() task = asyncio.create_task(portfolio_sync_loop(config, mock_session_factory)) # Give it time for 2 iterations (interval=1s) await asyncio.sleep(2.5) task.cancel() try: await task except asyncio.CancelledError: pass assert call_count >= 2, "Loop should have retried after the error" async def test_no_credentials_returns_immediately( self, config_no_creds, mock_session_factory ) -> None: """When Alpaca credentials are empty, the loop should exit immediately.""" task = asyncio.create_task( portfolio_sync_loop(config_no_creds, mock_session_factory) ) # Should complete almost immediately since no creds await asyncio.wait_for(task, timeout=2.0) # If we get here without timeout, the function returned correctly # --------------------------------------------------------------------------- # Market hours integration with loop # --------------------------------------------------------------------------- class TestSyncLoopMarketHours: """Tests that the loop respects market hours.""" async def test_skips_sync_outside_market_hours( self, config, mock_session_factory ) -> None: sync_called = False async def mock_sync(broker, sf): nonlocal sync_called sync_called = True with ( patch( "services.api_gateway.tasks.portfolio_sync.AlpacaBroker" ) as MockBroker, patch( "services.api_gateway.tasks.portfolio_sync._sync_once", side_effect=mock_sync, ), patch( "services.api_gateway.tasks.portfolio_sync.is_market_open", return_value=False, ), ): MockBroker.return_value = AsyncMock() task = asyncio.create_task(portfolio_sync_loop(config, mock_session_factory)) await asyncio.sleep(1.5) task.cancel() try: await task except asyncio.CancelledError: pass assert not sync_called, "Sync should not run outside market hours" async def test_runs_sync_during_market_hours( self, config, mock_session_factory ) -> None: sync_called = False async def mock_sync(broker, sf): nonlocal sync_called sync_called = True with ( patch( "services.api_gateway.tasks.portfolio_sync.AlpacaBroker" ) as MockBroker, patch( "services.api_gateway.tasks.portfolio_sync._sync_once", side_effect=mock_sync, ), patch( "services.api_gateway.tasks.portfolio_sync.is_market_open", return_value=True, ), ): MockBroker.return_value = AsyncMock() task = asyncio.create_task(portfolio_sync_loop(config, mock_session_factory)) await asyncio.sleep(1.5) task.cancel() try: await task except asyncio.CancelledError: pass assert sync_called, "Sync should run during market hours"