"""Tests for the Trade Executor service. Covers RiskManager (market hours, positions, exposure, cooldown, position sizing) and the end-to-end executor flow with a mocked broker. """ from __future__ import annotations from datetime import datetime, timedelta, timezone from decimal import Decimal from unittest.mock import AsyncMock, MagicMock, patch from uuid import UUID from zoneinfo import ZoneInfo import pytest from services.trade_executor.config import TradeExecutorConfig from services.trade_executor.main import process_signal from services.trade_executor.risk_manager import RiskManager from shared.constants.kevin import KEVIN_STRATEGY_UUID from shared.models.trading import TradeSide as TradeSideModel from shared.schemas.trading import ( AccountInfo, OrderResult, OrderSide, OrderStatus, PositionInfo, SignalDirection, TradeSignal, ) _ET = ZoneInfo("America/New_York") # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _make_config(**overrides) -> TradeExecutorConfig: defaults = dict( max_position_pct=0.05, max_total_exposure_pct=0.80, max_positions=20, default_stop_loss_pct=0.03, cooldown_minutes=30, alpaca_api_key="test", alpaca_secret_key="test", paper_trading=True, ) defaults.update(overrides) return TradeExecutorConfig(**defaults) def _make_signal( ticker: str = "AAPL", direction: SignalDirection = SignalDirection.LONG, strength: float = 0.8, current_price: float = 150.0, ) -> TradeSignal: return TradeSignal( ticker=ticker, direction=direction, strength=strength, strategy_sources=["test"], sentiment_context={"current_price": current_price}, timestamp=datetime.now(timezone.utc), ) def _make_kevin_signal( ticker: str = "NVDA", direction: SignalDirection = SignalDirection.LONG, strength: float = 0.8, current_price: Decimal | None = Decimal("100"), target_dollars: Decimal | None = Decimal("2000"), stop_loss_pct: Decimal | None = Decimal("0.08"), take_profit_pct: Decimal | None = Decimal("0.20"), strategy_id: UUID | None = KEVIN_STRATEGY_UUID, ) -> TradeSignal: """A Kevin-style signal: price on the new ``current_price`` field, pre-computed ``target_dollars``, and stop/take percentages — but NO ``sentiment_context``.""" return TradeSignal( ticker=ticker, direction=direction, strength=strength, strategy_sources=["kevin:buy:0.8"], current_price=current_price, target_dollars=target_dollars, stop_loss_pct=stop_loss_pct, take_profit_pct=take_profit_pct, strategy_id=strategy_id, timestamp=datetime.now(timezone.utc), ) def _make_account(equity: float = 100_000.0) -> AccountInfo: return AccountInfo( equity=equity, cash=equity, buying_power=equity * 2, portfolio_value=equity, ) def _make_position(ticker: str = "AAPL", market_value: float = 5000.0) -> PositionInfo: return PositionInfo( ticker=ticker, qty=10.0, avg_entry=150.0, current_price=150.0, unrealized_pnl=0.0, market_value=market_value, ) def _mock_broker(positions: list[PositionInfo] | None = None, account: AccountInfo | None = None): """Create an AsyncMock broker with configurable positions and account.""" broker = AsyncMock() broker.get_positions = AsyncMock(return_value=positions or []) broker.get_account = AsyncMock(return_value=account or _make_account()) broker.submit_order = AsyncMock( return_value=OrderResult( order_id="ord-123", ticker="AAPL", side=OrderSide.BUY, qty=10.0, filled_price=150.0, status=OrderStatus.FILLED, timestamp=datetime.now(timezone.utc), ) ) return broker # --------------------------------------------------------------------------- # RiskManager — risk check passes # --------------------------------------------------------------------------- class TestRiskCheckPasses: """All conditions met -> risk check passes.""" @pytest.mark.asyncio async def test_all_conditions_met(self): config = _make_config() broker = _mock_broker(positions=[], account=_make_account(100_000)) rm = RiskManager(config, broker) signal = _make_signal() # Patch _is_market_hours to return True with patch.object(RiskManager, "_is_market_hours", return_value=True): approved, reason = await rm.check_risk(signal) assert approved is True assert reason == "approved" class TestRiskCheckTradingPaused: """Risk check fails when trading is paused via Redis flag.""" @pytest.mark.asyncio async def test_paused_flag_rejects(self): config = _make_config() broker = _mock_broker() redis_mock = AsyncMock() redis_mock.get = AsyncMock(return_value=b"1") rm = RiskManager(config, broker, redis=redis_mock) signal = _make_signal() with patch.object(RiskManager, "_is_market_hours", return_value=True): approved, reason = await rm.check_risk(signal) assert approved is False assert reason == "trading_paused" @pytest.mark.asyncio async def test_no_pause_flag_passes_through(self): config = _make_config() broker = _mock_broker(positions=[], account=_make_account(100_000)) redis_mock = AsyncMock() redis_mock.get = AsyncMock(return_value=None) rm = RiskManager(config, broker, redis=redis_mock) signal = _make_signal() with patch.object(RiskManager, "_is_market_hours", return_value=True): approved, reason = await rm.check_risk(signal) assert approved is True @pytest.mark.asyncio async def test_no_redis_skips_pause_check(self): config = _make_config() broker = _mock_broker(positions=[], account=_make_account(100_000)) rm = RiskManager(config, broker, redis=None) signal = _make_signal() with patch.object(RiskManager, "_is_market_hours", return_value=True): approved, reason = await rm.check_risk(signal) assert approved is True # --------------------------------------------------------------------------- # RiskManager — max positions exceeded # --------------------------------------------------------------------------- class TestRiskCheckMaxPositions: """Risk check fails when max_positions is already reached.""" @pytest.mark.asyncio async def test_max_positions_exceeded(self): config = _make_config(max_positions=2) # Already have 2 positions positions = [_make_position("AAPL"), _make_position("MSFT")] broker = _mock_broker(positions=positions, account=_make_account()) rm = RiskManager(config, broker) signal = _make_signal(ticker="GOOG") with patch.object(RiskManager, "_is_market_hours", return_value=True): approved, reason = await rm.check_risk(signal) assert approved is False assert "max_positions" in reason # --------------------------------------------------------------------------- # RiskManager — max exposure exceeded # --------------------------------------------------------------------------- class TestRiskCheckMaxExposure: """Risk check fails when total exposure exceeds the limit.""" @pytest.mark.asyncio async def test_max_exposure_exceeded(self): config = _make_config(max_total_exposure_pct=0.50) account = _make_account(equity=100_000) # Single position worth $60k = 60% of equity, limit is 50% positions = [_make_position("AAPL", market_value=60_000)] broker = _mock_broker(positions=positions, account=account) rm = RiskManager(config, broker) signal = _make_signal(ticker="MSFT") with patch.object(RiskManager, "_is_market_hours", return_value=True): approved, reason = await rm.check_risk(signal) assert approved is False assert "max_exposure" in reason # --------------------------------------------------------------------------- # RiskManager — cooldown active # --------------------------------------------------------------------------- class TestRiskCheckCooldown: """Risk check fails when a ticker is in cooldown.""" @pytest.mark.asyncio async def test_cooldown_active(self): config = _make_config(cooldown_minutes=30) broker = _mock_broker() rm = RiskManager(config, broker) # Record an exit 10 minutes ago now_et = datetime.now(tz=_ET) rm.record_exit("AAPL", now_et - timedelta(minutes=10)) signal = _make_signal(ticker="AAPL") with patch.object(RiskManager, "_is_market_hours", return_value=True): approved, reason = await rm.check_risk(signal) assert approved is False assert "cooldown" in reason @pytest.mark.asyncio async def test_cooldown_expired(self): """After cooldown period expires the trade should be approved.""" config = _make_config(cooldown_minutes=30) broker = _mock_broker() rm = RiskManager(config, broker) # Record an exit 45 minutes ago now_et = datetime.now(tz=_ET) rm.record_exit("AAPL", now_et - timedelta(minutes=45)) signal = _make_signal(ticker="AAPL") with patch.object(RiskManager, "_is_market_hours", return_value=True): approved, reason = await rm.check_risk(signal) assert approved is True # --------------------------------------------------------------------------- # RiskManager — outside market hours # --------------------------------------------------------------------------- class TestRiskCheckMarketHours: """Risk check fails outside regular market hours.""" @pytest.mark.asyncio async def test_outside_market_hours(self): config = _make_config() broker = _mock_broker() rm = RiskManager(config, broker) signal = _make_signal() # Force market hours check to fail (no patching — use the real check # with a time that is definitely outside market hours) with patch.object(RiskManager, "_is_market_hours", return_value=False): approved, reason = await rm.check_risk(signal) assert approved is False assert "market_hours" in reason def test_market_hours_weekday(self): """A weekday at 10:00 AM ET should be within market hours.""" # Tuesday 10:00 AM ET t = datetime(2026, 2, 24, 10, 0, 0, tzinfo=_ET) assert RiskManager._is_market_hours(t) is True def test_market_hours_weekend(self): """Saturday should always be outside market hours.""" t = datetime(2026, 2, 21, 10, 0, 0, tzinfo=_ET) # Saturday assert RiskManager._is_market_hours(t) is False def test_market_hours_before_open(self): """8:00 AM ET on a weekday is before market open.""" t = datetime(2026, 2, 24, 8, 0, 0, tzinfo=_ET) # Tuesday 8 AM assert RiskManager._is_market_hours(t) is False def test_market_hours_after_close(self): """5:00 PM ET on a weekday is after market close.""" t = datetime(2026, 2, 24, 17, 0, 0, tzinfo=_ET) # Tuesday 5 PM assert RiskManager._is_market_hours(t) is False # --------------------------------------------------------------------------- # Position sizing — scales by strength # --------------------------------------------------------------------------- class TestPositionSizingScalesByStrength: """Position size should scale proportionally with signal strength.""" def test_full_strength(self): config = _make_config(max_position_pct=0.05) broker = _mock_broker() rm = RiskManager(config, broker) signal = _make_signal(strength=1.0, current_price=100.0) account = _make_account(equity=100_000) qty = rm.calculate_position_size(signal, account) # position_value = 100k * 0.05 * 1.0 = 5000 / 100 = 50 shares assert qty == 50 def test_half_strength(self): config = _make_config(max_position_pct=0.05) broker = _mock_broker() rm = RiskManager(config, broker) signal = _make_signal(strength=0.5, current_price=100.0) account = _make_account(equity=100_000) qty = rm.calculate_position_size(signal, account) # position_value = 100k * 0.05 * 0.5 = 2500 / 100 = 25 shares assert qty == 25 # --------------------------------------------------------------------------- # Position sizing — respects max_position_pct # --------------------------------------------------------------------------- class TestPositionSizingRespectsMaxPct: """Position size should respect the max_position_pct cap.""" def test_respects_max_pct(self): config = _make_config(max_position_pct=0.02) broker = _mock_broker() rm = RiskManager(config, broker) signal = _make_signal(strength=1.0, current_price=50.0) account = _make_account(equity=100_000) qty = rm.calculate_position_size(signal, account) # position_value = 100k * 0.02 * 1.0 = 2000 / 50 = 40 shares assert qty == 40 def test_zero_price_returns_zero(self): config = _make_config() broker = _mock_broker() rm = RiskManager(config, broker) signal = _make_signal(strength=0.8, current_price=0.0) account = _make_account(equity=100_000) qty = rm.calculate_position_size(signal, account) assert qty == 0 # --------------------------------------------------------------------------- # Position sizing — Kevin path (target_dollars + current_price field) # --------------------------------------------------------------------------- class TestPositionSizingHonorsTargetDollars: """When the signal carries ``target_dollars`` (Kevin's pre-computed sizing), use it directly and ignore signal strength.""" def test_target_dollars_drives_qty(self): config = _make_config(max_position_pct=0.05) broker = _mock_broker() rm = RiskManager(config, broker) signal = _make_kevin_signal( target_dollars=Decimal("2000"), current_price=Decimal("100") ) account = _make_account(equity=100_000) qty = rm.calculate_position_size(signal, account) # 2000 / 100 = 20 shares — NOT scaled by strength/max_position_pct. assert qty == 20 def test_target_dollars_ignores_strength(self): """qty must be identical regardless of strength when target_dollars is set.""" config = _make_config(max_position_pct=0.05) broker = _mock_broker() rm = RiskManager(config, broker) account = _make_account(equity=100_000) low = rm.calculate_position_size( _make_kevin_signal( strength=0.1, target_dollars=Decimal("2000"), current_price=Decimal("100") ), account, ) high = rm.calculate_position_size( _make_kevin_signal( strength=1.0, target_dollars=Decimal("2000"), current_price=Decimal("100") ), account, ) assert low == high == 20 class TestPositionSizingReadsCurrentPriceField: """Sizing must read the new ``current_price`` field when ``sentiment_context`` is absent (the legacy price source).""" def test_current_price_field_used_when_no_sentiment_context(self): config = _make_config(max_position_pct=0.05) broker = _mock_broker() rm = RiskManager(config, broker) # Legacy fixed-fractional path (no target_dollars) but price lives # on the new field, not in sentiment_context. signal = _make_kevin_signal( strength=1.0, current_price=Decimal("100"), target_dollars=None, stop_loss_pct=None, take_profit_pct=None, ) assert signal.sentiment_context is None account = _make_account(equity=100_000) qty = rm.calculate_position_size(signal, account) # 100k * 0.05 * 1.0 = 5000 / 100 = 50 shares. assert qty == 50 # --------------------------------------------------------------------------- # Executor flow — approved signal # --------------------------------------------------------------------------- class TestExecutorFlowApproved: """End-to-end: approved signal -> order submitted -> trade published.""" @pytest.mark.asyncio async def test_approved_signal_flow(self): config = _make_config() broker = _mock_broker(positions=[], account=_make_account(100_000)) publisher = AsyncMock() publisher.publish = AsyncMock(return_value=b"1-0") counters = { "trades_executed": MagicMock(), "rejections": MagicMock(), "fill_latency": MagicMock(), } signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0) # Patch risk check to approve with patch.object(RiskManager, "check_risk", return_value=(True, "approved")): await process_signal(signal, RiskManager(config, broker), broker, publisher, counters) # Verify order was submitted broker.submit_order.assert_called_once() order_arg = broker.submit_order.call_args[0][0] assert order_arg.ticker == "AAPL" assert order_arg.side == OrderSide.BUY # Verify trade was published publisher.publish.assert_called_once() counters["trades_executed"].add.assert_called_once_with(1) # --------------------------------------------------------------------------- # Executor flow — bracket vs simple order construction # --------------------------------------------------------------------------- class TestExecutorBracketOrders: """LONG entries with both stop/take pcts become BRACKET orders; EXIT signals (or signals missing a pct) stay SIMPLE.""" @pytest.mark.asyncio async def test_long_entry_with_pcts_builds_bracket(self): config = _make_config() broker = _mock_broker(positions=[], account=_make_account(100_000)) publisher = AsyncMock() publisher.publish = AsyncMock(return_value=b"1-0") counters = { "trades_executed": MagicMock(), "rejections": MagicMock(), "fill_latency": MagicMock(), } signal = _make_kevin_signal( ticker="NVDA", direction=SignalDirection.LONG, current_price=Decimal("100"), target_dollars=Decimal("2000"), stop_loss_pct=Decimal("0.08"), take_profit_pct=Decimal("0.20"), ) with patch.object(RiskManager, "check_risk", return_value=(True, "approved")): await process_signal(signal, RiskManager(config, broker), broker, publisher, counters) broker.submit_order.assert_called_once() order_arg = broker.submit_order.call_args[0][0] assert order_arg.order_class == "bracket" assert order_arg.side == OrderSide.BUY # entry=100 → stop=100*(1-0.08)=92.0, take=100*(1+0.20)=120.0 assert order_arg.stop_loss_price == 92.0 assert order_arg.take_profit_price == 120.0 @pytest.mark.asyncio async def test_exit_signal_with_pcts_stays_simple(self): """The bridge stamps stop/take pcts even on EXIT signals; the direction guard must keep the resulting SELL order SIMPLE.""" config = _make_config() # EXIT now requires a held position to size from. held = _make_position(ticker="NVDA", market_value=2000.0) broker = _mock_broker(positions=[held], account=_make_account(100_000)) publisher = AsyncMock() publisher.publish = AsyncMock(return_value=b"1-0") counters = { "trades_executed": MagicMock(), "rejections": MagicMock(), "fill_latency": MagicMock(), } signal = _make_kevin_signal( ticker="NVDA", direction=SignalDirection.EXIT, current_price=Decimal("100"), target_dollars=Decimal("2000"), stop_loss_pct=Decimal("0.08"), take_profit_pct=Decimal("0.20"), ) with patch.object(RiskManager, "check_risk", return_value=(True, "approved")): await process_signal(signal, RiskManager(config, broker), broker, publisher, counters) broker.submit_order.assert_called_once() order_arg = broker.submit_order.call_args[0][0] assert order_arg.order_class == "simple" assert order_arg.side == OrderSide.SELL assert order_arg.take_profit_price is None assert order_arg.stop_loss_price is None # --------------------------------------------------------------------------- # Executor flow — EXIT sizing from the held broker position # --------------------------------------------------------------------------- def _held_position(ticker: str, qty: float, avg_entry: float) -> PositionInfo: return PositionInfo( ticker=ticker, qty=qty, avg_entry=avg_entry, current_price=avg_entry, unrealized_pnl=0.0, market_value=qty * avg_entry, ) def _exit_filled_broker( positions: list[PositionInfo], fill_price: float, fill_qty: float ): """Broker whose submit_order returns a FILLED SELL at fill_price/fill_qty.""" broker = AsyncMock() broker.get_positions = AsyncMock(return_value=positions) broker.get_account = AsyncMock(return_value=_make_account(100_000)) broker.submit_order = AsyncMock( return_value=OrderResult( order_id="ord-exit", ticker=positions[0].ticker if positions else "NVDA", side=OrderSide.SELL, qty=fill_qty, filled_price=fill_price, status=OrderStatus.FILLED, timestamp=datetime.now(timezone.utc), ) ) return broker class TestExecutorExitSizing: """EXIT signals must be sized from the currently-held broker position, NOT from the signal's target_dollars (which would open/size a fresh position).""" @pytest.mark.asyncio async def test_exit_sells_full_held_qty(self): """A Kevin EXIT carrying target_dollars=$2000 (=20 sh @ $100) on a position of 37 held shares must SELL 37 — the full held qty.""" config = _make_config() held = _held_position("NVDA", qty=37.0, avg_entry=90.0) broker = _exit_filled_broker([held], fill_price=110.0, fill_qty=37.0) publisher = AsyncMock() publisher.publish = AsyncMock(return_value=b"1-0") counters = { "trades_executed": MagicMock(), "rejections": MagicMock(), "fill_latency": MagicMock(), } signal = _make_kevin_signal( ticker="NVDA", direction=SignalDirection.EXIT, current_price=Decimal("100"), target_dollars=Decimal("2000"), ) with patch.object(RiskManager, "check_risk", return_value=(True, "approved")): await process_signal( signal, RiskManager(config, broker), broker, publisher, counters ) broker.submit_order.assert_called_once() order_arg = broker.submit_order.call_args[0][0] assert order_arg.side == OrderSide.SELL assert order_arg.qty == 37.0 # held qty, NOT 20 (=target_dollars/price) assert order_arg.order_class == "simple" @pytest.mark.asyncio async def test_exit_with_no_held_position_submits_nothing(self): """EXIT for a ticker with no held position → no order, skip logged, rejection counted (never a zero/garbage sell).""" config = _make_config() # Holds a DIFFERENT ticker — nothing for NVDA. broker = _exit_filled_broker( [_held_position("AAPL", qty=10.0, avg_entry=150.0)], fill_price=110.0, fill_qty=0.0, ) publisher = AsyncMock() publisher.publish = AsyncMock(return_value=b"1-0") counters = { "trades_executed": MagicMock(), "rejections": MagicMock(), "fill_latency": MagicMock(), } signal = _make_kevin_signal( ticker="NVDA", direction=SignalDirection.EXIT, current_price=Decimal("100"), target_dollars=Decimal("2000"), ) with patch.object(RiskManager, "check_risk", return_value=(True, "approved")): await process_signal( signal, RiskManager(config, broker), broker, publisher, counters ) broker.submit_order.assert_not_called() publisher.publish.assert_not_called() counters["rejections"].add.assert_called_once() @pytest.mark.asyncio async def test_entry_sizing_path_unchanged(self): """LONG entries keep the risk_manager.calculate_position_size path — target_dollars=$2000 @ $100 → 20 shares (not driven by held qty).""" config = _make_config() broker = _mock_broker(positions=[], account=_make_account(100_000)) publisher = AsyncMock() publisher.publish = AsyncMock(return_value=b"1-0") counters = { "trades_executed": MagicMock(), "rejections": MagicMock(), "fill_latency": MagicMock(), } signal = _make_kevin_signal( ticker="NVDA", direction=SignalDirection.LONG, current_price=Decimal("100"), target_dollars=Decimal("2000"), ) with patch.object(RiskManager, "check_risk", return_value=(True, "approved")): await process_signal( signal, RiskManager(config, broker), broker, publisher, counters ) broker.submit_order.assert_called_once() order_arg = broker.submit_order.call_args[0][0] assert order_arg.side == OrderSide.BUY assert order_arg.qty == 20.0 # target_dollars / current_price # --------------------------------------------------------------------------- # Executor flow — realized P&L on close # --------------------------------------------------------------------------- class TestExecutorRealizedPnl: """When an EXIT fill closes a long, the persisted Trade row carries the round-trip realized P&L; ENTRY trades leave pnl=None.""" @pytest.mark.asyncio async def test_exit_writes_realized_pnl(self): """SELL 10 @ 110 against avg_entry 90 → pnl = (110-90)*10 = 200.""" config = _make_config() held = _held_position("NVDA", qty=10.0, avg_entry=90.0) broker = _exit_filled_broker([held], fill_price=110.0, fill_qty=10.0) publisher = AsyncMock() publisher.publish = AsyncMock(return_value=b"1-0") counters = { "trades_executed": MagicMock(), "rejections": MagicMock(), "fill_latency": MagicMock(), } signal = _make_kevin_signal( ticker="NVDA", direction=SignalDirection.EXIT, current_price=Decimal("100"), target_dollars=Decimal("2000"), ) mock_session = AsyncMock() mock_session.add = MagicMock() mock_session.commit = AsyncMock() db_factory = _make_mock_db_session_factory(mock_session) with patch.object(RiskManager, "check_risk", return_value=(True, "approved")): await process_signal( signal, RiskManager(config, broker), broker, publisher, counters, db_factory, ) trade_obj = mock_session.add.call_args[0][0] assert trade_obj.side == TradeSideModel.SELL assert trade_obj.pnl == 200.0 @pytest.mark.asyncio async def test_entry_trade_has_null_pnl(self): """An ENTRY (LONG) trade is persisted with pnl=None.""" config = _make_config() broker = _mock_broker(positions=[], account=_make_account(100_000)) publisher = AsyncMock() publisher.publish = AsyncMock(return_value=b"1-0") counters = { "trades_executed": MagicMock(), "rejections": MagicMock(), "fill_latency": MagicMock(), } signal = _make_kevin_signal( ticker="NVDA", direction=SignalDirection.LONG, current_price=Decimal("100"), target_dollars=Decimal("2000"), ) mock_session = AsyncMock() mock_session.add = MagicMock() mock_session.commit = AsyncMock() db_factory = _make_mock_db_session_factory(mock_session) with patch.object(RiskManager, "check_risk", return_value=(True, "approved")): await process_signal( signal, RiskManager(config, broker), broker, publisher, counters, db_factory, ) trade_obj = mock_session.add.call_args[0][0] assert trade_obj.side == TradeSideModel.BUY assert trade_obj.pnl is None # --------------------------------------------------------------------------- # Executor flow — rejected signal # --------------------------------------------------------------------------- class TestExecutorFlowRejected: """End-to-end: rejected signal -> no order, rejection logged.""" @pytest.mark.asyncio async def test_rejected_signal_flow(self): config = _make_config() broker = _mock_broker() publisher = AsyncMock() counters = { "trades_executed": MagicMock(), "rejections": MagicMock(), "fill_latency": MagicMock(), } signal = _make_signal(ticker="AAPL") with patch.object( RiskManager, "check_risk", return_value=(False, "outside_market_hours") ): await process_signal(signal, RiskManager(config, broker), broker, publisher, counters) # No order should have been submitted broker.submit_order.assert_not_called() # No trade should have been published publisher.publish.assert_not_called() # Rejection counter should have been incremented counters["rejections"].add.assert_called_once() # --------------------------------------------------------------------------- # Executor flow — DB persistence # --------------------------------------------------------------------------- def _make_mock_db_session_factory(session=None): """Create a mock async_sessionmaker that yields a mock session.""" if session is None: session = AsyncMock() session.add = MagicMock() session.commit = AsyncMock() factory = MagicMock() ctx = AsyncMock() ctx.__aenter__ = AsyncMock(return_value=session) ctx.__aexit__ = AsyncMock(return_value=False) factory.return_value = ctx return factory class TestExecutorDBPersistence: """Verify that trades are persisted to the DB when db_session_factory is provided.""" @pytest.mark.asyncio async def test_trade_persisted_with_signal_id(self): """When db_session_factory is provided, a Trade row should be created.""" config = _make_config() broker = _mock_broker(positions=[], account=_make_account(100_000)) publisher = AsyncMock() publisher.publish = AsyncMock(return_value=b"1-0") counters = { "trades_executed": MagicMock(), "rejections": MagicMock(), "fill_latency": MagicMock(), } signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0) mock_session = AsyncMock() mock_session.add = MagicMock() mock_session.commit = AsyncMock() db_factory = _make_mock_db_session_factory(mock_session) with patch.object(RiskManager, "check_risk", return_value=(True, "approved")): await process_signal( signal, RiskManager(config, broker), broker, publisher, counters, db_factory ) # Trade should be persisted mock_session.add.assert_called_once() mock_session.commit.assert_awaited_once() # Verify the trade object trade_obj = mock_session.add.call_args[0][0] assert trade_obj.ticker == "AAPL" assert trade_obj.signal_id == signal.signal_id @pytest.mark.asyncio async def test_entry_trade_records_broker_order_id(self): """The persisted entry Trade must carry the Alpaca order id so reconciliation can find the bracket order later.""" config = _make_config() broker = _mock_broker(positions=[], account=_make_account(100_000)) publisher = AsyncMock() publisher.publish = AsyncMock(return_value=b"1-0") counters = { "trades_executed": MagicMock(), "rejections": MagicMock(), "fill_latency": MagicMock(), } signal = _make_kevin_signal(direction=SignalDirection.LONG) mock_session = AsyncMock() mock_session.add = MagicMock() mock_session.commit = AsyncMock() db_factory = _make_mock_db_session_factory(mock_session) with patch.object(RiskManager, "check_risk", return_value=(True, "approved")): await process_signal( signal, RiskManager(config, broker), broker, publisher, counters, db_factory ) trade_obj = mock_session.add.call_args[0][0] assert trade_obj.broker_order_id == "ord-123" @pytest.mark.asyncio async def test_trade_not_persisted_without_db(self): """When db_session_factory is None, no DB write should happen.""" config = _make_config() broker = _mock_broker(positions=[], account=_make_account(100_000)) publisher = AsyncMock() publisher.publish = AsyncMock(return_value=b"1-0") counters = { "trades_executed": MagicMock(), "rejections": MagicMock(), "fill_latency": MagicMock(), } signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0) with patch.object(RiskManager, "check_risk", return_value=(True, "approved")): await process_signal( signal, RiskManager(config, broker), broker, publisher, counters, None ) # Should still publish publisher.publish.assert_called_once() @pytest.mark.asyncio async def test_db_error_does_not_block_publishing(self): """A DB error should not prevent the trade from being published.""" config = _make_config() broker = _mock_broker(positions=[], account=_make_account(100_000)) publisher = AsyncMock() publisher.publish = AsyncMock(return_value=b"1-0") counters = { "trades_executed": MagicMock(), "rejections": MagicMock(), "fill_latency": MagicMock(), } signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0) mock_session = AsyncMock() mock_session.add = MagicMock() mock_session.commit = AsyncMock(side_effect=RuntimeError("DB connection lost")) db_factory = _make_mock_db_session_factory(mock_session) with patch.object(RiskManager, "check_risk", return_value=(True, "approved")): await process_signal( signal, RiskManager(config, broker), broker, publisher, counters, db_factory ) # Trade should still be published despite DB error publisher.publish.assert_called_once() counters["trades_executed"].add.assert_called_once_with(1) def test_signal_id_flows_through_execution(self): """signal_id from TradeSignal should appear in the published TradeExecution.""" signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0) assert signal.signal_id is not None assert isinstance(signal.signal_id, UUID) class TestExecutorStrategyAttribution: """Kevin trades must carry strategy_id so they are recorded as Kevin trades and surface on the dashboard (which filters by strategy_id).""" @pytest.mark.asyncio async def test_strategy_id_persisted_and_published(self): config = _make_config() broker = _mock_broker(positions=[], account=_make_account(100_000)) publisher = AsyncMock() publisher.publish = AsyncMock(return_value=b"1-0") counters = { "trades_executed": MagicMock(), "rejections": MagicMock(), "fill_latency": MagicMock(), } signal = _make_kevin_signal() mock_session = AsyncMock() mock_session.add = MagicMock() mock_session.commit = AsyncMock() db_factory = _make_mock_db_session_factory(mock_session) with patch.object(RiskManager, "check_risk", return_value=(True, "approved")): await process_signal( signal, RiskManager(config, broker), broker, publisher, counters, db_factory ) trade_obj = mock_session.add.call_args[0][0] assert trade_obj.strategy_id == KEVIN_STRATEGY_UUID published = publisher.publish.call_args[0][0] assert published["strategy_id"] == str(KEVIN_STRATEGY_UUID)