feat: trade executor — risk management and order execution
This commit is contained in:
parent
f3e5fc944d
commit
3fef8a631c
5 changed files with 753 additions and 0 deletions
403
tests/services/test_trade_executor.py
Normal file
403
tests/services/test_trade_executor.py
Normal file
|
|
@ -0,0 +1,403 @@
|
|||
"""Tests for the Trade Executor service.
|
||||
|
||||
Covers RiskManager (market hours, positions, exposure, cooldown,
|
||||
position sizing) and the end-to-end executor flow with a mocked broker.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import pytest
|
||||
|
||||
from services.trade_executor.config import TradeExecutorConfig
|
||||
from services.trade_executor.main import process_signal
|
||||
from services.trade_executor.risk_manager import RiskManager
|
||||
from shared.schemas.trading import (
|
||||
AccountInfo,
|
||||
OrderResult,
|
||||
OrderSide,
|
||||
OrderStatus,
|
||||
PositionInfo,
|
||||
SignalDirection,
|
||||
TradeSignal,
|
||||
)
|
||||
|
||||
_ET = ZoneInfo("America/New_York")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_config(**overrides) -> TradeExecutorConfig:
|
||||
defaults = dict(
|
||||
max_position_pct=0.05,
|
||||
max_total_exposure_pct=0.80,
|
||||
max_positions=20,
|
||||
default_stop_loss_pct=0.03,
|
||||
cooldown_minutes=30,
|
||||
alpaca_api_key="test",
|
||||
alpaca_secret_key="test",
|
||||
paper_trading=True,
|
||||
)
|
||||
defaults.update(overrides)
|
||||
return TradeExecutorConfig(**defaults)
|
||||
|
||||
|
||||
def _make_signal(
|
||||
ticker: str = "AAPL",
|
||||
direction: SignalDirection = SignalDirection.LONG,
|
||||
strength: float = 0.8,
|
||||
current_price: float = 150.0,
|
||||
) -> TradeSignal:
|
||||
return TradeSignal(
|
||||
ticker=ticker,
|
||||
direction=direction,
|
||||
strength=strength,
|
||||
strategy_sources=["test"],
|
||||
sentiment_context={"current_price": current_price},
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
def _make_account(equity: float = 100_000.0) -> AccountInfo:
|
||||
return AccountInfo(
|
||||
equity=equity,
|
||||
cash=equity,
|
||||
buying_power=equity * 2,
|
||||
portfolio_value=equity,
|
||||
)
|
||||
|
||||
|
||||
def _make_position(ticker: str = "AAPL", market_value: float = 5000.0) -> PositionInfo:
|
||||
return PositionInfo(
|
||||
ticker=ticker,
|
||||
qty=10.0,
|
||||
avg_entry=150.0,
|
||||
current_price=150.0,
|
||||
unrealized_pnl=0.0,
|
||||
market_value=market_value,
|
||||
)
|
||||
|
||||
|
||||
def _mock_broker(positions: list[PositionInfo] | None = None, account: AccountInfo | None = None):
|
||||
"""Create an AsyncMock broker with configurable positions and account."""
|
||||
broker = AsyncMock()
|
||||
broker.get_positions = AsyncMock(return_value=positions or [])
|
||||
broker.get_account = AsyncMock(return_value=account or _make_account())
|
||||
broker.submit_order = AsyncMock(
|
||||
return_value=OrderResult(
|
||||
order_id="ord-123",
|
||||
ticker="AAPL",
|
||||
side=OrderSide.BUY,
|
||||
qty=10.0,
|
||||
filled_price=150.0,
|
||||
status=OrderStatus.FILLED,
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
)
|
||||
return broker
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RiskManager — risk check passes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRiskCheckPasses:
|
||||
"""All conditions met -> risk check passes."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_conditions_met(self):
|
||||
config = _make_config()
|
||||
broker = _mock_broker(positions=[], account=_make_account(100_000))
|
||||
rm = RiskManager(config, broker)
|
||||
signal = _make_signal()
|
||||
|
||||
# Patch _is_market_hours to return True
|
||||
with patch.object(RiskManager, "_is_market_hours", return_value=True):
|
||||
approved, reason = await rm.check_risk(signal)
|
||||
|
||||
assert approved is True
|
||||
assert reason == "approved"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RiskManager — max positions exceeded
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRiskCheckMaxPositions:
|
||||
"""Risk check fails when max_positions is already reached."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_positions_exceeded(self):
|
||||
config = _make_config(max_positions=2)
|
||||
# Already have 2 positions
|
||||
positions = [_make_position("AAPL"), _make_position("MSFT")]
|
||||
broker = _mock_broker(positions=positions, account=_make_account())
|
||||
rm = RiskManager(config, broker)
|
||||
signal = _make_signal(ticker="GOOG")
|
||||
|
||||
with patch.object(RiskManager, "_is_market_hours", return_value=True):
|
||||
approved, reason = await rm.check_risk(signal)
|
||||
|
||||
assert approved is False
|
||||
assert "max_positions" in reason
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RiskManager — max exposure exceeded
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRiskCheckMaxExposure:
|
||||
"""Risk check fails when total exposure exceeds the limit."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_exposure_exceeded(self):
|
||||
config = _make_config(max_total_exposure_pct=0.50)
|
||||
account = _make_account(equity=100_000)
|
||||
# Single position worth $60k = 60% of equity, limit is 50%
|
||||
positions = [_make_position("AAPL", market_value=60_000)]
|
||||
broker = _mock_broker(positions=positions, account=account)
|
||||
rm = RiskManager(config, broker)
|
||||
signal = _make_signal(ticker="MSFT")
|
||||
|
||||
with patch.object(RiskManager, "_is_market_hours", return_value=True):
|
||||
approved, reason = await rm.check_risk(signal)
|
||||
|
||||
assert approved is False
|
||||
assert "max_exposure" in reason
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RiskManager — cooldown active
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRiskCheckCooldown:
|
||||
"""Risk check fails when a ticker is in cooldown."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cooldown_active(self):
|
||||
config = _make_config(cooldown_minutes=30)
|
||||
broker = _mock_broker()
|
||||
rm = RiskManager(config, broker)
|
||||
|
||||
# Record an exit 10 minutes ago
|
||||
now_et = datetime.now(tz=_ET)
|
||||
rm.record_exit("AAPL", now_et - timedelta(minutes=10))
|
||||
|
||||
signal = _make_signal(ticker="AAPL")
|
||||
with patch.object(RiskManager, "_is_market_hours", return_value=True):
|
||||
approved, reason = await rm.check_risk(signal)
|
||||
|
||||
assert approved is False
|
||||
assert "cooldown" in reason
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cooldown_expired(self):
|
||||
"""After cooldown period expires the trade should be approved."""
|
||||
config = _make_config(cooldown_minutes=30)
|
||||
broker = _mock_broker()
|
||||
rm = RiskManager(config, broker)
|
||||
|
||||
# Record an exit 45 minutes ago
|
||||
now_et = datetime.now(tz=_ET)
|
||||
rm.record_exit("AAPL", now_et - timedelta(minutes=45))
|
||||
|
||||
signal = _make_signal(ticker="AAPL")
|
||||
with patch.object(RiskManager, "_is_market_hours", return_value=True):
|
||||
approved, reason = await rm.check_risk(signal)
|
||||
|
||||
assert approved is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RiskManager — outside market hours
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRiskCheckMarketHours:
|
||||
"""Risk check fails outside regular market hours."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_outside_market_hours(self):
|
||||
config = _make_config()
|
||||
broker = _mock_broker()
|
||||
rm = RiskManager(config, broker)
|
||||
signal = _make_signal()
|
||||
|
||||
# Force market hours check to fail (no patching — use the real check
|
||||
# with a time that is definitely outside market hours)
|
||||
with patch.object(RiskManager, "_is_market_hours", return_value=False):
|
||||
approved, reason = await rm.check_risk(signal)
|
||||
|
||||
assert approved is False
|
||||
assert "market_hours" in reason
|
||||
|
||||
def test_market_hours_weekday(self):
|
||||
"""A weekday at 10:00 AM ET should be within market hours."""
|
||||
# Tuesday 10:00 AM ET
|
||||
t = datetime(2026, 2, 24, 10, 0, 0, tzinfo=_ET)
|
||||
assert RiskManager._is_market_hours(t) is True
|
||||
|
||||
def test_market_hours_weekend(self):
|
||||
"""Saturday should always be outside market hours."""
|
||||
t = datetime(2026, 2, 21, 10, 0, 0, tzinfo=_ET) # Saturday
|
||||
assert RiskManager._is_market_hours(t) is False
|
||||
|
||||
def test_market_hours_before_open(self):
|
||||
"""8:00 AM ET on a weekday is before market open."""
|
||||
t = datetime(2026, 2, 24, 8, 0, 0, tzinfo=_ET) # Tuesday 8 AM
|
||||
assert RiskManager._is_market_hours(t) is False
|
||||
|
||||
def test_market_hours_after_close(self):
|
||||
"""5:00 PM ET on a weekday is after market close."""
|
||||
t = datetime(2026, 2, 24, 17, 0, 0, tzinfo=_ET) # Tuesday 5 PM
|
||||
assert RiskManager._is_market_hours(t) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Position sizing — scales by strength
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPositionSizingScalesByStrength:
|
||||
"""Position size should scale proportionally with signal strength."""
|
||||
|
||||
def test_full_strength(self):
|
||||
config = _make_config(max_position_pct=0.05)
|
||||
broker = _mock_broker()
|
||||
rm = RiskManager(config, broker)
|
||||
|
||||
signal = _make_signal(strength=1.0, current_price=100.0)
|
||||
account = _make_account(equity=100_000)
|
||||
|
||||
qty = rm.calculate_position_size(signal, account)
|
||||
# position_value = 100k * 0.05 * 1.0 = 5000 / 100 = 50 shares
|
||||
assert qty == 50
|
||||
|
||||
def test_half_strength(self):
|
||||
config = _make_config(max_position_pct=0.05)
|
||||
broker = _mock_broker()
|
||||
rm = RiskManager(config, broker)
|
||||
|
||||
signal = _make_signal(strength=0.5, current_price=100.0)
|
||||
account = _make_account(equity=100_000)
|
||||
|
||||
qty = rm.calculate_position_size(signal, account)
|
||||
# position_value = 100k * 0.05 * 0.5 = 2500 / 100 = 25 shares
|
||||
assert qty == 25
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Position sizing — respects max_position_pct
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPositionSizingRespectsMaxPct:
|
||||
"""Position size should respect the max_position_pct cap."""
|
||||
|
||||
def test_respects_max_pct(self):
|
||||
config = _make_config(max_position_pct=0.02)
|
||||
broker = _mock_broker()
|
||||
rm = RiskManager(config, broker)
|
||||
|
||||
signal = _make_signal(strength=1.0, current_price=50.0)
|
||||
account = _make_account(equity=100_000)
|
||||
|
||||
qty = rm.calculate_position_size(signal, account)
|
||||
# position_value = 100k * 0.02 * 1.0 = 2000 / 50 = 40 shares
|
||||
assert qty == 40
|
||||
|
||||
def test_zero_price_returns_zero(self):
|
||||
config = _make_config()
|
||||
broker = _mock_broker()
|
||||
rm = RiskManager(config, broker)
|
||||
|
||||
signal = _make_signal(strength=0.8, current_price=0.0)
|
||||
account = _make_account(equity=100_000)
|
||||
|
||||
qty = rm.calculate_position_size(signal, account)
|
||||
assert qty == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Executor flow — approved signal
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExecutorFlowApproved:
|
||||
"""End-to-end: approved signal -> order submitted -> trade published."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approved_signal_flow(self):
|
||||
config = _make_config()
|
||||
broker = _mock_broker(positions=[], account=_make_account(100_000))
|
||||
publisher = AsyncMock()
|
||||
publisher.publish = AsyncMock(return_value=b"1-0")
|
||||
|
||||
counters = {
|
||||
"trades_executed": MagicMock(),
|
||||
"rejections": MagicMock(),
|
||||
"fill_latency": MagicMock(),
|
||||
}
|
||||
|
||||
signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0)
|
||||
|
||||
# Patch risk check to approve
|
||||
with patch.object(RiskManager, "check_risk", return_value=(True, "approved")):
|
||||
await process_signal(signal, RiskManager(config, broker), broker, publisher, counters)
|
||||
|
||||
# Verify order was submitted
|
||||
broker.submit_order.assert_called_once()
|
||||
order_arg = broker.submit_order.call_args[0][0]
|
||||
assert order_arg.ticker == "AAPL"
|
||||
assert order_arg.side == OrderSide.BUY
|
||||
|
||||
# Verify trade was published
|
||||
publisher.publish.assert_called_once()
|
||||
counters["trades_executed"].add.assert_called_once_with(1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Executor flow — rejected signal
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExecutorFlowRejected:
|
||||
"""End-to-end: rejected signal -> no order, rejection logged."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejected_signal_flow(self):
|
||||
config = _make_config()
|
||||
broker = _mock_broker()
|
||||
publisher = AsyncMock()
|
||||
|
||||
counters = {
|
||||
"trades_executed": MagicMock(),
|
||||
"rejections": MagicMock(),
|
||||
"fill_latency": MagicMock(),
|
||||
}
|
||||
|
||||
signal = _make_signal(ticker="AAPL")
|
||||
|
||||
with patch.object(
|
||||
RiskManager, "check_risk", return_value=(False, "outside_market_hours")
|
||||
):
|
||||
await process_signal(signal, RiskManager(config, broker), broker, publisher, counters)
|
||||
|
||||
# No order should have been submitted
|
||||
broker.submit_order.assert_not_called()
|
||||
|
||||
# No trade should have been published
|
||||
publisher.publish.assert_not_called()
|
||||
|
||||
# Rejection counter should have been incremented
|
||||
counters["rejections"].add.assert_called_once()
|
||||
Loading…
Add table
Add a link
Reference in a new issue