"""Tests for the Trade Executor service. Covers RiskManager (market hours, positions, exposure, cooldown, position sizing) and the end-to-end executor flow with a mocked broker. """ from __future__ import annotations from datetime import datetime, timedelta, timezone from unittest.mock import AsyncMock, MagicMock, patch from zoneinfo import ZoneInfo import pytest from services.trade_executor.config import TradeExecutorConfig from services.trade_executor.main import process_signal from services.trade_executor.risk_manager import RiskManager from shared.schemas.trading import ( AccountInfo, OrderResult, OrderSide, OrderStatus, PositionInfo, SignalDirection, TradeSignal, ) _ET = ZoneInfo("America/New_York") # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _make_config(**overrides) -> TradeExecutorConfig: defaults = dict( max_position_pct=0.05, max_total_exposure_pct=0.80, max_positions=20, default_stop_loss_pct=0.03, cooldown_minutes=30, alpaca_api_key="test", alpaca_secret_key="test", paper_trading=True, ) defaults.update(overrides) return TradeExecutorConfig(**defaults) def _make_signal( ticker: str = "AAPL", direction: SignalDirection = SignalDirection.LONG, strength: float = 0.8, current_price: float = 150.0, ) -> TradeSignal: return TradeSignal( ticker=ticker, direction=direction, strength=strength, strategy_sources=["test"], sentiment_context={"current_price": current_price}, timestamp=datetime.now(timezone.utc), ) def _make_account(equity: float = 100_000.0) -> AccountInfo: return AccountInfo( equity=equity, cash=equity, buying_power=equity * 2, portfolio_value=equity, ) def _make_position(ticker: str = "AAPL", market_value: float = 5000.0) -> PositionInfo: return PositionInfo( ticker=ticker, qty=10.0, avg_entry=150.0, current_price=150.0, unrealized_pnl=0.0, market_value=market_value, ) def _mock_broker(positions: list[PositionInfo] | None = None, account: AccountInfo | None = None): """Create an AsyncMock broker with configurable positions and account.""" broker = AsyncMock() broker.get_positions = AsyncMock(return_value=positions or []) broker.get_account = AsyncMock(return_value=account or _make_account()) broker.submit_order = AsyncMock( return_value=OrderResult( order_id="ord-123", ticker="AAPL", side=OrderSide.BUY, qty=10.0, filled_price=150.0, status=OrderStatus.FILLED, timestamp=datetime.now(timezone.utc), ) ) return broker # --------------------------------------------------------------------------- # RiskManager — risk check passes # --------------------------------------------------------------------------- class TestRiskCheckPasses: """All conditions met -> risk check passes.""" @pytest.mark.asyncio async def test_all_conditions_met(self): config = _make_config() broker = _mock_broker(positions=[], account=_make_account(100_000)) rm = RiskManager(config, broker) signal = _make_signal() # Patch _is_market_hours to return True with patch.object(RiskManager, "_is_market_hours", return_value=True): approved, reason = await rm.check_risk(signal) assert approved is True assert reason == "approved" class TestRiskCheckTradingPaused: """Risk check fails when trading is paused via Redis flag.""" @pytest.mark.asyncio async def test_paused_flag_rejects(self): config = _make_config() broker = _mock_broker() redis_mock = AsyncMock() redis_mock.get = AsyncMock(return_value=b"1") rm = RiskManager(config, broker, redis=redis_mock) signal = _make_signal() with patch.object(RiskManager, "_is_market_hours", return_value=True): approved, reason = await rm.check_risk(signal) assert approved is False assert reason == "trading_paused" @pytest.mark.asyncio async def test_no_pause_flag_passes_through(self): config = _make_config() broker = _mock_broker(positions=[], account=_make_account(100_000)) redis_mock = AsyncMock() redis_mock.get = AsyncMock(return_value=None) rm = RiskManager(config, broker, redis=redis_mock) signal = _make_signal() with patch.object(RiskManager, "_is_market_hours", return_value=True): approved, reason = await rm.check_risk(signal) assert approved is True @pytest.mark.asyncio async def test_no_redis_skips_pause_check(self): config = _make_config() broker = _mock_broker(positions=[], account=_make_account(100_000)) rm = RiskManager(config, broker, redis=None) signal = _make_signal() with patch.object(RiskManager, "_is_market_hours", return_value=True): approved, reason = await rm.check_risk(signal) assert approved is True # --------------------------------------------------------------------------- # RiskManager — max positions exceeded # --------------------------------------------------------------------------- class TestRiskCheckMaxPositions: """Risk check fails when max_positions is already reached.""" @pytest.mark.asyncio async def test_max_positions_exceeded(self): config = _make_config(max_positions=2) # Already have 2 positions positions = [_make_position("AAPL"), _make_position("MSFT")] broker = _mock_broker(positions=positions, account=_make_account()) rm = RiskManager(config, broker) signal = _make_signal(ticker="GOOG") with patch.object(RiskManager, "_is_market_hours", return_value=True): approved, reason = await rm.check_risk(signal) assert approved is False assert "max_positions" in reason # --------------------------------------------------------------------------- # RiskManager — max exposure exceeded # --------------------------------------------------------------------------- class TestRiskCheckMaxExposure: """Risk check fails when total exposure exceeds the limit.""" @pytest.mark.asyncio async def test_max_exposure_exceeded(self): config = _make_config(max_total_exposure_pct=0.50) account = _make_account(equity=100_000) # Single position worth $60k = 60% of equity, limit is 50% positions = [_make_position("AAPL", market_value=60_000)] broker = _mock_broker(positions=positions, account=account) rm = RiskManager(config, broker) signal = _make_signal(ticker="MSFT") with patch.object(RiskManager, "_is_market_hours", return_value=True): approved, reason = await rm.check_risk(signal) assert approved is False assert "max_exposure" in reason # --------------------------------------------------------------------------- # RiskManager — cooldown active # --------------------------------------------------------------------------- class TestRiskCheckCooldown: """Risk check fails when a ticker is in cooldown.""" @pytest.mark.asyncio async def test_cooldown_active(self): config = _make_config(cooldown_minutes=30) broker = _mock_broker() rm = RiskManager(config, broker) # Record an exit 10 minutes ago now_et = datetime.now(tz=_ET) rm.record_exit("AAPL", now_et - timedelta(minutes=10)) signal = _make_signal(ticker="AAPL") with patch.object(RiskManager, "_is_market_hours", return_value=True): approved, reason = await rm.check_risk(signal) assert approved is False assert "cooldown" in reason @pytest.mark.asyncio async def test_cooldown_expired(self): """After cooldown period expires the trade should be approved.""" config = _make_config(cooldown_minutes=30) broker = _mock_broker() rm = RiskManager(config, broker) # Record an exit 45 minutes ago now_et = datetime.now(tz=_ET) rm.record_exit("AAPL", now_et - timedelta(minutes=45)) signal = _make_signal(ticker="AAPL") with patch.object(RiskManager, "_is_market_hours", return_value=True): approved, reason = await rm.check_risk(signal) assert approved is True # --------------------------------------------------------------------------- # RiskManager — outside market hours # --------------------------------------------------------------------------- class TestRiskCheckMarketHours: """Risk check fails outside regular market hours.""" @pytest.mark.asyncio async def test_outside_market_hours(self): config = _make_config() broker = _mock_broker() rm = RiskManager(config, broker) signal = _make_signal() # Force market hours check to fail (no patching — use the real check # with a time that is definitely outside market hours) with patch.object(RiskManager, "_is_market_hours", return_value=False): approved, reason = await rm.check_risk(signal) assert approved is False assert "market_hours" in reason def test_market_hours_weekday(self): """A weekday at 10:00 AM ET should be within market hours.""" # Tuesday 10:00 AM ET t = datetime(2026, 2, 24, 10, 0, 0, tzinfo=_ET) assert RiskManager._is_market_hours(t) is True def test_market_hours_weekend(self): """Saturday should always be outside market hours.""" t = datetime(2026, 2, 21, 10, 0, 0, tzinfo=_ET) # Saturday assert RiskManager._is_market_hours(t) is False def test_market_hours_before_open(self): """8:00 AM ET on a weekday is before market open.""" t = datetime(2026, 2, 24, 8, 0, 0, tzinfo=_ET) # Tuesday 8 AM assert RiskManager._is_market_hours(t) is False def test_market_hours_after_close(self): """5:00 PM ET on a weekday is after market close.""" t = datetime(2026, 2, 24, 17, 0, 0, tzinfo=_ET) # Tuesday 5 PM assert RiskManager._is_market_hours(t) is False # --------------------------------------------------------------------------- # Position sizing — scales by strength # --------------------------------------------------------------------------- class TestPositionSizingScalesByStrength: """Position size should scale proportionally with signal strength.""" def test_full_strength(self): config = _make_config(max_position_pct=0.05) broker = _mock_broker() rm = RiskManager(config, broker) signal = _make_signal(strength=1.0, current_price=100.0) account = _make_account(equity=100_000) qty = rm.calculate_position_size(signal, account) # position_value = 100k * 0.05 * 1.0 = 5000 / 100 = 50 shares assert qty == 50 def test_half_strength(self): config = _make_config(max_position_pct=0.05) broker = _mock_broker() rm = RiskManager(config, broker) signal = _make_signal(strength=0.5, current_price=100.0) account = _make_account(equity=100_000) qty = rm.calculate_position_size(signal, account) # position_value = 100k * 0.05 * 0.5 = 2500 / 100 = 25 shares assert qty == 25 # --------------------------------------------------------------------------- # Position sizing — respects max_position_pct # --------------------------------------------------------------------------- class TestPositionSizingRespectsMaxPct: """Position size should respect the max_position_pct cap.""" def test_respects_max_pct(self): config = _make_config(max_position_pct=0.02) broker = _mock_broker() rm = RiskManager(config, broker) signal = _make_signal(strength=1.0, current_price=50.0) account = _make_account(equity=100_000) qty = rm.calculate_position_size(signal, account) # position_value = 100k * 0.02 * 1.0 = 2000 / 50 = 40 shares assert qty == 40 def test_zero_price_returns_zero(self): config = _make_config() broker = _mock_broker() rm = RiskManager(config, broker) signal = _make_signal(strength=0.8, current_price=0.0) account = _make_account(equity=100_000) qty = rm.calculate_position_size(signal, account) assert qty == 0 # --------------------------------------------------------------------------- # Executor flow — approved signal # --------------------------------------------------------------------------- class TestExecutorFlowApproved: """End-to-end: approved signal -> order submitted -> trade published.""" @pytest.mark.asyncio async def test_approved_signal_flow(self): config = _make_config() broker = _mock_broker(positions=[], account=_make_account(100_000)) publisher = AsyncMock() publisher.publish = AsyncMock(return_value=b"1-0") counters = { "trades_executed": MagicMock(), "rejections": MagicMock(), "fill_latency": MagicMock(), } signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0) # Patch risk check to approve with patch.object(RiskManager, "check_risk", return_value=(True, "approved")): await process_signal(signal, RiskManager(config, broker), broker, publisher, counters) # Verify order was submitted broker.submit_order.assert_called_once() order_arg = broker.submit_order.call_args[0][0] assert order_arg.ticker == "AAPL" assert order_arg.side == OrderSide.BUY # Verify trade was published publisher.publish.assert_called_once() counters["trades_executed"].add.assert_called_once_with(1) # --------------------------------------------------------------------------- # Executor flow — rejected signal # --------------------------------------------------------------------------- class TestExecutorFlowRejected: """End-to-end: rejected signal -> no order, rejection logged.""" @pytest.mark.asyncio async def test_rejected_signal_flow(self): config = _make_config() broker = _mock_broker() publisher = AsyncMock() counters = { "trades_executed": MagicMock(), "rejections": MagicMock(), "fill_latency": MagicMock(), } signal = _make_signal(ticker="AAPL") with patch.object( RiskManager, "check_risk", return_value=(False, "outside_market_hours") ): await process_signal(signal, RiskManager(config, broker), broker, publisher, counters) # No order should have been submitted broker.submit_order.assert_not_called() # No trade should have been published publisher.publish.assert_not_called() # Rejection counter should have been incremented counters["rejections"].add.assert_called_once() # --------------------------------------------------------------------------- # Executor flow — DB persistence # --------------------------------------------------------------------------- def _make_mock_db_session_factory(session=None): """Create a mock async_sessionmaker that yields a mock session.""" if session is None: session = AsyncMock() session.add = MagicMock() session.commit = AsyncMock() factory = MagicMock() ctx = AsyncMock() ctx.__aenter__ = AsyncMock(return_value=session) ctx.__aexit__ = AsyncMock(return_value=False) factory.return_value = ctx return factory class TestExecutorDBPersistence: """Verify that trades are persisted to the DB when db_session_factory is provided.""" @pytest.mark.asyncio async def test_trade_persisted_with_signal_id(self): """When db_session_factory is provided, a Trade row should be created.""" config = _make_config() broker = _mock_broker(positions=[], account=_make_account(100_000)) publisher = AsyncMock() publisher.publish = AsyncMock(return_value=b"1-0") counters = { "trades_executed": MagicMock(), "rejections": MagicMock(), "fill_latency": MagicMock(), } signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0) mock_session = AsyncMock() mock_session.add = MagicMock() mock_session.commit = AsyncMock() db_factory = _make_mock_db_session_factory(mock_session) with patch.object(RiskManager, "check_risk", return_value=(True, "approved")): await process_signal( signal, RiskManager(config, broker), broker, publisher, counters, db_factory ) # Trade should be persisted mock_session.add.assert_called_once() mock_session.commit.assert_awaited_once() # Verify the trade object trade_obj = mock_session.add.call_args[0][0] assert trade_obj.ticker == "AAPL" assert trade_obj.signal_id == signal.signal_id @pytest.mark.asyncio async def test_trade_not_persisted_without_db(self): """When db_session_factory is None, no DB write should happen.""" config = _make_config() broker = _mock_broker(positions=[], account=_make_account(100_000)) publisher = AsyncMock() publisher.publish = AsyncMock(return_value=b"1-0") counters = { "trades_executed": MagicMock(), "rejections": MagicMock(), "fill_latency": MagicMock(), } signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0) with patch.object(RiskManager, "check_risk", return_value=(True, "approved")): await process_signal( signal, RiskManager(config, broker), broker, publisher, counters, None ) # Should still publish publisher.publish.assert_called_once() @pytest.mark.asyncio async def test_db_error_does_not_block_publishing(self): """A DB error should not prevent the trade from being published.""" config = _make_config() broker = _mock_broker(positions=[], account=_make_account(100_000)) publisher = AsyncMock() publisher.publish = AsyncMock(return_value=b"1-0") counters = { "trades_executed": MagicMock(), "rejections": MagicMock(), "fill_latency": MagicMock(), } signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0) mock_session = AsyncMock() mock_session.add = MagicMock() mock_session.commit = AsyncMock(side_effect=RuntimeError("DB connection lost")) db_factory = _make_mock_db_session_factory(mock_session) with patch.object(RiskManager, "check_risk", return_value=(True, "approved")): await process_signal( signal, RiskManager(config, broker), broker, publisher, counters, db_factory ) # Trade should still be published despite DB error publisher.publish.assert_called_once() counters["trades_executed"].add.assert_called_once_with(1) def test_signal_id_flows_through_execution(self): """signal_id from TradeSignal should appear in the published TradeExecution.""" signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0) assert signal.signal_id is not None # Verify signal_id is a UUID from uuid import UUID assert isinstance(signal.signal_id, UUID)