Some checks failed
ci/woodpecker/push/woodpecker Pipeline was canceled
Walks mentions chronologically, T+1 entry, time-based exit per KevinStrategy. Reuses backtester/metrics::compute_metrics for headline numbers. KevinPriceLoader fronts market_data + Alpaca.
367 lines
12 KiB
Python
367 lines
12 KiB
Python
"""Mention-driven backtest mini-engine for the Kevin strategy.
|
|
|
|
Parallel to the bar-driven BacktestEngine. Walks mentions chronologically,
|
|
entry at the T+1 open, exit at the open of the first session on/after entry + holding_days (trading days approximated as calendar days * 1.4). Calls the
|
|
shared KevinStrategy.evaluate_mention so backtest and live agree.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime, timedelta, timezone
|
|
from decimal import Decimal
|
|
from typing import Any, Protocol
|
|
|
|
import pandas as pd
|
|
|
|
from backtester.metrics import BacktestResult, compute_metrics
|
|
from shared.schemas.kevin import (
|
|
KevinAccountState,
|
|
KevinDecision,
|
|
KevinDecisionType,
|
|
)
|
|
from shared.strategies.kevin import KevinStrategy
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class PriceLoader(Protocol):
    """Structural interface for the market-data source used by KevinBacktestRunner."""

    async def daily_bars(
        self, symbol: str, start: datetime, end: datetime
    ) -> pd.DataFrame:
        """Return bars for ``symbol`` covering ``start``..``end``.

        The runner drops symbols whose frame is empty, reads the ``open`` and
        ``close`` columns, and relies on a datetime-like index (``.date``,
        ``to_pydatetime``).
        """
        ...

    async def is_tradable(self, symbol: str) -> bool:
        """Report whether ``symbol`` is tradable; forwarded to the strategy as-is."""
        ...

    async def benchmark_bars(
        self, start: datetime, end: datetime
    ) -> pd.DataFrame:
        """Return benchmark bars for ``start``..``end``.

        Their index defines the backtest's trading sessions, and the frame is
        passed to ``compute_metrics`` as ``benchmark_bars``.
        """
        ...
|
|
|
|
|
|
@dataclass
class KevinBacktestParams:
    """Tunable knobs for a single ``KevinBacktestRunner.run`` call."""

    # Starting cash for the simulated account.
    initial_capital: Decimal = Decimal("100000")
    # Fractional price penalty: buys fill at open * (1 + pct), sells at open * (1 - pct).
    slippage_pct: Decimal = Decimal("0.0005")
    # Flat fee deducted from cash on every entry and every exit.
    commission_per_trade: Decimal = Decimal("0")
    # Behaviour when an OPEN_LONG arrives for an already-held symbol:
    # "roll" pushes the existing exit target later; anything else ("ignore")
    # leaves the existing position untouched.
    dedupe_policy: str = "roll"  # "roll" | "ignore"
|
|
|
|
|
|
@dataclass
class _BacktestTrade:
    """One simulated long position, open or closed."""

    # Ticker the position is in.
    symbol: str
    # ``id`` of the mention that opened the position.
    source_mention_id: int
    # Session at whose open the position was entered.
    entry_at: datetime
    # Fill price at entry, already including entry slippage.
    entry_price: Decimal
    # Share count (the runner quantizes it to 4 decimal places).
    qty: Decimal
    # Session on or after which the time-based exit fires.
    target_exit_at: datetime
    # The remaining fields stay None until the position is closed.
    exit_at: datetime | None = None
    exit_price: Decimal | None = None
    pnl_usd: Decimal | None = None
    pnl_pct: Decimal | None = None
    # Calendar days between entry and exit sessions.
    holding_days_actual: int | None = None
|
|
|
|
|
|
@dataclass
class _Portfolio:
    """Mutable cash / position ledger for one backtest run."""

    # Free cash available for new entries.
    cash: Decimal
    # Currently open positions, at most one per symbol.
    open_trades: dict[str, _BacktestTrade] = field(default_factory=dict)
    # Closed positions, in the order they were closed.
    closed_trades: list[_BacktestTrade] = field(default_factory=list)
    # symbol -> moment until which the symbol is blocklisted.
    blocklist_expiry: dict[str, datetime] = field(default_factory=dict)

    def equity_at(self, mark_prices: dict[str, Decimal]) -> Decimal:
        """Cash plus open positions valued at ``mark_prices``.

        A position whose symbol is missing from ``mark_prices`` is valued at
        its entry price.
        """
        return sum(
            (
                trade.qty * mark_prices.get(sym, trade.entry_price)
                for sym, trade in self.open_trades.items()
            ),
            self.cash,
        )

    def held_dollars(self) -> dict[str, Decimal]:
        """Cost basis (qty * entry price) of each open position, keyed by symbol."""
        exposure: dict[str, Decimal] = {}
        for sym, trade in self.open_trades.items():
            exposure[sym] = trade.qty * trade.entry_price
        return exposure

    def active_blocklist(self, now: datetime) -> set[str]:
        """Symbols whose blocklist expiry is strictly later than ``now``."""
        return {sym for sym, until in self.blocklist_expiry.items() if now < until}
|
|
|
|
|
|
class KevinBacktestRunner:
    """Replay mentions through ``KevinStrategy`` against daily bars.

    Each mention is acted on at the open of the first benchmark session whose
    calendar date is strictly after the mention's ``created_at`` date (T+1).
    Positions are closed at the open of the session on/after their time-based
    exit target, and anything still open is closed at the open of the final
    session. The trade log and daily (close-marked) equity curve are handed
    to ``compute_metrics``.
    """

    def __init__(self, strategy: KevinStrategy, price_loader: PriceLoader) -> None:
        self.strategy = strategy
        self.price_loader = price_loader

    async def run(
        self, mentions: list[Any], params: KevinBacktestParams
    ) -> BacktestResult:
        """Run the backtest over ``mentions`` and return the computed metrics.

        Args:
            mentions: Mention objects. This method reads ``created_at`` and
                ``symbol``; ``_apply_mention`` reads further attributes.
            params: Capital, slippage, commission and dedupe settings.

        Returns:
            The ``BacktestResult`` produced by ``compute_metrics``.
        """
        if not mentions:
            return compute_metrics(
                trade_log=[], equity_curve=[], initial_capital=params.initial_capital
            )

        sorted_mentions = sorted(mentions, key=lambda m: m.created_at)
        # Pad the price window: one day before the first mention and 120
        # calendar days past the last one (presumably room for time exits).
        start = sorted_mentions[0].created_at - timedelta(days=1)
        end = sorted_mentions[-1].created_at + timedelta(days=120)

        bars: dict[str, pd.DataFrame] = {}
        for sym in sorted({m.symbol for m in sorted_mentions}):
            df = await self.price_loader.daily_bars(sym, start, end)
            if not df.empty:
                bars[sym] = df

        spy_bars = await self.price_loader.benchmark_bars(start, end)
        all_dates = _trading_dates(spy_bars)

        # Resolve every mention's T+1 entry session once up front instead of
        # re-scanning all mentions against all dates on every simulated day.
        # Insertion keeps chronological order within a session.
        mentions_by_day: dict[datetime, list[Any]] = {}
        for mention in sorted_mentions:
            entry_day = _entry_day(mention.created_at, all_dates)
            if entry_day is not None:
                mentions_by_day.setdefault(entry_day, []).append(mention)

        portfolio = _Portfolio(cash=params.initial_capital)
        equity_curve: list[tuple[datetime, Decimal]] = []

        for day in all_dates:
            # 1. Apply mentions whose T+1 entry session is this day.
            for mention in mentions_by_day.get(day, ()):
                await self._apply_mention(mention, day, portfolio, bars, params)

            # 2. Roll exits whose target_exit_at <= day (filled at this open).
            _close_expired(day, portfolio, bars, params)

            # 3. Mark-to-market equity at this session's close.
            mark = _mark_prices(bars, portfolio.open_trades, day)
            equity_curve.append((day, portfolio.equity_at(mark)))

        # Close any still-open positions at the open of the last session.
        if all_dates:
            _close_all(all_dates[-1], portfolio, bars, params)

        trades_dict = [self._trade_to_dict(t) for t in portfolio.closed_trades]
        return compute_metrics(
            trade_log=trades_dict,
            equity_curve=equity_curve,
            initial_capital=params.initial_capital,
            benchmark_bars=spy_bars,
        )

    @staticmethod
    def _open_marks(
        bars: dict[str, pd.DataFrame],
        open_trades: dict[str, _BacktestTrade],
        day: datetime,
    ) -> dict[str, Decimal]:
        """Mark each open position at ``day``'s open (last bar on/before ``day``).

        Symbols without bars or without a bar on/before ``day`` are omitted,
        so ``_Portfolio.equity_at`` values them at their entry price.
        """
        marks: dict[str, Decimal] = {}
        for sym in open_trades:
            if sym in bars:
                price = _price_at(bars[sym], day, "open")
                if price is not None:
                    marks[sym] = price
        return marks

    async def _apply_mention(
        self,
        mention: Any,
        day: datetime,
        portfolio: _Portfolio,
        bars: dict[str, pd.DataFrame],
        params: KevinBacktestParams,
    ) -> None:
        """Evaluate one mention at ``day``'s open and execute the decision.

        Skips (with a debug log) symbols lacking bar data or an open price on
        or before ``day``. OPEN_LONG opens or rolls a position; CLOSE_LONG
        closes any held position.
        """
        symbol = mention.symbol
        if symbol not in bars:
            logger.debug("Skipping %s on %s: no price data", symbol, day)
            return

        is_tradable = await self.price_loader.is_tradable(symbol)
        # Decisions happen at the open, so value holdings at the open too;
        # marking at this session's close would leak same-day future prices
        # into position sizing (look-ahead bias).
        mark = self._open_marks(bars, portfolio.open_trades, day)
        equity = portfolio.equity_at(mark)
        state = KevinAccountState(
            equity_usd=equity,
            cash_usd=portfolio.cash,
            held_positions=portfolio.held_dollars(),
            blocklisted_symbols=portfolio.active_blocklist(day),
            daily_trade_count=0,  # backtest doesn't enforce daily caps
            daily_alloc_usd=Decimal("0"),
            paused=False,
        )

        current_price = _price_at(bars[symbol], day, "open")
        if current_price is None:
            logger.debug("Skipping %s on %s: no open price yet", symbol, day)
            return

        decision = await self.strategy.evaluate_mention(
            mention,
            state,
            effective_conviction=mention.conviction,
            current_price=current_price,
            is_tradable=is_tradable,
        )

        if decision.decision == KevinDecisionType.OPEN_LONG:
            self._open_or_roll(decision, mention, day, portfolio, bars, params)
        elif decision.decision == KevinDecisionType.CLOSE_LONG:
            self._close_position(symbol, day, portfolio, bars, params)
            # NOTE(review): an "avoid" mention only blocklists the symbol when
            # the strategy also returns CLOSE_LONG; confirm that avoid
            # mentions on un-held symbols are meant to be ignored here.
            if mention.action.value == "avoid":
                portfolio.blocklist_expiry[symbol] = day + timedelta(
                    days=self.strategy.config.avoid_blocks_days
                )

    def _open_or_roll(
        self,
        decision: KevinDecision,
        mention: Any,
        day: datetime,
        portfolio: _Portfolio,
        bars: dict[str, pd.DataFrame],
        params: KevinBacktestParams,
    ) -> None:
        """Act on an OPEN_LONG decision at ``day``'s open.

        If the symbol is already held no second position is opened: with
        ``dedupe_policy == "roll"`` the existing exit target moves to the later
        of its current target and the new one; any other policy leaves it
        alone. Otherwise a position worth ``decision.target_dollars`` at the
        slipped open price is opened when cash covers qty * price + commission.
        """
        symbol = decision.symbol
        open_price = _price_at(bars[symbol], day, "open")
        if open_price is None or decision.target_dollars is None:
            return

        # Holding period is in trading days; approximate calendar days as
        # 7/5 = 1.4x, then snap to the first bar on/after that date.
        hold_days = decision.holding_days or 5
        target_exit = day + timedelta(days=int(hold_days * 1.4))
        target_exit = _next_trading_day(target_exit, bars[symbol].index)

        existing = portfolio.open_trades.get(symbol)
        if existing is not None:
            # Rolling only moves the exit date and spends no cash, so it must
            # not be gated by the sizing / cash checks below.
            if params.dedupe_policy == "roll":
                existing.target_exit_at = max(existing.target_exit_at, target_exit)
            return  # ignore: don't add second position

        entry_price = open_price * (Decimal("1") + params.slippage_pct)
        qty = (decision.target_dollars / entry_price).quantize(Decimal("0.0001"))
        if qty <= 0:
            return

        cost = qty * entry_price + params.commission_per_trade
        if cost > portfolio.cash:
            logger.debug(
                "Skipping %s on %s: cost %s exceeds cash %s",
                symbol,
                day,
                cost,
                portfolio.cash,
            )
            return

        portfolio.cash -= cost
        portfolio.open_trades[symbol] = _BacktestTrade(
            symbol=symbol,
            source_mention_id=mention.id,
            entry_at=day,
            entry_price=entry_price,
            qty=qty,
            target_exit_at=target_exit,
        )

    def _close_position(
        self,
        symbol: str,
        day: datetime,
        portfolio: _Portfolio,
        bars: dict[str, pd.DataFrame],
        params: KevinBacktestParams,
    ) -> None:
        """Close ``symbol`` at ``day``'s open if it is held; no-op otherwise.

        Shares the fill / P&L bookkeeping with the time-based exits by
        delegating to ``_force_close``.
        """
        if symbol not in portfolio.open_trades:
            return
        _force_close(symbol, day, portfolio, bars, params)

    def _trade_to_dict(self, t: _BacktestTrade) -> dict[str, Any]:
        """Flatten a closed trade into the row shape passed as ``trade_log``."""
        return {
            "symbol": t.symbol,
            "source_mention_id": t.source_mention_id,
            "entry_at": t.entry_at,
            "entry_price": t.entry_price,
            "exit_at": t.exit_at,
            "exit_price": t.exit_price,
            "qty": t.qty,
            "pnl_usd": t.pnl_usd,
            "pnl_pct": t.pnl_pct,
            "holding_days_actual": t.holding_days_actual,
        }
|
|
|
|
|
|
# --- helpers ---
|
|
|
|
|
|
def _mark_prices(
    bars: dict[str, pd.DataFrame],
    open_trades: dict[str, _BacktestTrade],
    day: datetime,
) -> dict[str, Decimal]:
    """Close-price marks for open positions as of ``day``.

    Uses the last bar on or before ``day``. Symbols without bars, or without
    such a bar, are left out so ``_Portfolio.equity_at`` falls back to the
    entry price for them.
    """
    candidates = (
        (sym, _price_at(bars[sym], day, "close"))
        for sym in open_trades
        if sym in bars
    )
    return {sym: px for sym, px in candidates if px is not None}
|
|
|
|
|
|
def _trading_dates(bars: pd.DataFrame) -> list[datetime]:
    """Convert a bar frame's index into a list of UTC-labelled datetimes.

    Uses ``replace(tzinfo=UTC)``, which relabels rather than converts: the
    wall-clock time is kept. Returns ``[]`` for ``None`` or an empty frame.
    """
    if bars is not None and not bars.empty:
        return [
            stamp.to_pydatetime().replace(tzinfo=timezone.utc)
            for stamp in bars.index
        ]
    return []
|
|
|
|
|
|
def _entry_day(created_at: datetime, dates: list[datetime]) -> datetime | None:
    """Return the first entry of ``dates`` dated strictly after ``created_at`` (T+1).

    Comparison is by calendar ``.date()`` only, scanning ``dates`` in list
    order. Returns None when no such session exists.
    """
    cutoff = created_at.date()
    return next((session for session in dates if session.date() > cutoff), None)
|
|
|
|
|
|
def _price_at(df: pd.DataFrame, day: datetime, col: str) -> Decimal | None:
    """Value of ``col`` on the last row dated on or before ``day``.

    Rows are matched on the index's calendar date. The value goes through
    ``str`` before becoming a Decimal, which avoids binary-float artefacts.
    Returns None for a missing/empty frame or when no row qualifies.
    """
    if df is None or df.empty:
        return None
    on_or_before = df.index.date <= day.date()
    eligible = df[on_or_before]
    if eligible.empty:
        return None
    last_row = eligible.iloc[-1]
    return Decimal(str(last_row[col]))
|
|
|
|
|
|
def _next_trading_day(target: datetime, index: pd.DatetimeIndex) -> datetime:
    """First index entry at or after ``target``, else the index's last entry.

    Entries are relabelled as UTC via ``replace(tzinfo=UTC)`` before being
    compared or returned. Raises IndexError if ``index`` is empty.
    """
    sessions = (stamp.to_pydatetime().replace(tzinfo=timezone.utc) for stamp in index)
    hit = next((session for session in sessions if session >= target), None)
    if hit is not None:
        return hit
    fallback: datetime = index[-1].to_pydatetime().replace(tzinfo=timezone.utc)
    return fallback
|
|
|
|
|
|
def _close_expired(
    day: datetime,
    portfolio: _Portfolio,
    bars: dict[str, pd.DataFrame],
    params: KevinBacktestParams,
) -> None:
    """Force-close, at ``day``'s open, every position whose exit target is <= ``day``."""
    due = [
        sym
        for sym, trade in portfolio.open_trades.items()
        if trade.target_exit_at <= day
    ]
    for sym in due:
        _force_close(sym, day, portfolio, bars, params)
