trading/tests/services/test_trade_executor.py
Viktor Barzin a3cdd0f1a5
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed
fix: resolve all remaining TODOs, add dev mode auth bypass
- Learning engine: expand default weights from 3 to all 9 strategies
- Learning engine: resolve placeholder strategy_id with DB lookup
- Learning engine: pass strategy_sources from trade execution
- Trade executor: respect trading:paused Redis flag in RiskManager
- Portfolio sync: compute actual daily P&L from day-start snapshot
- Portfolio API: cumulative P&L from first snapshot, read pause flag
- Portfolio metrics: compute max drawdown and avg hold duration
- Add strategy_sources field to TradeExecution schema
- Add dev_mode config (TRADING_DEV_MODE) to bypass auth for local dev
- Dashboard: VITE_DEV_MODE bypasses ProtectedRoute and 401 redirects
- Vite proxy target configurable via VITE_API_TARGET
- Add top-level README.md and remaining-work-plan.md
- Update CLAUDE.md with correct counts and remove stale TODOs
- 404 tests passing

Made-with: Cursor
2026-02-25 22:02:25 +00:00

564 lines
20 KiB
Python

"""Tests for the Trade Executor service.
Covers RiskManager (market hours, positions, exposure, cooldown,
position sizing) and the end-to-end executor flow with a mocked broker.
"""
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch
from zoneinfo import ZoneInfo
import pytest
from services.trade_executor.config import TradeExecutorConfig
from services.trade_executor.main import process_signal
from services.trade_executor.risk_manager import RiskManager
from shared.schemas.trading import (
AccountInfo,
OrderResult,
OrderSide,
OrderStatus,
PositionInfo,
SignalDirection,
TradeSignal,
)
# Eastern-time zone used to build timezone-aware datetimes for the
# market-hours and cooldown tests below.
_ET = ZoneInfo("America/New_York")
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_config(**overrides) -> TradeExecutorConfig:
    """Build a TradeExecutorConfig populated with test-friendly defaults.

    Any keyword passed in *overrides* replaces the matching default before
    the config object is constructed.
    """
    settings = {
        "max_position_pct": 0.05,
        "max_total_exposure_pct": 0.80,
        "max_positions": 20,
        "default_stop_loss_pct": 0.03,
        "cooldown_minutes": 30,
        "alpaca_api_key": "test",
        "alpaca_secret_key": "test",
        "paper_trading": True,
    }
    return TradeExecutorConfig(**{**settings, **overrides})
def _make_signal(
    ticker: str = "AAPL",
    direction: SignalDirection = SignalDirection.LONG,
    strength: float = 0.8,
    current_price: float = 150.0,
) -> TradeSignal:
    """Return a TradeSignal attributed to the ``"test"`` strategy source.

    The price travels in ``sentiment_context["current_price"]``; the
    position-sizing tests rely on it. The timestamp is the current UTC time.
    """
    context = {"current_price": current_price}
    now_utc = datetime.now(timezone.utc)
    return TradeSignal(
        ticker=ticker,
        direction=direction,
        strength=strength,
        strategy_sources=["test"],
        sentiment_context=context,
        timestamp=now_utc,
    )
def _make_account(equity: float = 100_000.0) -> AccountInfo:
    """Return an AccountInfo whose cash and portfolio value both equal *equity*.

    Buying power is set to twice the equity.
    """
    doubled = equity * 2
    return AccountInfo(
        equity=equity,
        cash=equity,
        buying_power=doubled,
        portfolio_value=equity,
    )
def _make_position(ticker: str = "AAPL", market_value: float = 5000.0) -> PositionInfo:
    """Return a flat 10-share position at $150 with zero unrealized P&L.

    *market_value* is passed through as-is rather than derived from
    qty * price, so exposure tests can set it independently.
    """
    price = 150.0
    return PositionInfo(
        ticker=ticker,
        qty=10.0,
        avg_entry=price,
        current_price=price,
        unrealized_pnl=0.0,
        market_value=market_value,
    )
