"""Trade Executor service -- main entry point. Consumes ``signals:generated`` from Redis Streams, runs risk checks, submits orders via the brokerage abstraction layer, records trades in the database, and publishes ``TradeExecution`` messages to ``trades:executed``. """ from __future__ import annotations import asyncio import logging import signal import time import uuid from redis.asyncio import Redis from sqlalchemy.ext.asyncio import async_sessionmaker from services.trade_executor.config import TradeExecutorConfig from services.trade_executor.risk_manager import RiskManager from shared.broker.alpaca_broker import AlpacaBroker from shared.db import create_db from shared.models.trading import Trade as TradeModel from shared.models.trading import TradeSide as TradeSideModel from shared.models.trading import TradeStatus as TradeStatusModel from shared.redis_streams import StreamConsumer, StreamPublisher from shared.schemas.trading import ( OrderRequest, OrderSide, OrderStatus, SignalDirection, TradeExecution, TradeSignal, ) from shared.telemetry import setup_telemetry logger = logging.getLogger(__name__) async def process_signal( signal: TradeSignal, risk_manager: RiskManager, broker: AlpacaBroker, publisher: StreamPublisher, counters: dict, db_session_factory: async_sessionmaker | None = None, ) -> None: """Process a single trade signal: risk check, order, record, publish. Parameters ---------- signal: The trade signal to act on. risk_manager: Performs pre-trade risk checks and position sizing. broker: Brokerage adapter for submitting orders. publisher: Publishes execution results to ``trades:executed``. counters: Dict of OpenTelemetry counter/histogram instruments. db_session_factory: Optional async session factory for persisting trades to the DB. """ # --- Step 1: risk check --- approved, reason = await risk_manager.check_risk(signal) if not approved: logger.info("Signal REJECTED for %s: %s", signal.ticker, reason) counters["rejections"].add(1, {"reason": reason.split(" ")[0]}) return # --- Step 2: calculate position size --- account = await broker.get_account() qty = risk_manager.calculate_position_size(signal, account) if qty <= 0: logger.info("Position size is zero for %s — skipping", signal.ticker) counters["rejections"].add(1, {"reason": "zero_position_size"}) return # --- Step 3: create order --- side = OrderSide.BUY if signal.direction == SignalDirection.LONG else OrderSide.SELL order_request = OrderRequest( ticker=signal.ticker, side=side, qty=float(qty), ) # --- Step 4: submit order --- start = time.monotonic() result = await broker.submit_order(order_request) elapsed = time.monotonic() - start counters["fill_latency"].record(elapsed) # --- Step 5: build trade execution --- trade_id = uuid.uuid4() execution = TradeExecution( trade_id=trade_id, ticker=signal.ticker, side=side, qty=result.qty, price=result.filled_price or 0.0, status=result.status, signal_id=signal.signal_id, strategy_id=None, strategy_sources=signal.strategy_sources, timestamp=result.timestamp, ) # --- Step 6: persist trade to DB --- if db_session_factory is not None: try: side_map = { OrderSide.BUY: TradeSideModel.BUY, OrderSide.SELL: TradeSideModel.SELL, } status_map = { OrderStatus.PENDING: TradeStatusModel.PENDING, OrderStatus.FILLED: TradeStatusModel.FILLED, OrderStatus.CANCELLED: TradeStatusModel.CANCELLED, OrderStatus.REJECTED: TradeStatusModel.REJECTED, } async with db_session_factory() as session: db_trade = TradeModel( id=trade_id, ticker=signal.ticker, side=side_map[side], qty=result.qty, price=result.filled_price or 0.0, timestamp=str(result.timestamp), signal_id=signal.signal_id, status=status_map.get(result.status, TradeStatusModel.PENDING), ) session.add(db_trade) await session.commit() logger.debug("Persisted trade %s to DB (signal_id=%s)", trade_id, signal.signal_id) except Exception: logger.exception("Failed to persist trade to DB") # --- Step 7: publish to trades:executed --- await publisher.publish(execution.model_dump(mode="json")) counters["trades_executed"].add(1) logger.info( "Trade executed: %s %s %.0f shares @ %s status=%s", side.value, signal.ticker, result.qty, result.filled_price, result.status.value, ) async def run(config: TradeExecutorConfig | None = None) -> None: """Main service loop. Connects to Redis, initialises the broker and risk manager, then continuously consumes from ``signals:generated`` and publishes execution results to ``trades:executed``. """ if config is None: config = TradeExecutorConfig() logging.basicConfig(level=config.log_level) logger.info("Starting Trade Executor service") # --- Telemetry --- meter = setup_telemetry("trade-executor", config.otel_metrics_port) counters = { "trades_executed": meter.create_counter( "trades_executed", description="Total trades successfully submitted", ), "rejections": meter.create_counter( "trade_rejections", description="Signals rejected by risk checks", ), "fill_latency": meter.create_histogram( "order_fill_latency_seconds", description="Time from order submission to response", unit="s", ), } # --- Redis --- redis = Redis.from_url(config.redis_url, decode_responses=False) consumer = StreamConsumer(redis, "signals:generated", "trade-executor", "worker-1") publisher = StreamPublisher(redis, "trades:executed") # --- Broker --- broker = AlpacaBroker( api_key=config.alpaca_api_key, secret_key=config.alpaca_secret_key, paper=config.paper_trading, ) # --- Risk manager --- risk_manager = RiskManager(config, broker, redis=redis) # --- Database (for persisting trades) --- db_session_factory = None try: _engine, db_session_factory = create_db(config) logger.info("Database session factory initialised for trade persistence") except Exception: logger.exception("Failed to initialise DB — trades will NOT be persisted") logger.info("Consuming from signals:generated, publishing to trades:executed") # Graceful shutdown on SIGTERM/SIGINT shutdown_event = asyncio.Event() loop = asyncio.get_running_loop() for sig in (signal.SIGTERM, signal.SIGINT): loop.add_signal_handler(sig, shutdown_event.set) # --- Consume loop --- try: async for _msg_id, data in consumer.consume(): if shutdown_event.is_set(): break try: signal_msg = TradeSignal.model_validate(data) await process_signal(signal_msg, risk_manager, broker, publisher, counters, db_session_factory) except Exception: logger.exception("Error processing signal: %s", data) finally: await redis.aclose() logger.info("Trade executor stopped gracefully") def main() -> None: """CLI entry point.""" asyncio.run(run()) if __name__ == "__main__": main()