trading/services/trade_executor/risk_manager.py
Viktor Barzin a3cdd0f1a5
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed
fix: resolve all remaining TODOs, add dev mode auth bypass
- Learning engine: expand default weights from 3 to all 9 strategies
- Learning engine: resolve placeholder strategy_id with DB lookup
- Learning engine: pass strategy_sources from trade execution
- Trade executor: respect trading:paused Redis flag in RiskManager
- Portfolio sync: compute actual daily P&L from day-start snapshot
- Portfolio API: cumulative P&L from first snapshot, read pause flag
- Portfolio metrics: compute max drawdown and avg hold duration
- Add strategy_sources field to TradeExecution schema
- Add dev_mode config (TRADING_DEV_MODE) to bypass auth for local dev
- Dashboard: VITE_DEV_MODE bypasses ProtectedRoute and 401 redirects
- Vite proxy target configurable via VITE_API_TARGET
- Add top-level README.md and remaining-work-plan.md
- Update CLAUDE.md with correct counts and remove stale TODOs
- 404 tests passing

Made-with: Cursor
2026-02-25 22:02:25 +00:00

173 lines
5.6 KiB
Python

"""Pre-trade risk management checks and position sizing.
Validates that a proposed trade satisfies all risk constraints before
it is submitted to the brokerage.
"""
from __future__ import annotations

import logging
import math
from datetime import datetime, timedelta
from zoneinfo import ZoneInfo

from redis.asyncio import Redis

from services.trade_executor.config import TradeExecutorConfig
from shared.broker.base import BaseBroker
from shared.schemas.trading import AccountInfo, PositionInfo, SignalDirection, TradeSignal
logger = logging.getLogger(__name__)

# Every clock comparison in this module is done in US Eastern Time, the
# timezone of the regular US equity session.
_ET = ZoneInfo("America/New_York")

# Regular-session boundaries in Eastern Time: opens 09:30, closes 16:00.
_MARKET_OPEN_HOUR, _MARKET_OPEN_MINUTE = 9, 30
_MARKET_CLOSE_HOUR, _MARKET_CLOSE_MINUTE = 16, 0

# Redis key read by RiskManager.check_risk; any truthy value stored under it
# causes new trades to be rejected with reason "trading_paused".
TRADING_PAUSED_KEY = "trading:paused"
class RiskManager:
    """Performs pre-trade risk checks and calculates position sizes.

    Parameters
    ----------
    config:
        Trade executor configuration with risk parameters
        (``cooldown_minutes``, ``max_positions``,
        ``max_total_exposure_pct``, ``max_position_pct``).
    broker:
        Broker instance for querying current positions and account info.
    redis:
        Optional Redis client for checking the trading pause flag. When
        ``None`` the pause check is skipped.
    """

    def __init__(
        self,
        config: TradeExecutorConfig,
        broker: BaseBroker,
        redis: Redis | None = None,
    ) -> None:
        self.config = config
        self.broker = broker
        self.redis = redis
        # ticker -> last exit timestamp. Compared against an aware Eastern
        # Time "now" in check_risk, so stored values must be timezone-aware.
        self._cooldowns: dict[str, datetime] = {}

    def record_exit(self, ticker: str, exit_time: datetime | None = None) -> None:
        """Record the time a position in *ticker* was exited, for cooldown tracking.

        Parameters
        ----------
        ticker:
            Symbol whose cooldown window starts at *exit_time*.
        exit_time:
            When the exit happened; defaults to the current Eastern Time.
            Must be timezone-aware: :meth:`check_risk` compares it with an
            aware ET timestamp, and a naive value would make that comparison
            raise ``TypeError``.
        """
        if exit_time is None:
            exit_time = datetime.now(tz=_ET)
        self._cooldowns[ticker] = exit_time

    async def check_risk(self, signal: TradeSignal) -> tuple[bool, str]:
        """Run all pre-trade risk checks, stopping at the first failure.

        Checks, in order: trading pause flag, market hours, per-ticker
        cooldown, maximum open positions, maximum total exposure.

        Parameters
        ----------
        signal:
            The proposed trade; only ``signal.ticker`` is consulted here.

        Returns
        -------
        tuple[bool, str]
            ``(approved, reason)`` — ``(True, "approved")`` when all checks
            pass, otherwise ``approved`` is ``False`` and ``reason`` names
            the first failed check.
        """
        # 0. Trading pause flag. Redis returns None for a missing key; any
        # non-empty value counts as paused.
        # NOTE(review): that includes values such as "0" or "false" — confirm
        # the writer of this flag deletes the key to unpause.
        if self.redis is not None:
            paused = await self.redis.get(TRADING_PAUSED_KEY)
            if paused:
                return False, "trading_paused"

        # 1. Market hours
        now_et = datetime.now(tz=_ET)
        if not self._is_market_hours(now_et):
            return False, "outside_market_hours"

        # 2. Cooldown: block re-entry for cooldown_minutes after an exit.
        last_exit = self._cooldowns.get(signal.ticker)
        if last_exit is not None:
            cooldown_end = last_exit + timedelta(minutes=self.config.cooldown_minutes)
            if now_et < cooldown_end:
                remaining = (cooldown_end - now_et).total_seconds() / 60
                return False, f"cooldown_active ({remaining:.1f}m remaining)"

        # 3. Max positions.
        # NOTE(review): this and the exposure check run for every signal,
        # including ones that would reduce or close an existing position
        # (the signal's direction is not consulted), so an exit can be
        # rejected when the book is full — confirm this is intended.
        positions = await self.broker.get_positions()
        if len(positions) >= self.config.max_positions:
            return False, "max_positions_exceeded"

        # 4. Max total exposure: gross market value of current holdings versus
        # the equity-based cap. The proposed trade's own size is not added.
        account = await self.broker.get_account()
        total_exposure = sum(abs(p.market_value) for p in positions)
        max_exposure = account.equity * self.config.max_total_exposure_pct
        if total_exposure >= max_exposure:
            return False, "max_exposure_exceeded"

        return True, "approved"

    def calculate_position_size(
        self,
        signal: TradeSignal,
        account: AccountInfo,
    ) -> float:
        """Calculate the number of whole shares to buy/sell.

        Uses fixed-fractional sizing: ``equity * max_position_pct`` is the
        maximum dollar value per position, which is then scaled by signal
        strength. Strength is capped at ``1.0`` so the result never exceeds
        that maximum.

        Parameters
        ----------
        signal:
            The trade signal. ``strength`` scales the position; the current
            price is read from ``signal.sentiment_context["current_price"]``.
        account:
            Current account info; only ``equity`` is used.

        Returns
        -------
        float
            Number of whole shares (an ``int`` when sizing succeeds, ``0.0``
            when the position cannot be sized).
        """
        strength = signal.strength
        equity = account.equity
        # NaN/inf inputs would otherwise reach int(), which raises on them.
        if not (math.isfinite(strength) and math.isfinite(equity)):
            return 0.0
        if strength <= 0 or equity <= 0:
            return 0.0

        position_value = equity * self.config.max_position_pct
        # Clamp so that strength can only shrink the per-position budget.
        position_value *= min(strength, 1.0)

        price = self._current_price(signal)
        if price is None:
            logger.warning("No current price for %s, cannot size position", signal.ticker)
            return 0.0

        qty = position_value / price
        return max(int(qty), 0)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _current_price(signal: TradeSignal) -> float | None:
        """Return the signal's ``current_price`` if it is a positive finite
        number, else ``None`` (missing, non-numeric, NaN/inf, or <= 0)."""
        context = signal.sentiment_context
        if not context or "current_price" not in context:
            return None
        try:
            price = float(context["current_price"])
        except (TypeError, ValueError):
            return None
        if not math.isfinite(price) or price <= 0:
            return None
        return price

    @staticmethod
    def _is_market_hours(now_et: datetime) -> bool:
        """Return ``True`` if *now_et* falls within regular US market hours.

        Market hours: Monday--Friday, 9:30 AM (inclusive) -- 4:00 PM
        (exclusive) ET. Exchange holidays and early closes are not
        considered. Timezone-aware inputs are converted to Eastern Time
        first; naive inputs are treated as Eastern wall-clock time.
        """
        if now_et.tzinfo is not None:
            now_et = now_et.astimezone(_ET)
        # Weekday check (0=Monday ... 6=Sunday)
        if now_et.weekday() >= 5:
            return False
        market_open = now_et.replace(
            hour=_MARKET_OPEN_HOUR,
            minute=_MARKET_OPEN_MINUTE,
            second=0,
            microsecond=0,
        )
        market_close = now_et.replace(
            hour=_MARKET_CLOSE_HOUR,
            minute=_MARKET_CLOSE_MINUTE,
            second=0,
            microsecond=0,
        )
        return market_open <= now_et < market_close