def _mock_broker(positions: list[PositionInfo] | None = None, account: AccountInfo | None = None):
    """Build an AsyncMock broker double.

    - ``get_positions`` resolves to *positions* (an empty list when falsy).
    - ``get_account`` resolves to *account* (a fresh default account when falsy).
    - ``submit_order`` always resolves to a filled 10-share AAPL buy at $150.
    """
    filled_order = OrderResult(
        order_id="ord-123",
        ticker="AAPL",
        side=OrderSide.BUY,
        qty=10.0,
        filled_price=150.0,
        status=OrderStatus.FILLED,
        timestamp=datetime.now(timezone.utc),
    )
    broker = AsyncMock()
    broker.get_positions = AsyncMock(return_value=positions if positions else [])
    broker.get_account = AsyncMock(return_value=account if account else _make_account())
    broker.submit_order = AsyncMock(return_value=filled_order)
    return broker
# ---------------------------------------------------------------------------
# RiskManager — risk check passes
# ---------------------------------------------------------------------------
class TestRiskCheckPasses:
    """A signal that satisfies every risk rule is approved."""

    @pytest.mark.asyncio
    async def test_all_conditions_met(self):
        risk_manager = RiskManager(
            _make_config(),
            _mock_broker(positions=[], account=_make_account(100_000)),
        )
        # Pin the market-hours check open so the outcome does not depend on
        # the wall-clock time the test runs at.
        with patch.object(RiskManager, "_is_market_hours", return_value=True):
            outcome = await risk_manager.check_risk(_make_signal())
        approved, reason = outcome
        assert approved is True
        assert reason == "approved"
class TestRiskCheckTradingPaused:
    """The Redis pause flag gates the risk check when a Redis client is supplied."""

    @pytest.mark.asyncio
    async def test_paused_flag_rejects(self):
        """A truthy pause flag rejects the signal with reason ``trading_paused``."""
        config = _make_config()
        broker = _mock_broker()
        redis_mock = AsyncMock()
        redis_mock.get = AsyncMock(return_value=b"1")
        rm = RiskManager(config, broker, redis=redis_mock)
        signal = _make_signal()
        with patch.object(RiskManager, "_is_market_hours", return_value=True):
            approved, reason = await rm.check_risk(signal)
        assert approved is False
        assert reason == "trading_paused"
        redis_mock.get.assert_awaited()

    @pytest.mark.asyncio
    async def test_no_pause_flag_passes_through(self):
        """A missing flag (``None``) lets an otherwise valid signal through."""
        config = _make_config()
        broker = _mock_broker(positions=[], account=_make_account(100_000))
        redis_mock = AsyncMock()
        redis_mock.get = AsyncMock(return_value=None)
        rm = RiskManager(config, broker, redis=redis_mock)
        signal = _make_signal()
        with patch.object(RiskManager, "_is_market_hours", return_value=True):
            approved, reason = await rm.check_risk(signal)
        assert approved is True
        assert reason == "approved"
        # Guard against a vacuous pass: the flag must actually be consulted.
        redis_mock.get.assert_awaited()

    @pytest.mark.asyncio
    async def test_no_redis_skips_pause_check(self):
        """Without a Redis client the pause check is skipped and the signal is approved."""
        config = _make_config()
        broker = _mock_broker(positions=[], account=_make_account(100_000))
        rm = RiskManager(config, broker, redis=None)
        signal = _make_signal()
        with patch.object(RiskManager, "_is_market_hours", return_value=True):
            approved, reason = await rm.check_risk(signal)
        assert approved is True
        assert reason == "approved"
# ---------------------------------------------------------------------------
# RiskManager — max positions exceeded
# ---------------------------------------------------------------------------
class TestRiskCheckMaxPositions:
    """A new ticker is rejected once the open-position limit is reached."""

    @pytest.mark.asyncio
    async def test_max_positions_exceeded(self):
        # Limit of two, with two positions already open.
        held = [_make_position(symbol) for symbol in ("AAPL", "MSFT")]
        risk_manager = RiskManager(
            _make_config(max_positions=2),
            _mock_broker(positions=held, account=_make_account()),
        )
        new_signal = _make_signal(ticker="GOOG")
        with patch.object(RiskManager, "_is_market_hours", return_value=True):
            approved, reason = await risk_manager.check_risk(new_signal)
        assert approved is False
        assert "max_positions" in reason
