feat: make backtest work end-to-end with Alpaca bars, ticker selection, all 9 strategies
- Change BacktestRequest from strategy_weights dict to strategies list to match frontend - Add tickers field so users can select which stocks to backtest - Fetch historical bars from Alpaca StockHistoricalDataClient instead of empty data loader - Register all 9 strategies (momentum, mean_reversion, news_driven, value, macd_crossover, bollinger_breakout, vwap, liquidity, ma_stack) filtered by user selection - Fix response format: use frontend field names (max_drawdown, total_trades, win_rate as 0-1 decimal), include equity_curve and run_id in response - Add ticker selector with checkboxes and custom ticker input to dashboard - Add alpaca-py to api dependency group in pyproject.toml
This commit is contained in:
parent
82d30bde80
commit
a2c08743ac
3 changed files with 236 additions and 31 deletions
|
|
@ -4,11 +4,14 @@ import client from '../api/client';
|
|||
import { EquityCurve } from '../components/EquityCurve';
|
||||
import { MetricsRow } from '../components/MetricsRow';
|
||||
|
||||
const DEFAULT_TICKERS = ['AAPL', 'TSLA', 'NVDA', 'MSFT', 'GOOGL'];
|
||||
|
||||
interface BacktestConfig {
|
||||
start_date: string;
|
||||
end_date: string;
|
||||
initial_capital: number;
|
||||
strategies: string[];
|
||||
tickers: string[];
|
||||
}
|
||||
|
||||
interface BacktestResult {
|
||||
|
|
@ -37,6 +40,8 @@ export default function Backtest() {
|
|||
const [endDate, setEndDate] = useState('2026-01-01');
|
||||
const [initialCapital, setInitialCapital] = useState(100000);
|
||||
const [selectedStrategies, setSelectedStrategies] = useState<string[]>([]);
|
||||
const [selectedTickers, setSelectedTickers] = useState<string[]>([...DEFAULT_TICKERS]);
|
||||
const [customTicker, setCustomTicker] = useState('');
|
||||
const [currentRunId, setCurrentRunId] = useState<string | null>(null);
|
||||
|
||||
const { data: strategyOptions } = useQuery<StrategyOption[]>({
|
||||
|
|
@ -74,12 +79,13 @@ export default function Backtest() {
|
|||
});
|
||||
|
||||
const handleSubmit = () => {
|
||||
if (!startDate || !endDate || selectedStrategies.length === 0) return;
|
||||
if (!startDate || !endDate || selectedStrategies.length === 0 || selectedTickers.length === 0) return;
|
||||
runMutation.mutate({
|
||||
start_date: startDate,
|
||||
end_date: endDate,
|
||||
initial_capital: initialCapital,
|
||||
strategies: selectedStrategies,
|
||||
tickers: selectedTickers,
|
||||
});
|
||||
};
|
||||
|
||||
|
|
@ -89,6 +95,20 @@ export default function Backtest() {
|
|||
);
|
||||
};
|
||||
|
||||
const toggleTicker = (ticker: string) => {
|
||||
setSelectedTickers((prev) =>
|
||||
prev.includes(ticker) ? prev.filter((t) => t !== ticker) : [...prev, ticker]
|
||||
);
|
||||
};
|
||||
|
||||
const addCustomTicker = () => {
|
||||
const ticker = customTicker.trim().toUpperCase();
|
||||
if (ticker && !selectedTickers.includes(ticker)) {
|
||||
setSelectedTickers((prev) => [...prev, ticker]);
|
||||
}
|
||||
setCustomTicker('');
|
||||
};
|
||||
|
||||
const metricsDisplay = result?.metrics
|
||||
? [
|
||||
{
|
||||
|
|
@ -132,7 +152,7 @@ export default function Backtest() {
|
|||
<h3 className="text-lg font-semibold text-white mb-4">
|
||||
Configuration
|
||||
</h3>
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-4 mb-6">
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4 mb-6">
|
||||
<div>
|
||||
<label className="block text-xs text-slate-400 mb-1">
|
||||
Start Date
|
||||
|
|
@ -168,6 +188,9 @@ export default function Backtest() {
|
|||
className="w-full px-3 py-2 bg-slate-700 border border-slate-600 rounded-lg text-white text-sm focus:outline-none focus:ring-1 focus:ring-blue-500"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 gap-4 mb-6">
|
||||
<div>
|
||||
<label className="block text-xs text-slate-400 mb-1">
|
||||
Strategies
|
||||
|
|
@ -191,6 +214,61 @@ export default function Backtest() {
|
|||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label className="block text-xs text-slate-400 mb-1">
|
||||
Tickers
|
||||
</label>
|
||||
<div className="space-y-2 mt-1">
|
||||
{DEFAULT_TICKERS.map((ticker) => (
|
||||
<label
|
||||
key={ticker}
|
||||
className="flex items-center gap-2 cursor-pointer"
|
||||
>
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={selectedTickers.includes(ticker)}
|
||||
onChange={() => toggleTicker(ticker)}
|
||||
className="w-4 h-4 rounded border-slate-600 bg-slate-700 text-blue-600 focus:ring-blue-500"
|
||||
/>
|
||||
<span className="text-sm text-slate-300">{ticker}</span>
|
||||
</label>
|
||||
))}
|
||||
{selectedTickers
|
||||
.filter((t) => !DEFAULT_TICKERS.includes(t))
|
||||
.map((ticker) => (
|
||||
<label
|
||||
key={ticker}
|
||||
className="flex items-center gap-2 cursor-pointer"
|
||||
>
|
||||
<input
|
||||
type="checkbox"
|
||||
checked
|
||||
onChange={() => toggleTicker(ticker)}
|
||||
className="w-4 h-4 rounded border-slate-600 bg-slate-700 text-blue-600 focus:ring-blue-500"
|
||||
/>
|
||||
<span className="text-sm text-slate-300">{ticker}</span>
|
||||
</label>
|
||||
))}
|
||||
<div className="flex items-center gap-2 mt-2">
|
||||
<input
|
||||
type="text"
|
||||
value={customTicker}
|
||||
onChange={(e) => setCustomTicker(e.target.value)}
|
||||
onKeyDown={(e) => e.key === 'Enter' && addCustomTicker()}
|
||||
placeholder="Add ticker..."
|
||||
className="w-28 px-2 py-1 bg-slate-700 border border-slate-600 rounded text-white text-sm focus:outline-none focus:ring-1 focus:ring-blue-500"
|
||||
/>
|
||||
<button
|
||||
type="button"
|
||||
onClick={addCustomTicker}
|
||||
className="px-2 py-1 bg-slate-600 hover:bg-slate-500 text-white text-sm rounded transition-colors"
|
||||
>
|
||||
+
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<button
|
||||
|
|
@ -198,6 +276,7 @@ export default function Backtest() {
|
|||
disabled={
|
||||
runMutation.isPending ||
|
||||
selectedStrategies.length === 0 ||
|
||||
selectedTickers.length === 0 ||
|
||||
!startDate ||
|
||||
!endDate
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ dependencies = [
|
|||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
api = ["fastapi>=0.110", "uvicorn[standard]>=0.27", "websockets>=12.0", "webauthn>=2.0", "pyjwt[crypto]>=2.8"]
|
||||
api = ["fastapi>=0.110", "uvicorn[standard]>=0.27", "websockets>=12.0", "webauthn>=2.0", "pyjwt[crypto]>=2.8", "alpaca-py>=0.21"]
|
||||
news = ["feedparser>=6.0", "praw>=7.7", "asyncpraw>=7.7", "httpx>=0.27"]
|
||||
sentiment = ["transformers>=4.38", "torch>=2.2", "ollama>=0.1"]
|
||||
trading = ["alpaca-py>=0.21", "pytz>=2024.1", "yfinance>=0.2", "httpx>=0.27"]
|
||||
|
|
|
|||
|
|
@ -20,6 +20,39 @@ router = APIRouter(prefix="/api/backtest", tags=["backtest"])
|
|||
# Store references to background tasks to prevent garbage collection
|
||||
_background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
# All available strategy classes keyed by name
|
||||
_STRATEGY_REGISTRY: dict[str, type] | None = None
|
||||
|
||||
|
||||
def _get_strategy_registry() -> dict[str, type]:
|
||||
"""Lazy-load strategy classes to avoid import-time side effects."""
|
||||
global _STRATEGY_REGISTRY
|
||||
if _STRATEGY_REGISTRY is None:
|
||||
from shared.strategies import (
|
||||
MomentumStrategy,
|
||||
MeanReversionStrategy,
|
||||
NewsDrivenStrategy,
|
||||
ValueStrategy,
|
||||
MACDCrossoverStrategy,
|
||||
BollingerBreakoutStrategy,
|
||||
VWAPStrategy,
|
||||
LiquidityStrategy,
|
||||
MAStackStrategy,
|
||||
)
|
||||
|
||||
_STRATEGY_REGISTRY = {
|
||||
"momentum": MomentumStrategy,
|
||||
"mean_reversion": MeanReversionStrategy,
|
||||
"news_driven": NewsDrivenStrategy,
|
||||
"value": ValueStrategy,
|
||||
"macd_crossover": MACDCrossoverStrategy,
|
||||
"bollinger_breakout": BollingerBreakoutStrategy,
|
||||
"vwap": VWAPStrategy,
|
||||
"liquidity": LiquidityStrategy,
|
||||
"ma_stack": MAStackStrategy,
|
||||
}
|
||||
return _STRATEGY_REGISTRY
|
||||
|
||||
|
||||
class BacktestRequest(BaseModel):
|
||||
"""Request body for starting a new backtest."""
|
||||
|
|
@ -29,7 +62,8 @@ class BacktestRequest(BaseModel):
|
|||
initial_capital: float = Field(default=100_000.0, gt=0)
|
||||
commission_per_trade: float = Field(default=0.0, ge=0)
|
||||
slippage_pct: float = Field(default=0.001, ge=0)
|
||||
strategy_weights: dict[str, float] = Field(default_factory=dict)
|
||||
strategies: list[str] = Field(default_factory=list)
|
||||
tickers: list[str] = Field(default_factory=lambda: ["AAPL", "TSLA", "NVDA", "MSFT", "GOOGL"])
|
||||
max_position_pct: float = Field(default=0.05, gt=0, le=1.0)
|
||||
signal_threshold: float = Field(default=0.3, ge=0, le=1.0)
|
||||
|
||||
|
|
@ -47,12 +81,14 @@ async def run_backtest(
|
|||
"""
|
||||
run_id = str(uuid.uuid4())
|
||||
redis = request.app.state.redis
|
||||
config = request.app.state.config
|
||||
|
||||
# Store initial status
|
||||
await redis.setex(
|
||||
f"backtest:{run_id}",
|
||||
86400, # 24h TTL
|
||||
json.dumps({
|
||||
"run_id": run_id,
|
||||
"status": "running",
|
||||
"config": body.model_dump(mode="json"),
|
||||
"started_at": datetime.now(tz=timezone.utc).isoformat(),
|
||||
|
|
@ -60,7 +96,7 @@ async def run_backtest(
|
|||
)
|
||||
|
||||
# Launch background task (stored in set to prevent GC)
|
||||
task = asyncio.create_task(_run_backtest_task(run_id, body, redis))
|
||||
task = asyncio.create_task(_run_backtest_task(run_id, body, redis, config))
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
|
|
@ -71,14 +107,57 @@ async def _run_backtest_task(
|
|||
run_id: str,
|
||||
config: BacktestRequest,
|
||||
redis,
|
||||
app_config,
|
||||
) -> None:
|
||||
"""Execute the backtest in the background and store results in Redis."""
|
||||
try:
|
||||
from backtester.config import BacktestConfig
|
||||
from backtester.data_loader import BacktestDataLoader
|
||||
from backtester.engine import BacktestEngine
|
||||
from shared.strategies.momentum import MomentumStrategy
|
||||
from shared.strategies.mean_reversion import MeanReversionStrategy
|
||||
from shared.strategies.news_driven import NewsDrivenStrategy
|
||||
|
||||
# ---- Fetch historical bars from Alpaca ----
|
||||
bars = await _fetch_alpaca_bars(
|
||||
tickers=config.tickers,
|
||||
start=config.start_date,
|
||||
end=config.end_date,
|
||||
api_key=app_config.alpaca_api_key,
|
||||
secret_key=app_config.alpaca_secret_key,
|
||||
)
|
||||
|
||||
if not bars:
|
||||
await redis.setex(
|
||||
f"backtest:{run_id}",
|
||||
86400,
|
||||
json.dumps({
|
||||
"run_id": run_id,
|
||||
"status": "failed",
|
||||
"error": "No historical bar data returned from Alpaca. Check tickers and date range.",
|
||||
}),
|
||||
)
|
||||
return
|
||||
|
||||
data_loader = BacktestDataLoader(bars=bars)
|
||||
|
||||
# ---- Build strategy list ----
|
||||
registry = _get_strategy_registry()
|
||||
strategy_names = config.strategies or list(registry.keys())
|
||||
strategies = [
|
||||
registry[name]()
|
||||
for name in strategy_names
|
||||
if name in registry
|
||||
]
|
||||
|
||||
if not strategies:
|
||||
await redis.setex(
|
||||
f"backtest:{run_id}",
|
||||
86400,
|
||||
json.dumps({
|
||||
"run_id": run_id,
|
||||
"status": "failed",
|
||||
"error": f"No valid strategies selected. Available: {list(registry.keys())}",
|
||||
}),
|
||||
)
|
||||
return
|
||||
|
||||
bt_config = BacktestConfig(
|
||||
start_date=config.start_date,
|
||||
|
|
@ -86,44 +165,35 @@ async def _run_backtest_task(
|
|||
initial_capital=config.initial_capital,
|
||||
commission_per_trade=config.commission_per_trade,
|
||||
slippage_pct=config.slippage_pct,
|
||||
strategy_weights=config.strategy_weights,
|
||||
strategy_weights={}, # equal weights
|
||||
max_position_pct=config.max_position_pct,
|
||||
signal_threshold=config.signal_threshold,
|
||||
)
|
||||
|
||||
strategies = [
|
||||
MomentumStrategy(),
|
||||
MeanReversionStrategy(),
|
||||
NewsDrivenStrategy(),
|
||||
]
|
||||
|
||||
engine = BacktestEngine(config=bt_config, strategies=strategies)
|
||||
|
||||
# Use an empty data loader for now; a full implementation
|
||||
# would load historical bars from TimescaleDB.
|
||||
from backtester.data_loader import BacktestDataLoader
|
||||
|
||||
data_loader = BacktestDataLoader(bars=[], sentiments=[])
|
||||
|
||||
result = await engine.run(data_loader)
|
||||
|
||||
# ---- Build response matching frontend expectations ----
|
||||
equity_curve = [
|
||||
{"timestamp": ts.isoformat(), "value": eq}
|
||||
for ts, eq in result.equity_curve
|
||||
]
|
||||
|
||||
await redis.setex(
|
||||
f"backtest:{run_id}",
|
||||
86400,
|
||||
json.dumps({
|
||||
"run_id": run_id,
|
||||
"status": "completed",
|
||||
"config": config.model_dump(mode="json"),
|
||||
"result": {
|
||||
"equity_curve": equity_curve,
|
||||
"metrics": {
|
||||
"total_return": result.total_return,
|
||||
"annualized_return": result.annualized_return,
|
||||
"sharpe_ratio": result.sharpe_ratio,
|
||||
"sortino_ratio": result.sortino_ratio,
|
||||
"max_drawdown_pct": result.max_drawdown_pct,
|
||||
"max_drawdown_duration_days": result.max_drawdown_duration_days,
|
||||
"win_rate": result.win_rate,
|
||||
"avg_win_loss_ratio": result.avg_win_loss_ratio,
|
||||
"trade_count": result.trade_count,
|
||||
"avg_hold_duration_seconds": result.avg_hold_duration.total_seconds(),
|
||||
"max_drawdown": result.max_drawdown_pct,
|
||||
"win_rate": result.win_rate / 100.0,
|
||||
"total_trades": result.trade_count,
|
||||
"avg_hold_duration": str(result.avg_hold_duration),
|
||||
},
|
||||
"completed_at": datetime.now(tz=timezone.utc).isoformat(),
|
||||
}),
|
||||
|
|
@ -134,12 +204,68 @@ async def _run_backtest_task(
|
|||
f"backtest:{run_id}",
|
||||
86400,
|
||||
json.dumps({
|
||||
"run_id": run_id,
|
||||
"status": "failed",
|
||||
"error": str(exc),
|
||||
}),
|
||||
)
|
||||
|
||||
|
||||
async def _fetch_alpaca_bars(
|
||||
tickers: list[str],
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
api_key: str,
|
||||
secret_key: str,
|
||||
) -> list[dict]:
|
||||
"""Fetch historical bars from Alpaca's market data API.
|
||||
|
||||
Runs the synchronous Alpaca SDK call in a thread executor to avoid
|
||||
blocking the event loop.
|
||||
"""
|
||||
if not api_key or not secret_key:
|
||||
raise ValueError("Alpaca API credentials not configured (TRADING_ALPACA_API_KEY / TRADING_ALPACA_SECRET_KEY)")
|
||||
|
||||
def _fetch() -> list[dict]:
|
||||
from alpaca.data.historical import StockHistoricalDataClient
|
||||
from alpaca.data.requests import StockBarsRequest
|
||||
from alpaca.data.timeframe import TimeFrame
|
||||
|
||||
client = StockHistoricalDataClient(api_key, secret_key)
|
||||
|
||||
# Ensure timezone-aware datetimes
|
||||
start_dt = start if start.tzinfo else start.replace(tzinfo=timezone.utc)
|
||||
end_dt = end if end.tzinfo else end.replace(tzinfo=timezone.utc)
|
||||
|
||||
req = StockBarsRequest(
|
||||
symbol_or_symbols=tickers,
|
||||
timeframe=TimeFrame.Day,
|
||||
start=start_dt,
|
||||
end=end_dt,
|
||||
)
|
||||
|
||||
bars_response = client.get_stock_bars(req)
|
||||
all_bars: list[dict] = []
|
||||
|
||||
for ticker in tickers:
|
||||
ticker_bars = bars_response.get(ticker, []) if bars_response else []
|
||||
for bar in ticker_bars:
|
||||
all_bars.append({
|
||||
"timestamp": bar.timestamp,
|
||||
"ticker": ticker,
|
||||
"open": float(bar.open),
|
||||
"high": float(bar.high),
|
||||
"low": float(bar.low),
|
||||
"close": float(bar.close),
|
||||
"volume": int(bar.volume),
|
||||
})
|
||||
|
||||
return all_bars
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(None, _fetch)
|
||||
|
||||
|
||||
@router.get("/{run_id}")
|
||||
async def get_backtest(
|
||||
run_id: str,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue