"""Backtest endpoints: run backtests and retrieve results."""

from __future__ import annotations

import asyncio
import json
import logging
import uuid
from datetime import datetime, timezone

from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel, Field

from services.api_gateway.auth.middleware import get_current_user

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/backtest", tags=["backtest"])

# Status and result records expire from Redis after 24 hours.
_RESULT_TTL_SECONDS = 86_400

# Store references to background tasks to prevent garbage collection
# (the event loop itself only keeps weak references to tasks).
_background_tasks: set[asyncio.Task] = set()

# All available strategy classes keyed by name (populated lazily).
_STRATEGY_REGISTRY: dict[str, type] | None = None


def _get_strategy_registry() -> dict[str, type]:
    """Lazy-load strategy classes to avoid import-time side effects."""
    global _STRATEGY_REGISTRY
    if _STRATEGY_REGISTRY is None:
        from shared.strategies import (
            MomentumStrategy,
            MeanReversionStrategy,
            NewsDrivenStrategy,
            ValueStrategy,
            MACDCrossoverStrategy,
            BollingerBreakoutStrategy,
            VWAPStrategy,
            LiquidityStrategy,
            MAStackStrategy,
        )

        _STRATEGY_REGISTRY = {
            "momentum": MomentumStrategy,
            "mean_reversion": MeanReversionStrategy,
            "news_driven": NewsDrivenStrategy,
            "value": ValueStrategy,
            "macd_crossover": MACDCrossoverStrategy,
            "bollinger_breakout": BollingerBreakoutStrategy,
            "vwap": VWAPStrategy,
            "liquidity": LiquidityStrategy,
            "ma_stack": MAStackStrategy,
        }
    return _STRATEGY_REGISTRY


class BacktestRequest(BaseModel):
    """Request body for starting a new backtest."""

    start_date: datetime
    end_date: datetime
    initial_capital: float = Field(default=100_000.0, gt=0)
    commission_per_trade: float = Field(default=0.0, ge=0)
    slippage_pct: float = Field(default=0.001, ge=0)
    # An empty list selects every registered strategy.
    strategies: list[str] = Field(default_factory=list)
    tickers: list[str] = Field(
        default_factory=lambda: ["AAPL", "TSLA", "NVDA", "MSFT", "GOOGL"]
    )
    max_position_pct: float = Field(default=0.05, gt=0, le=1.0)
    signal_threshold: float = Field(default=0.3, ge=0, le=1.0)


@router.post("/run")
async def run_backtest(
    body: BacktestRequest,
    request: Request,
    _user: dict = Depends(get_current_user),
) -> dict:
    """Start a backtest with the given configuration.

    Returns a ``run_id`` immediately. The backtest runs in a background task,
    and its results can be retrieved via ``GET /api/backtest/{run_id}``.
    """
    run_id = str(uuid.uuid4())
    redis = request.app.state.redis
    config = request.app.state.config

    # Store the initial "running" status
    await redis.setex(
        f"backtest:{run_id}",
        _RESULT_TTL_SECONDS,
        json.dumps({
            "run_id": run_id,
            "status": "running",
            "config": body.model_dump(mode="json"),
            "started_at": datetime.now(tz=timezone.utc).isoformat(),
        }),
    )

    # Launch the background task (stored in the set to prevent GC)
    task = asyncio.create_task(_run_backtest_task(run_id, body, redis, config))
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)

    return {"run_id": run_id, "status": "running"}


async def _run_backtest_task(
    run_id: str,
    config: BacktestRequest,
    redis,
    app_config,
) -> None:
    """Execute the backtest in the background and store results in Redis."""
    try:
        from backtester.config import BacktestConfig
        from backtester.data_loader import BacktestDataLoader
        from backtester.engine import BacktestEngine

        # ---- Fetch historical bars from Alpaca ----
        bars = await _fetch_alpaca_bars(
            tickers=config.tickers,
            start=config.start_date,
            end=config.end_date,
            api_key=app_config.alpaca_api_key,
            secret_key=app_config.alpaca_secret_key,
        )
        if not bars:
            await redis.setex(
                f"backtest:{run_id}",
                _RESULT_TTL_SECONDS,
                json.dumps({
                    "run_id": run_id,
                    "status": "failed",
                    "error": (
                        "No historical bar data returned from Alpaca. "
                        "Check tickers and date range."
                    ),
                }),
            )
            return

        data_loader = BacktestDataLoader(bars=bars)

        # ---- Build strategy list ----
        # Unknown strategy names are skipped; the run fails only if none remain.
        registry = _get_strategy_registry()
        strategy_names = config.strategies or list(registry.keys())
        strategies = [
            registry[name]() for name in strategy_names if name in registry
        ]
        if not strategies:
            await redis.setex(
                f"backtest:{run_id}",
                _RESULT_TTL_SECONDS,
                json.dumps({
                    "run_id": run_id,
                    "status": "failed",
                    "error": f"No valid strategies selected. Available: {list(registry.keys())}",
                }),
            )
            return

        bt_config = BacktestConfig(
            start_date=config.start_date,
            end_date=config.end_date,
            initial_capital=config.initial_capital,
            commission_per_trade=config.commission_per_trade,
            slippage_pct=config.slippage_pct,
            strategy_weights={},  # empty dict = equal weights
            max_position_pct=config.max_position_pct,
            signal_threshold=config.signal_threshold,
        )
        engine = BacktestEngine(config=bt_config, strategies=strategies)
        result = await engine.run(data_loader)

        # ---- Build response matching frontend expectations ----
        equity_curve = [
            {"timestamp": ts.isoformat(), "value": eq}
            for ts, eq in result.equity_curve
        ]
        await redis.setex(
            f"backtest:{run_id}",
            _RESULT_TTL_SECONDS,
            json.dumps({
                "run_id": run_id,
                "status": "completed",
                "equity_curve": equity_curve,
                "metrics": {
                    "total_return": result.total_return,
                    "annualized_return": result.annualized_return,
                    "sharpe_ratio": result.sharpe_ratio,
                    "max_drawdown": result.max_drawdown_pct,
                    # Convert the engine's percentage to a 0-1 fraction.
                    "win_rate": result.win_rate / 100.0,
                    "total_trades": result.trade_count,
                    "avg_hold_duration": str(result.avg_hold_duration),
                },
                "completed_at": datetime.now(tz=timezone.utc).isoformat(),
            }),
        )
    except Exception as exc:
        logger.exception("Backtest %s failed", run_id)
        await redis.setex(
            f"backtest:{run_id}",
            _RESULT_TTL_SECONDS,
            json.dumps({
                "run_id": run_id,
                "status": "failed",
                "error": str(exc),
            }),
        )


async def _fetch_alpaca_bars(
    tickers: list[str],
    start: datetime,
    end: datetime,
    api_key: str,
    secret_key: str,
) -> list[dict]:
    """Fetch daily historical bars from Alpaca's market data API.

    Runs the synchronous Alpaca SDK call in a thread executor to avoid
    blocking the event loop.
    """
    if not api_key or not secret_key:
        raise ValueError(
            "Alpaca API credentials not configured "
            "(TRADING_ALPACA_API_KEY / TRADING_ALPACA_SECRET_KEY)"
        )

    def _fetch() -> list[dict]:
        from alpaca.data.historical import StockHistoricalDataClient
        from alpaca.data.requests import StockBarsRequest
        from alpaca.data.timeframe import TimeFrame

        client = StockHistoricalDataClient(api_key, secret_key)

        # Ensure timezone-aware datetimes; naive input is treated as UTC.
        start_dt = start if start.tzinfo else start.replace(tzinfo=timezone.utc)
        end_dt = end if end.tzinfo else end.replace(tzinfo=timezone.utc)

        req = StockBarsRequest(
            symbol_or_symbols=tickers,
            timeframe=TimeFrame.Day,
            start=start_dt,
            end=end_dt,
        )
        bars_response = client.get_stock_bars(req)

        # Flatten the per-ticker response into a single list of plain dicts.
        all_bars: list[dict] = []
        for ticker in tickers:
            ticker_bars = bars_response.get(ticker, []) if bars_response else []
            for bar in ticker_bars:
                all_bars.append({
                    "timestamp": bar.timestamp,
                    "ticker": ticker,
                    "open": float(bar.open),
                    "high": float(bar.high),
                    "low": float(bar.low),
                    "close": float(bar.close),
                    "volume": int(bar.volume),
                })
        return all_bars

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _fetch)


@router.get("/{run_id}")
async def get_backtest(
    run_id: str,
    request: Request,
    _user: dict = Depends(get_current_user),
) -> dict:
    """Get backtest results by run ID."""
    redis = request.app.state.redis
    raw = await redis.get(f"backtest:{run_id}")
    if raw is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Backtest run not found",
        )
    return json.loads(raw)
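

# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical; not called by the endpoints above).
# It factors the strong-reference pattern from ``run_backtest`` into one
# helper. asyncio keeps only weak references to running tasks, so a task
# created with ``asyncio.create_task`` and stored nowhere else can be
# garbage-collected before it finishes. The name ``_spawn_tracked`` and its
# ``registry`` parameter are assumptions made for this example.
#
# Possible usage inside ``run_backtest``, replacing the three task lines:
#     _spawn_tracked(
#         _run_backtest_task(run_id, body, redis, config), _background_tasks
#     )
# ---------------------------------------------------------------------------
def _spawn_tracked(coro, registry: "set[asyncio.Task]") -> "asyncio.Task":
    """Schedule *coro* as a task and hold a strong reference until it is done.

    Must be called while an event loop is running.
    """
    import asyncio  # local import keeps this sketch self-contained

    task = asyncio.create_task(coro)
    # Strong reference: keeps the task alive while it runs.
    registry.add(task)
    # Drop the reference once the task finishes (success, error, or cancel).
    task.add_done_callback(registry.discard)
    return task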