# ---------------------------------------------------------------------------
# RiskManager — max exposure exceeded
# ---------------------------------------------------------------------------
class TestRiskCheckMaxExposure:
    """A new position is rejected when current exposure is already above the cap."""

    @pytest.mark.asyncio
    async def test_max_exposure_exceeded(self):
        # One $60k position against $100k equity is 60% exposure, over a 50% cap.
        existing = [_make_position("AAPL", market_value=60_000)]
        broker = _mock_broker(positions=existing, account=_make_account(equity=100_000))
        risk_manager = RiskManager(_make_config(max_total_exposure_pct=0.50), broker)
        with patch.object(RiskManager, "_is_market_hours", return_value=True):
            approved, reason = await risk_manager.check_risk(_make_signal(ticker="MSFT"))
        assert approved is False
        assert "max_exposure" in reason
# ---------------------------------------------------------------------------
# RiskManager — cooldown active
# ---------------------------------------------------------------------------
class TestRiskCheckCooldown:
    """Re-entering a ticker is blocked until its post-exit cooldown elapses."""

    @pytest.mark.asyncio
    async def test_cooldown_active(self):
        """An exit 10 minutes ago is still inside a 30-minute cooldown."""
        config = _make_config(cooldown_minutes=30)
        broker = _mock_broker()
        rm = RiskManager(config, broker)
        # Exit timestamps are recorded as Eastern-time-aware datetimes.
        now_et = datetime.now(tz=_ET)
        rm.record_exit("AAPL", now_et - timedelta(minutes=10))
        signal = _make_signal(ticker="AAPL")
        with patch.object(RiskManager, "_is_market_hours", return_value=True):
            approved, reason = await rm.check_risk(signal)
        assert approved is False
        assert "cooldown" in reason

    @pytest.mark.asyncio
    async def test_cooldown_expired(self):
        """After the cooldown period expires the trade should be approved."""
        config = _make_config(cooldown_minutes=30)
        broker = _mock_broker()
        rm = RiskManager(config, broker)
        # An exit 45 minutes ago is outside the 30-minute window.
        now_et = datetime.now(tz=_ET)
        rm.record_exit("AAPL", now_et - timedelta(minutes=45))
        signal = _make_signal(ticker="AAPL")
        with patch.object(RiskManager, "_is_market_hours", return_value=True):
            approved, reason = await rm.check_risk(signal)
        assert approved is True
        assert reason == "approved"
# ---------------------------------------------------------------------------
# RiskManager — outside market hours
# ---------------------------------------------------------------------------
class TestRiskCheckMarketHours:
    """Risk check fails outside regular market hours."""

    @pytest.mark.asyncio
    async def test_outside_market_hours(self):
        """check_risk rejects with a market_hours reason when the market is closed."""
        config = _make_config()
        broker = _mock_broker()
        rm = RiskManager(config, broker)
        signal = _make_signal()
        # Patch _is_market_hours to report "closed" so the rejection does not
        # depend on the wall-clock time the test runs at. The real clock and
        # calendar logic is exercised directly by the _is_market_hours tests
        # below.
        with patch.object(RiskManager, "_is_market_hours", return_value=False):
            approved, reason = await rm.check_risk(signal)
        assert approved is False
        assert "market_hours" in reason

    def test_market_hours_weekday(self):
        """A weekday at 10:00 AM ET should be within market hours."""
        t = datetime(2026, 2, 24, 10, 0, 0, tzinfo=_ET)  # Tuesday 10 AM
        assert RiskManager._is_market_hours(t) is True

    def test_market_hours_weekend(self):
        """Saturday should always be outside market hours."""
        t = datetime(2026, 2, 21, 10, 0, 0, tzinfo=_ET)  # Saturday 10 AM
        assert RiskManager._is_market_hours(t) is False

    def test_market_hours_before_open(self):
        """8:00 AM ET on a weekday is before market open."""
        t = datetime(2026, 2, 24, 8, 0, 0, tzinfo=_ET)  # Tuesday 8 AM
        assert RiskManager._is_market_hours(t) is False

    def test_market_hours_after_close(self):
        """5:00 PM ET on a weekday is after market close."""
        t = datetime(2026, 2, 24, 17, 0, 0, tzinfo=_ET)  # Tuesday 5 PM
        assert RiskManager._is_market_hours(t) is False
