feat: trade executor — risk management and order execution

This commit is contained in:
Viktor Barzin 2026-02-22 15:36:08 +00:00
parent f3e5fc944d
commit 3fef8a631c
No known key found for this signature in database
GPG key ID: 0EB088298288D958
5 changed files with 753 additions and 0 deletions

View file

@ -0,0 +1 @@
"""Trade Executor service — risk management and order execution."""

View file

@ -0,0 +1,18 @@
"""Configuration for the trade executor service."""
from shared.config import BaseConfig
class TradeExecutorConfig(BaseConfig):
"""Extends BaseConfig with trade-executor-specific settings."""
max_position_pct: float = 0.05
max_total_exposure_pct: float = 0.80
max_positions: int = 20
default_stop_loss_pct: float = 0.03
cooldown_minutes: int = 30
alpaca_api_key: str = ""
alpaca_secret_key: str = ""
paper_trading: bool = True
model_config = {"env_prefix": "TRADING_"}

View file

@ -0,0 +1,176 @@
"""Trade Executor service -- main entry point.
Consumes ``signals:generated`` from Redis Streams, runs risk checks,
submits orders via the brokerage abstraction layer, records trades
in the database, and publishes ``TradeExecution`` messages to
``trades:executed``.
"""
from __future__ import annotations
import asyncio
import logging
import time
import uuid
from redis.asyncio import Redis
from services.trade_executor.config import TradeExecutorConfig
from services.trade_executor.risk_manager import RiskManager
from shared.broker.alpaca_broker import AlpacaBroker
from shared.redis_streams import StreamConsumer, StreamPublisher
from shared.schemas.trading import (
OrderRequest,
OrderSide,
OrderStatus,
SignalDirection,
TradeExecution,
TradeSignal,
)
from shared.telemetry import setup_telemetry
logger = logging.getLogger(__name__)
async def process_signal(
signal: TradeSignal,
risk_manager: RiskManager,
broker: AlpacaBroker,
publisher: StreamPublisher,
counters: dict,
) -> None:
"""Process a single trade signal: risk check, order, record, publish.
Parameters
----------
signal:
The trade signal to act on.
risk_manager:
Performs pre-trade risk checks and position sizing.
broker:
Brokerage adapter for submitting orders.
publisher:
Publishes execution results to ``trades:executed``.
counters:
Dict of OpenTelemetry counter/histogram instruments.
"""
# --- Step 1: risk check ---
approved, reason = await risk_manager.check_risk(signal)
if not approved:
logger.info("Signal REJECTED for %s: %s", signal.ticker, reason)
counters["rejections"].add(1, {"reason": reason.split(" ")[0]})
return
# --- Step 2: calculate position size ---
account = await broker.get_account()
qty = risk_manager.calculate_position_size(signal, account)
if qty <= 0:
logger.info("Position size is zero for %s — skipping", signal.ticker)
counters["rejections"].add(1, {"reason": "zero_position_size"})
return
# --- Step 3: create order ---
side = OrderSide.BUY if signal.direction == SignalDirection.LONG else OrderSide.SELL
order_request = OrderRequest(
ticker=signal.ticker,
side=side,
qty=float(qty),
)
# --- Step 4: submit order ---
start = time.monotonic()
result = await broker.submit_order(order_request)
elapsed = time.monotonic() - start
counters["fill_latency"].record(elapsed)
# --- Step 5: build trade execution ---
trade_id = uuid.uuid4()
execution = TradeExecution(
trade_id=trade_id,
ticker=signal.ticker,
side=side,
qty=result.qty,
price=result.filled_price or 0.0,
status=result.status,
signal_id=None,
strategy_id=None,
timestamp=result.timestamp,
)
# --- Step 6: publish to trades:executed ---
await publisher.publish(execution.model_dump(mode="json"))
counters["trades_executed"].add(1)
logger.info(
"Trade executed: %s %s %.0f shares @ %s status=%s",
side.value,
signal.ticker,
result.qty,
result.filled_price,
result.status.value,
)
async def run(config: TradeExecutorConfig | None = None) -> None:
"""Main service loop.
Connects to Redis, initialises the broker and risk manager, then
continuously consumes from ``signals:generated`` and publishes
execution results to ``trades:executed``.
"""
if config is None:
config = TradeExecutorConfig()
logging.basicConfig(level=config.log_level)
logger.info("Starting Trade Executor service")
# --- Telemetry ---
meter = setup_telemetry("trade-executor", config.otel_metrics_port)
counters = {
"trades_executed": meter.create_counter(
"trades_executed",
description="Total trades successfully submitted",
),
"rejections": meter.create_counter(
"trade_rejections",
description="Signals rejected by risk checks",
),
"fill_latency": meter.create_histogram(
"order_fill_latency_seconds",
description="Time from order submission to response",
unit="s",
),
}
# --- Redis ---
redis = Redis.from_url(config.redis_url, decode_responses=False)
consumer = StreamConsumer(redis, "signals:generated", "trade-executor", "worker-1")
publisher = StreamPublisher(redis, "trades:executed")
# --- Broker ---
broker = AlpacaBroker(
api_key=config.alpaca_api_key,
secret_key=config.alpaca_secret_key,
paper=config.paper_trading,
)
# --- Risk manager ---
risk_manager = RiskManager(config, broker)
logger.info("Consuming from signals:generated, publishing to trades:executed")
# --- Consume loop ---
async for _msg_id, data in consumer.consume():
try:
signal = TradeSignal.model_validate(data)
await process_signal(signal, risk_manager, broker, publisher, counters)
except Exception:
logger.exception("Error processing signal: %s", data)
def main() -> None:
"""CLI entry point."""
asyncio.run(run())
if __name__ == "__main__":
main()

