diff --git a/services/trade_executor/__init__.py b/services/trade_executor/__init__.py new file mode 100644 index 0000000..ec5bf0f --- /dev/null +++ b/services/trade_executor/__init__.py @@ -0,0 +1 @@ +"""Trade Executor service — risk management and order execution.""" diff --git a/services/trade_executor/config.py b/services/trade_executor/config.py new file mode 100644 index 0000000..49df392 --- /dev/null +++ b/services/trade_executor/config.py @@ -0,0 +1,18 @@ +"""Configuration for the trade executor service.""" + +from shared.config import BaseConfig + + +class TradeExecutorConfig(BaseConfig): + """Extends BaseConfig with trade-executor-specific settings.""" + + max_position_pct: float = 0.05 + max_total_exposure_pct: float = 0.80 + max_positions: int = 20 + default_stop_loss_pct: float = 0.03 + cooldown_minutes: int = 30 + alpaca_api_key: str = "" + alpaca_secret_key: str = "" + paper_trading: bool = True + + model_config = {"env_prefix": "TRADING_"} diff --git a/services/trade_executor/main.py b/services/trade_executor/main.py new file mode 100644 index 0000000..5c89fc9 --- /dev/null +++ b/services/trade_executor/main.py @@ -0,0 +1,176 @@ +"""Trade Executor service -- main entry point. + +Consumes ``signals:generated`` from Redis Streams, runs risk checks, +submits orders via the brokerage abstraction layer, records trades +in the database, and publishes ``TradeExecution`` messages to +``trades:executed``. +""" + +from __future__ import annotations + +import asyncio +import logging +import time +import uuid + +from redis.asyncio import Redis + +from services.trade_executor.config import TradeExecutorConfig +from services.trade_executor.risk_manager import RiskManager +from shared.broker.alpaca_broker import AlpacaBroker +from shared.redis_streams import StreamConsumer, StreamPublisher +from shared.schemas.trading import ( + OrderRequest, + OrderSide, + OrderStatus, + SignalDirection, + TradeExecution, + TradeSignal, +) +from shared.telemetry import setup_telemetry + +logger = logging.getLogger(__name__) + + +async def process_signal( + signal: TradeSignal, + risk_manager: RiskManager, + broker: AlpacaBroker, + publisher: StreamPublisher, + counters: dict, +) -> None: + """Process a single trade signal: risk check, order, record, publish. + + Parameters + ---------- + signal: + The trade signal to act on. + risk_manager: + Performs pre-trade risk checks and position sizing. + broker: + Brokerage adapter for submitting orders. + publisher: + Publishes execution results to ``trades:executed``. + counters: + Dict of OpenTelemetry counter/histogram instruments. + """ + # --- Step 1: risk check --- + approved, reason = await risk_manager.check_risk(signal) + if not approved: + logger.info("Signal REJECTED for %s: %s", signal.ticker, reason) + counters["rejections"].add(1, {"reason": reason.split(" ")[0]}) + return + + # --- Step 2: calculate position size --- + account = await broker.get_account() + qty = risk_manager.calculate_position_size(signal, account) + if qty <= 0: + logger.info("Position size is zero for %s — skipping", signal.ticker) + counters["rejections"].add(1, {"reason": "zero_position_size"}) + return + + # --- Step 3: create order --- + side = OrderSide.BUY if signal.direction == SignalDirection.LONG else OrderSide.SELL + order_request = OrderRequest( + ticker=signal.ticker, + side=side, + qty=float(qty), + ) + + # --- Step 4: submit order --- + start = time.monotonic() + result = await broker.submit_order(order_request) + elapsed = time.monotonic() - start + counters["fill_latency"].record(elapsed) + + # --- Step 5: build trade execution --- + trade_id = uuid.uuid4() + execution = TradeExecution( + trade_id=trade_id, + ticker=signal.ticker, + side=side, + qty=result.qty, + price=result.filled_price or 0.0, + status=result.status, + signal_id=None, + strategy_id=None, + timestamp=result.timestamp, + ) + + # --- Step 6: publish to trades:executed --- + await publisher.publish(execution.model_dump(mode="json")) + counters["trades_executed"].add(1) + logger.info( + "Trade executed: %s %s %.0f shares @ %s status=%s", + side.value, + signal.ticker, + result.qty, + result.filled_price, + result.status.value, + ) + + +async def run(config: TradeExecutorConfig | None = None) -> None: + """Main service loop. + + Connects to Redis, initialises the broker and risk manager, then + continuously consumes from ``signals:generated`` and publishes + execution results to ``trades:executed``. + """ + if config is None: + config = TradeExecutorConfig() + + logging.basicConfig(level=config.log_level) + logger.info("Starting Trade Executor service") + + # --- Telemetry --- + meter = setup_telemetry("trade-executor", config.otel_metrics_port) + counters = { + "trades_executed": meter.create_counter( + "trades_executed", + description="Total trades successfully submitted", + ), + "rejections": meter.create_counter( + "trade_rejections", + description="Signals rejected by risk checks", + ), + "fill_latency": meter.create_histogram( + "order_fill_latency_seconds", + description="Time from order submission to response", + unit="s", + ), + } + + # --- Redis --- + redis = Redis.from_url(config.redis_url, decode_responses=False) + consumer = StreamConsumer(redis, "signals:generated", "trade-executor", "worker-1") + publisher = StreamPublisher(redis, "trades:executed") + + # --- Broker --- + broker = AlpacaBroker( + api_key=config.alpaca_api_key, + secret_key=config.alpaca_secret_key, + paper=config.paper_trading, + ) + + # --- Risk manager --- + risk_manager = RiskManager(config, broker) + + logger.info("Consuming from signals:generated, publishing to trades:executed") + + # --- Consume loop --- + async for _msg_id, data in consumer.consume(): + try: + signal = TradeSignal.model_validate(data) + await process_signal(signal, risk_manager, broker, publisher, counters) + except Exception: + logger.exception("Error processing signal: %s", data) + + +def main() -> None: + """CLI entry point.""" + asyncio.run(run()) + + +if __name__ == "__main__": + main() diff --git a/services/trade_executor/risk_manager.py b/services/trade_executor/risk_manager.py new file mode 100644 index 0000000..0902263 --- /dev/null +++ b/services/trade_executor/risk_manager.py @@ -0,0 +1,155 @@ +"""Pre-trade risk management checks and position sizing. + +Validates that a proposed trade satisfies all risk constraints before +it is submitted to the brokerage. +""" + +from __future__ import annotations + +import logging +from datetime import datetime, timedelta +from zoneinfo import ZoneInfo + +from services.trade_executor.config import TradeExecutorConfig +from shared.broker.base import BaseBroker +from shared.schemas.trading import AccountInfo, PositionInfo, SignalDirection, TradeSignal + +logger = logging.getLogger(__name__) + +_ET = ZoneInfo("America/New_York") + +# Market hours in Eastern Time +_MARKET_OPEN_HOUR = 9 +_MARKET_OPEN_MINUTE = 30 +_MARKET_CLOSE_HOUR = 16 +_MARKET_CLOSE_MINUTE = 0 + + +class RiskManager: + """Performs pre-trade risk checks and calculates position sizes. + + Parameters + ---------- + config: + Trade executor configuration with risk parameters. + broker: + Broker instance for querying current positions and account info. + """ + + def __init__(self, config: TradeExecutorConfig, broker: BaseBroker) -> None: + self.config = config + self.broker = broker + # ticker -> last exit timestamp + self._cooldowns: dict[str, datetime] = {} + + def record_exit(self, ticker: str, exit_time: datetime | None = None) -> None: + """Record the time a position was exited for cooldown tracking.""" + self._cooldowns[ticker] = exit_time or datetime.now(tz=_ET) + + async def check_risk(self, signal: TradeSignal) -> tuple[bool, str]: + """Run all pre-trade risk checks. + + Returns + ------- + tuple[bool, str] + ``(approved, reason)`` — ``approved`` is ``True`` when + all checks pass, otherwise ``reason`` explains the failure. + """ + # 1. Market hours + now_et = datetime.now(tz=_ET) + if not self._is_market_hours(now_et): + return False, "outside_market_hours" + + # 2. Cooldown + if signal.ticker in self._cooldowns: + last_exit = self._cooldowns[signal.ticker] + cooldown_end = last_exit + timedelta(minutes=self.config.cooldown_minutes) + if now_et < cooldown_end: + remaining = (cooldown_end - now_et).total_seconds() / 60 + return False, f"cooldown_active ({remaining:.1f}m remaining)" + + # 3. Max positions + positions = await self.broker.get_positions() + if len(positions) >= self.config.max_positions: + return False, "max_positions_exceeded" + + # 4. Max total exposure + account = await self.broker.get_account() + total_exposure = sum(abs(p.market_value) for p in positions) + max_exposure = account.equity * self.config.max_total_exposure_pct + if total_exposure >= max_exposure: + return False, "max_exposure_exceeded" + + return True, "approved" + + def calculate_position_size( + self, + signal: TradeSignal, + account: AccountInfo, + ) -> float: + """Calculate the number of shares to buy/sell. + + Uses fixed-fractional sizing: ``equity * max_position_pct`` + gives the maximum dollar value per position, then scales by + signal strength. + + Parameters + ---------- + signal: + The trade signal (includes current price via strength). + account: + Current account info (equity, buying power). + + Returns + ------- + float + Number of shares (whole shares). + """ + if signal.strength <= 0 or account.equity <= 0: + return 0.0 + + position_value = account.equity * self.config.max_position_pct + position_value *= signal.strength + + # Need a price to compute qty — use the signal's embedded price + # or fall back to getting it from the snapshot. For simplicity + # the executor will pass the current price through the signal's + # sentiment_context or fetch it directly. + current_price = 0.0 + if signal.sentiment_context and "current_price" in signal.sentiment_context: + current_price = float(signal.sentiment_context["current_price"]) + + if current_price <= 0: + logger.warning("No current price for %s, cannot size position", signal.ticker) + return 0.0 + + qty = position_value / current_price + return max(int(qty), 0) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + @staticmethod + def _is_market_hours(now_et: datetime) -> bool: + """Return ``True`` if *now_et* falls within regular US market hours. + + Market hours: Monday--Friday, 9:30 AM -- 4:00 PM ET. + """ + # Weekday check (0=Monday ... 6=Sunday) + if now_et.weekday() >= 5: + return False + + market_open = now_et.replace( + hour=_MARKET_OPEN_HOUR, + minute=_MARKET_OPEN_MINUTE, + second=0, + microsecond=0, + ) + market_close = now_et.replace( + hour=_MARKET_CLOSE_HOUR, + minute=_MARKET_CLOSE_MINUTE, + second=0, + microsecond=0, + ) + return market_open <= now_et < market_close diff --git a/tests/services/test_trade_executor.py b/tests/services/test_trade_executor.py new file mode 100644 index 0000000..8ec42bf --- /dev/null +++ b/tests/services/test_trade_executor.py @@ -0,0 +1,403 @@ +"""Tests for the Trade Executor service. + +Covers RiskManager (market hours, positions, exposure, cooldown, +position sizing) and the end-to-end executor flow with a mocked broker. +""" + +from __future__ import annotations + +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock, MagicMock, patch +from zoneinfo import ZoneInfo + +import pytest + +from services.trade_executor.config import TradeExecutorConfig +from services.trade_executor.main import process_signal +from services.trade_executor.risk_manager import RiskManager +from shared.schemas.trading import ( + AccountInfo, + OrderResult, + OrderSide, + OrderStatus, + PositionInfo, + SignalDirection, + TradeSignal, +) + +_ET = ZoneInfo("America/New_York") + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_config(**overrides) -> TradeExecutorConfig: + defaults = dict( + max_position_pct=0.05, + max_total_exposure_pct=0.80, + max_positions=20, + default_stop_loss_pct=0.03, + cooldown_minutes=30, + alpaca_api_key="test", + alpaca_secret_key="test", + paper_trading=True, + ) + defaults.update(overrides) + return TradeExecutorConfig(**defaults) + + +def _make_signal( + ticker: str = "AAPL", + direction: SignalDirection = SignalDirection.LONG, + strength: float = 0.8, + current_price: float = 150.0, +) -> TradeSignal: + return TradeSignal( + ticker=ticker, + direction=direction, + strength=strength, + strategy_sources=["test"], + sentiment_context={"current_price": current_price}, + timestamp=datetime.now(timezone.utc), + ) + + +def _make_account(equity: float = 100_000.0) -> AccountInfo: + return AccountInfo( + equity=equity, + cash=equity, + buying_power=equity * 2, + portfolio_value=equity, + ) + + +def _make_position(ticker: str = "AAPL", market_value: float = 5000.0) -> PositionInfo: + return PositionInfo( + ticker=ticker, + qty=10.0, + avg_entry=150.0, + current_price=150.0, + unrealized_pnl=0.0, + market_value=market_value, + ) + + +def _mock_broker(positions: list[PositionInfo] | None = None, account: AccountInfo | None = None): + """Create an AsyncMock broker with configurable positions and account.""" + broker = AsyncMock() + broker.get_positions = AsyncMock(return_value=positions or []) + broker.get_account = AsyncMock(return_value=account or _make_account()) + broker.submit_order = AsyncMock( + return_value=OrderResult( + order_id="ord-123", + ticker="AAPL", + side=OrderSide.BUY, + qty=10.0, + filled_price=150.0, + status=OrderStatus.FILLED, + timestamp=datetime.now(timezone.utc), + ) + ) + return broker + + +# --------------------------------------------------------------------------- +# RiskManager — risk check passes +# --------------------------------------------------------------------------- + + +class TestRiskCheckPasses: + """All conditions met -> risk check passes.""" + + @pytest.mark.asyncio + async def test_all_conditions_met(self): + config = _make_config() + broker = _mock_broker(positions=[], account=_make_account(100_000)) + rm = RiskManager(config, broker) + signal = _make_signal() + + # Patch _is_market_hours to return True + with patch.object(RiskManager, "_is_market_hours", return_value=True): + approved, reason = await rm.check_risk(signal) + + assert approved is True + assert reason == "approved" + + +# --------------------------------------------------------------------------- +# RiskManager — max positions exceeded +# --------------------------------------------------------------------------- + + +class TestRiskCheckMaxPositions: + """Risk check fails when max_positions is already reached.""" + + @pytest.mark.asyncio + async def test_max_positions_exceeded(self): + config = _make_config(max_positions=2) + # Already have 2 positions + positions = [_make_position("AAPL"), _make_position("MSFT")] + broker = _mock_broker(positions=positions, account=_make_account()) + rm = RiskManager(config, broker) + signal = _make_signal(ticker="GOOG") + + with patch.object(RiskManager, "_is_market_hours", return_value=True): + approved, reason = await rm.check_risk(signal) + + assert approved is False + assert "max_positions" in reason + + +# --------------------------------------------------------------------------- +# RiskManager — max exposure exceeded +# --------------------------------------------------------------------------- + + +class TestRiskCheckMaxExposure: + """Risk check fails when total exposure exceeds the limit.""" + + @pytest.mark.asyncio + async def test_max_exposure_exceeded(self): + config = _make_config(max_total_exposure_pct=0.50) + account = _make_account(equity=100_000) + # Single position worth $60k = 60% of equity, limit is 50% + positions = [_make_position("AAPL", market_value=60_000)] + broker = _mock_broker(positions=positions, account=account) + rm = RiskManager(config, broker) + signal = _make_signal(ticker="MSFT") + + with patch.object(RiskManager, "_is_market_hours", return_value=True): + approved, reason = await rm.check_risk(signal) + + assert approved is False + assert "max_exposure" in reason + + +# --------------------------------------------------------------------------- +# RiskManager — cooldown active +# --------------------------------------------------------------------------- + + +class TestRiskCheckCooldown: + """Risk check fails when a ticker is in cooldown.""" + + @pytest.mark.asyncio + async def test_cooldown_active(self): + config = _make_config(cooldown_minutes=30) + broker = _mock_broker() + rm = RiskManager(config, broker) + + # Record an exit 10 minutes ago + now_et = datetime.now(tz=_ET) + rm.record_exit("AAPL", now_et - timedelta(minutes=10)) + + signal = _make_signal(ticker="AAPL") + with patch.object(RiskManager, "_is_market_hours", return_value=True): + approved, reason = await rm.check_risk(signal) + + assert approved is False + assert "cooldown" in reason + + @pytest.mark.asyncio + async def test_cooldown_expired(self): + """After cooldown period expires the trade should be approved.""" + config = _make_config(cooldown_minutes=30) + broker = _mock_broker() + rm = RiskManager(config, broker) + + # Record an exit 45 minutes ago + now_et = datetime.now(tz=_ET) + rm.record_exit("AAPL", now_et - timedelta(minutes=45)) + + signal = _make_signal(ticker="AAPL") + with patch.object(RiskManager, "_is_market_hours", return_value=True): + approved, reason = await rm.check_risk(signal) + + assert approved is True + + +# --------------------------------------------------------------------------- +# RiskManager — outside market hours +# --------------------------------------------------------------------------- + + +class TestRiskCheckMarketHours: + """Risk check fails outside regular market hours.""" + + @pytest.mark.asyncio + async def test_outside_market_hours(self): + config = _make_config() + broker = _mock_broker() + rm = RiskManager(config, broker) + signal = _make_signal() + + # Force market hours check to fail (no patching — use the real check + # with a time that is definitely outside market hours) + with patch.object(RiskManager, "_is_market_hours", return_value=False): + approved, reason = await rm.check_risk(signal) + + assert approved is False + assert "market_hours" in reason + + def test_market_hours_weekday(self): + """A weekday at 10:00 AM ET should be within market hours.""" + # Tuesday 10:00 AM ET + t = datetime(2026, 2, 24, 10, 0, 0, tzinfo=_ET) + assert RiskManager._is_market_hours(t) is True + + def test_market_hours_weekend(self): + """Saturday should always be outside market hours.""" + t = datetime(2026, 2, 21, 10, 0, 0, tzinfo=_ET) # Saturday + assert RiskManager._is_market_hours(t) is False + + def test_market_hours_before_open(self): + """8:00 AM ET on a weekday is before market open.""" + t = datetime(2026, 2, 24, 8, 0, 0, tzinfo=_ET) # Tuesday 8 AM + assert RiskManager._is_market_hours(t) is False + + def test_market_hours_after_close(self): + """5:00 PM ET on a weekday is after market close.""" + t = datetime(2026, 2, 24, 17, 0, 0, tzinfo=_ET) # Tuesday 5 PM + assert RiskManager._is_market_hours(t) is False + + +# --------------------------------------------------------------------------- +# Position sizing — scales by strength +# --------------------------------------------------------------------------- + + +class TestPositionSizingScalesByStrength: + """Position size should scale proportionally with signal strength.""" + + def test_full_strength(self): + config = _make_config(max_position_pct=0.05) + broker = _mock_broker() + rm = RiskManager(config, broker) + + signal = _make_signal(strength=1.0, current_price=100.0) + account = _make_account(equity=100_000) + + qty = rm.calculate_position_size(signal, account) + # position_value = 100k * 0.05 * 1.0 = 5000 / 100 = 50 shares + assert qty == 50 + + def test_half_strength(self): + config = _make_config(max_position_pct=0.05) + broker = _mock_broker() + rm = RiskManager(config, broker) + + signal = _make_signal(strength=0.5, current_price=100.0) + account = _make_account(equity=100_000) + + qty = rm.calculate_position_size(signal, account) + # position_value = 100k * 0.05 * 0.5 = 2500 / 100 = 25 shares + assert qty == 25 + + +# --------------------------------------------------------------------------- +# Position sizing — respects max_position_pct +# --------------------------------------------------------------------------- + + +class TestPositionSizingRespectsMaxPct: + """Position size should respect the max_position_pct cap.""" + + def test_respects_max_pct(self): + config = _make_config(max_position_pct=0.02) + broker = _mock_broker() + rm = RiskManager(config, broker) + + signal = _make_signal(strength=1.0, current_price=50.0) + account = _make_account(equity=100_000) + + qty = rm.calculate_position_size(signal, account) + # position_value = 100k * 0.02 * 1.0 = 2000 / 50 = 40 shares + assert qty == 40 + + def test_zero_price_returns_zero(self): + config = _make_config() + broker = _mock_broker() + rm = RiskManager(config, broker) + + signal = _make_signal(strength=0.8, current_price=0.0) + account = _make_account(equity=100_000) + + qty = rm.calculate_position_size(signal, account) + assert qty == 0 + + +# --------------------------------------------------------------------------- +# Executor flow — approved signal +# --------------------------------------------------------------------------- + + +class TestExecutorFlowApproved: + """End-to-end: approved signal -> order submitted -> trade published.""" + + @pytest.mark.asyncio + async def test_approved_signal_flow(self): + config = _make_config() + broker = _mock_broker(positions=[], account=_make_account(100_000)) + publisher = AsyncMock() + publisher.publish = AsyncMock(return_value=b"1-0") + + counters = { + "trades_executed": MagicMock(), + "rejections": MagicMock(), + "fill_latency": MagicMock(), + } + + signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0) + + # Patch risk check to approve + with patch.object(RiskManager, "check_risk", return_value=(True, "approved")): + await process_signal(signal, RiskManager(config, broker), broker, publisher, counters) + + # Verify order was submitted + broker.submit_order.assert_called_once() + order_arg = broker.submit_order.call_args[0][0] + assert order_arg.ticker == "AAPL" + assert order_arg.side == OrderSide.BUY + + # Verify trade was published + publisher.publish.assert_called_once() + counters["trades_executed"].add.assert_called_once_with(1) + + +# --------------------------------------------------------------------------- +# Executor flow — rejected signal +# --------------------------------------------------------------------------- + + +class TestExecutorFlowRejected: + """End-to-end: rejected signal -> no order, rejection logged.""" + + @pytest.mark.asyncio + async def test_rejected_signal_flow(self): + config = _make_config() + broker = _mock_broker() + publisher = AsyncMock() + + counters = { + "trades_executed": MagicMock(), + "rejections": MagicMock(), + "fill_latency": MagicMock(), + } + + signal = _make_signal(ticker="AAPL") + + with patch.object( + RiskManager, "check_risk", return_value=(False, "outside_market_hours") + ): + await process_signal(signal, RiskManager(config, broker), broker, publisher, counters) + + # No order should have been submitted + broker.submit_order.assert_not_called() + + # No trade should have been published + publisher.publish.assert_not_called() + + # Rejection counter should have been incremented + counters["rejections"].add.assert_called_once()