"""Simulated brokerage for backtesting. :class:`SimulatedBroker` implements :class:`~shared.broker.base.BaseBroker` and fills orders instantly at the current bar price adjusted for slippage. All state (cash, positions, trade log) lives in memory. """ from __future__ import annotations import uuid from datetime import datetime, timezone from shared.broker.base import BaseBroker from shared.schemas.trading import ( AccountInfo, OrderRequest, OrderResult, OrderSide, OrderStatus, PositionInfo, TradeExecution, ) class SimulatedBroker(BaseBroker): """In-memory broker that fills orders instantly with simulated slippage. Parameters ---------- initial_capital: Starting cash balance. slippage_pct: Slippage as a fraction of price (e.g. 0.001 = 0.1%). commission_per_trade: Fixed fee deducted per order fill. """ def __init__( self, initial_capital: float = 100_000.0, slippage_pct: float = 0.001, commission_per_trade: float = 0.0, ) -> None: self.cash: float = initial_capital self.slippage_pct = slippage_pct self.commission_per_trade = commission_per_trade # ticker -> {"qty": float, "avg_entry": float} self._positions: dict[str, dict[str, float]] = {} # Current market prices set externally before each order self._current_prices: dict[str, float] = {} # Complete log of every simulated trade self._trade_log: list[TradeExecution] = [] # ------------------------------------------------------------------ # Price management # ------------------------------------------------------------------ def set_current_prices(self, prices: dict[str, float]) -> None: """Update current prices used to simulate fills.""" self._current_prices.update(prices) # ------------------------------------------------------------------ # BaseBroker interface # ------------------------------------------------------------------ async def submit_order(self, order: OrderRequest) -> OrderResult: """Fill an order immediately at current_price +/- slippage. Updates internal cash balance, positions, and appends to the trade log. """ base_price = self._current_prices.get(order.ticker) if base_price is None: return OrderResult( order_id=str(uuid.uuid4()), ticker=order.ticker, side=order.side, qty=order.qty, filled_price=None, status=OrderStatus.REJECTED, timestamp=datetime.now(tz=timezone.utc), ) # Apply slippage if order.side == OrderSide.BUY: fill_price = base_price * (1.0 + self.slippage_pct) else: fill_price = base_price * (1.0 - self.slippage_pct) fill_price = round(fill_price, 4) cost = fill_price * order.qty # Deduct / credit cash if order.side == OrderSide.BUY: total_cost = cost + self.commission_per_trade if total_cost > self.cash: return OrderResult( order_id=str(uuid.uuid4()), ticker=order.ticker, side=order.side, qty=order.qty, filled_price=None, status=OrderStatus.REJECTED, timestamp=datetime.now(tz=timezone.utc), ) self.cash -= total_cost self._update_position_buy(order.ticker, order.qty, fill_price) else: # Validate sufficient shares to sell current_qty = self._positions.get(order.ticker, {}).get("qty", 0.0) if order.qty > current_qty: return OrderResult( order_id=str(uuid.uuid4()), ticker=order.ticker, side=order.side, qty=order.qty, filled_price=None, status=OrderStatus.REJECTED, timestamp=datetime.now(tz=timezone.utc), ) self.cash += cost self.cash -= self.commission_per_trade self._update_position_sell(order.ticker, order.qty) order_id = str(uuid.uuid4()) now = datetime.now(tz=timezone.utc) # Record in trade log execution = TradeExecution( trade_id=uuid.uuid4(), ticker=order.ticker, side=order.side, qty=order.qty, price=fill_price, status=OrderStatus.FILLED, timestamp=now, ) self._trade_log.append(execution) return OrderResult( order_id=order_id, ticker=order.ticker, side=order.side, qty=order.qty, filled_price=fill_price, status=OrderStatus.FILLED, timestamp=now, ) async def cancel_order(self, order_id: str) -> bool: """No-op — all orders fill instantly in simulation.""" return True async def get_positions(self) -> list[PositionInfo]: """Return current positions with unrealized P&L.""" positions: list[PositionInfo] = [] for ticker, pos in self._positions.items(): current_price = self._current_prices.get(ticker, pos["avg_entry"]) qty = pos["qty"] avg_entry = pos["avg_entry"] market_value = current_price * qty unrealized_pnl = (current_price - avg_entry) * qty positions.append( PositionInfo( ticker=ticker, qty=qty, avg_entry=avg_entry, current_price=current_price, unrealized_pnl=round(unrealized_pnl, 4), market_value=round(market_value, 4), ) ) return positions async def get_account(self) -> AccountInfo: """Compute equity = cash + sum(position market values).""" positions = await self.get_positions() portfolio_value = sum(p.market_value for p in positions) equity = self.cash + portfolio_value return AccountInfo( equity=round(equity, 4), cash=round(self.cash, 4), buying_power=round(self.cash, 4), portfolio_value=round(portfolio_value, 4), ) async def get_order_status(self, order_id: str) -> OrderResult: """Always return FILLED (all orders fill instantly).""" return OrderResult( order_id=order_id, ticker="", side=OrderSide.BUY, qty=0, filled_price=0.0, status=OrderStatus.FILLED, timestamp=datetime.now(tz=timezone.utc), ) # ------------------------------------------------------------------ # Extra backtest-only methods # ------------------------------------------------------------------ def get_trade_log(self) -> list[TradeExecution]: """Return all simulated trade executions.""" return list(self._trade_log) # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ def _update_position_buy(self, ticker: str, qty: float, fill_price: float) -> None: """Add to an existing position or create a new one.""" if ticker in self._positions: existing = self._positions[ticker] total_qty = existing["qty"] + qty # Weighted average entry existing["avg_entry"] = ( (existing["avg_entry"] * existing["qty"]) + (fill_price * qty) ) / total_qty existing["qty"] = total_qty else: self._positions[ticker] = {"qty": qty, "avg_entry": fill_price} def _update_position_sell(self, ticker: str, qty: float) -> None: """Reduce or close a position. Removes the entry when qty hits 0.""" if ticker not in self._positions: return existing = self._positions[ticker] existing["qty"] -= qty if existing["qty"] <= 0: del self._positions[ticker]