feat: brokerage abstraction layer with Alpaca implementation

This commit is contained in:
Viktor Barzin 2026-02-22 15:26:41 +00:00
parent 9f46071502
commit 5696da6472
No known key found for this signature in database
GPG key ID: 0EB088298288D958
4 changed files with 877 additions and 0 deletions

11
shared/broker/__init__.py Normal file
View file

@ -0,0 +1,11 @@
"""Brokerage abstraction layer.
Provides :class:`BaseBroker` (the interface) and :class:`AlpacaBroker`
(the default Alpaca implementation). Additional brokerage adapters can be
added by subclassing ``BaseBroker`` and implementing its abstract methods.
"""
from shared.broker.alpaca_broker import AlpacaBroker
from shared.broker.base import BaseBroker
__all__ = ["AlpacaBroker", "BaseBroker"]

View file

@ -0,0 +1,240 @@
"""Alpaca brokerage adapter implementing :class:`BaseBroker`.
Uses the ``alpaca-py`` SDK (``TradingClient``) for order management,
position retrieval, and account information. Paper vs. live trading is
controlled via the ``paper`` flag in the constructor.
"""
from __future__ import annotations
import asyncio
import logging
from datetime import datetime, timezone
from alpaca.common.exceptions import APIError
from alpaca.trading.client import TradingClient
from alpaca.trading.enums import OrderSide as AlpacaOrderSide
from alpaca.trading.enums import OrderStatus as AlpacaOrderStatus
from alpaca.trading.enums import TimeInForce
from alpaca.trading.models import Order as AlpacaOrder
from alpaca.trading.models import Position as AlpacaPosition
from alpaca.trading.models import TradeAccount
from alpaca.trading.requests import (
LimitOrderRequest,
MarketOrderRequest,
StopOrderRequest,
)
from shared.broker.base import BaseBroker
from shared.schemas.trading import (
AccountInfo,
OrderRequest,
OrderResult,
OrderSide,
OrderStatus,
OrderType,
PositionInfo,
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Status mapping helpers
# ---------------------------------------------------------------------------
_STATUS_MAP: dict[AlpacaOrderStatus, OrderStatus] = {
AlpacaOrderStatus.NEW: OrderStatus.PENDING,
AlpacaOrderStatus.ACCEPTED: OrderStatus.PENDING,
AlpacaOrderStatus.PENDING_NEW: OrderStatus.PENDING,
AlpacaOrderStatus.ACCEPTED_FOR_BIDDING: OrderStatus.PENDING,
AlpacaOrderStatus.PENDING_CANCEL: OrderStatus.PENDING,
AlpacaOrderStatus.PENDING_REPLACE: OrderStatus.PENDING,
AlpacaOrderStatus.PENDING_REVIEW: OrderStatus.PENDING,
AlpacaOrderStatus.HELD: OrderStatus.PENDING,
AlpacaOrderStatus.PARTIALLY_FILLED: OrderStatus.PENDING,
AlpacaOrderStatus.FILLED: OrderStatus.FILLED,
AlpacaOrderStatus.DONE_FOR_DAY: OrderStatus.FILLED,
AlpacaOrderStatus.CANCELED: OrderStatus.CANCELLED,
AlpacaOrderStatus.EXPIRED: OrderStatus.CANCELLED,
AlpacaOrderStatus.REPLACED: OrderStatus.CANCELLED,
AlpacaOrderStatus.STOPPED: OrderStatus.CANCELLED,
AlpacaOrderStatus.SUSPENDED: OrderStatus.CANCELLED,
AlpacaOrderStatus.REJECTED: OrderStatus.REJECTED,
}
def _map_status(alpaca_status: AlpacaOrderStatus) -> OrderStatus:
"""Convert an Alpaca order status to our canonical ``OrderStatus``."""
return _STATUS_MAP.get(alpaca_status, OrderStatus.PENDING)
def _map_side(alpaca_side: AlpacaOrderSide | None) -> OrderSide:
"""Convert an Alpaca order side to our canonical ``OrderSide``."""
if alpaca_side == AlpacaOrderSide.SELL:
return OrderSide.SELL
return OrderSide.BUY
# ---------------------------------------------------------------------------
# Alpaca broker implementation
# ---------------------------------------------------------------------------
class AlpacaBroker(BaseBroker):
"""Brokerage adapter backed by the Alpaca ``TradingClient``.
Parameters
----------
api_key:
Alpaca API key ID.
secret_key:
Alpaca API secret key.
paper:
If ``True`` (the default), connect to the Alpaca paper-trading
sandbox. Set to ``False`` for live trading.
"""
def __init__(self, api_key: str, secret_key: str, *, paper: bool = True) -> None:
self._client = TradingClient(
api_key=api_key,
secret_key=secret_key,
paper=paper,
)
# -- internal helpers ----------------------------------------------------
def _build_order_request(
self, order: OrderRequest
) -> MarketOrderRequest | LimitOrderRequest | StopOrderRequest:
"""Convert our ``OrderRequest`` into the appropriate Alpaca request."""
side = AlpacaOrderSide.BUY if order.side == OrderSide.BUY else AlpacaOrderSide.SELL
if order.order_type == OrderType.LIMIT:
if order.limit_price is None:
raise ValueError("limit_price is required for limit orders")
return LimitOrderRequest(
symbol=order.ticker,
qty=order.qty,
side=side,
time_in_force=TimeInForce.DAY,
limit_price=order.limit_price,
)
elif order.order_type == OrderType.STOP:
if order.stop_price is None:
raise ValueError("stop_price is required for stop orders")
return StopOrderRequest(
symbol=order.ticker,
qty=order.qty,
side=side,
time_in_force=TimeInForce.DAY,
stop_price=order.stop_price,
)
else:
# Default: MARKET order
return MarketOrderRequest(
symbol=order.ticker,
qty=order.qty,
side=side,
time_in_force=TimeInForce.DAY,
)
@staticmethod
def _order_to_result(alpaca_order: AlpacaOrder) -> OrderResult:
"""Convert an Alpaca ``Order`` model to our ``OrderResult``."""
filled_price: float | None = None
if alpaca_order.filled_avg_price is not None:
filled_price = float(alpaca_order.filled_avg_price)
qty = float(alpaca_order.qty) if alpaca_order.qty is not None else 0.0
timestamp = alpaca_order.submitted_at or alpaca_order.created_at or datetime.now(timezone.utc)
return OrderResult(
order_id=str(alpaca_order.id),
ticker=alpaca_order.symbol or "",
side=_map_side(alpaca_order.side),
qty=qty,
filled_price=filled_price,
status=_map_status(alpaca_order.status),
timestamp=timestamp,
)
@staticmethod
def _position_to_info(pos: AlpacaPosition) -> PositionInfo:
"""Convert an Alpaca ``Position`` to our ``PositionInfo``."""
return PositionInfo(
ticker=pos.symbol,
qty=float(pos.qty),
avg_entry=float(pos.avg_entry_price),
current_price=float(pos.current_price) if pos.current_price else 0.0,
unrealized_pnl=float(pos.unrealized_pl) if pos.unrealized_pl else 0.0,
market_value=float(pos.market_value) if pos.market_value else 0.0,
)
# -- BaseBroker interface ------------------------------------------------
async def submit_order(self, order: OrderRequest) -> OrderResult:
"""Submit an order to Alpaca.
Converts the ``OrderRequest`` into the appropriate Alpaca request
object, submits it via the ``TradingClient``, and returns an
``OrderResult`` reflecting the initial state of the order.
If the API rejects the order an ``OrderResult`` with status
``REJECTED`` is returned rather than raising an exception.
"""
try:
alpaca_request = self._build_order_request(order)
alpaca_order: AlpacaOrder = await asyncio.to_thread(
self._client.submit_order, alpaca_request
)
return self._order_to_result(alpaca_order)
except APIError as exc:
logger.warning("Order rejected by Alpaca: %s", exc)
return OrderResult(
order_id="",
ticker=order.ticker,
side=order.side,
qty=order.qty,
filled_price=None,
status=OrderStatus.REJECTED,
timestamp=datetime.now(timezone.utc),
)
async def cancel_order(self, order_id: str) -> bool:
"""Cancel an order on Alpaca by its ID.
Returns ``True`` if the cancellation request was accepted, or
``False`` if the API raised an error (e.g. the order was already
filled or does not exist).
"""
try:
await asyncio.to_thread(self._client.cancel_order_by_id, order_id)
return True
except APIError as exc:
logger.warning("Failed to cancel order %s: %s", order_id, exc)
return False
async def get_positions(self) -> list[PositionInfo]:
"""Return all open positions from Alpaca."""
positions: list[AlpacaPosition] = await asyncio.to_thread(
self._client.get_all_positions
)
return [self._position_to_info(p) for p in positions]
async def get_account(self) -> AccountInfo:
"""Return account summary from Alpaca."""
account: TradeAccount = await asyncio.to_thread(self._client.get_account)
return AccountInfo(
equity=float(account.equity) if account.equity else 0.0,
cash=float(account.cash) if account.cash else 0.0,
buying_power=float(account.buying_power) if account.buying_power else 0.0,
portfolio_value=float(account.portfolio_value) if account.portfolio_value else 0.0,
)
async def get_order_status(self, order_id: str) -> OrderResult:
"""Fetch the current state of an order from Alpaca."""
alpaca_order: AlpacaOrder = await asyncio.to_thread(
self._client.get_order_by_id, order_id
)
return self._order_to_result(alpaca_order)

87
shared/broker/base.py Normal file
View file

@ -0,0 +1,87 @@
"""Abstract base class for brokerage integrations.
All broker implementations must inherit from ``BaseBroker`` and provide
concrete implementations for order management, position tracking, and
account information retrieval. This abstraction layer allows the trading
bot to swap brokerages (Alpaca, Interactive Brokers, Tradier, ...) without
changing strategy or execution logic.
"""
from abc import ABC, abstractmethod
from shared.schemas.trading import AccountInfo, OrderRequest, OrderResult, PositionInfo
class BaseBroker(ABC):
"""Interface that every brokerage adapter must implement."""
@abstractmethod
async def submit_order(self, order: OrderRequest) -> OrderResult:
"""Submit a new order to the brokerage.
Parameters
----------
order:
The order details including ticker, side, quantity, and order type.
Returns
-------
OrderResult
Result containing the order ID, status, and fill information.
"""
...
@abstractmethod
async def cancel_order(self, order_id: str) -> bool:
"""Cancel an open order.
Parameters
----------
order_id:
The brokerage-assigned order identifier.
Returns
-------
bool
``True`` if the cancellation was accepted, ``False`` otherwise.
"""
...
@abstractmethod
async def get_positions(self) -> list[PositionInfo]:
"""Return all currently open positions.
Returns
-------
list[PositionInfo]
One entry per open position with quantity, average entry, current
price, and unrealized P&L.
"""
...
@abstractmethod
async def get_account(self) -> AccountInfo:
"""Return account-level summary information.
Returns
-------
AccountInfo
Equity, cash, buying power, and total portfolio value.
"""
...
@abstractmethod
async def get_order_status(self, order_id: str) -> OrderResult:
"""Fetch the current status of an existing order.
Parameters
----------
order_id:
The brokerage-assigned order identifier.
Returns
-------
OrderResult
Current state of the order including fill price if applicable.
"""
...

539
tests/test_broker.py Normal file
View file

@ -0,0 +1,539 @@
"""Tests for the brokerage abstraction layer with a mocked Alpaca TradingClient."""
import uuid
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from alpaca.common.exceptions import APIError
from alpaca.trading.enums import OrderSide as AlpacaOrderSide
from alpaca.trading.enums import OrderStatus as AlpacaOrderStatus
from alpaca.trading.enums import OrderType as AlpacaOrderType
from alpaca.trading.enums import AssetClass, AssetExchange, OrderClass, PositionSide, TimeInForce
from alpaca.trading.models import Order as AlpacaOrder
from alpaca.trading.models import Position as AlpacaPosition
from alpaca.trading.models import TradeAccount
from alpaca.trading.requests import LimitOrderRequest, MarketOrderRequest, StopOrderRequest
from shared.broker.alpaca_broker import AlpacaBroker
from shared.schemas.trading import (
AccountInfo,
OrderRequest,
OrderResult,
OrderSide,
OrderStatus,
OrderType,
PositionInfo,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
def _make_alpaca_order(
*,
order_id: str | None = None,
symbol: str = "AAPL",
side: AlpacaOrderSide = AlpacaOrderSide.BUY,
qty: str = "10",
filled_avg_price: str | None = None,
status: AlpacaOrderStatus = AlpacaOrderStatus.NEW,
submitted_at: datetime | None = None,
) -> AlpacaOrder:
"""Build a minimal Alpaca ``Order`` model for testing."""
oid = uuid.UUID(order_id) if order_id else uuid.uuid4()
now = submitted_at or datetime.now(timezone.utc)
return AlpacaOrder(
id=oid,
client_order_id=str(uuid.uuid4()),
created_at=now,
updated_at=now,
submitted_at=now,
filled_at=None,
expired_at=None,
canceled_at=None,
failed_at=None,
replaced_at=None,
replaced_by=None,
replaces=None,
asset_id=uuid.uuid4(),
symbol=symbol,
asset_class=AssetClass.US_EQUITY,
notional=None,
qty=qty,
filled_qty="0",
filled_avg_price=filled_avg_price,
order_class=OrderClass.SIMPLE,
order_type=AlpacaOrderType.MARKET,
type=AlpacaOrderType.MARKET,
side=side,
time_in_force=TimeInForce.DAY,
limit_price=None,
stop_price=None,
status=status,
extended_hours=False,
legs=None,
trail_percent=None,
trail_price=None,
hwm=None,
)
def _make_alpaca_position(
*,
symbol: str = "AAPL",
qty: str = "10",
avg_entry_price: str = "150.00",
current_price: str = "155.00",
unrealized_pl: str = "50.00",
market_value: str = "1550.00",
) -> AlpacaPosition:
"""Build a minimal Alpaca ``Position`` model for testing."""
return AlpacaPosition(
asset_id=uuid.uuid4(),
symbol=symbol,
exchange=AssetExchange.NASDAQ,
asset_class=AssetClass.US_EQUITY,
asset_marginable=True,
avg_entry_price=avg_entry_price,
qty=qty,
side=PositionSide.LONG,
market_value=market_value,
cost_basis="1500.00",
unrealized_pl=unrealized_pl,
unrealized_plpc="0.0333",
unrealized_intraday_pl="50.00",
unrealized_intraday_plpc="0.0333",
current_price=current_price,
lastday_price="150.00",
change_today="0.0333",
)
def _make_alpaca_account(
*,
equity: str = "100000.00",
cash: str = "50000.00",
buying_power: str = "200000.00",
portfolio_value: str = "100000.00",
) -> TradeAccount:
"""Build a minimal Alpaca ``TradeAccount`` model for testing."""
return TradeAccount(
id=uuid.uuid4(),
account_number="test-123",
status="ACTIVE",
crypto_status=None,
currency="USD",
buying_power=buying_power,
regt_buying_power=buying_power,
daytrading_buying_power=buying_power,
non_marginable_buying_power=cash,
cash=cash,
accrued_fees="0",
pending_transfer_in=None,
portfolio_value=portfolio_value,
pattern_day_trader=False,
trading_blocked=False,
transfers_blocked=False,
account_blocked=False,
created_at=datetime.now(timezone.utc),
trade_suspended_by_user=False,
multiplier="2",
shorting_enabled=True,
equity=equity,
last_equity=equity,
long_market_value="50000.00",
short_market_value="0",
initial_margin="25000.00",
maintenance_margin="15000.00",
last_maintenance_margin="15000.00",
sma="100000.00",
daytrade_count=0,
)
@pytest.fixture
def mock_client() -> MagicMock:
"""Return a mocked ``TradingClient``."""
return MagicMock()
@pytest.fixture
def broker(mock_client: MagicMock) -> AlpacaBroker:
"""Return an ``AlpacaBroker`` whose internal client is mocked."""
with patch("shared.broker.alpaca_broker.TradingClient", return_value=mock_client):
b = AlpacaBroker(api_key="test-key", secret_key="test-secret", paper=True)
return b
# ---------------------------------------------------------------------------
# Order submission
# ---------------------------------------------------------------------------
class TestSubmitMarketOrder:
@pytest.mark.asyncio
async def test_submit_market_order(self, broker: AlpacaBroker, mock_client: MagicMock) -> None:
"""A market buy order should be converted and submitted to Alpaca."""
alpaca_order = _make_alpaca_order(
symbol="AAPL",
side=AlpacaOrderSide.BUY,
qty="10",
status=AlpacaOrderStatus.NEW,
)
mock_client.submit_order.return_value = alpaca_order
order = OrderRequest(ticker="AAPL", side=OrderSide.BUY, qty=10.0, order_type=OrderType.MARKET)
result = await broker.submit_order(order)
assert isinstance(result, OrderResult)
assert result.ticker == "AAPL"
assert result.side == OrderSide.BUY
assert result.qty == 10.0
assert result.status == OrderStatus.PENDING
assert result.order_id == str(alpaca_order.id)
assert result.filled_price is None
# Verify the Alpaca client received a MarketOrderRequest
submitted = mock_client.submit_order.call_args[0][0]
assert isinstance(submitted, MarketOrderRequest)
assert submitted.symbol == "AAPL"
assert submitted.qty == 10.0
assert submitted.side == AlpacaOrderSide.BUY
@pytest.mark.asyncio
async def test_submit_market_sell(self, broker: AlpacaBroker, mock_client: MagicMock) -> None:
"""A market sell order maps the side correctly."""
alpaca_order = _make_alpaca_order(
symbol="TSLA",
side=AlpacaOrderSide.SELL,
qty="5",
status=AlpacaOrderStatus.FILLED,
filled_avg_price="200.50",
)
mock_client.submit_order.return_value = alpaca_order
order = OrderRequest(ticker="TSLA", side=OrderSide.SELL, qty=5.0)
result = await broker.submit_order(order)
assert result.side == OrderSide.SELL
assert result.status == OrderStatus.FILLED
assert result.filled_price == 200.50
class TestSubmitLimitOrder:
@pytest.mark.asyncio
async def test_submit_limit_order(self, broker: AlpacaBroker, mock_client: MagicMock) -> None:
"""A limit order should include the limit price in the Alpaca request."""
alpaca_order = _make_alpaca_order(
symbol="MSFT",
side=AlpacaOrderSide.BUY,
qty="20",
status=AlpacaOrderStatus.ACCEPTED,
)
mock_client.submit_order.return_value = alpaca_order
order = OrderRequest(
ticker="MSFT",
side=OrderSide.BUY,
qty=20.0,
order_type=OrderType.LIMIT,
limit_price=350.00,
)
result = await broker.submit_order(order)
assert isinstance(result, OrderResult)
assert result.ticker == "MSFT"
assert result.status == OrderStatus.PENDING
submitted = mock_client.submit_order.call_args[0][0]
assert isinstance(submitted, LimitOrderRequest)
assert submitted.limit_price == 350.00
@pytest.mark.asyncio
async def test_limit_order_missing_price_raises(self, broker: AlpacaBroker, mock_client: MagicMock) -> None:
"""A limit order without limit_price should raise ValueError."""
order = OrderRequest(
ticker="MSFT",
side=OrderSide.BUY,
qty=20.0,
order_type=OrderType.LIMIT,
limit_price=None,
)
with pytest.raises(ValueError, match="limit_price is required"):
await broker.submit_order(order)
class TestSubmitStopOrder:
@pytest.mark.asyncio
async def test_submit_stop_order(self, broker: AlpacaBroker, mock_client: MagicMock) -> None:
"""A stop order should include the stop price in the Alpaca request."""
alpaca_order = _make_alpaca_order(
symbol="GOOG",
side=AlpacaOrderSide.SELL,
qty="15",
status=AlpacaOrderStatus.NEW,
)
mock_client.submit_order.return_value = alpaca_order
order = OrderRequest(
ticker="GOOG",
side=OrderSide.SELL,
qty=15.0,
order_type=OrderType.STOP,
stop_price=140.00,
)
result = await broker.submit_order(order)
assert isinstance(result, OrderResult)
assert result.ticker == "GOOG"
submitted = mock_client.submit_order.call_args[0][0]
assert isinstance(submitted, StopOrderRequest)
assert submitted.stop_price == 140.00
@pytest.mark.asyncio
async def test_stop_order_missing_price_raises(self, broker: AlpacaBroker, mock_client: MagicMock) -> None:
"""A stop order without stop_price should raise ValueError."""
order = OrderRequest(
ticker="GOOG",
side=OrderSide.SELL,
qty=15.0,
order_type=OrderType.STOP,
stop_price=None,
)
with pytest.raises(ValueError, match="stop_price is required"):
await broker.submit_order(order)
# ---------------------------------------------------------------------------
# Order rejection
# ---------------------------------------------------------------------------
class TestSubmitOrderRejected:
@pytest.mark.asyncio
async def test_api_error_returns_rejected_result(self, broker: AlpacaBroker, mock_client: MagicMock) -> None:
"""If Alpaca's API raises an error, submit_order returns REJECTED."""
mock_client.submit_order.side_effect = APIError("insufficient buying power")
order = OrderRequest(ticker="AAPL", side=OrderSide.BUY, qty=10000.0)
result = await broker.submit_order(order)
assert isinstance(result, OrderResult)
assert result.status == OrderStatus.REJECTED
assert result.ticker == "AAPL"
assert result.side == OrderSide.BUY
assert result.qty == 10000.0
assert result.filled_price is None
assert result.order_id == ""
# ---------------------------------------------------------------------------
# Cancel order
# ---------------------------------------------------------------------------
class TestCancelOrder:
@pytest.mark.asyncio
async def test_cancel_order_success(self, broker: AlpacaBroker, mock_client: MagicMock) -> None:
"""Successful cancellation returns True."""
mock_client.cancel_order_by_id.return_value = None # void method
result = await broker.cancel_order("some-order-id")
assert result is True
mock_client.cancel_order_by_id.assert_called_once_with("some-order-id")
@pytest.mark.asyncio
async def test_cancel_order_failure(self, broker: AlpacaBroker, mock_client: MagicMock) -> None:
"""If the API raises an error (e.g. order already filled), return False."""
mock_client.cancel_order_by_id.side_effect = APIError("order is not cancelable")
result = await broker.cancel_order("some-order-id")
assert result is False
# ---------------------------------------------------------------------------
# Positions
# ---------------------------------------------------------------------------
class TestGetPositions:
@pytest.mark.asyncio
async def test_get_positions_empty(self, broker: AlpacaBroker, mock_client: MagicMock) -> None:
"""When there are no open positions, return an empty list."""
mock_client.get_all_positions.return_value = []
positions = await broker.get_positions()
assert positions == []
@pytest.mark.asyncio
async def test_get_positions_with_data(self, broker: AlpacaBroker, mock_client: MagicMock) -> None:
"""Positions are correctly converted to PositionInfo."""
mock_client.get_all_positions.return_value = [
_make_alpaca_position(
symbol="AAPL",
qty="10",
avg_entry_price="150.00",
current_price="155.00",
unrealized_pl="50.00",
market_value="1550.00",
),
_make_alpaca_position(
symbol="TSLA",
qty="5",
avg_entry_price="200.00",
current_price="190.00",
unrealized_pl="-50.00",
market_value="950.00",
),
]
positions = await broker.get_positions()
assert len(positions) == 2
aapl = positions[0]
assert isinstance(aapl, PositionInfo)
assert aapl.ticker == "AAPL"
assert aapl.qty == 10.0
assert aapl.avg_entry == 150.0
assert aapl.current_price == 155.0
assert aapl.unrealized_pnl == 50.0
assert aapl.market_value == 1550.0
tsla = positions[1]
assert tsla.ticker == "TSLA"
assert tsla.qty == 5.0
assert tsla.unrealized_pnl == -50.0
# ---------------------------------------------------------------------------
# Account
# ---------------------------------------------------------------------------
class TestGetAccount:
@pytest.mark.asyncio
async def test_get_account(self, broker: AlpacaBroker, mock_client: MagicMock) -> None:
"""Account info is correctly converted from TradeAccount."""
mock_client.get_account.return_value = _make_alpaca_account(
equity="100000.00",
cash="50000.00",
buying_power="200000.00",
portfolio_value="100000.00",
)
account = await broker.get_account()
assert isinstance(account, AccountInfo)
assert account.equity == 100000.0
assert account.cash == 50000.0
assert account.buying_power == 200000.0
assert account.portfolio_value == 100000.0
# ---------------------------------------------------------------------------
# Order status
# ---------------------------------------------------------------------------
class TestGetOrderStatus:
@pytest.mark.asyncio
async def test_get_order_status_filled(self, broker: AlpacaBroker, mock_client: MagicMock) -> None:
"""A filled order returns FILLED status and the fill price."""
order_id = str(uuid.uuid4())
alpaca_order = _make_alpaca_order(
order_id=order_id,
symbol="AAPL",
side=AlpacaOrderSide.BUY,
qty="10",
status=AlpacaOrderStatus.FILLED,
filled_avg_price="152.30",
)
mock_client.get_order_by_id.return_value = alpaca_order
result = await broker.get_order_status(order_id)
assert isinstance(result, OrderResult)
assert result.order_id == order_id
assert result.status == OrderStatus.FILLED
assert result.filled_price == 152.30
assert result.ticker == "AAPL"
assert result.side == OrderSide.BUY
mock_client.get_order_by_id.assert_called_once_with(order_id)
@pytest.mark.asyncio
async def test_get_order_status_pending(self, broker: AlpacaBroker, mock_client: MagicMock) -> None:
"""A pending order returns PENDING status with no fill price."""
order_id = str(uuid.uuid4())
alpaca_order = _make_alpaca_order(
order_id=order_id,
symbol="MSFT",
status=AlpacaOrderStatus.PENDING_NEW,
)
mock_client.get_order_by_id.return_value = alpaca_order
result = await broker.get_order_status(order_id)
assert result.status == OrderStatus.PENDING
assert result.filled_price is None
@pytest.mark.asyncio
async def test_get_order_status_cancelled(self, broker: AlpacaBroker, mock_client: MagicMock) -> None:
"""A cancelled order returns CANCELLED status."""
order_id = str(uuid.uuid4())
alpaca_order = _make_alpaca_order(
order_id=order_id,
symbol="TSLA",
status=AlpacaOrderStatus.CANCELED,
)
mock_client.get_order_by_id.return_value = alpaca_order
result = await broker.get_order_status(order_id)
assert result.status == OrderStatus.CANCELLED
@pytest.mark.asyncio
async def test_get_order_status_rejected(self, broker: AlpacaBroker, mock_client: MagicMock) -> None:
"""A rejected order returns REJECTED status."""
order_id = str(uuid.uuid4())
alpaca_order = _make_alpaca_order(
order_id=order_id,
symbol="GOOG",
status=AlpacaOrderStatus.REJECTED,
)
mock_client.get_order_by_id.return_value = alpaca_order
result = await broker.get_order_status(order_id)
assert result.status == OrderStatus.REJECTED
# ---------------------------------------------------------------------------
# BaseBroker interface
# ---------------------------------------------------------------------------
class TestBaseBrokerInterface:
def test_alpaca_broker_is_subclass(self) -> None:
"""AlpacaBroker should be a proper subclass of BaseBroker."""
from shared.broker.base import BaseBroker
assert issubclass(AlpacaBroker, BaseBroker)
def test_package_exports(self) -> None:
"""The broker package should export BaseBroker and AlpacaBroker."""
from shared.broker import AlpacaBroker as AB
from shared.broker import BaseBroker as BB
assert AB is AlpacaBroker
from shared.broker.base import BaseBroker
assert BB is BaseBroker