View file

@ -0,0 +1,155 @@
"""Pre-trade risk management checks and position sizing.
Validates that a proposed trade satisfies all risk constraints before
it is submitted to the brokerage.
"""
from __future__ import annotations
import logging
from datetime import datetime, timedelta
from zoneinfo import ZoneInfo
from services.trade_executor.config import TradeExecutorConfig
from shared.broker.base import BaseBroker
from shared.schemas.trading import AccountInfo, PositionInfo, SignalDirection, TradeSignal
logger = logging.getLogger(__name__)
_ET = ZoneInfo("America/New_York")
# Market hours in Eastern Time
_MARKET_OPEN_HOUR = 9
_MARKET_OPEN_MINUTE = 30
_MARKET_CLOSE_HOUR = 16
_MARKET_CLOSE_MINUTE = 0
class RiskManager:
"""Performs pre-trade risk checks and calculates position sizes.
Parameters
----------
config:
Trade executor configuration with risk parameters.
broker:
Broker instance for querying current positions and account info.
"""
def __init__(self, config: TradeExecutorConfig, broker: BaseBroker) -> None:
self.config = config
self.broker = broker
# ticker -> last exit timestamp
self._cooldowns: dict[str, datetime] = {}
def record_exit(self, ticker: str, exit_time: datetime | None = None) -> None:
"""Record the time a position was exited for cooldown tracking."""
self._cooldowns[ticker] = exit_time or datetime.now(tz=_ET)
async def check_risk(self, signal: TradeSignal) -> tuple[bool, str]:
"""Run all pre-trade risk checks.
Returns
-------
tuple[bool, str]
``(approved, reason)`` ``approved`` is ``True`` when
all checks pass, otherwise ``reason`` explains the failure.
"""
# 1. Market hours
now_et = datetime.now(tz=_ET)
if not self._is_market_hours(now_et):
return False, "outside_market_hours"
# 2. Cooldown
if signal.ticker in self._cooldowns:
last_exit = self._cooldowns[signal.ticker]
cooldown_end = last_exit + timedelta(minutes=self.config.cooldown_minutes)
if now_et < cooldown_end:
remaining = (cooldown_end - now_et).total_seconds() / 60
return False, f"cooldown_active ({remaining:.1f}m remaining)"
# 3. Max positions
positions = await self.broker.get_positions()
if len(positions) >= self.config.max_positions:
return False, "max_positions_exceeded"
# 4. Max total exposure
account = await self.broker.get_account()
total_exposure = sum(abs(p.market_value) for p in positions)
max_exposure = account.equity * self.config.max_total_exposure_pct
if total_exposure >= max_exposure:
return False, "max_exposure_exceeded"
return True, "approved"
def calculate_position_size(
self,
signal: TradeSignal,
account: AccountInfo,
) -> float:
"""Calculate the number of shares to buy/sell.
Uses fixed-fractional sizing: ``equity * max_position_pct``
gives the maximum dollar value per position, then scales by
signal strength.
Parameters
----------
signal:
The trade signal (includes current price via strength).
account:
Current account info (equity, buying power).
Returns
-------
float
Number of shares (whole shares).
"""
if signal.strength <= 0 or account.equity <= 0:
return 0.0
position_value = account.equity * self.config.max_position_pct
position_value *= signal.strength
# Need a price to compute qty — use the signal's embedded price
# or fall back to getting it from the snapshot. For simplicity
# the executor will pass the current price through the signal's
# sentiment_context or fetch it directly.
current_price = 0.0
if signal.sentiment_context and "current_price" in signal.sentiment_context:
current_price = float(signal.sentiment_context["current_price"])
if current_price <= 0:
logger.warning("No current price for %s, cannot size position", signal.ticker)
return 0.0
qty = position_value / current_price
return max(int(qty), 0)
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
@staticmethod
def _is_market_hours(now_et: datetime) -> bool:
"""Return ``True`` if *now_et* falls within regular US market hours.
Market hours: Monday--Friday, 9:30 AM -- 4:00 PM ET.
"""
# Weekday check (0=Monday ... 6=Sunday)
if now_et.weekday() >= 5:
return False
market_open = now_et.replace(
hour=_MARKET_OPEN_HOUR,
minute=_MARKET_OPEN_MINUTE,
second=0,
microsecond=0,
)
market_close = now_et.replace(
hour=_MARKET_CLOSE_HOUR,
minute=_MARKET_CLOSE_MINUTE,
second=0,
microsecond=0,
)
return market_open <= now_et < market_close