"""Tests for the Trade Executor service. Covers RiskManager (market hours, positions, exposure, cooldown, position sizing) and the end-to-end executor flow with a mocked broker. """ from __future__ import annotations from datetime import datetime, timedelta, timezone from decimal import Decimal from unittest.mock import AsyncMock, MagicMock, patch from zoneinfo import ZoneInfo import pytest from services.trade_executor.config import TradeExecutorConfig from services.trade_executor.main import process_signal from services.trade_executor.risk_manager import RiskManager from shared.schemas.trading import ( AccountInfo, OrderResult, OrderSide, OrderStatus, PositionInfo, SignalDirection, TradeSignal, ) _ET = ZoneInfo("America/New_York") # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _make_config(**overrides) -> TradeExecutorConfig: defaults = dict( max_position_pct=0.05, max_total_exposure_pct=0.80, max_positions=20, default_stop_loss_pct=0.03, cooldown_minutes=30, alpaca_api_key="test", alpaca_secret_key="test", paper_trading=True, ) defaults.update(overrides) return TradeExecutorConfig(**defaults) def _make_signal( ticker: str = "AAPL", direction: SignalDirection = SignalDirection.LONG, strength: float = 0.8, current_price: float = 150.0, ) -> TradeSignal: return TradeSignal( ticker=ticker, direction=direction, strength=strength, strategy_sources=["test"], sentiment_context={"current_price": current_price}, timestamp=datetime.now(timezone.utc), ) def _make_kevin_signal( ticker: str = "NVDA", direction: SignalDirection = SignalDirection.LONG, strength: float = 0.8, current_price: Decimal | None = Decimal("100"), target_dollars: Decimal | None = Decimal("2000"), stop_loss_pct: Decimal | None = Decimal("0.08"), take_profit_pct: Decimal | None = Decimal("0.20"), ) -> TradeSignal: """A Kevin-style signal: price on the new ``current_price`` field, 
pre-computed ``target_dollars``, and stop/take percentages — but NO ``sentiment_context``.""" return TradeSignal( ticker=ticker, direction=direction, strength=strength, strategy_sources=["kevin:buy:0.8"], current_price=current_price, target_dollars=target_dollars, stop_loss_pct=stop_loss_pct, take_profit_pct=take_profit_pct, timestamp=datetime.now(timezone.utc), ) def _make_account(equity: float = 100_000.0) -> AccountInfo: return AccountInfo( equity=equity, cash=equity, buying_power=equity * 2, portfolio_value=equity, ) def _make_position(ticker: str = "AAPL", market_value: float = 5000.0) -> PositionInfo: return PositionInfo( ticker=ticker, qty=10.0, avg_entry=150.0, current_price=150.0, unrealized_pnl=0.0, market_value=market_value, ) def _mock_broker(positions: list[PositionInfo] | None = None, account: AccountInfo | None = None): """Create an AsyncMock broker with configurable positions and account.""" broker = AsyncMock() broker.get_positions = AsyncMock(return_value=positions or []) broker.get_account = AsyncMock(return_value=account or _make_account()) broker.submit_order = AsyncMock( return_value=OrderResult( order_id="ord-123", ticker="AAPL", side=OrderSide.BUY, qty=10.0, filled_price=150.0, status=OrderStatus.FILLED, timestamp=datetime.now(timezone.utc), ) ) return broker # --------------------------------------------------------------------------- # RiskManager — risk check passes # --------------------------------------------------------------------------- class TestRiskCheckPasses: """All conditions met -> risk check passes.""" @pytest.mark.asyncio async def test_all_conditions_met(self): config = _make_config() broker = _mock_broker(positions=[], account=_make_account(100_000)) rm = RiskManager(config, broker) signal = _make_signal() # Patch _is_market_hours to return True with patch.object(RiskManager, "_is_market_hours", return_value=True): approved, reason = await rm.check_risk(signal) assert approved is True assert reason == "approved" class 
TestRiskCheckTradingPaused: """Risk check fails when trading is paused via Redis flag.""" @pytest.mark.asyncio async def test_paused_flag_rejects(self): config = _make_config() broker = _mock_broker() redis_mock = AsyncMock() redis_mock.get = AsyncMock(return_value=b"1") rm = RiskManager(config, broker, redis=redis_mock) signal = _make_signal() with patch.object(RiskManager, "_is_market_hours", return_value=True): approved, reason = await rm.check_risk(signal) assert approved is False assert reason == "trading_paused" @pytest.mark.asyncio async def test_no_pause_flag_passes_through(self): config = _make_config() broker = _mock_broker(positions=[], account=_make_account(100_000)) redis_mock = AsyncMock() redis_mock.get = AsyncMock(return_value=None) rm = RiskManager(config, broker, redis=redis_mock) signal = _make_signal() with patch.object(RiskManager, "_is_market_hours", return_value=True): approved, reason = await rm.check_risk(signal) assert approved is True @pytest.mark.asyncio async def test_no_redis_skips_pause_check(self): config = _make_config() broker = _mock_broker(positions=[], account=_make_account(100_000)) rm = RiskManager(config, broker, redis=None) signal = _make_signal() with patch.object(RiskManager, "_is_market_hours", return_value=True): approved, reason = await rm.check_risk(signal) assert approved is True # --------------------------------------------------------------------------- # RiskManager — max positions exceeded # --------------------------------------------------------------------------- class TestRiskCheckMaxPositions: """Risk check fails when max_positions is already reached.""" @pytest.mark.asyncio async def test_max_positions_exceeded(self): config = _make_config(max_positions=2) # Already have 2 positions positions = [_make_position("AAPL"), _make_position("MSFT")] broker = _mock_broker(positions=positions, account=_make_account()) rm = RiskManager(config, broker) signal = _make_signal(ticker="GOOG") with 
patch.object(RiskManager, "_is_market_hours", return_value=True): approved, reason = await rm.check_risk(signal) assert approved is False assert "max_positions" in reason # --------------------------------------------------------------------------- # RiskManager — max exposure exceeded # --------------------------------------------------------------------------- class TestRiskCheckMaxExposure: """Risk check fails when total exposure exceeds the limit.""" @pytest.mark.asyncio async def test_max_exposure_exceeded(self): config = _make_config(max_total_exposure_pct=0.50) account = _make_account(equity=100_000) # Single position worth $60k = 60% of equity, limit is 50% positions = [_make_position("AAPL", market_value=60_000)] broker = _mock_broker(positions=positions, account=account) rm = RiskManager(config, broker) signal = _make_signal(ticker="MSFT") with patch.object(RiskManager, "_is_market_hours", return_value=True): approved, reason = await rm.check_risk(signal) assert approved is False assert "max_exposure" in reason # --------------------------------------------------------------------------- # RiskManager — cooldown active # --------------------------------------------------------------------------- class TestRiskCheckCooldown: """Risk check fails when a ticker is in cooldown.""" @pytest.mark.asyncio async def test_cooldown_active(self): config = _make_config(cooldown_minutes=30) broker = _mock_broker() rm = RiskManager(config, broker) # Record an exit 10 minutes ago now_et = datetime.now(tz=_ET) rm.record_exit("AAPL", now_et - timedelta(minutes=10)) signal = _make_signal(ticker="AAPL") with patch.object(RiskManager, "_is_market_hours", return_value=True): approved, reason = await rm.check_risk(signal) assert approved is False assert "cooldown" in reason @pytest.mark.asyncio async def test_cooldown_expired(self): """After cooldown period expires the trade should be approved.""" config = _make_config(cooldown_minutes=30) broker = _mock_broker() rm = 
RiskManager(config, broker) # Record an exit 45 minutes ago now_et = datetime.now(tz=_ET) rm.record_exit("AAPL", now_et - timedelta(minutes=45)) signal = _make_signal(ticker="AAPL") with patch.object(RiskManager, "_is_market_hours", return_value=True): approved, reason = await rm.check_risk(signal) assert approved is True # --------------------------------------------------------------------------- # RiskManager — outside market hours # --------------------------------------------------------------------------- class TestRiskCheckMarketHours: """Risk check fails outside regular market hours.""" @pytest.mark.asyncio async def test_outside_market_hours(self): config = _make_config() broker = _mock_broker() rm = RiskManager(config, broker) signal = _make_signal() # Force market hours check to fail (no patching — use the real check # with a time that is definitely outside market hours) with patch.object(RiskManager, "_is_market_hours", return_value=False): approved, reason = await rm.check_risk(signal) assert approved is False assert "market_hours" in reason def test_market_hours_weekday(self): """A weekday at 10:00 AM ET should be within market hours.""" # Tuesday 10:00 AM ET t = datetime(2026, 2, 24, 10, 0, 0, tzinfo=_ET) assert RiskManager._is_market_hours(t) is True def test_market_hours_weekend(self): """Saturday should always be outside market hours.""" t = datetime(2026, 2, 21, 10, 0, 0, tzinfo=_ET) # Saturday assert RiskManager._is_market_hours(t) is False def test_market_hours_before_open(self): """8:00 AM ET on a weekday is before market open.""" t = datetime(2026, 2, 24, 8, 0, 0, tzinfo=_ET) # Tuesday 8 AM assert RiskManager._is_market_hours(t) is False def test_market_hours_after_close(self): """5:00 PM ET on a weekday is after market close.""" t = datetime(2026, 2, 24, 17, 0, 0, tzinfo=_ET) # Tuesday 5 PM assert RiskManager._is_market_hours(t) is False # --------------------------------------------------------------------------- # Position sizing — scales 
# ---------------------------------------------------------------------------
# Position sizing — scales by strength
# ---------------------------------------------------------------------------


class TestPositionSizingScalesByStrength:
    """Without target_dollars, shares grow in direct proportion to strength."""

    def test_full_strength(self):
        account = _make_account(equity=100_000)
        sizer = RiskManager(_make_config(max_position_pct=0.05), _mock_broker())
        full_signal = _make_signal(strength=1.0, current_price=100.0)

        shares = sizer.calculate_position_size(full_signal, account)

        # Budget: $100k * 5% * 1.0 = $5,000  ->  $5,000 / $100 = 50 shares.
        assert shares == 50

    def test_half_strength(self):
        account = _make_account(equity=100_000)
        sizer = RiskManager(_make_config(max_position_pct=0.05), _mock_broker())
        half_signal = _make_signal(strength=0.5, current_price=100.0)

        shares = sizer.calculate_position_size(half_signal, account)

        # Budget: $100k * 5% * 0.5 = $2,500  ->  $2,500 / $100 = 25 shares.
        assert shares == 25


# ---------------------------------------------------------------------------
# Position sizing — respects max_position_pct
# ---------------------------------------------------------------------------


class TestPositionSizingRespectsMaxPct:
    """The per-position budget is bounded by ``max_position_pct`` of equity."""

    def test_respects_max_pct(self):
        account = _make_account(equity=100_000)
        sizer = RiskManager(_make_config(max_position_pct=0.02), _mock_broker())
        full_signal = _make_signal(strength=1.0, current_price=50.0)

        shares = sizer.calculate_position_size(full_signal, account)

        # Budget: $100k * 2% * 1.0 = $2,000  ->  $2,000 / $50 = 40 shares.
        assert shares == 40

    def test_zero_price_returns_zero(self):
        account = _make_account(equity=100_000)
        sizer = RiskManager(_make_config(), _mock_broker())
        free_signal = _make_signal(strength=0.8, current_price=0.0)

        # A zero price must size to nothing rather than divide by zero.
        assert sizer.calculate_position_size(free_signal, account) == 0
# ---------------------------------------------------------------------------
# Position sizing — Kevin path (target_dollars + current_price field)
# ---------------------------------------------------------------------------


class TestPositionSizingHonorsTargetDollars:
    """When the signal carries ``target_dollars`` (Kevin's pre-computed
    sizing), use it directly and ignore signal strength."""

    def test_target_dollars_drives_qty(self):
        config = _make_config(max_position_pct=0.05)
        broker = _mock_broker()
        rm = RiskManager(config, broker)
        signal = _make_kevin_signal(
            target_dollars=Decimal("2000"), current_price=Decimal("100")
        )
        account = _make_account(equity=100_000)

        qty = rm.calculate_position_size(signal, account)

        # 2000 / 100 = 20 shares — NOT scaled by strength/max_position_pct.
        assert qty == 20

    def test_target_dollars_ignores_strength(self):
        """qty must be identical regardless of strength when target_dollars is set."""
        config = _make_config(max_position_pct=0.05)
        broker = _mock_broker()
        rm = RiskManager(config, broker)
        account = _make_account(equity=100_000)

        low = rm.calculate_position_size(
            _make_kevin_signal(
                strength=0.1, target_dollars=Decimal("2000"), current_price=Decimal("100")
            ),
            account,
        )
        high = rm.calculate_position_size(
            _make_kevin_signal(
                strength=1.0, target_dollars=Decimal("2000"), current_price=Decimal("100")
            ),
            account,
        )

        assert low == high == 20


class TestPositionSizingReadsCurrentPriceField:
    """Sizing must read the new ``current_price`` field when
    ``sentiment_context`` (the legacy price source) is absent."""

    def test_current_price_field_used_when_no_sentiment_context(self):
        config = _make_config(max_position_pct=0.05)
        broker = _mock_broker()
        rm = RiskManager(config, broker)
        # Legacy fixed-fractional path (no target_dollars) but the price lives
        # on the new field, not in sentiment_context.
        signal = _make_kevin_signal(
            strength=1.0,
            current_price=Decimal("100"),
            target_dollars=None,
            stop_loss_pct=None,
            take_profit_pct=None,
        )
        assert signal.sentiment_context is None
        account = _make_account(equity=100_000)

        qty = rm.calculate_position_size(signal, account)

        # 100k * 0.05 * 1.0 = 5000 / 100 = 50 shares.
        assert qty == 50


# ---------------------------------------------------------------------------
# Executor flow — shared mocks
# ---------------------------------------------------------------------------


def _make_counters() -> dict[str, MagicMock]:
    """Return fresh metric mocks in the dict shape process_signal is given.

    A new dict is built on every call so no call history leaks between tests.
    """
    return {
        "trades_executed": MagicMock(),
        "rejections": MagicMock(),
        "fill_latency": MagicMock(),
    }


def _make_publisher() -> AsyncMock:
    """Return an async publisher mock whose ``publish`` resolves to ``b"1-0"``."""
    publisher = AsyncMock()
    publisher.publish = AsyncMock(return_value=b"1-0")
    return publisher


# ---------------------------------------------------------------------------
# Executor flow — approved signal
# ---------------------------------------------------------------------------


class TestExecutorFlowApproved:
    """End-to-end: approved signal -> order submitted -> trade published."""

    @pytest.mark.asyncio
    async def test_approved_signal_flow(self):
        config = _make_config()
        broker = _mock_broker(positions=[], account=_make_account(100_000))
        publisher = _make_publisher()
        counters = _make_counters()
        signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0)

        # Force approval so the test exercises only order submission/publishing.
        with patch.object(RiskManager, "check_risk", return_value=(True, "approved")):
            await process_signal(
                signal, RiskManager(config, broker), broker, publisher, counters
            )

        # Verify the order was submitted.
        broker.submit_order.assert_called_once()
        order_arg = broker.submit_order.call_args[0][0]
        assert order_arg.ticker == "AAPL"
        assert order_arg.side == OrderSide.BUY

        # Verify the trade was published and counted.
        publisher.publish.assert_called_once()
        counters["trades_executed"].add.assert_called_once_with(1)


# ---------------------------------------------------------------------------
# Executor flow — bracket vs simple order construction
# ---------------------------------------------------------------------------


class TestExecutorBracketOrders:
    """LONG entries with both stop/take pcts become BRACKET orders; EXIT
    signals (or signals missing a pct) stay SIMPLE."""

    @pytest.mark.asyncio
    async def test_long_entry_with_pcts_builds_bracket(self):
        config = _make_config()
        broker = _mock_broker(positions=[], account=_make_account(100_000))
        publisher = _make_publisher()
        counters = _make_counters()
        signal = _make_kevin_signal(
            ticker="NVDA",
            direction=SignalDirection.LONG,
            current_price=Decimal("100"),
            target_dollars=Decimal("2000"),
            stop_loss_pct=Decimal("0.08"),
            take_profit_pct=Decimal("0.20"),
        )

        with patch.object(RiskManager, "check_risk", return_value=(True, "approved")):
            await process_signal(
                signal, RiskManager(config, broker), broker, publisher, counters
            )

        broker.submit_order.assert_called_once()
        order_arg = broker.submit_order.call_args[0][0]
        assert order_arg.order_class == "bracket"
        assert order_arg.side == OrderSide.BUY
        # entry=100 -> stop=100*(1-0.08)=92.0, take=100*(1+0.20)=120.0
        assert order_arg.stop_loss_price == 92.0
        assert order_arg.take_profit_price == 120.0

    @pytest.mark.asyncio
    async def test_exit_signal_with_pcts_stays_simple(self):
        """The bridge stamps stop/take pcts even on EXIT signals; the
        direction guard must keep the resulting SELL order SIMPLE."""
        config = _make_config()
        broker = _mock_broker(positions=[], account=_make_account(100_000))
        publisher = _make_publisher()
        counters = _make_counters()
        signal = _make_kevin_signal(
            ticker="NVDA",
            direction=SignalDirection.EXIT,
            current_price=Decimal("100"),
            target_dollars=Decimal("2000"),
            stop_loss_pct=Decimal("0.08"),
            take_profit_pct=Decimal("0.20"),
        )

        with patch.object(RiskManager, "check_risk", return_value=(True, "approved")):
            await process_signal(
                signal, RiskManager(config, broker), broker, publisher, counters
            )

        broker.submit_order.assert_called_once()
        order_arg = broker.submit_order.call_args[0][0]
        assert order_arg.order_class == "simple"
        assert order_arg.side == OrderSide.SELL
        assert order_arg.take_profit_price is None
        assert order_arg.stop_loss_price is None


# ---------------------------------------------------------------------------
# Executor flow — rejected signal
# ---------------------------------------------------------------------------


class TestExecutorFlowRejected:
    """End-to-end: rejected signal -> no order, rejection logged."""

    @pytest.mark.asyncio
    async def test_rejected_signal_flow(self):
        config = _make_config()
        broker = _mock_broker()
        publisher = _make_publisher()
        counters = _make_counters()
        signal = _make_signal(ticker="AAPL")

        with patch.object(
            RiskManager, "check_risk", return_value=(False, "outside_market_hours")
        ):
            await process_signal(
                signal, RiskManager(config, broker), broker, publisher, counters
            )

        # No order should have been submitted.
        broker.submit_order.assert_not_called()
        # No trade should have been published or counted as executed.
        publisher.publish.assert_not_called()
        counters["trades_executed"].add.assert_not_called()
        # Rejection counter should have been incremented.
        counters["rejections"].add.assert_called_once()


# ---------------------------------------------------------------------------
# Executor flow — DB persistence
# ---------------------------------------------------------------------------


def _make_mock_db_session_factory(session=None):
    """Create a mock async_sessionmaker whose context manager yields *session*.

    When *session* is None a fresh AsyncMock session is built. Its ``add`` is
    replaced with a plain MagicMock because ``session.add`` is called without
    ``await``, and an AsyncMock attribute would return an un-awaited coroutine.
    """
    if session is None:
        session = AsyncMock()
        session.add = MagicMock()
        session.commit = AsyncMock()

    factory = MagicMock()
    ctx = AsyncMock()
    ctx.__aenter__ = AsyncMock(return_value=session)
    # Returning False means exceptions raised inside `async with` propagate.
    ctx.__aexit__ = AsyncMock(return_value=False)
    factory.return_value = ctx
    return factory


class TestExecutorDBPersistence:
    """Verify that trades are persisted to the DB when db_session_factory is provided."""

    @pytest.mark.asyncio
    async def test_trade_persisted_with_signal_id(self):
        """When db_session_factory is provided, a Trade row should be created."""
        config = _make_config()
        broker = _mock_broker(positions=[], account=_make_account(100_000))
        publisher = _make_publisher()
        counters = _make_counters()
        signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0)

        mock_session = AsyncMock()
        mock_session.add = MagicMock()
        mock_session.commit = AsyncMock()
        db_factory = _make_mock_db_session_factory(mock_session)

        with patch.object(RiskManager, "check_risk", return_value=(True, "approved")):
            await process_signal(
                signal, RiskManager(config, broker), broker, publisher, counters, db_factory
            )

        # The Trade should be persisted exactly once.
        mock_session.add.assert_called_once()
        mock_session.commit.assert_awaited_once()

        # ...and carry the originating signal's id.
        trade_obj = mock_session.add.call_args[0][0]
        assert trade_obj.ticker == "AAPL"
        assert trade_obj.signal_id == signal.signal_id

    @pytest.mark.asyncio
    async def test_trade_not_persisted_without_db(self):
        """When db_session_factory is None, no DB write should happen."""
        config = _make_config()
        broker = _mock_broker(positions=[], account=_make_account(100_000))
        publisher = _make_publisher()
        counters = _make_counters()
        signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0)

        with patch.object(RiskManager, "check_risk", return_value=(True, "approved")):
            await process_signal(
                signal, RiskManager(config, broker), broker, publisher, counters, None
            )

        # Publishing must still happen without a DB.
        publisher.publish.assert_called_once()

    @pytest.mark.asyncio
    async def test_db_error_does_not_block_publishing(self):
        """A DB error should not prevent the trade from being published."""
        config = _make_config()
        broker = _mock_broker(positions=[], account=_make_account(100_000))
        publisher = _make_publisher()
        counters = _make_counters()
        signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0)

        mock_session = AsyncMock()
        mock_session.add = MagicMock()
        mock_session.commit = AsyncMock(side_effect=RuntimeError("DB connection lost"))
        db_factory = _make_mock_db_session_factory(mock_session)

        with patch.object(RiskManager, "check_risk", return_value=(True, "approved")):
            await process_signal(
                signal, RiskManager(config, broker), broker, publisher, counters, db_factory
            )

        # The trade should still be published and counted despite the DB error.
        publisher.publish.assert_called_once()
        counters["trades_executed"].add.assert_called_once_with(1)

    def test_signal_id_flows_through_execution(self):
        """TradeSignal auto-assigns a UUID ``signal_id`` when none is given.

        This test only guards the source of the id. That the id is copied
        onto the persisted Trade row is asserted in
        ``test_trade_persisted_with_signal_id``.
        """
        from uuid import UUID

        signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0)

        assert signal.signal_id is not None
        assert isinstance(signal.signal_id, UUID)