|
|
|
|
|
|
def _close_all(
    day: datetime,
    portfolio: _Portfolio,
    bars: dict[str, pd.DataFrame],
    params: KevinBacktestParams,
) -> None:
    """Force-close every open position at ``day``'s open."""
    # Snapshot the keys: _force_close pops from open_trades as it goes.
    held = tuple(portfolio.open_trades)
    for sym in held:
        _force_close(sym, day, portfolio, bars, params)
|
|
|
|
|
|
def _force_close(
    symbol: str,
    day: datetime,
    portfolio: _Portfolio,
    bars: dict[str, pd.DataFrame],
    params: KevinBacktestParams,
) -> None:
    """Unconditionally close the open ``symbol`` position at ``day``'s open.

    Fills at the last open on/before ``day`` minus slippage, falling back to
    the entry price when no such bar exists. Proceeds net of one commission
    are credited to cash. ``pnl_usd``/``pnl_pct`` are computed from fill
    prices only; commissions are not deducted from them.

    Raises:
        KeyError: if ``symbol`` is not an open position or has no bars.
    """
    trade = portfolio.open_trades.pop(symbol)

    raw_exit = _price_at(bars[symbol], day, "open")
    if raw_exit is None:
        raw_exit = trade.entry_price  # no usable bar: exit at cost
    fill = raw_exit * (Decimal("1") - params.slippage_pct)

    portfolio.cash += trade.qty * fill - params.commission_per_trade

    price_move = fill - trade.entry_price
    trade.exit_at = day
    trade.exit_price = fill
    trade.pnl_usd = price_move * trade.qty
    trade.pnl_pct = price_move / trade.entry_price * Decimal("100")
    trade.holding_days_actual = (day - trade.entry_at).days
    portfolio.closed_trades.append(trade)
|