feat: trade executor — risk management and order execution
This commit is contained in:
parent
f3e5fc944d
commit
3fef8a631c
5 changed files with 753 additions and 0 deletions
1
services/trade_executor/__init__.py
Normal file
1
services/trade_executor/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
"""Trade Executor service — risk management and order execution."""
|
||||||
18
services/trade_executor/config.py
Normal file
18
services/trade_executor/config.py
Normal file
|
|
@ -0,0 +1,18 @@
|
||||||
|
"""Configuration for the trade executor service."""
|
||||||
|
|
||||||
|
from shared.config import BaseConfig
|
||||||
|
|
||||||
|
|
||||||
|
class TradeExecutorConfig(BaseConfig):
|
||||||
|
"""Extends BaseConfig with trade-executor-specific settings."""
|
||||||
|
|
||||||
|
max_position_pct: float = 0.05
|
||||||
|
max_total_exposure_pct: float = 0.80
|
||||||
|
max_positions: int = 20
|
||||||
|
default_stop_loss_pct: float = 0.03
|
||||||
|
cooldown_minutes: int = 30
|
||||||
|
alpaca_api_key: str = ""
|
||||||
|
alpaca_secret_key: str = ""
|
||||||
|
paper_trading: bool = True
|
||||||
|
|
||||||
|
model_config = {"env_prefix": "TRADING_"}
|
||||||
176
services/trade_executor/main.py
Normal file
176
services/trade_executor/main.py
Normal file
|
|
@ -0,0 +1,176 @@
|
||||||
|
"""Trade Executor service -- main entry point.
|
||||||
|
|
||||||
|
Consumes ``signals:generated`` from Redis Streams, runs risk checks,
|
||||||
|
submits orders via the brokerage abstraction layer, records trades
|
||||||
|
in the database, and publishes ``TradeExecution`` messages to
|
||||||
|
``trades:executed``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from redis.asyncio import Redis
|
||||||
|
|
||||||
|
from services.trade_executor.config import TradeExecutorConfig
|
||||||
|
from services.trade_executor.risk_manager import RiskManager
|
||||||
|
from shared.broker.alpaca_broker import AlpacaBroker
|
||||||
|
from shared.redis_streams import StreamConsumer, StreamPublisher
|
||||||
|
from shared.schemas.trading import (
|
||||||
|
OrderRequest,
|
||||||
|
OrderSide,
|
||||||
|
OrderStatus,
|
||||||
|
SignalDirection,
|
||||||
|
TradeExecution,
|
||||||
|
TradeSignal,
|
||||||
|
)
|
||||||
|
from shared.telemetry import setup_telemetry
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def process_signal(
|
||||||
|
signal: TradeSignal,
|
||||||
|
risk_manager: RiskManager,
|
||||||
|
broker: AlpacaBroker,
|
||||||
|
publisher: StreamPublisher,
|
||||||
|
counters: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Process a single trade signal: risk check, order, record, publish.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
signal:
|
||||||
|
The trade signal to act on.
|
||||||
|
risk_manager:
|
||||||
|
Performs pre-trade risk checks and position sizing.
|
||||||
|
broker:
|
||||||
|
Brokerage adapter for submitting orders.
|
||||||
|
publisher:
|
||||||
|
Publishes execution results to ``trades:executed``.
|
||||||
|
counters:
|
||||||
|
Dict of OpenTelemetry counter/histogram instruments.
|
||||||
|
"""
|
||||||
|
# --- Step 1: risk check ---
|
||||||
|
approved, reason = await risk_manager.check_risk(signal)
|
||||||
|
if not approved:
|
||||||
|
logger.info("Signal REJECTED for %s: %s", signal.ticker, reason)
|
||||||
|
counters["rejections"].add(1, {"reason": reason.split(" ")[0]})
|
||||||
|
return
|
||||||
|
|
||||||
|
# --- Step 2: calculate position size ---
|
||||||
|
account = await broker.get_account()
|
||||||
|
qty = risk_manager.calculate_position_size(signal, account)
|
||||||
|
if qty <= 0:
|
||||||
|
logger.info("Position size is zero for %s — skipping", signal.ticker)
|
||||||
|
counters["rejections"].add(1, {"reason": "zero_position_size"})
|
||||||
|
return
|
||||||
|
|
||||||
|
# --- Step 3: create order ---
|
||||||
|
side = OrderSide.BUY if signal.direction == SignalDirection.LONG else OrderSide.SELL
|
||||||
|
order_request = OrderRequest(
|
||||||
|
ticker=signal.ticker,
|
||||||
|
side=side,
|
||||||
|
qty=float(qty),
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Step 4: submit order ---
|
||||||
|
start = time.monotonic()
|
||||||
|
result = await broker.submit_order(order_request)
|
||||||
|
elapsed = time.monotonic() - start
|
||||||
|
counters["fill_latency"].record(elapsed)
|
||||||
|
|
||||||
|
# --- Step 5: build trade execution ---
|
||||||
|
trade_id = uuid.uuid4()
|
||||||
|
execution = TradeExecution(
|
||||||
|
trade_id=trade_id,
|
||||||
|
ticker=signal.ticker,
|
||||||
|
side=side,
|
||||||
|
qty=result.qty,
|
||||||
|
price=result.filled_price or 0.0,
|
||||||
|
status=result.status,
|
||||||
|
signal_id=None,
|
||||||
|
strategy_id=None,
|
||||||
|
timestamp=result.timestamp,
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Step 6: publish to trades:executed ---
|
||||||
|
await publisher.publish(execution.model_dump(mode="json"))
|
||||||
|
counters["trades_executed"].add(1)
|
||||||
|
logger.info(
|
||||||
|
"Trade executed: %s %s %.0f shares @ %s status=%s",
|
||||||
|
side.value,
|
||||||
|
signal.ticker,
|
||||||
|
result.qty,
|
||||||
|
result.filled_price,
|
||||||
|
result.status.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def run(config: TradeExecutorConfig | None = None) -> None:
|
||||||
|
"""Main service loop.
|
||||||
|
|
||||||
|
Connects to Redis, initialises the broker and risk manager, then
|
||||||
|
continuously consumes from ``signals:generated`` and publishes
|
||||||
|
execution results to ``trades:executed``.
|
||||||
|
"""
|
||||||
|
if config is None:
|
||||||
|
config = TradeExecutorConfig()
|
||||||
|
|
||||||
|
logging.basicConfig(level=config.log_level)
|
||||||
|
logger.info("Starting Trade Executor service")
|
||||||
|
|
||||||
|
# --- Telemetry ---
|
||||||
|
meter = setup_telemetry("trade-executor", config.otel_metrics_port)
|
||||||
|
counters = {
|
||||||
|
"trades_executed": meter.create_counter(
|
||||||
|
"trades_executed",
|
||||||
|
description="Total trades successfully submitted",
|
||||||
|
),
|
||||||
|
"rejections": meter.create_counter(
|
||||||
|
"trade_rejections",
|
||||||
|
description="Signals rejected by risk checks",
|
||||||
|
),
|
||||||
|
"fill_latency": meter.create_histogram(
|
||||||
|
"order_fill_latency_seconds",
|
||||||
|
description="Time from order submission to response",
|
||||||
|
unit="s",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
# --- Redis ---
|
||||||
|
redis = Redis.from_url(config.redis_url, decode_responses=False)
|
||||||
|
consumer = StreamConsumer(redis, "signals:generated", "trade-executor", "worker-1")
|
||||||
|
publisher = StreamPublisher(redis, "trades:executed")
|
||||||
|
|
||||||
|
# --- Broker ---
|
||||||
|
broker = AlpacaBroker(
|
||||||
|
api_key=config.alpaca_api_key,
|
||||||
|
secret_key=config.alpaca_secret_key,
|
||||||
|
paper=config.paper_trading,
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Risk manager ---
|
||||||
|
risk_manager = RiskManager(config, broker)
|
||||||
|
|
||||||
|
logger.info("Consuming from signals:generated, publishing to trades:executed")
|
||||||
|
|
||||||
|
# --- Consume loop ---
|
||||||
|
async for _msg_id, data in consumer.consume():
|
||||||
|
try:
|
||||||
|
signal = TradeSignal.model_validate(data)
|
||||||
|
await process_signal(signal, risk_manager, broker, publisher, counters)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error processing signal: %s", data)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
"""CLI entry point."""
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
155
services/trade_executor/risk_manager.py
Normal file
155
services/trade_executor/risk_manager.py
Normal file
|
|
@ -0,0 +1,155 @@
|
||||||
|
"""Pre-trade risk management checks and position sizing.
|
||||||
|
|
||||||
|
Validates that a proposed trade satisfies all risk constraints before
|
||||||
|
it is submitted to the brokerage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
|
from services.trade_executor.config import TradeExecutorConfig
|
||||||
|
from shared.broker.base import BaseBroker
|
||||||
|
from shared.schemas.trading import AccountInfo, PositionInfo, SignalDirection, TradeSignal
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_ET = ZoneInfo("America/New_York")
|
||||||
|
|
||||||
|
# Market hours in Eastern Time
|
||||||
|
_MARKET_OPEN_HOUR = 9
|
||||||
|
_MARKET_OPEN_MINUTE = 30
|
||||||
|
_MARKET_CLOSE_HOUR = 16
|
||||||
|
_MARKET_CLOSE_MINUTE = 0
|
||||||
|
|
||||||
|
|
||||||
|
class RiskManager:
|
||||||
|
"""Performs pre-trade risk checks and calculates position sizes.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
config:
|
||||||
|
Trade executor configuration with risk parameters.
|
||||||
|
broker:
|
||||||
|
Broker instance for querying current positions and account info.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: TradeExecutorConfig, broker: BaseBroker) -> None:
|
||||||
|
self.config = config
|
||||||
|
self.broker = broker
|
||||||
|
# ticker -> last exit timestamp
|
||||||
|
self._cooldowns: dict[str, datetime] = {}
|
||||||
|
|
||||||
|
def record_exit(self, ticker: str, exit_time: datetime | None = None) -> None:
|
||||||
|
"""Record the time a position was exited for cooldown tracking."""
|
||||||
|
self._cooldowns[ticker] = exit_time or datetime.now(tz=_ET)
|
||||||
|
|
||||||
|
async def check_risk(self, signal: TradeSignal) -> tuple[bool, str]:
|
||||||
|
"""Run all pre-trade risk checks.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
tuple[bool, str]
|
||||||
|
``(approved, reason)`` — ``approved`` is ``True`` when
|
||||||
|
all checks pass, otherwise ``reason`` explains the failure.
|
||||||
|
"""
|
||||||
|
# 1. Market hours
|
||||||
|
now_et = datetime.now(tz=_ET)
|
||||||
|
if not self._is_market_hours(now_et):
|
||||||
|
return False, "outside_market_hours"
|
||||||
|
|
||||||
|
# 2. Cooldown
|
||||||
|
if signal.ticker in self._cooldowns:
|
||||||
|
last_exit = self._cooldowns[signal.ticker]
|
||||||
|
cooldown_end = last_exit + timedelta(minutes=self.config.cooldown_minutes)
|
||||||
|
if now_et < cooldown_end:
|
||||||
|
remaining = (cooldown_end - now_et).total_seconds() / 60
|
||||||
|
return False, f"cooldown_active ({remaining:.1f}m remaining)"
|
||||||
|
|
||||||
|
# 3. Max positions
|
||||||
|
positions = await self.broker.get_positions()
|
||||||
|
if len(positions) >= self.config.max_positions:
|
||||||
|
return False, "max_positions_exceeded"
|
||||||
|
|
||||||
|
# 4. Max total exposure
|
||||||
|
account = await self.broker.get_account()
|
||||||
|
total_exposure = sum(abs(p.market_value) for p in positions)
|
||||||
|
max_exposure = account.equity * self.config.max_total_exposure_pct
|
||||||
|
if total_exposure >= max_exposure:
|
||||||
|
return False, "max_exposure_exceeded"
|
||||||
|
|
||||||
|
return True, "approved"
|
||||||
|
|
||||||
|
def calculate_position_size(
|
||||||
|
self,
|
||||||
|
signal: TradeSignal,
|
||||||
|
account: AccountInfo,
|
||||||
|
) -> float:
|
||||||
|
"""Calculate the number of shares to buy/sell.
|
||||||
|
|
||||||
|
Uses fixed-fractional sizing: ``equity * max_position_pct``
|
||||||
|
gives the maximum dollar value per position, then scales by
|
||||||
|
signal strength.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
signal:
|
||||||
|
The trade signal (includes current price via strength).
|
||||||
|
account:
|
||||||
|
Current account info (equity, buying power).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
float
|
||||||
|
Number of shares (whole shares).
|
||||||
|
"""
|
||||||
|
if signal.strength <= 0 or account.equity <= 0:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
position_value = account.equity * self.config.max_position_pct
|
||||||
|
position_value *= signal.strength
|
||||||
|
|
||||||
|
# Need a price to compute qty — use the signal's embedded price
|
||||||
|
# or fall back to getting it from the snapshot. For simplicity
|
||||||
|
# the executor will pass the current price through the signal's
|
||||||
|
# sentiment_context or fetch it directly.
|
||||||
|
current_price = 0.0
|
||||||
|
if signal.sentiment_context and "current_price" in signal.sentiment_context:
|
||||||
|
current_price = float(signal.sentiment_context["current_price"])
|
||||||
|
|
||||||
|
if current_price <= 0:
|
||||||
|
logger.warning("No current price for %s, cannot size position", signal.ticker)
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
qty = position_value / current_price
|
||||||
|
return max(int(qty), 0)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Internal helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_market_hours(now_et: datetime) -> bool:
|
||||||
|
"""Return ``True`` if *now_et* falls within regular US market hours.
|
||||||
|
|
||||||
|
Market hours: Monday--Friday, 9:30 AM -- 4:00 PM ET.
|
||||||
|
"""
|
||||||
|
# Weekday check (0=Monday ... 6=Sunday)
|
||||||
|
if now_et.weekday() >= 5:
|
||||||
|
return False
|
||||||
|
|
||||||
|
market_open = now_et.replace(
|
||||||
|
hour=_MARKET_OPEN_HOUR,
|
||||||
|
minute=_MARKET_OPEN_MINUTE,
|
||||||
|
second=0,
|
||||||
|
microsecond=0,
|
||||||
|
)
|
||||||
|
market_close = now_et.replace(
|
||||||
|
hour=_MARKET_CLOSE_HOUR,
|
||||||
|
minute=_MARKET_CLOSE_MINUTE,
|
||||||
|
second=0,
|
||||||
|
microsecond=0,
|
||||||
|
)
|
||||||
|
return market_open <= now_et < market_close
|
||||||
403
tests/services/test_trade_executor.py
Normal file
403
tests/services/test_trade_executor.py
Normal file
|
|
@ -0,0 +1,403 @@
|
||||||
|
"""Tests for the Trade Executor service.
|
||||||
|
|
||||||
|
Covers RiskManager (market hours, positions, exposure, cooldown,
|
||||||
|
position sizing) and the end-to-end executor flow with a mocked broker.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from services.trade_executor.config import TradeExecutorConfig
|
||||||
|
from services.trade_executor.main import process_signal
|
||||||
|
from services.trade_executor.risk_manager import RiskManager
|
||||||
|
from shared.schemas.trading import (
|
||||||
|
AccountInfo,
|
||||||
|
OrderResult,
|
||||||
|
OrderSide,
|
||||||
|
OrderStatus,
|
||||||
|
PositionInfo,
|
||||||
|
SignalDirection,
|
||||||
|
TradeSignal,
|
||||||
|
)
|
||||||
|
|
||||||
|
_ET = ZoneInfo("America/New_York")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_config(**overrides) -> TradeExecutorConfig:
|
||||||
|
defaults = dict(
|
||||||
|
max_position_pct=0.05,
|
||||||
|
max_total_exposure_pct=0.80,
|
||||||
|
max_positions=20,
|
||||||
|
default_stop_loss_pct=0.03,
|
||||||
|
cooldown_minutes=30,
|
||||||
|
alpaca_api_key="test",
|
||||||
|
alpaca_secret_key="test",
|
||||||
|
paper_trading=True,
|
||||||
|
)
|
||||||
|
defaults.update(overrides)
|
||||||
|
return TradeExecutorConfig(**defaults)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_signal(
|
||||||
|
ticker: str = "AAPL",
|
||||||
|
direction: SignalDirection = SignalDirection.LONG,
|
||||||
|
strength: float = 0.8,
|
||||||
|
current_price: float = 150.0,
|
||||||
|
) -> TradeSignal:
|
||||||
|
return TradeSignal(
|
||||||
|
ticker=ticker,
|
||||||
|
direction=direction,
|
||||||
|
strength=strength,
|
||||||
|
strategy_sources=["test"],
|
||||||
|
sentiment_context={"current_price": current_price},
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_account(equity: float = 100_000.0) -> AccountInfo:
|
||||||
|
return AccountInfo(
|
||||||
|
equity=equity,
|
||||||
|
cash=equity,
|
||||||
|
buying_power=equity * 2,
|
||||||
|
portfolio_value=equity,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_position(ticker: str = "AAPL", market_value: float = 5000.0) -> PositionInfo:
|
||||||
|
return PositionInfo(
|
||||||
|
ticker=ticker,
|
||||||
|
qty=10.0,
|
||||||
|
avg_entry=150.0,
|
||||||
|
current_price=150.0,
|
||||||
|
unrealized_pnl=0.0,
|
||||||
|
market_value=market_value,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_broker(positions: list[PositionInfo] | None = None, account: AccountInfo | None = None):
|
||||||
|
"""Create an AsyncMock broker with configurable positions and account."""
|
||||||
|
broker = AsyncMock()
|
||||||
|
broker.get_positions = AsyncMock(return_value=positions or [])
|
||||||
|
broker.get_account = AsyncMock(return_value=account or _make_account())
|
||||||
|
broker.submit_order = AsyncMock(
|
||||||
|
return_value=OrderResult(
|
||||||
|
order_id="ord-123",
|
||||||
|
ticker="AAPL",
|
||||||
|
side=OrderSide.BUY,
|
||||||
|
qty=10.0,
|
||||||
|
filled_price=150.0,
|
||||||
|
status=OrderStatus.FILLED,
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return broker
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# RiskManager — risk check passes
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRiskCheckPasses:
|
||||||
|
"""All conditions met -> risk check passes."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_all_conditions_met(self):
|
||||||
|
config = _make_config()
|
||||||
|
broker = _mock_broker(positions=[], account=_make_account(100_000))
|
||||||
|
rm = RiskManager(config, broker)
|
||||||
|
signal = _make_signal()
|
||||||
|
|
||||||
|
# Patch _is_market_hours to return True
|
||||||
|
with patch.object(RiskManager, "_is_market_hours", return_value=True):
|
||||||
|
approved, reason = await rm.check_risk(signal)
|
||||||
|
|
||||||
|
assert approved is True
|
||||||
|
assert reason == "approved"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# RiskManager — max positions exceeded
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRiskCheckMaxPositions:
|
||||||
|
"""Risk check fails when max_positions is already reached."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_max_positions_exceeded(self):
|
||||||
|
config = _make_config(max_positions=2)
|
||||||
|
# Already have 2 positions
|
||||||
|
positions = [_make_position("AAPL"), _make_position("MSFT")]
|
||||||
|
broker = _mock_broker(positions=positions, account=_make_account())
|
||||||
|
rm = RiskManager(config, broker)
|
||||||
|
signal = _make_signal(ticker="GOOG")
|
||||||
|
|
||||||
|
with patch.object(RiskManager, "_is_market_hours", return_value=True):
|
||||||
|
approved, reason = await rm.check_risk(signal)
|
||||||
|
|
||||||
|
assert approved is False
|
||||||
|
assert "max_positions" in reason
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# RiskManager — max exposure exceeded
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRiskCheckMaxExposure:
|
||||||
|
"""Risk check fails when total exposure exceeds the limit."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_max_exposure_exceeded(self):
|
||||||
|
config = _make_config(max_total_exposure_pct=0.50)
|
||||||
|
account = _make_account(equity=100_000)
|
||||||
|
# Single position worth $60k = 60% of equity, limit is 50%
|
||||||
|
positions = [_make_position("AAPL", market_value=60_000)]
|
||||||
|
broker = _mock_broker(positions=positions, account=account)
|
||||||
|
rm = RiskManager(config, broker)
|
||||||
|
signal = _make_signal(ticker="MSFT")
|
||||||
|
|
||||||
|
with patch.object(RiskManager, "_is_market_hours", return_value=True):
|
||||||
|
approved, reason = await rm.check_risk(signal)
|
||||||
|
|
||||||
|
assert approved is False
|
||||||
|
assert "max_exposure" in reason
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# RiskManager — cooldown active
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRiskCheckCooldown:
|
||||||
|
"""Risk check fails when a ticker is in cooldown."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cooldown_active(self):
|
||||||
|
config = _make_config(cooldown_minutes=30)
|
||||||
|
broker = _mock_broker()
|
||||||
|
rm = RiskManager(config, broker)
|
||||||
|
|
||||||
|
# Record an exit 10 minutes ago
|
||||||
|
now_et = datetime.now(tz=_ET)
|
||||||
|
rm.record_exit("AAPL", now_et - timedelta(minutes=10))
|
||||||
|
|
||||||
|
signal = _make_signal(ticker="AAPL")
|
||||||
|
with patch.object(RiskManager, "_is_market_hours", return_value=True):
|
||||||
|
approved, reason = await rm.check_risk(signal)
|
||||||
|
|
||||||
|
assert approved is False
|
||||||
|
assert "cooldown" in reason
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cooldown_expired(self):
|
||||||
|
"""After cooldown period expires the trade should be approved."""
|
||||||
|
config = _make_config(cooldown_minutes=30)
|
||||||
|
broker = _mock_broker()
|
||||||
|
rm = RiskManager(config, broker)
|
||||||
|
|
||||||
|
# Record an exit 45 minutes ago
|
||||||
|
now_et = datetime.now(tz=_ET)
|
||||||
|
rm.record_exit("AAPL", now_et - timedelta(minutes=45))
|
||||||
|
|
||||||
|
signal = _make_signal(ticker="AAPL")
|
||||||
|
with patch.object(RiskManager, "_is_market_hours", return_value=True):
|
||||||
|
approved, reason = await rm.check_risk(signal)
|
||||||
|
|
||||||
|
assert approved is True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# RiskManager — outside market hours
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRiskCheckMarketHours:
|
||||||
|
"""Risk check fails outside regular market hours."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_outside_market_hours(self):
|
||||||
|
config = _make_config()
|
||||||
|
broker = _mock_broker()
|
||||||
|
rm = RiskManager(config, broker)
|
||||||
|
signal = _make_signal()
|
||||||
|
|
||||||
|
# Force market hours check to fail (no patching — use the real check
|
||||||
|
# with a time that is definitely outside market hours)
|
||||||
|
with patch.object(RiskManager, "_is_market_hours", return_value=False):
|
||||||
|
approved, reason = await rm.check_risk(signal)
|
||||||
|
|
||||||
|
assert approved is False
|
||||||
|
assert "market_hours" in reason
|
||||||
|
|
||||||
|
def test_market_hours_weekday(self):
|
||||||
|
"""A weekday at 10:00 AM ET should be within market hours."""
|
||||||
|
# Tuesday 10:00 AM ET
|
||||||
|
t = datetime(2026, 2, 24, 10, 0, 0, tzinfo=_ET)
|
||||||
|
assert RiskManager._is_market_hours(t) is True
|
||||||
|
|
||||||
|
def test_market_hours_weekend(self):
|
||||||
|
"""Saturday should always be outside market hours."""
|
||||||
|
t = datetime(2026, 2, 21, 10, 0, 0, tzinfo=_ET) # Saturday
|
||||||
|
assert RiskManager._is_market_hours(t) is False
|
||||||
|
|
||||||
|
def test_market_hours_before_open(self):
|
||||||
|
"""8:00 AM ET on a weekday is before market open."""
|
||||||
|
t = datetime(2026, 2, 24, 8, 0, 0, tzinfo=_ET) # Tuesday 8 AM
|
||||||
|
assert RiskManager._is_market_hours(t) is False
|
||||||
|
|
||||||
|
def test_market_hours_after_close(self):
|
||||||
|
"""5:00 PM ET on a weekday is after market close."""
|
||||||
|
t = datetime(2026, 2, 24, 17, 0, 0, tzinfo=_ET) # Tuesday 5 PM
|
||||||
|
assert RiskManager._is_market_hours(t) is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Position sizing — scales by strength
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestPositionSizingScalesByStrength:
|
||||||
|
"""Position size should scale proportionally with signal strength."""
|
||||||
|
|
||||||
|
def test_full_strength(self):
|
||||||
|
config = _make_config(max_position_pct=0.05)
|
||||||
|
broker = _mock_broker()
|
||||||
|
rm = RiskManager(config, broker)
|
||||||
|
|
||||||
|
signal = _make_signal(strength=1.0, current_price=100.0)
|
||||||
|
account = _make_account(equity=100_000)
|
||||||
|
|
||||||
|
qty = rm.calculate_position_size(signal, account)
|
||||||
|
# position_value = 100k * 0.05 * 1.0 = 5000 / 100 = 50 shares
|
||||||
|
assert qty == 50
|
||||||
|
|
||||||
|
def test_half_strength(self):
|
||||||
|
config = _make_config(max_position_pct=0.05)
|
||||||
|
broker = _mock_broker()
|
||||||
|
rm = RiskManager(config, broker)
|
||||||
|
|
||||||
|
signal = _make_signal(strength=0.5, current_price=100.0)
|
||||||
|
account = _make_account(equity=100_000)
|
||||||
|
|
||||||
|
qty = rm.calculate_position_size(signal, account)
|
||||||
|
# position_value = 100k * 0.05 * 0.5 = 2500 / 100 = 25 shares
|
||||||
|
assert qty == 25
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Position sizing — respects max_position_pct
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestPositionSizingRespectsMaxPct:
|
||||||
|
"""Position size should respect the max_position_pct cap."""
|
||||||
|
|
||||||
|
def test_respects_max_pct(self):
|
||||||
|
config = _make_config(max_position_pct=0.02)
|
||||||
|
broker = _mock_broker()
|
||||||
|
rm = RiskManager(config, broker)
|
||||||
|
|
||||||
|
signal = _make_signal(strength=1.0, current_price=50.0)
|
||||||
|
account = _make_account(equity=100_000)
|
||||||
|
|
||||||
|
qty = rm.calculate_position_size(signal, account)
|
||||||
|
# position_value = 100k * 0.02 * 1.0 = 2000 / 50 = 40 shares
|
||||||
|
assert qty == 40
|
||||||
|
|
||||||
|
def test_zero_price_returns_zero(self):
|
||||||
|
config = _make_config()
|
||||||
|
broker = _mock_broker()
|
||||||
|
rm = RiskManager(config, broker)
|
||||||
|
|
||||||
|
signal = _make_signal(strength=0.8, current_price=0.0)
|
||||||
|
account = _make_account(equity=100_000)
|
||||||
|
|
||||||
|
qty = rm.calculate_position_size(signal, account)
|
||||||
|
assert qty == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Executor flow — approved signal
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestExecutorFlowApproved:
|
||||||
|
"""End-to-end: approved signal -> order submitted -> trade published."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_approved_signal_flow(self):
|
||||||
|
config = _make_config()
|
||||||
|
broker = _mock_broker(positions=[], account=_make_account(100_000))
|
||||||
|
publisher = AsyncMock()
|
||||||
|
publisher.publish = AsyncMock(return_value=b"1-0")
|
||||||
|
|
||||||
|
counters = {
|
||||||
|
"trades_executed": MagicMock(),
|
||||||
|
"rejections": MagicMock(),
|
||||||
|
"fill_latency": MagicMock(),
|
||||||
|
}
|
||||||
|
|
||||||
|
signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0)
|
||||||
|
|
||||||
|
# Patch risk check to approve
|
||||||
|
with patch.object(RiskManager, "check_risk", return_value=(True, "approved")):
|
||||||
|
await process_signal(signal, RiskManager(config, broker), broker, publisher, counters)
|
||||||
|
|
||||||
|
# Verify order was submitted
|
||||||
|
broker.submit_order.assert_called_once()
|
||||||
|
order_arg = broker.submit_order.call_args[0][0]
|
||||||
|
assert order_arg.ticker == "AAPL"
|
||||||
|
assert order_arg.side == OrderSide.BUY
|
||||||
|
|
||||||
|
# Verify trade was published
|
||||||
|
publisher.publish.assert_called_once()
|
||||||
|
counters["trades_executed"].add.assert_called_once_with(1)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Executor flow — rejected signal
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestExecutorFlowRejected:
|
||||||
|
"""End-to-end: rejected signal -> no order, rejection logged."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rejected_signal_flow(self):
|
||||||
|
config = _make_config()
|
||||||
|
broker = _mock_broker()
|
||||||
|
publisher = AsyncMock()
|
||||||
|
|
||||||
|
counters = {
|
||||||
|
"trades_executed": MagicMock(),
|
||||||
|
"rejections": MagicMock(),
|
||||||
|
"fill_latency": MagicMock(),
|
||||||
|
}
|
||||||
|
|
||||||
|
signal = _make_signal(ticker="AAPL")
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
RiskManager, "check_risk", return_value=(False, "outside_market_hours")
|
||||||
|
):
|
||||||
|
await process_signal(signal, RiskManager(config, broker), broker, publisher, counters)
|
||||||
|
|
||||||
|
# No order should have been submitted
|
||||||
|
broker.submit_order.assert_not_called()
|
||||||
|
|
||||||
|
# No trade should have been published
|
||||||
|
publisher.publish.assert_not_called()
|
||||||
|
|
||||||
|
# Rejection counter should have been incremented
|
||||||
|
counters["rejections"].add.assert_called_once()
|
||||||
Loading…
Add table
Add a link
Reference in a new issue