"""Pre-trade risk management checks and position sizing. Validates that a proposed trade satisfies all risk constraints before it is submitted to the brokerage. """ from __future__ import annotations import logging from datetime import datetime, timedelta from zoneinfo import ZoneInfo from redis.asyncio import Redis from services.trade_executor.config import TradeExecutorConfig from shared.broker.base import BaseBroker from shared.schemas.trading import AccountInfo, PositionInfo, SignalDirection, TradeSignal logger = logging.getLogger(__name__) _ET = ZoneInfo("America/New_York") # Market hours in Eastern Time _MARKET_OPEN_HOUR = 9 _MARKET_OPEN_MINUTE = 30 _MARKET_CLOSE_HOUR = 16 _MARKET_CLOSE_MINUTE = 0 TRADING_PAUSED_KEY = "trading:paused" class RiskManager: """Performs pre-trade risk checks and calculates position sizes. Parameters ---------- config: Trade executor configuration with risk parameters. broker: Broker instance for querying current positions and account info. redis: Redis client for checking the trading pause flag. """ def __init__( self, config: TradeExecutorConfig, broker: BaseBroker, redis: Redis | None = None, ) -> None: self.config = config self.broker = broker self.redis = redis # ticker -> last exit timestamp self._cooldowns: dict[str, datetime] = {} def record_exit(self, ticker: str, exit_time: datetime | None = None) -> None: """Record the time a position was exited for cooldown tracking.""" self._cooldowns[ticker] = exit_time or datetime.now(tz=_ET) async def check_risk(self, signal: TradeSignal) -> tuple[bool, str]: """Run all pre-trade risk checks. Returns ------- tuple[bool, str] ``(approved, reason)`` — ``approved`` is ``True`` when all checks pass, otherwise ``reason`` explains the failure. """ # 0. Trading pause flag if self.redis is not None: paused = await self.redis.get(TRADING_PAUSED_KEY) if paused: return False, "trading_paused" # 1. Market hours now_et = datetime.now(tz=_ET) if not self._is_market_hours(now_et): return False, "outside_market_hours" # 2. Cooldown if signal.ticker in self._cooldowns: last_exit = self._cooldowns[signal.ticker] cooldown_end = last_exit + timedelta(minutes=self.config.cooldown_minutes) if now_et < cooldown_end: remaining = (cooldown_end - now_et).total_seconds() / 60 return False, f"cooldown_active ({remaining:.1f}m remaining)" # 3. Max positions positions = await self.broker.get_positions() if len(positions) >= self.config.max_positions: return False, "max_positions_exceeded" # 4. Max total exposure account = await self.broker.get_account() total_exposure = sum(abs(p.market_value) for p in positions) max_exposure = account.equity * self.config.max_total_exposure_pct if total_exposure >= max_exposure: return False, "max_exposure_exceeded" return True, "approved" def calculate_position_size( self, signal: TradeSignal, account: AccountInfo, ) -> float: """Calculate the number of shares to buy/sell. Uses fixed-fractional sizing: ``equity * max_position_pct`` gives the maximum dollar value per position, then scales by signal strength. Parameters ---------- signal: The trade signal (includes current price via strength). account: Current account info (equity, buying power). Returns ------- float Number of shares (whole shares). """ if signal.strength <= 0 or account.equity <= 0: return 0.0 position_value = account.equity * self.config.max_position_pct position_value *= signal.strength # Need a price to compute qty — use the signal's embedded price # or fall back to getting it from the snapshot. For simplicity # the executor will pass the current price through the signal's # sentiment_context or fetch it directly. current_price = 0.0 if signal.sentiment_context and "current_price" in signal.sentiment_context: current_price = float(signal.sentiment_context["current_price"]) if current_price <= 0: logger.warning("No current price for %s, cannot size position", signal.ticker) return 0.0 qty = position_value / current_price return max(int(qty), 0) # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ @staticmethod def _is_market_hours(now_et: datetime) -> bool: """Return ``True`` if *now_et* falls within regular US market hours. Market hours: Monday--Friday, 9:30 AM -- 4:00 PM ET. """ # Weekday check (0=Monday ... 6=Sunday) if now_et.weekday() >= 5: return False market_open = now_et.replace( hour=_MARKET_OPEN_HOUR, minute=_MARKET_OPEN_MINUTE, second=0, microsecond=0, ) market_close = now_et.replace( hour=_MARKET_CLOSE_HOUR, minute=_MARKET_CLOSE_MINUTE, second=0, microsecond=0, ) return market_open <= now_et < market_close