240 lines
9 KiB
Python
240 lines
9 KiB
Python
"""Alpaca brokerage adapter implementing :class:`BaseBroker`.
|
|
|
|
Uses the ``alpaca-py`` SDK (``TradingClient``) for order management,
|
|
position retrieval, and account information. Paper vs. live trading is
|
|
controlled via the ``paper`` flag in the constructor.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
from datetime import datetime, timezone
|
|
|
|
from alpaca.common.exceptions import APIError
|
|
from alpaca.trading.client import TradingClient
|
|
from alpaca.trading.enums import OrderSide as AlpacaOrderSide
|
|
from alpaca.trading.enums import OrderStatus as AlpacaOrderStatus
|
|
from alpaca.trading.enums import TimeInForce
|
|
from alpaca.trading.models import Order as AlpacaOrder
|
|
from alpaca.trading.models import Position as AlpacaPosition
|
|
from alpaca.trading.models import TradeAccount
|
|
from alpaca.trading.requests import (
|
|
LimitOrderRequest,
|
|
MarketOrderRequest,
|
|
StopOrderRequest,
|
|
)
|
|
|
|
from shared.broker.base import BaseBroker
|
|
from shared.schemas.trading import (
|
|
AccountInfo,
|
|
OrderRequest,
|
|
OrderResult,
|
|
OrderSide,
|
|
OrderStatus,
|
|
OrderType,
|
|
PositionInfo,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Status mapping helpers
# ---------------------------------------------------------------------------

# Translation table from Alpaca order lifecycle states to the canonical
# ``OrderStatus`` values used by the rest of the system.
#
# The canonical statuses used here are PENDING (order may still trade),
# FILLED, CANCELLED and REJECTED.  Only states that are final on the Alpaca
# side are mapped to a final canonical status; everything that can still
# result in an execution stays PENDING so callers never treat a live order
# as finished.
_STATUS_MAP: dict[AlpacaOrderStatus, OrderStatus] = {
    # -- order is still working (or may still execute) ----------------------
    AlpacaOrderStatus.NEW: OrderStatus.PENDING,
    AlpacaOrderStatus.ACCEPTED: OrderStatus.PENDING,
    AlpacaOrderStatus.PENDING_NEW: OrderStatus.PENDING,
    AlpacaOrderStatus.ACCEPTED_FOR_BIDDING: OrderStatus.PENDING,
    AlpacaOrderStatus.PENDING_CANCEL: OrderStatus.PENDING,
    AlpacaOrderStatus.PENDING_REPLACE: OrderStatus.PENDING,
    AlpacaOrderStatus.PENDING_REVIEW: OrderStatus.PENDING,
    AlpacaOrderStatus.HELD: OrderStatus.PENDING,
    AlpacaOrderStatus.PARTIALLY_FILLED: OrderStatus.PENDING,
    # Alpaca documents "done_for_day" as "done executing for the day" -- it
    # says nothing about the order being filled, so it must not be reported
    # as FILLED.
    AlpacaOrderStatus.DONE_FOR_DAY: OrderStatus.PENDING,
    # Alpaca documents "stopped" as "a trade is guaranteed for the order ...
    # but has not yet occurred" -- the order is still live, not cancelled.
    AlpacaOrderStatus.STOPPED: OrderStatus.PENDING,
    # -- completely executed ------------------------------------------------
    AlpacaOrderStatus.FILLED: OrderStatus.FILLED,
    # -- no longer eligible to trade ----------------------------------------
    AlpacaOrderStatus.CANCELED: OrderStatus.CANCELLED,
    AlpacaOrderStatus.EXPIRED: OrderStatus.CANCELLED,
    AlpacaOrderStatus.REPLACED: OrderStatus.CANCELLED,
    # NOTE(review): Alpaca describes "suspended" as "not eligible for
    # trading"; kept as CANCELLED, but confirm a suspended order can never
    # resume before relying on this as a final state.
    AlpacaOrderStatus.SUSPENDED: OrderStatus.CANCELLED,
    # -- refused by the broker ----------------------------------------------
    AlpacaOrderStatus.REJECTED: OrderStatus.REJECTED,
}
|
|
|
|
|
|
def _map_status(alpaca_status: AlpacaOrderStatus) -> OrderStatus:
    """Translate an Alpaca order status into the canonical ``OrderStatus``.

    Statuses that have no entry in ``_STATUS_MAP`` are reported as
    ``OrderStatus.PENDING``.
    """
    try:
        return _STATUS_MAP[alpaca_status]
    except KeyError:
        # Unknown / unmapped status: fall back to PENDING.
        return OrderStatus.PENDING
|
|
|
|
|
|
def _map_side(alpaca_side: AlpacaOrderSide | None) -> OrderSide:
    """Translate an Alpaca order side into the canonical ``OrderSide``.

    Only ``AlpacaOrderSide.SELL`` yields ``OrderSide.SELL``; every other
    value, including ``None``, yields ``OrderSide.BUY``.
    """
    is_sell = alpaca_side == AlpacaOrderSide.SELL
    return OrderSide.SELL if is_sell else OrderSide.BUY
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Alpaca broker implementation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class AlpacaBroker(BaseBroker):
    """:class:`BaseBroker` implementation that talks to Alpaca.

    Every broker operation is delegated to an ``alpaca-py``
    ``TradingClient``.  The coroutine methods hand the client's calls to
    :func:`asyncio.to_thread` and await the result.

    Parameters
    ----------
    api_key:
        Alpaca API key ID.
    secret_key:
        Alpaca API secret key.
    paper:
        When ``True`` (the default) the client targets Alpaca's
        paper-trading sandbox; pass ``False`` for live trading.
    """

    def __init__(self, api_key: str, secret_key: str, *, paper: bool = True) -> None:
        client = TradingClient(api_key=api_key, secret_key=secret_key, paper=paper)
        self._client = client

    # -- internal helpers ----------------------------------------------------

    @staticmethod
    def _float_or_zero(raw) -> float:
        """Return ``float(raw)``, or ``0.0`` when *raw* is falsy."""
        if not raw:
            return 0.0
        return float(raw)

    def _build_order_request(
        self, order: OrderRequest
    ) -> MarketOrderRequest | LimitOrderRequest | StopOrderRequest:
        """Build the Alpaca request object matching ``order``.

        ``LIMIT`` and ``STOP`` orders produce a ``LimitOrderRequest`` /
        ``StopOrderRequest``; any other order type produces a
        ``MarketOrderRequest``.  All requests use ``TimeInForce.DAY``.

        Raises
        ------
        ValueError
            If a limit order has no ``limit_price`` or a stop order has no
            ``stop_price``.
        """
        if order.side == OrderSide.BUY:
            alpaca_side = AlpacaOrderSide.BUY
        else:
            alpaca_side = AlpacaOrderSide.SELL

        # Pick the request class plus the one price field specific to it.
        price_kwargs = {}
        if order.order_type == OrderType.LIMIT:
            if order.limit_price is None:
                raise ValueError("limit_price is required for limit orders")
            request_cls = LimitOrderRequest
            price_kwargs["limit_price"] = order.limit_price
        elif order.order_type == OrderType.STOP:
            if order.stop_price is None:
                raise ValueError("stop_price is required for stop orders")
            request_cls = StopOrderRequest
            price_kwargs["stop_price"] = order.stop_price
        else:
            # Anything that is neither LIMIT nor STOP is sent as a market order.
            request_cls = MarketOrderRequest

        return request_cls(
            symbol=order.ticker,
            qty=order.qty,
            side=alpaca_side,
            time_in_force=TimeInForce.DAY,
            **price_kwargs,
        )

    @staticmethod
    def _order_to_result(alpaca_order: AlpacaOrder) -> OrderResult:
        """Build an ``OrderResult`` from an Alpaca ``Order`` model.

        Missing values are defaulted: no average fill price becomes
        ``None``, no quantity becomes ``0.0``, no symbol becomes ``""`` and
        the timestamp falls back from ``submitted_at`` to ``created_at`` to
        the current UTC time.
        """
        avg_price = alpaca_order.filled_avg_price
        fill = None if avg_price is None else float(avg_price)

        if alpaca_order.qty is not None:
            quantity = float(alpaca_order.qty)
        else:
            quantity = 0.0

        when = alpaca_order.submitted_at
        if not when:
            when = alpaca_order.created_at
        if not when:
            when = datetime.now(timezone.utc)

        return OrderResult(
            order_id=str(alpaca_order.id),
            ticker=alpaca_order.symbol or "",
            side=_map_side(alpaca_order.side),
            qty=quantity,
            filled_price=fill,
            status=_map_status(alpaca_order.status),
            timestamp=when,
        )

    @staticmethod
    def _position_to_info(pos: AlpacaPosition) -> PositionInfo:
        """Build a ``PositionInfo`` from an Alpaca ``Position`` model.

        ``qty`` and ``avg_entry_price`` are always converted with
        ``float``; the price, P&L and market value fall back to ``0.0``
        when the Alpaca field is falsy.
        """
        as_number = AlpacaBroker._float_or_zero
        return PositionInfo(
            ticker=pos.symbol,
            qty=float(pos.qty),
            avg_entry=float(pos.avg_entry_price),
            current_price=as_number(pos.current_price),
            unrealized_pnl=as_number(pos.unrealized_pl),
            market_value=as_number(pos.market_value),
        )

    # -- BaseBroker interface ------------------------------------------------

    async def submit_order(self, order: OrderRequest) -> OrderResult:
        """Send ``order`` to Alpaca and report its initial state.

        The ``OrderRequest`` is converted to the matching Alpaca request,
        submitted through the ``TradingClient`` and the returned Alpaca
        order is converted to an ``OrderResult``.

        An ``APIError`` is not propagated: it is logged as a warning and an
        ``OrderResult`` with status ``REJECTED`` and an empty ``order_id``
        is returned instead.
        """
        try:
            request = self._build_order_request(order)
            placed: AlpacaOrder = await asyncio.to_thread(
                self._client.submit_order, request
            )
            outcome = self._order_to_result(placed)
        except APIError as exc:
            logger.warning("Order rejected by Alpaca: %s", exc)
            outcome = OrderResult(
                order_id="",
                ticker=order.ticker,
                side=order.side,
                qty=order.qty,
                filled_price=None,
                status=OrderStatus.REJECTED,
                timestamp=datetime.now(timezone.utc),
            )
        return outcome

    async def cancel_order(self, order_id: str) -> bool:
        """Ask Alpaca to cancel the order identified by ``order_id``.

        Returns ``True`` when the cancel call completed without error and
        ``False`` when the API raised an ``APIError`` (which is logged as a
        warning), e.g. because the order was already filled or does not
        exist.
        """
        try:
            await asyncio.to_thread(self._client.cancel_order_by_id, order_id)
        except APIError as exc:
            logger.warning("Failed to cancel order %s: %s", order_id, exc)
            return False
        return True

    async def get_positions(self) -> list[PositionInfo]:
        """Fetch every open position from Alpaca as ``PositionInfo`` objects."""
        raw_positions: list[AlpacaPosition] = await asyncio.to_thread(
            self._client.get_all_positions
        )
        return list(map(self._position_to_info, raw_positions))

    async def get_account(self) -> AccountInfo:
        """Fetch the Alpaca account summary as an ``AccountInfo``.

        Each monetary field falls back to ``0.0`` when the Alpaca value is
        falsy.
        """
        acct: TradeAccount = await asyncio.to_thread(self._client.get_account)
        as_number = self._float_or_zero
        return AccountInfo(
            equity=as_number(acct.equity),
            cash=as_number(acct.cash),
            buying_power=as_number(acct.buying_power),
            portfolio_value=as_number(acct.portfolio_value),
        )

    async def get_order_status(self, order_id: str) -> OrderResult:
        """Look up ``order_id`` on Alpaca and return its current state."""
        fetched: AlpacaOrder = await asyncio.to_thread(
            self._client.get_order_by_id, order_id
        )
        return self._order_to_result(fetched)
|