# ---------------------------------------------------------------------------
# Position sizing — scales by strength
# ---------------------------------------------------------------------------
class TestPositionSizingScalesByStrength:
    """Position size scales proportionally with signal strength."""

    @staticmethod
    def _quantity(strength: float):
        """Size a $100 signal of *strength* against $100k equity with a 5% cap."""
        rm = RiskManager(_make_config(max_position_pct=0.05), _mock_broker())
        signal = _make_signal(strength=strength, current_price=100.0)
        return rm.calculate_position_size(signal, _make_account(equity=100_000))

    def test_full_strength(self):
        # 100k * 0.05 * 1.0 = $5,000 -> 5,000 / $100 = 50 shares
        assert self._quantity(1.0) == 50

    def test_half_strength(self):
        # 100k * 0.05 * 0.5 = $2,500 -> 2,500 / $100 = 25 shares
        assert self._quantity(0.5) == 25
# ---------------------------------------------------------------------------
# Position sizing — respects max_position_pct
# ---------------------------------------------------------------------------
class TestPositionSizingRespectsMaxPct:
    """Position size is capped by max_position_pct of account equity."""

    def test_respects_max_pct(self):
        rm = RiskManager(_make_config(max_position_pct=0.02), _mock_broker())
        qty = rm.calculate_position_size(
            _make_signal(strength=1.0, current_price=50.0),
            _make_account(equity=100_000),
        )
        # 100k * 0.02 * 1.0 = $2,000 -> 2,000 / $50 = 40 shares
        assert qty == 40

    def test_zero_price_returns_zero(self):
        """A signal priced at 0.0 sizes to zero shares."""
        rm = RiskManager(_make_config(), _mock_broker())
        qty = rm.calculate_position_size(
            _make_signal(strength=0.8, current_price=0.0),
            _make_account(equity=100_000),
        )
        assert qty == 0
# ---------------------------------------------------------------------------
# Executor flow — approved signal
# ---------------------------------------------------------------------------
class TestExecutorFlowApproved:
    """End-to-end: an approved signal is submitted to the broker and published."""

    @pytest.mark.asyncio
    async def test_approved_signal_flow(self):
        broker = _mock_broker(positions=[], account=_make_account(100_000))
        publisher = AsyncMock()
        publisher.publish = AsyncMock(return_value=b"1-0")
        counters = {
            name: MagicMock() for name in ("trades_executed", "rejections", "fill_latency")
        }
        signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0)

        # Short-circuit the risk gate so only the execution path is exercised.
        with patch.object(RiskManager, "check_risk", return_value=(True, "approved")):
            await process_signal(
                signal, RiskManager(_make_config(), broker), broker, publisher, counters
            )

        # A long AAPL signal becomes exactly one AAPL buy order.
        broker.submit_order.assert_called_once()
        order = broker.submit_order.call_args.args[0]
        assert order.ticker == "AAPL"
        assert order.side == OrderSide.BUY

        # The execution is published once and counted once.
        publisher.publish.assert_called_once()
        counters["trades_executed"].add.assert_called_once_with(1)
# ---------------------------------------------------------------------------
# Executor flow — rejected signal
# ---------------------------------------------------------------------------
class TestExecutorFlowRejected:
    """End-to-end: a rejected signal never reaches the broker or the publisher."""

    @pytest.mark.asyncio
    async def test_rejected_signal_flow(self):
        broker = _mock_broker()
        publisher = AsyncMock()
        counters = {
            key: MagicMock() for key in ("trades_executed", "rejections", "fill_latency")
        }
        signal = _make_signal(ticker="AAPL")
        rejection = (False, "outside_market_hours")
        with patch.object(RiskManager, "check_risk", return_value=rejection):
            await process_signal(
                signal, RiskManager(_make_config(), broker), broker, publisher, counters
            )
        # Nothing is ordered or published, but the rejection is counted.
        broker.submit_order.assert_not_called()
        publisher.publish.assert_not_called()
        counters["rejections"].add.assert_called_once()
