feat: backtesting engine — historical replay with shared strategies
This commit is contained in:
parent
1d9900838d
commit
5e5425a0f7
8 changed files with 1242 additions and 1 deletions
21
backtester/__init__.py
Normal file
21
backtester/__init__.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
"""Backtesting engine for historical replay with shared strategies.
|
||||
|
||||
Provides a simulated broker, data loader, metrics calculator, and the
|
||||
main :class:`BacktestEngine` that replays market data through the same
|
||||
strategy ensemble used in live trading.
|
||||
"""
|
||||
|
||||
from backtester.config import BacktestConfig
|
||||
from backtester.data_loader import BacktestDataLoader
|
||||
from backtester.engine import BacktestEngine
|
||||
from backtester.metrics import BacktestResult, compute_metrics
|
||||
from backtester.simulated_broker import SimulatedBroker
|
||||
|
||||
__all__ = [
|
||||
"BacktestConfig",
|
||||
"BacktestDataLoader",
|
||||
"BacktestEngine",
|
||||
"BacktestResult",
|
||||
"SimulatedBroker",
|
||||
"compute_metrics",
|
||||
]
|
||||
42
backtester/config.py
Normal file
42
backtester/config.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
"""Backtest configuration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class BacktestConfig:
|
||||
"""Configuration for a single backtest run.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
start_date:
|
||||
Inclusive start of the replay window.
|
||||
end_date:
|
||||
Inclusive end of the replay window.
|
||||
initial_capital:
|
||||
Starting cash balance in USD.
|
||||
commission_per_trade:
|
||||
Fixed commission charged per order (Alpaca is commission-free,
|
||||
so the default is 0.0).
|
||||
slippage_pct:
|
||||
Simulated slippage as a fraction of price (0.001 = 0.1%).
|
||||
strategy_weights:
|
||||
Mapping of strategy name to weight. If empty, strategies
|
||||
receive equal weight (0.333...).
|
||||
max_position_pct:
|
||||
Maximum fraction of equity per position (default 5%).
|
||||
signal_threshold:
|
||||
Minimum combined signal strength to trigger a trade (default 0.3).
|
||||
"""
|
||||
|
||||
start_date: datetime
|
||||
end_date: datetime
|
||||
initial_capital: float = 100_000.0
|
||||
commission_per_trade: float = 0.0
|
||||
slippage_pct: float = 0.001
|
||||
strategy_weights: dict[str, float] = field(default_factory=dict)
|
||||
max_position_pct: float = 0.05
|
||||
signal_threshold: float = 0.3
|
||||
99
backtester/data_loader.py
Normal file
99
backtester/data_loader.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
"""Historical data loader for backtesting.
|
||||
|
||||
:class:`BacktestDataLoader` takes pre-loaded bar and sentiment data and
|
||||
yields it in chronological order, making the backtester independent of
|
||||
any database.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Any, AsyncIterator
|
||||
|
||||
from shared.schemas.trading import SentimentContext
|
||||
|
||||
|
||||
class BacktestDataLoader:
|
||||
"""Iterates over historical bars (and optional sentiment) chronologically.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bars:
|
||||
Pre-loaded OHLCV data. Each dict must contain at minimum:
|
||||
``timestamp``, ``ticker``, ``open``, ``high``, ``low``,
|
||||
``close``, ``volume``.
|
||||
sentiments:
|
||||
Optional pre-loaded sentiment data. Each dict must contain:
|
||||
``timestamp``, ``ticker``, ``score``, ``confidence``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bars: list[dict[str, Any]],
|
||||
sentiments: list[dict[str, Any]] | None = None,
|
||||
) -> None:
|
||||
self._bars = sorted(bars, key=lambda b: b["timestamp"])
|
||||
self._sentiments = sorted(sentiments or [], key=lambda s: s["timestamp"])
|
||||
|
||||
async def iterate(
|
||||
self,
|
||||
) -> AsyncIterator[tuple[datetime, str, dict[str, Any], SentimentContext | None]]:
|
||||
"""Yield ``(timestamp, ticker, bar_data, sentiment_context)`` in order.
|
||||
|
||||
For each bar the loader aggregates all sentiment records for the
|
||||
same ticker whose timestamps are <= the current bar's timestamp,
|
||||
building a :class:`SentimentContext`. If no sentiment data is
|
||||
available for the ticker, ``None`` is yielded instead.
|
||||
"""
|
||||
# Pre-index sentiments by ticker for efficient lookup
|
||||
sentiment_by_ticker: dict[str, list[dict[str, Any]]] = defaultdict(list)
|
||||
for s in self._sentiments:
|
||||
sentiment_by_ticker[s["ticker"]].append(s)
|
||||
|
||||
for bar in self._bars:
|
||||
ts = bar["timestamp"]
|
||||
ticker = bar["ticker"]
|
||||
|
||||
# Build bar_data dict suitable for MarketDataManager.add_bar
|
||||
bar_data = {
|
||||
"timestamp": ts,
|
||||
"open": bar["open"],
|
||||
"high": bar["high"],
|
||||
"low": bar["low"],
|
||||
"close": bar["close"],
|
||||
"volume": bar["volume"],
|
||||
}
|
||||
|
||||
# Aggregate sentiment up to this timestamp
|
||||
sentiment_ctx = self._build_sentiment(
|
||||
ticker, ts, sentiment_by_ticker.get(ticker, [])
|
||||
)
|
||||
|
||||
yield ts, ticker, bar_data, sentiment_ctx
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _build_sentiment(
|
||||
ticker: str,
|
||||
up_to: datetime,
|
||||
records: list[dict[str, Any]],
|
||||
) -> SentimentContext | None:
|
||||
"""Build a SentimentContext from all records with timestamp <= up_to."""
|
||||
relevant = [r for r in records if r["timestamp"] <= up_to]
|
||||
if not relevant:
|
||||
return None
|
||||
|
||||
scores = [r["score"] for r in relevant]
|
||||
confidences = [r["confidence"] for r in relevant]
|
||||
|
||||
return SentimentContext(
|
||||
ticker=ticker,
|
||||
avg_score=sum(scores) / len(scores),
|
||||
article_count=len(relevant),
|
||||
recent_scores=scores[-10:], # last 10 scores
|
||||
avg_confidence=sum(confidences) / len(confidences),
|
||||
)
|
||||
164
backtester/engine.py
Normal file
164
backtester/engine.py
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
"""Main backtest engine that replays historical data through strategies.
|
||||
|
||||
Ties together the :class:`~backtester.data_loader.BacktestDataLoader`,
|
||||
:class:`~backtester.simulated_broker.SimulatedBroker`,
|
||||
:class:`~services.signal_generator.ensemble.WeightedEnsemble`, and
|
||||
:class:`~services.signal_generator.market_data.MarketDataManager` to
|
||||
produce a :class:`~backtester.metrics.BacktestResult`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from backtester.config import BacktestConfig
|
||||
from backtester.data_loader import BacktestDataLoader
|
||||
from backtester.metrics import BacktestResult, compute_metrics
|
||||
from backtester.simulated_broker import SimulatedBroker
|
||||
from services.signal_generator.ensemble import WeightedEnsemble
|
||||
from services.signal_generator.market_data import MarketDataManager
|
||||
from shared.schemas.trading import (
|
||||
OrderRequest,
|
||||
OrderSide,
|
||||
SignalDirection,
|
||||
)
|
||||
from shared.strategies.base import BaseStrategy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BacktestEngine:
|
||||
"""Replays historical data through the trading pipeline.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config:
|
||||
Backtest configuration (dates, capital, slippage, weights, etc.).
|
||||
strategies:
|
||||
List of strategy instances to evaluate.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: BacktestConfig,
|
||||
strategies: list[BaseStrategy],
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.strategies = strategies
|
||||
|
||||
async def run(self, data_loader: BacktestDataLoader) -> BacktestResult:
|
||||
"""Execute the full backtest and return metrics.
|
||||
|
||||
Steps
|
||||
-----
|
||||
1. Create SimulatedBroker, MarketDataManager, WeightedEnsemble.
|
||||
2. Iterate over data_loader bars in chronological order.
|
||||
3. For each bar: update market data, update broker prices,
|
||||
build snapshot, run ensemble, submit orders.
|
||||
4. Close remaining positions at final prices.
|
||||
5. Compute and return metrics.
|
||||
"""
|
||||
broker = SimulatedBroker(
|
||||
initial_capital=self.config.initial_capital,
|
||||
slippage_pct=self.config.slippage_pct,
|
||||
commission_per_trade=self.config.commission_per_trade,
|
||||
)
|
||||
market_data = MarketDataManager()
|
||||
ensemble = WeightedEnsemble(
|
||||
strategies=self.strategies,
|
||||
threshold=self.config.signal_threshold,
|
||||
)
|
||||
|
||||
# Resolve strategy weights
|
||||
weights = self._resolve_weights()
|
||||
|
||||
equity_curve: list[tuple[datetime, float]] = []
|
||||
|
||||
# ---- Main replay loop ----
|
||||
async for timestamp, ticker, bar_data, sentiment in data_loader.iterate():
|
||||
# a. Update market data manager with the new bar
|
||||
market_data.add_bar(ticker, bar_data)
|
||||
|
||||
# b. Update broker prices
|
||||
broker.set_current_prices({ticker: bar_data["close"]})
|
||||
|
||||
# c. Build market snapshot
|
||||
snapshot = market_data.get_snapshot(ticker)
|
||||
if snapshot is None:
|
||||
continue
|
||||
|
||||
# d. Run ensemble
|
||||
signal = await ensemble.evaluate(ticker, snapshot, sentiment, weights)
|
||||
|
||||
# e. If signal, do simple position sizing and submit order
|
||||
if signal is not None:
|
||||
account = await broker.get_account()
|
||||
positions = await broker.get_positions()
|
||||
position_tickers = {p.ticker for p in positions}
|
||||
|
||||
# Determine order side
|
||||
if signal.direction == SignalDirection.LONG and ticker not in position_tickers:
|
||||
# Buy: size using max_position_pct * equity * strength
|
||||
position_value = account.equity * self.config.max_position_pct * signal.strength
|
||||
current_price = bar_data["close"]
|
||||
if current_price > 0:
|
||||
qty = int(position_value / current_price)
|
||||
if qty > 0:
|
||||
order = OrderRequest(
|
||||
ticker=ticker,
|
||||
side=OrderSide.BUY,
|
||||
qty=float(qty),
|
||||
)
|
||||
await broker.submit_order(order)
|
||||
|
||||
elif signal.direction == SignalDirection.SHORT and ticker in position_tickers:
|
||||
# Sell: close entire position
|
||||
for pos in positions:
|
||||
if pos.ticker == ticker:
|
||||
order = OrderRequest(
|
||||
ticker=ticker,
|
||||
side=OrderSide.SELL,
|
||||
qty=pos.qty,
|
||||
)
|
||||
await broker.submit_order(order)
|
||||
break
|
||||
|
||||
# g. Record equity snapshot
|
||||
account = await broker.get_account()
|
||||
equity_curve.append((timestamp, account.equity))
|
||||
|
||||
# ---- Close all remaining positions at final prices ----
|
||||
remaining_positions = await broker.get_positions()
|
||||
for pos in remaining_positions:
|
||||
order = OrderRequest(
|
||||
ticker=pos.ticker,
|
||||
side=OrderSide.SELL,
|
||||
qty=pos.qty,
|
||||
)
|
||||
await broker.submit_order(order)
|
||||
|
||||
# Record final equity after closing
|
||||
if equity_curve:
|
||||
final_account = await broker.get_account()
|
||||
equity_curve.append((equity_curve[-1][0], final_account.equity))
|
||||
|
||||
# ---- Compute metrics ----
|
||||
trade_log = broker.get_trade_log()
|
||||
result = compute_metrics(trade_log, equity_curve, self.config.initial_capital)
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _resolve_weights(self) -> dict[str, float]:
|
||||
"""Return strategy weights, defaulting to equal if none configured."""
|
||||
if self.config.strategy_weights:
|
||||
return dict(self.config.strategy_weights)
|
||||
|
||||
# Equal weights
|
||||
if not self.strategies:
|
||||
return {}
|
||||
equal_w = round(1.0 / len(self.strategies), 6)
|
||||
return {s.name: equal_w for s in self.strategies}
|
||||
280
backtester/metrics.py
Normal file
280
backtester/metrics.py
Normal file
|
|
@ -0,0 +1,280 @@
|
|||
"""Performance metrics for backtesting results.
|
||||
|
||||
Computes standard risk and return metrics from the trade log and equity
|
||||
curve produced by a backtest run.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from shared.schemas.trading import OrderSide, TradeExecution
|
||||
|
||||
|
||||
@dataclass
|
||||
class BacktestResult:
|
||||
"""Container for all computed backtest metrics.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
total_return:
|
||||
``(final - initial) / initial * 100`` as a percentage.
|
||||
annualized_return:
|
||||
Total return annualized using 252 trading days.
|
||||
sharpe_ratio:
|
||||
``mean(daily_returns) / std(daily_returns) * sqrt(252)``.
|
||||
sortino_ratio:
|
||||
Like Sharpe but using only downside deviation.
|
||||
max_drawdown_pct:
|
||||
Maximum peak-to-trough decline as a percentage.
|
||||
max_drawdown_duration_days:
|
||||
Duration (in calendar days) of the longest drawdown.
|
||||
win_rate:
|
||||
Percentage of winning trades.
|
||||
avg_win_loss_ratio:
|
||||
``avg(winning_pnl) / abs(avg(losing_pnl))``.
|
||||
trade_count:
|
||||
Total number of round-trip trades.
|
||||
avg_hold_duration:
|
||||
Mean hold duration across all round-trip trades.
|
||||
equity_curve:
|
||||
List of ``(timestamp, equity)`` snapshots.
|
||||
trade_log:
|
||||
Raw list of :class:`TradeExecution` objects.
|
||||
"""
|
||||
|
||||
total_return: float = 0.0
|
||||
annualized_return: float = 0.0
|
||||
sharpe_ratio: float = 0.0
|
||||
sortino_ratio: float = 0.0
|
||||
max_drawdown_pct: float = 0.0
|
||||
max_drawdown_duration_days: float = 0.0
|
||||
win_rate: float = 0.0
|
||||
avg_win_loss_ratio: float = 0.0
|
||||
trade_count: int = 0
|
||||
avg_hold_duration: timedelta = field(default_factory=lambda: timedelta(0))
|
||||
equity_curve: list[tuple[datetime, float]] = field(default_factory=list)
|
||||
trade_log: list[TradeExecution] = field(default_factory=list)
|
||||
|
||||
|
||||
def compute_metrics(
|
||||
trade_log: list[TradeExecution],
|
||||
equity_curve: list[tuple[datetime, float]],
|
||||
initial_capital: float = 100_000.0,
|
||||
) -> BacktestResult:
|
||||
"""Compute all performance metrics from a backtest run.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_log:
|
||||
Chronological list of every executed trade (buys and sells).
|
||||
equity_curve:
|
||||
List of ``(timestamp, portfolio_equity)`` snapshots.
|
||||
initial_capital:
|
||||
Starting capital used to compute total return.
|
||||
|
||||
Returns
|
||||
-------
|
||||
BacktestResult
|
||||
Populated metrics dataclass.
|
||||
"""
|
||||
result = BacktestResult(
|
||||
equity_curve=equity_curve,
|
||||
trade_log=trade_log,
|
||||
)
|
||||
|
||||
if not equity_curve:
|
||||
return result
|
||||
|
||||
# ----- Total return -----
|
||||
final_equity = equity_curve[-1][1]
|
||||
result.total_return = (final_equity - initial_capital) / initial_capital * 100.0
|
||||
|
||||
# ----- Annualized return -----
|
||||
if len(equity_curve) >= 2:
|
||||
total_days = (equity_curve[-1][0] - equity_curve[0][0]).days
|
||||
if total_days > 0:
|
||||
trading_years = total_days / 365.25
|
||||
growth_factor = final_equity / initial_capital
|
||||
if growth_factor > 0:
|
||||
result.annualized_return = (
|
||||
(growth_factor ** (1.0 / trading_years)) - 1.0
|
||||
) * 100.0
|
||||
|
||||
# ----- Daily returns -----
|
||||
daily_returns = _compute_daily_returns(equity_curve)
|
||||
|
||||
# ----- Sharpe ratio -----
|
||||
result.sharpe_ratio = _compute_sharpe(daily_returns)
|
||||
|
||||
# ----- Sortino ratio -----
|
||||
result.sortino_ratio = _compute_sortino(daily_returns)
|
||||
|
||||
# ----- Max drawdown -----
|
||||
dd_pct, dd_duration = _compute_max_drawdown(equity_curve)
|
||||
result.max_drawdown_pct = dd_pct
|
||||
result.max_drawdown_duration_days = dd_duration
|
||||
|
||||
# ----- Round-trip trade analysis -----
|
||||
round_trips = _build_round_trips(trade_log)
|
||||
result.trade_count = len(round_trips)
|
||||
|
||||
if round_trips:
|
||||
pnls = [rt["pnl"] for rt in round_trips]
|
||||
wins = [p for p in pnls if p > 0]
|
||||
losses = [p for p in pnls if p <= 0]
|
||||
|
||||
result.win_rate = (len(wins) / len(pnls)) * 100.0 if pnls else 0.0
|
||||
|
||||
avg_win = sum(wins) / len(wins) if wins else 0.0
|
||||
avg_loss = sum(losses) / len(losses) if losses else 0.0
|
||||
if avg_loss != 0:
|
||||
result.avg_win_loss_ratio = abs(avg_win / avg_loss)
|
||||
elif avg_win > 0:
|
||||
result.avg_win_loss_ratio = float("inf")
|
||||
|
||||
durations = [rt["duration"] for rt in round_trips]
|
||||
result.avg_hold_duration = sum(durations, timedelta()) / len(durations)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compute_daily_returns(equity_curve: list[tuple[datetime, float]]) -> list[float]:
|
||||
"""Compute simple daily returns from the equity curve."""
|
||||
if len(equity_curve) < 2:
|
||||
return []
|
||||
returns: list[float] = []
|
||||
for i in range(1, len(equity_curve)):
|
||||
prev = equity_curve[i - 1][1]
|
||||
curr = equity_curve[i][1]
|
||||
if prev != 0:
|
||||
returns.append((curr - prev) / prev)
|
||||
else:
|
||||
returns.append(0.0)
|
||||
return returns
|
||||
|
||||
|
||||
def _compute_sharpe(daily_returns: list[float]) -> float:
|
||||
"""Sharpe ratio: mean / std * sqrt(252)."""
|
||||
if len(daily_returns) < 2:
|
||||
return 0.0
|
||||
|
||||
mean_ret = sum(daily_returns) / len(daily_returns)
|
||||
variance = sum((r - mean_ret) ** 2 for r in daily_returns) / (len(daily_returns) - 1)
|
||||
std_ret = math.sqrt(variance)
|
||||
|
||||
if std_ret == 0:
|
||||
return 0.0
|
||||
|
||||
return (mean_ret / std_ret) * math.sqrt(252)
|
||||
|
||||
|
||||
def _compute_sortino(daily_returns: list[float]) -> float:
|
||||
"""Sortino ratio: mean / downside_deviation * sqrt(252)."""
|
||||
if len(daily_returns) < 2:
|
||||
return 0.0
|
||||
|
||||
mean_ret = sum(daily_returns) / len(daily_returns)
|
||||
downside = [r for r in daily_returns if r < 0]
|
||||
|
||||
if not downside:
|
||||
return 0.0 if mean_ret == 0 else float("inf")
|
||||
|
||||
downside_variance = sum(r ** 2 for r in downside) / len(downside)
|
||||
downside_dev = math.sqrt(downside_variance)
|
||||
|
||||
if downside_dev == 0:
|
||||
return 0.0
|
||||
|
||||
return (mean_ret / downside_dev) * math.sqrt(252)
|
||||
|
||||
|
||||
def _compute_max_drawdown(
|
||||
equity_curve: list[tuple[datetime, float]],
|
||||
) -> tuple[float, float]:
|
||||
"""Compute max drawdown percentage and duration in days.
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple[float, float]
|
||||
``(max_drawdown_pct, max_drawdown_duration_days)``
|
||||
"""
|
||||
if len(equity_curve) < 2:
|
||||
return 0.0, 0.0
|
||||
|
||||
peak = equity_curve[0][1]
|
||||
peak_ts = equity_curve[0][0]
|
||||
max_dd = 0.0
|
||||
max_dd_duration = 0.0
|
||||
|
||||
for ts, equity in equity_curve[1:]:
|
||||
if equity >= peak:
|
||||
peak = equity
|
||||
peak_ts = ts
|
||||
else:
|
||||
dd = (peak - equity) / peak * 100.0 if peak > 0 else 0.0
|
||||
duration = (ts - peak_ts).days
|
||||
if dd > max_dd:
|
||||
max_dd = dd
|
||||
max_dd_duration = duration
|
||||
|
||||
return max_dd, max_dd_duration
|
||||
|
||||
|
||||
def _build_round_trips(
|
||||
trade_log: list[TradeExecution],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Match buys with sells to produce round-trip P&L and duration.
|
||||
|
||||
Uses a simple FIFO approach: each BUY opens (or adds to) a
|
||||
position; each SELL closes (reduces) it.
|
||||
"""
|
||||
# ticker -> list of {"qty": float, "price": float, "timestamp": datetime}
|
||||
open_positions: dict[str, list[dict[str, Any]]] = {}
|
||||
round_trips: list[dict[str, Any]] = []
|
||||
|
||||
for trade in trade_log:
|
||||
ticker = trade.ticker
|
||||
if trade.side == OrderSide.BUY:
|
||||
if ticker not in open_positions:
|
||||
open_positions[ticker] = []
|
||||
open_positions[ticker].append({
|
||||
"qty": trade.qty,
|
||||
"price": trade.price,
|
||||
"timestamp": trade.timestamp,
|
||||
})
|
||||
elif trade.side == OrderSide.SELL:
|
||||
if ticker not in open_positions or not open_positions[ticker]:
|
||||
continue
|
||||
remaining_sell_qty = trade.qty
|
||||
while remaining_sell_qty > 0 and open_positions.get(ticker):
|
||||
entry = open_positions[ticker][0]
|
||||
matched_qty = min(remaining_sell_qty, entry["qty"])
|
||||
|
||||
pnl = (trade.price - entry["price"]) * matched_qty
|
||||
duration = trade.timestamp - entry["timestamp"]
|
||||
|
||||
round_trips.append({
|
||||
"ticker": ticker,
|
||||
"qty": matched_qty,
|
||||
"entry_price": entry["price"],
|
||||
"exit_price": trade.price,
|
||||
"pnl": pnl,
|
||||
"duration": duration,
|
||||
})
|
||||
|
||||
entry["qty"] -= matched_qty
|
||||
remaining_sell_qty -= matched_qty
|
||||
|
||||
if entry["qty"] <= 0:
|
||||
open_positions[ticker].pop(0)
|
||||
|
||||
return round_trips
|
||||
210
backtester/simulated_broker.py
Normal file
210
backtester/simulated_broker.py
Normal file
|
|
@ -0,0 +1,210 @@
|
|||
"""Simulated brokerage for backtesting.
|
||||
|
||||
:class:`SimulatedBroker` implements :class:`~shared.broker.base.BaseBroker`
|
||||
and fills orders instantly at the current bar price adjusted for slippage.
|
||||
All state (cash, positions, trade log) lives in memory.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from shared.broker.base import BaseBroker
|
||||
from shared.schemas.trading import (
|
||||
AccountInfo,
|
||||
OrderRequest,
|
||||
OrderResult,
|
||||
OrderSide,
|
||||
OrderStatus,
|
||||
PositionInfo,
|
||||
TradeExecution,
|
||||
)
|
||||
|
||||
|
||||
class SimulatedBroker(BaseBroker):
|
||||
"""In-memory broker that fills orders instantly with simulated slippage.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
initial_capital:
|
||||
Starting cash balance.
|
||||
slippage_pct:
|
||||
Slippage as a fraction of price (e.g. 0.001 = 0.1%).
|
||||
commission_per_trade:
|
||||
Fixed fee deducted per order fill.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
initial_capital: float = 100_000.0,
|
||||
slippage_pct: float = 0.001,
|
||||
commission_per_trade: float = 0.0,
|
||||
) -> None:
|
||||
self.cash: float = initial_capital
|
||||
self.slippage_pct = slippage_pct
|
||||
self.commission_per_trade = commission_per_trade
|
||||
|
||||
# ticker -> {"qty": float, "avg_entry": float}
|
||||
self._positions: dict[str, dict[str, float]] = {}
|
||||
# Current market prices set externally before each order
|
||||
self._current_prices: dict[str, float] = {}
|
||||
# Complete log of every simulated trade
|
||||
self._trade_log: list[TradeExecution] = []
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Price management
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def set_current_prices(self, prices: dict[str, float]) -> None:
|
||||
"""Update current prices used to simulate fills."""
|
||||
self._current_prices.update(prices)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# BaseBroker interface
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def submit_order(self, order: OrderRequest) -> OrderResult:
|
||||
"""Fill an order immediately at current_price +/- slippage.
|
||||
|
||||
Updates internal cash balance, positions, and appends to the
|
||||
trade log.
|
||||
"""
|
||||
base_price = self._current_prices.get(order.ticker)
|
||||
if base_price is None:
|
||||
return OrderResult(
|
||||
order_id=str(uuid.uuid4()),
|
||||
ticker=order.ticker,
|
||||
side=order.side,
|
||||
qty=order.qty,
|
||||
filled_price=None,
|
||||
status=OrderStatus.REJECTED,
|
||||
timestamp=datetime.now(tz=timezone.utc),
|
||||
)
|
||||
|
||||
# Apply slippage
|
||||
if order.side == OrderSide.BUY:
|
||||
fill_price = base_price * (1.0 + self.slippage_pct)
|
||||
else:
|
||||
fill_price = base_price * (1.0 - self.slippage_pct)
|
||||
|
||||
fill_price = round(fill_price, 4)
|
||||
cost = fill_price * order.qty
|
||||
|
||||
# Deduct / credit cash
|
||||
if order.side == OrderSide.BUY:
|
||||
self.cash -= cost
|
||||
self.cash -= self.commission_per_trade
|
||||
self._update_position_buy(order.ticker, order.qty, fill_price)
|
||||
else:
|
||||
self.cash += cost
|
||||
self.cash -= self.commission_per_trade
|
||||
self._update_position_sell(order.ticker, order.qty)
|
||||
|
||||
order_id = str(uuid.uuid4())
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
|
||||
# Record in trade log
|
||||
execution = TradeExecution(
|
||||
trade_id=uuid.uuid4(),
|
||||
ticker=order.ticker,
|
||||
side=order.side,
|
||||
qty=order.qty,
|
||||
price=fill_price,
|
||||
status=OrderStatus.FILLED,
|
||||
timestamp=now,
|
||||
)
|
||||
self._trade_log.append(execution)
|
||||
|
||||
return OrderResult(
|
||||
order_id=order_id,
|
||||
ticker=order.ticker,
|
||||
side=order.side,
|
||||
qty=order.qty,
|
||||
filled_price=fill_price,
|
||||
status=OrderStatus.FILLED,
|
||||
timestamp=now,
|
||||
)
|
||||
|
||||
async def cancel_order(self, order_id: str) -> bool:
|
||||
"""No-op — all orders fill instantly in simulation."""
|
||||
return True
|
||||
|
||||
async def get_positions(self) -> list[PositionInfo]:
|
||||
"""Return current positions with unrealized P&L."""
|
||||
positions: list[PositionInfo] = []
|
||||
for ticker, pos in self._positions.items():
|
||||
current_price = self._current_prices.get(ticker, pos["avg_entry"])
|
||||
qty = pos["qty"]
|
||||
avg_entry = pos["avg_entry"]
|
||||
market_value = current_price * qty
|
||||
unrealized_pnl = (current_price - avg_entry) * qty
|
||||
positions.append(
|
||||
PositionInfo(
|
||||
ticker=ticker,
|
||||
qty=qty,
|
||||
avg_entry=avg_entry,
|
||||
current_price=current_price,
|
||||
unrealized_pnl=round(unrealized_pnl, 4),
|
||||
market_value=round(market_value, 4),
|
||||
)
|
||||
)
|
||||
return positions
|
||||
|
||||
async def get_account(self) -> AccountInfo:
|
||||
"""Compute equity = cash + sum(position market values)."""
|
||||
positions = await self.get_positions()
|
||||
portfolio_value = sum(p.market_value for p in positions)
|
||||
equity = self.cash + portfolio_value
|
||||
return AccountInfo(
|
||||
equity=round(equity, 4),
|
||||
cash=round(self.cash, 4),
|
||||
buying_power=round(self.cash, 4),
|
||||
portfolio_value=round(portfolio_value, 4),
|
||||
)
|
||||
|
||||
async def get_order_status(self, order_id: str) -> OrderResult:
|
||||
"""Always return FILLED (all orders fill instantly)."""
|
||||
return OrderResult(
|
||||
order_id=order_id,
|
||||
ticker="",
|
||||
side=OrderSide.BUY,
|
||||
qty=0,
|
||||
filled_price=0.0,
|
||||
status=OrderStatus.FILLED,
|
||||
timestamp=datetime.now(tz=timezone.utc),
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Extra backtest-only methods
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_trade_log(self) -> list[TradeExecution]:
|
||||
"""Return all simulated trade executions."""
|
||||
return list(self._trade_log)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _update_position_buy(self, ticker: str, qty: float, fill_price: float) -> None:
|
||||
"""Add to an existing position or create a new one."""
|
||||
if ticker in self._positions:
|
||||
existing = self._positions[ticker]
|
||||
total_qty = existing["qty"] + qty
|
||||
# Weighted average entry
|
||||
existing["avg_entry"] = (
|
||||
(existing["avg_entry"] * existing["qty"]) + (fill_price * qty)
|
||||
) / total_qty
|
||||
existing["qty"] = total_qty
|
||||
else:
|
||||
self._positions[ticker] = {"qty": qty, "avg_entry": fill_price}
|
||||
|
||||
def _update_position_sell(self, ticker: str, qty: float) -> None:
|
||||
"""Reduce or close a position. Removes the entry when qty hits 0."""
|
||||
if ticker not in self._positions:
|
||||
return
|
||||
existing = self._positions[ticker]
|
||||
existing["qty"] -= qty
|
||||
if existing["qty"] <= 0:
|
||||
del self._positions[ticker]
|
||||
|
|
@ -27,7 +27,7 @@ requires = ["setuptools>=70.0"]
|
|||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["shared*", "services*", "tests*"]
|
||||
include = ["shared*", "services*", "backtester*", "tests*"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
|
|
|
|||
425
tests/test_backtester.py
Normal file
425
tests/test_backtester.py
Normal file
|
|
@ -0,0 +1,425 @@
|
|||
"""Tests for the backtesting engine.
|
||||
|
||||
Covers:
|
||||
- SimulatedBroker: slippage, commission, positions, equity
|
||||
- BacktestDataLoader: chronological ordering, sentiment aggregation
|
||||
- Metrics: total return, Sharpe ratio, max drawdown, win rate
|
||||
- BacktestEngine: full run with buy+sell, position closing at end
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from backtester.config import BacktestConfig
|
||||
from backtester.data_loader import BacktestDataLoader
|
||||
from backtester.engine import BacktestEngine
|
||||
from backtester.metrics import BacktestResult, compute_metrics
|
||||
from backtester.simulated_broker import SimulatedBroker
|
||||
from shared.schemas.trading import (
|
||||
OrderRequest,
|
||||
OrderSide,
|
||||
OrderStatus,
|
||||
TradeExecution,
|
||||
)
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# Helpers
|
||||
# ======================================================================
|
||||
|
||||
def _ts(day: int) -> datetime:
|
||||
"""Return a timezone-aware datetime for 2025-01-{day}."""
|
||||
return datetime(2025, 1, day, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
def _make_bar(day: int, ticker: str, close: float, *, open_: float | None = None,
|
||||
high: float | None = None, low: float | None = None, volume: float = 1000.0) -> dict:
|
||||
"""Build a bar dict for the data loader."""
|
||||
return {
|
||||
"timestamp": _ts(day),
|
||||
"ticker": ticker,
|
||||
"open": open_ or close,
|
||||
"high": high or close,
|
||||
"low": low or close,
|
||||
"close": close,
|
||||
"volume": volume,
|
||||
}
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# SimulatedBroker tests
|
||||
# ======================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simulated_broker_buy_fills_with_slippage():
|
||||
"""BUY orders fill at price * (1 + slippage_pct)."""
|
||||
broker = SimulatedBroker(initial_capital=100_000.0, slippage_pct=0.001)
|
||||
broker.set_current_prices({"AAPL": 100.0})
|
||||
|
||||
result = await broker.submit_order(
|
||||
OrderRequest(ticker="AAPL", side=OrderSide.BUY, qty=10)
|
||||
)
|
||||
|
||||
assert result.status == OrderStatus.FILLED
|
||||
expected_fill = 100.0 * 1.001 # 100.1
|
||||
assert result.filled_price == pytest.approx(expected_fill, abs=0.01)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simulated_broker_sell_fills_with_slippage():
|
||||
"""SELL orders fill at price * (1 - slippage_pct)."""
|
||||
broker = SimulatedBroker(initial_capital=100_000.0, slippage_pct=0.001)
|
||||
broker.set_current_prices({"AAPL": 100.0})
|
||||
|
||||
# First buy to have a position
|
||||
await broker.submit_order(
|
||||
OrderRequest(ticker="AAPL", side=OrderSide.BUY, qty=10)
|
||||
)
|
||||
|
||||
result = await broker.submit_order(
|
||||
OrderRequest(ticker="AAPL", side=OrderSide.SELL, qty=10)
|
||||
)
|
||||
|
||||
assert result.status == OrderStatus.FILLED
|
||||
expected_fill = 100.0 * 0.999 # 99.9
|
||||
assert result.filled_price == pytest.approx(expected_fill, abs=0.01)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simulated_broker_tracks_positions():
|
||||
"""After buying, the position should be tracked; after selling, removed."""
|
||||
broker = SimulatedBroker(initial_capital=100_000.0, slippage_pct=0.0)
|
||||
broker.set_current_prices({"AAPL": 150.0})
|
||||
|
||||
# Buy 5 shares
|
||||
await broker.submit_order(
|
||||
OrderRequest(ticker="AAPL", side=OrderSide.BUY, qty=5)
|
||||
)
|
||||
positions = await broker.get_positions()
|
||||
assert len(positions) == 1
|
||||
assert positions[0].ticker == "AAPL"
|
||||
assert positions[0].qty == 5
|
||||
|
||||
# Sell all
|
||||
await broker.submit_order(
|
||||
OrderRequest(ticker="AAPL", side=OrderSide.SELL, qty=5)
|
||||
)
|
||||
positions = await broker.get_positions()
|
||||
assert len(positions) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simulated_broker_commission_deducted():
|
||||
"""Commission should be deducted from cash on each trade."""
|
||||
commission = 5.0
|
||||
broker = SimulatedBroker(
|
||||
initial_capital=100_000.0, slippage_pct=0.0, commission_per_trade=commission
|
||||
)
|
||||
broker.set_current_prices({"TSLA": 200.0})
|
||||
|
||||
# Buy 10 shares: cost = 200*10 + 5 commission = 2005
|
||||
await broker.submit_order(
|
||||
OrderRequest(ticker="TSLA", side=OrderSide.BUY, qty=10)
|
||||
)
|
||||
|
||||
expected_cash = 100_000.0 - (200.0 * 10) - commission
|
||||
assert broker.cash == pytest.approx(expected_cash)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simulated_broker_account_equity():
|
||||
"""Equity = cash + sum(position market values)."""
|
||||
broker = SimulatedBroker(initial_capital=50_000.0, slippage_pct=0.0)
|
||||
broker.set_current_prices({"GOOG": 100.0})
|
||||
|
||||
await broker.submit_order(
|
||||
OrderRequest(ticker="GOOG", side=OrderSide.BUY, qty=100)
|
||||
)
|
||||
# cash = 50k - 100*100 = 40k, position value = 100*100 = 10k
|
||||
account = await broker.get_account()
|
||||
assert account.equity == pytest.approx(50_000.0)
|
||||
|
||||
# Price moves up to 110 -> position value = 11k
|
||||
broker.set_current_prices({"GOOG": 110.0})
|
||||
account = await broker.get_account()
|
||||
assert account.equity == pytest.approx(40_000.0 + 110.0 * 100)
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# BacktestDataLoader tests
|
||||
# ======================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_data_loader_chronological_order():
|
||||
"""Bars should be yielded in ascending timestamp order even if input is shuffled."""
|
||||
bars = [
|
||||
_make_bar(3, "AAPL", 103.0),
|
||||
_make_bar(1, "AAPL", 101.0),
|
||||
_make_bar(2, "AAPL", 102.0),
|
||||
]
|
||||
loader = BacktestDataLoader(bars=bars)
|
||||
|
||||
timestamps: list[datetime] = []
|
||||
async for ts, _ticker, _bar, _sent in loader.iterate():
|
||||
timestamps.append(ts)
|
||||
|
||||
assert timestamps == sorted(timestamps)
|
||||
assert len(timestamps) == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_data_loader_with_sentiment():
|
||||
"""Sentiment context should aggregate records up to the current bar's timestamp."""
|
||||
bars = [
|
||||
_make_bar(2, "AAPL", 150.0),
|
||||
_make_bar(4, "AAPL", 155.0),
|
||||
]
|
||||
sentiments = [
|
||||
{"timestamp": _ts(1), "ticker": "AAPL", "score": 0.5, "confidence": 0.8},
|
||||
{"timestamp": _ts(3), "ticker": "AAPL", "score": 0.9, "confidence": 0.9},
|
||||
{"timestamp": _ts(5), "ticker": "AAPL", "score": -0.3, "confidence": 0.7}, # future
|
||||
]
|
||||
loader = BacktestDataLoader(bars=bars, sentiments=sentiments)
|
||||
|
||||
results: list = []
|
||||
async for ts, ticker, bar, sentiment in loader.iterate():
|
||||
results.append((ts, sentiment))
|
||||
|
||||
# At day 2: only the day-1 sentiment should be included
|
||||
assert results[0][1] is not None
|
||||
assert results[0][1].article_count == 1
|
||||
assert results[0][1].avg_score == pytest.approx(0.5)
|
||||
|
||||
# At day 4: day-1 and day-3 sentiments should be included
|
||||
assert results[1][1] is not None
|
||||
assert results[1][1].article_count == 2
|
||||
assert results[1][1].avg_score == pytest.approx(0.7) # (0.5 + 0.9) / 2
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# Metrics tests
|
||||
# ======================================================================
|
||||
|
||||
|
||||
def test_metrics_total_return():
|
||||
"""Total return should be (final - initial) / initial * 100."""
|
||||
curve = [(_ts(1), 100_000.0), (_ts(10), 110_000.0)]
|
||||
result = compute_metrics([], curve, initial_capital=100_000.0)
|
||||
assert result.total_return == pytest.approx(10.0)
|
||||
|
||||
|
||||
def test_metrics_sharpe_ratio():
|
||||
"""Test Sharpe with known daily returns.
|
||||
|
||||
With constant daily returns, std = 0 and Sharpe is 0.
|
||||
With varied returns, we can compute the expected value.
|
||||
"""
|
||||
# Daily returns: +1%, -0.5%, +1%, -0.5% (2 up, 2 down)
|
||||
equity = 100_000.0
|
||||
daily_rets = [0.01, -0.005, 0.01, -0.005]
|
||||
curve = [(_ts(1), equity)]
|
||||
for i, r in enumerate(daily_rets):
|
||||
equity *= (1 + r)
|
||||
curve.append((_ts(2 + i), equity))
|
||||
|
||||
result = compute_metrics([], curve, initial_capital=100_000.0)
|
||||
|
||||
# Manually compute expected Sharpe
|
||||
mean_ret = sum(daily_rets) / len(daily_rets)
|
||||
variance = sum((r - mean_ret) ** 2 for r in daily_rets) / (len(daily_rets) - 1)
|
||||
std_ret = math.sqrt(variance)
|
||||
expected_sharpe = (mean_ret / std_ret) * math.sqrt(252)
|
||||
|
||||
assert result.sharpe_ratio == pytest.approx(expected_sharpe, rel=0.01)
|
||||
|
||||
|
||||
def test_metrics_max_drawdown():
|
||||
"""Max drawdown should capture the largest peak-to-trough decline."""
|
||||
curve = [
|
||||
(_ts(1), 100_000.0),
|
||||
(_ts(2), 110_000.0), # new peak
|
||||
(_ts(3), 90_000.0), # trough: dd = (110k - 90k)/110k = 18.18%
|
||||
(_ts(4), 105_000.0), # partial recovery
|
||||
]
|
||||
result = compute_metrics([], curve, initial_capital=100_000.0)
|
||||
expected_dd = (110_000 - 90_000) / 110_000 * 100.0
|
||||
assert result.max_drawdown_pct == pytest.approx(expected_dd, rel=0.01)
|
||||
|
||||
|
||||
def test_metrics_win_rate():
|
||||
"""Win rate = winning_trades / total_trades * 100."""
|
||||
now = _ts(1)
|
||||
later = _ts(5)
|
||||
|
||||
trades = [
|
||||
# Round trip 1: buy 100 @ $10, sell 100 @ $12 -> profit
|
||||
TradeExecution(
|
||||
trade_id="aaaa1111-1111-1111-1111-111111111111",
|
||||
ticker="AAPL", side=OrderSide.BUY, qty=100, price=10.0,
|
||||
status=OrderStatus.FILLED, timestamp=now,
|
||||
),
|
||||
TradeExecution(
|
||||
trade_id="aaaa2222-2222-2222-2222-222222222222",
|
||||
ticker="AAPL", side=OrderSide.SELL, qty=100, price=12.0,
|
||||
status=OrderStatus.FILLED, timestamp=later,
|
||||
),
|
||||
# Round trip 2: buy 50 @ $20, sell 50 @ $18 -> loss
|
||||
TradeExecution(
|
||||
trade_id="bbbb1111-1111-1111-1111-111111111111",
|
||||
ticker="TSLA", side=OrderSide.BUY, qty=50, price=20.0,
|
||||
status=OrderStatus.FILLED, timestamp=now,
|
||||
),
|
||||
TradeExecution(
|
||||
trade_id="bbbb2222-2222-2222-2222-222222222222",
|
||||
ticker="TSLA", side=OrderSide.SELL, qty=50, price=18.0,
|
||||
status=OrderStatus.FILLED, timestamp=later,
|
||||
),
|
||||
]
|
||||
curve = [(_ts(1), 100_000.0), (_ts(5), 100_100.0)]
|
||||
result = compute_metrics(trades, curve, initial_capital=100_000.0)
|
||||
assert result.win_rate == pytest.approx(50.0)
|
||||
assert result.trade_count == 2
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# BacktestEngine tests
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class _AlwaysBuyStrategy:
|
||||
"""Trivial strategy that always emits a LONG signal."""
|
||||
|
||||
name = "always_buy"
|
||||
|
||||
async def evaluate(self, ticker, market, sentiment=None):
|
||||
from shared.schemas.trading import SignalDirection, TradeSignal
|
||||
|
||||
return TradeSignal(
|
||||
ticker=ticker,
|
||||
direction=SignalDirection.LONG,
|
||||
strength=0.8,
|
||||
strategy_sources=[self.name],
|
||||
timestamp=market.bars[-1]["timestamp"] if market.bars else datetime.now(tz=timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
class _AlwaysSellStrategy:
|
||||
"""Trivial strategy that always emits a SHORT signal."""
|
||||
|
||||
name = "always_sell"
|
||||
|
||||
async def evaluate(self, ticker, market, sentiment=None):
|
||||
from shared.schemas.trading import SignalDirection, TradeSignal
|
||||
|
||||
return TradeSignal(
|
||||
ticker=ticker,
|
||||
direction=SignalDirection.SHORT,
|
||||
strength=0.8,
|
||||
strategy_sources=[self.name],
|
||||
timestamp=market.bars[-1]["timestamp"] if market.bars else datetime.now(tz=timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
class _BuyThenSellStrategy:
|
||||
"""Strategy that buys on bar 1 and sells on bar 3."""
|
||||
|
||||
name = "buy_then_sell"
|
||||
|
||||
def __init__(self):
|
||||
self._call_count: dict[str, int] = {}
|
||||
|
||||
async def evaluate(self, ticker, market, sentiment=None):
|
||||
from shared.schemas.trading import SignalDirection, TradeSignal
|
||||
|
||||
self._call_count[ticker] = self._call_count.get(ticker, 0) + 1
|
||||
count = self._call_count[ticker]
|
||||
|
||||
if count == 1:
|
||||
return TradeSignal(
|
||||
ticker=ticker,
|
||||
direction=SignalDirection.LONG,
|
||||
strength=0.8,
|
||||
strategy_sources=[self.name],
|
||||
timestamp=datetime.now(tz=timezone.utc),
|
||||
)
|
||||
elif count == 3:
|
||||
return TradeSignal(
|
||||
ticker=ticker,
|
||||
direction=SignalDirection.SHORT,
|
||||
strength=0.8,
|
||||
strategy_sources=[self.name],
|
||||
timestamp=datetime.now(tz=timezone.utc),
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engine_runs_full_backtest():
|
||||
"""Run a simple 3-bar scenario: buy on bar 1, sell on bar 3."""
|
||||
config = BacktestConfig(
|
||||
start_date=_ts(1),
|
||||
end_date=_ts(3),
|
||||
initial_capital=100_000.0,
|
||||
slippage_pct=0.0,
|
||||
commission_per_trade=0.0,
|
||||
signal_threshold=0.0,
|
||||
max_position_pct=0.05,
|
||||
)
|
||||
|
||||
strategy = _BuyThenSellStrategy()
|
||||
engine = BacktestEngine(config=config, strategies=[strategy])
|
||||
|
||||
bars = [
|
||||
_make_bar(1, "AAPL", 100.0),
|
||||
_make_bar(2, "AAPL", 110.0),
|
||||
_make_bar(3, "AAPL", 120.0),
|
||||
]
|
||||
loader = BacktestDataLoader(bars=bars)
|
||||
|
||||
result = await engine.run(loader)
|
||||
|
||||
# Should have at least 2 trades (1 buy + 1 sell)
|
||||
assert result.trade_count >= 1
|
||||
assert len(result.equity_curve) > 0
|
||||
# Price went up 20%, so total return should be positive
|
||||
assert result.total_return >= 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engine_closes_positions_at_end():
|
||||
"""Any open positions should be closed at the final bar prices."""
|
||||
config = BacktestConfig(
|
||||
start_date=_ts(1),
|
||||
end_date=_ts(3),
|
||||
initial_capital=100_000.0,
|
||||
slippage_pct=0.0,
|
||||
commission_per_trade=0.0,
|
||||
signal_threshold=0.0,
|
||||
max_position_pct=0.10,
|
||||
)
|
||||
|
||||
# This strategy only buys, never sells
|
||||
strategy = _AlwaysBuyStrategy()
|
||||
engine = BacktestEngine(config=config, strategies=[strategy])
|
||||
|
||||
bars = [
|
||||
_make_bar(1, "AAPL", 100.0),
|
||||
_make_bar(2, "AAPL", 105.0),
|
||||
_make_bar(3, "AAPL", 110.0),
|
||||
]
|
||||
loader = BacktestDataLoader(bars=bars)
|
||||
|
||||
result = await engine.run(loader)
|
||||
|
||||
# The engine should have closed the position at the end.
|
||||
# The trade log should contain at least a buy and a sell.
|
||||
buys = [t for t in result.trade_log if t.side == OrderSide.BUY]
|
||||
sells = [t for t in result.trade_log if t.side == OrderSide.SELL]
|
||||
assert len(buys) >= 1
|
||||
assert len(sells) >= 1 # auto-closed at end
|
||||
Loading…
Add table
Add a link
Reference in a new issue