# ---------------------------------------------------------------------------
# Executor flow — DB persistence
# ---------------------------------------------------------------------------
def _make_mock_db_session_factory(session=None):
"""Create a mock async_sessionmaker that yields a mock session."""
if session is None:
session = AsyncMock()
session.add = MagicMock()
session.commit = AsyncMock()
factory = MagicMock()
ctx = AsyncMock()
ctx.__aenter__ = AsyncMock(return_value=session)
ctx.__aexit__ = AsyncMock(return_value=False)
factory.return_value = ctx
return factory
class TestExecutorDBPersistence:
    """Verify that trades are persisted to the DB when db_session_factory is provided."""

    @staticmethod
    def _counters():
        """Return fresh MagicMock counters for process_signal."""
        return {
            "trades_executed": MagicMock(),
            "rejections": MagicMock(),
            "fill_latency": MagicMock(),
        }

    @staticmethod
    def _publisher():
        """Return an AsyncMock publisher whose publish() resolves to b"1-0"."""
        publisher = AsyncMock()
        publisher.publish = AsyncMock(return_value=b"1-0")
        return publisher

    @pytest.mark.asyncio
    async def test_trade_persisted_with_signal_id(self):
        """When db_session_factory is provided, a Trade row should be created."""
        config = _make_config()
        broker = _mock_broker(positions=[], account=_make_account(100_000))
        publisher = self._publisher()
        counters = self._counters()
        signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0)
        mock_session = AsyncMock()
        mock_session.add = MagicMock()
        mock_session.commit = AsyncMock()
        db_factory = _make_mock_db_session_factory(mock_session)
        with patch.object(RiskManager, "check_risk", return_value=(True, "approved")):
            await process_signal(
                signal, RiskManager(config, broker), broker, publisher, counters, db_factory
            )
        # Exactly one row is added and committed.
        mock_session.add.assert_called_once()
        mock_session.commit.assert_awaited_once()
        # The persisted row is linked back to the originating signal.
        trade_obj = mock_session.add.call_args[0][0]
        assert trade_obj.ticker == "AAPL"
        assert trade_obj.signal_id == signal.signal_id

    @pytest.mark.asyncio
    async def test_trade_not_persisted_without_db(self):
        """With db_session_factory=None the trade is still published and counted."""
        config = _make_config()
        broker = _mock_broker(positions=[], account=_make_account(100_000))
        publisher = self._publisher()
        counters = self._counters()
        signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0)
        with patch.object(RiskManager, "check_risk", return_value=(True, "approved")):
            await process_signal(
                signal, RiskManager(config, broker), broker, publisher, counters, None
            )
        publisher.publish.assert_called_once()
        counters["trades_executed"].add.assert_called_once_with(1)

    @pytest.mark.asyncio
    async def test_db_error_does_not_block_publishing(self):
        """A DB error should not prevent the trade from being published."""
        config = _make_config()
        broker = _mock_broker(positions=[], account=_make_account(100_000))
        publisher = self._publisher()
        counters = self._counters()
        signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0)
        mock_session = AsyncMock()
        mock_session.add = MagicMock()
        mock_session.commit = AsyncMock(side_effect=RuntimeError("DB connection lost"))
        db_factory = _make_mock_db_session_factory(mock_session)
        with patch.object(RiskManager, "check_risk", return_value=(True, "approved")):
            await process_signal(
                signal, RiskManager(config, broker), broker, publisher, counters, db_factory
            )
        # The failing commit must actually have been attempted; otherwise this
        # test would pass without exercising the error path at all.
        mock_session.commit.assert_awaited()
        # Trade should still be published despite the DB error.
        publisher.publish.assert_called_once()
        counters["trades_executed"].add.assert_called_once_with(1)

    def test_signal_id_flows_through_execution(self):
        """Each TradeSignal carries a UUID ``signal_id``.

        This test only checks the signal side. Copying the ID onto the
        persisted trade is asserted in ``test_trade_persisted_with_signal_id``.
        NOTE(review): the published TradeExecution payload is not inspected
        anywhere in this class.
        """
        from uuid import UUID

        signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0)
        assert signal.signal_id is not None
        assert isinstance(signal.signal_id, UUID)