diff --git a/shared/broker/__init__.py b/shared/broker/__init__.py new file mode 100644 index 0000000..21772d6 --- /dev/null +++ b/shared/broker/__init__.py @@ -0,0 +1,11 @@ +"""Brokerage abstraction layer. + +Provides :class:`BaseBroker` (the interface) and :class:`AlpacaBroker` +(the default Alpaca implementation). Additional brokerage adapters can be +added by subclassing ``BaseBroker`` and implementing its abstract methods. +""" + +from shared.broker.alpaca_broker import AlpacaBroker +from shared.broker.base import BaseBroker + +__all__ = ["AlpacaBroker", "BaseBroker"] diff --git a/shared/broker/alpaca_broker.py b/shared/broker/alpaca_broker.py new file mode 100644 index 0000000..e5cda70 --- /dev/null +++ b/shared/broker/alpaca_broker.py @@ -0,0 +1,240 @@ +"""Alpaca brokerage adapter implementing :class:`BaseBroker`. + +Uses the ``alpaca-py`` SDK (``TradingClient``) for order management, +position retrieval, and account information. Paper vs. live trading is +controlled via the ``paper`` flag in the constructor. +""" + +from __future__ import annotations + +import asyncio +import logging +from datetime import datetime, timezone + +from alpaca.common.exceptions import APIError +from alpaca.trading.client import TradingClient +from alpaca.trading.enums import OrderSide as AlpacaOrderSide +from alpaca.trading.enums import OrderStatus as AlpacaOrderStatus +from alpaca.trading.enums import TimeInForce +from alpaca.trading.models import Order as AlpacaOrder +from alpaca.trading.models import Position as AlpacaPosition +from alpaca.trading.models import TradeAccount +from alpaca.trading.requests import ( + LimitOrderRequest, + MarketOrderRequest, + StopOrderRequest, +) + +from shared.broker.base import BaseBroker +from shared.schemas.trading import ( + AccountInfo, + OrderRequest, + OrderResult, + OrderSide, + OrderStatus, + OrderType, + PositionInfo, +) + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Status mapping helpers +# --------------------------------------------------------------------------- + +_STATUS_MAP: dict[AlpacaOrderStatus, OrderStatus] = { + AlpacaOrderStatus.NEW: OrderStatus.PENDING, + AlpacaOrderStatus.ACCEPTED: OrderStatus.PENDING, + AlpacaOrderStatus.PENDING_NEW: OrderStatus.PENDING, + AlpacaOrderStatus.ACCEPTED_FOR_BIDDING: OrderStatus.PENDING, + AlpacaOrderStatus.PENDING_CANCEL: OrderStatus.PENDING, + AlpacaOrderStatus.PENDING_REPLACE: OrderStatus.PENDING, + AlpacaOrderStatus.PENDING_REVIEW: OrderStatus.PENDING, + AlpacaOrderStatus.HELD: OrderStatus.PENDING, + AlpacaOrderStatus.PARTIALLY_FILLED: OrderStatus.PENDING, + AlpacaOrderStatus.FILLED: OrderStatus.FILLED, + AlpacaOrderStatus.DONE_FOR_DAY: OrderStatus.FILLED, + AlpacaOrderStatus.CANCELED: OrderStatus.CANCELLED, + AlpacaOrderStatus.EXPIRED: OrderStatus.CANCELLED, + AlpacaOrderStatus.REPLACED: OrderStatus.CANCELLED, + AlpacaOrderStatus.STOPPED: OrderStatus.CANCELLED, + AlpacaOrderStatus.SUSPENDED: OrderStatus.CANCELLED, + AlpacaOrderStatus.REJECTED: OrderStatus.REJECTED, +} + + +def _map_status(alpaca_status: AlpacaOrderStatus) -> OrderStatus: + """Convert an Alpaca order status to our canonical ``OrderStatus``.""" + return _STATUS_MAP.get(alpaca_status, OrderStatus.PENDING) + + +def _map_side(alpaca_side: AlpacaOrderSide | None) -> OrderSide: + """Convert an Alpaca order side to our canonical ``OrderSide``.""" + if alpaca_side == AlpacaOrderSide.SELL: + return OrderSide.SELL + return OrderSide.BUY + + +# --------------------------------------------------------------------------- +# Alpaca broker implementation +# --------------------------------------------------------------------------- + + +class AlpacaBroker(BaseBroker): + """Brokerage adapter backed by the Alpaca ``TradingClient``. + + Parameters + ---------- + api_key: + Alpaca API key ID. + secret_key: + Alpaca API secret key. + paper: + If ``True`` (the default), connect to the Alpaca paper-trading + sandbox. Set to ``False`` for live trading. + """ + + def __init__(self, api_key: str, secret_key: str, *, paper: bool = True) -> None: + self._client = TradingClient( + api_key=api_key, + secret_key=secret_key, + paper=paper, + ) + + # -- internal helpers ---------------------------------------------------- + + def _build_order_request( + self, order: OrderRequest + ) -> MarketOrderRequest | LimitOrderRequest | StopOrderRequest: + """Convert our ``OrderRequest`` into the appropriate Alpaca request.""" + side = AlpacaOrderSide.BUY if order.side == OrderSide.BUY else AlpacaOrderSide.SELL + + if order.order_type == OrderType.LIMIT: + if order.limit_price is None: + raise ValueError("limit_price is required for limit orders") + return LimitOrderRequest( + symbol=order.ticker, + qty=order.qty, + side=side, + time_in_force=TimeInForce.DAY, + limit_price=order.limit_price, + ) + elif order.order_type == OrderType.STOP: + if order.stop_price is None: + raise ValueError("stop_price is required for stop orders") + return StopOrderRequest( + symbol=order.ticker, + qty=order.qty, + side=side, + time_in_force=TimeInForce.DAY, + stop_price=order.stop_price, + ) + else: + # Default: MARKET order + return MarketOrderRequest( + symbol=order.ticker, + qty=order.qty, + side=side, + time_in_force=TimeInForce.DAY, + ) + + @staticmethod + def _order_to_result(alpaca_order: AlpacaOrder) -> OrderResult: + """Convert an Alpaca ``Order`` model to our ``OrderResult``.""" + filled_price: float | None = None + if alpaca_order.filled_avg_price is not None: + filled_price = float(alpaca_order.filled_avg_price) + + qty = float(alpaca_order.qty) if alpaca_order.qty is not None else 0.0 + + timestamp = alpaca_order.submitted_at or alpaca_order.created_at or datetime.now(timezone.utc) + + return OrderResult( + order_id=str(alpaca_order.id), + ticker=alpaca_order.symbol or "", + side=_map_side(alpaca_order.side), + qty=qty, + filled_price=filled_price, + status=_map_status(alpaca_order.status), + timestamp=timestamp, + ) + + @staticmethod + def _position_to_info(pos: AlpacaPosition) -> PositionInfo: + """Convert an Alpaca ``Position`` to our ``PositionInfo``.""" + return PositionInfo( + ticker=pos.symbol, + qty=float(pos.qty), + avg_entry=float(pos.avg_entry_price), + current_price=float(pos.current_price) if pos.current_price else 0.0, + unrealized_pnl=float(pos.unrealized_pl) if pos.unrealized_pl else 0.0, + market_value=float(pos.market_value) if pos.market_value else 0.0, + ) + + # -- BaseBroker interface ------------------------------------------------ + + async def submit_order(self, order: OrderRequest) -> OrderResult: + """Submit an order to Alpaca. + + Converts the ``OrderRequest`` into the appropriate Alpaca request + object, submits it via the ``TradingClient``, and returns an + ``OrderResult`` reflecting the initial state of the order. + + If the API rejects the order an ``OrderResult`` with status + ``REJECTED`` is returned rather than raising an exception. + """ + try: + alpaca_request = self._build_order_request(order) + alpaca_order: AlpacaOrder = await asyncio.to_thread( + self._client.submit_order, alpaca_request + ) + return self._order_to_result(alpaca_order) + except APIError as exc: + logger.warning("Order rejected by Alpaca: %s", exc) + return OrderResult( + order_id="", + ticker=order.ticker, + side=order.side, + qty=order.qty, + filled_price=None, + status=OrderStatus.REJECTED, + timestamp=datetime.now(timezone.utc), + ) + + async def cancel_order(self, order_id: str) -> bool: + """Cancel an order on Alpaca by its ID. + + Returns ``True`` if the cancellation request was accepted, or + ``False`` if the API raised an error (e.g. the order was already + filled or does not exist). + """ + try: + await asyncio.to_thread(self._client.cancel_order_by_id, order_id) + return True + except APIError as exc: + logger.warning("Failed to cancel order %s: %s", order_id, exc) + return False + + async def get_positions(self) -> list[PositionInfo]: + """Return all open positions from Alpaca.""" + positions: list[AlpacaPosition] = await asyncio.to_thread( + self._client.get_all_positions + ) + return [self._position_to_info(p) for p in positions] + + async def get_account(self) -> AccountInfo: + """Return account summary from Alpaca.""" + account: TradeAccount = await asyncio.to_thread(self._client.get_account) + return AccountInfo( + equity=float(account.equity) if account.equity else 0.0, + cash=float(account.cash) if account.cash else 0.0, + buying_power=float(account.buying_power) if account.buying_power else 0.0, + portfolio_value=float(account.portfolio_value) if account.portfolio_value else 0.0, + ) + + async def get_order_status(self, order_id: str) -> OrderResult: + """Fetch the current state of an order from Alpaca.""" + alpaca_order: AlpacaOrder = await asyncio.to_thread( + self._client.get_order_by_id, order_id + ) + return self._order_to_result(alpaca_order) diff --git a/shared/broker/base.py b/shared/broker/base.py new file mode 100644 index 0000000..7ed6ce8 --- /dev/null +++ b/shared/broker/base.py @@ -0,0 +1,87 @@ +"""Abstract base class for brokerage integrations. + +All broker implementations must inherit from ``BaseBroker`` and provide +concrete implementations for order management, position tracking, and +account information retrieval. This abstraction layer allows the trading +bot to swap brokerages (Alpaca, Interactive Brokers, Tradier, ...) without +changing strategy or execution logic. +""" + +from abc import ABC, abstractmethod + +from shared.schemas.trading import AccountInfo, OrderRequest, OrderResult, PositionInfo + + +class BaseBroker(ABC): + """Interface that every brokerage adapter must implement.""" + + @abstractmethod + async def submit_order(self, order: OrderRequest) -> OrderResult: + """Submit a new order to the brokerage. + + Parameters + ---------- + order: + The order details including ticker, side, quantity, and order type. + + Returns + ------- + OrderResult + Result containing the order ID, status, and fill information. + """ + ... + + @abstractmethod + async def cancel_order(self, order_id: str) -> bool: + """Cancel an open order. + + Parameters + ---------- + order_id: + The brokerage-assigned order identifier. + + Returns + ------- + bool + ``True`` if the cancellation was accepted, ``False`` otherwise. + """ + ... + + @abstractmethod + async def get_positions(self) -> list[PositionInfo]: + """Return all currently open positions. + + Returns + ------- + list[PositionInfo] + One entry per open position with quantity, average entry, current + price, and unrealized P&L. + """ + ... + + @abstractmethod + async def get_account(self) -> AccountInfo: + """Return account-level summary information. + + Returns + ------- + AccountInfo + Equity, cash, buying power, and total portfolio value. + """ + ... + + @abstractmethod + async def get_order_status(self, order_id: str) -> OrderResult: + """Fetch the current status of an existing order. + + Parameters + ---------- + order_id: + The brokerage-assigned order identifier. + + Returns + ------- + OrderResult + Current state of the order including fill price if applicable. + """ + ... diff --git a/tests/test_broker.py b/tests/test_broker.py new file mode 100644 index 0000000..81b515a --- /dev/null +++ b/tests/test_broker.py @@ -0,0 +1,539 @@ +"""Tests for the brokerage abstraction layer with a mocked Alpaca TradingClient.""" + +import uuid +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch + +import pytest + +from alpaca.common.exceptions import APIError +from alpaca.trading.enums import OrderSide as AlpacaOrderSide +from alpaca.trading.enums import OrderStatus as AlpacaOrderStatus +from alpaca.trading.enums import OrderType as AlpacaOrderType +from alpaca.trading.enums import AssetClass, AssetExchange, OrderClass, PositionSide, TimeInForce +from alpaca.trading.models import Order as AlpacaOrder +from alpaca.trading.models import Position as AlpacaPosition +from alpaca.trading.models import TradeAccount +from alpaca.trading.requests import LimitOrderRequest, MarketOrderRequest, StopOrderRequest + +from shared.broker.alpaca_broker import AlpacaBroker +from shared.schemas.trading import ( + AccountInfo, + OrderRequest, + OrderResult, + OrderSide, + OrderStatus, + OrderType, + PositionInfo, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +def _make_alpaca_order( + *, + order_id: str | None = None, + symbol: str = "AAPL", + side: AlpacaOrderSide = AlpacaOrderSide.BUY, + qty: str = "10", + filled_avg_price: str | None = None, + status: AlpacaOrderStatus = AlpacaOrderStatus.NEW, + submitted_at: datetime | None = None, +) -> AlpacaOrder: + """Build a minimal Alpaca ``Order`` model for testing.""" + oid = uuid.UUID(order_id) if order_id else uuid.uuid4() + now = submitted_at or datetime.now(timezone.utc) + return AlpacaOrder( + id=oid, + client_order_id=str(uuid.uuid4()), + created_at=now, + updated_at=now, + submitted_at=now, + filled_at=None, + expired_at=None, + canceled_at=None, + failed_at=None, + replaced_at=None, + replaced_by=None, + replaces=None, + asset_id=uuid.uuid4(), + symbol=symbol, + asset_class=AssetClass.US_EQUITY, + notional=None, + qty=qty, + filled_qty="0", + filled_avg_price=filled_avg_price, + order_class=OrderClass.SIMPLE, + order_type=AlpacaOrderType.MARKET, + type=AlpacaOrderType.MARKET, + side=side, + time_in_force=TimeInForce.DAY, + limit_price=None, + stop_price=None, + status=status, + extended_hours=False, + legs=None, + trail_percent=None, + trail_price=None, + hwm=None, + ) + + +def _make_alpaca_position( + *, + symbol: str = "AAPL", + qty: str = "10", + avg_entry_price: str = "150.00", + current_price: str = "155.00", + unrealized_pl: str = "50.00", + market_value: str = "1550.00", +) -> AlpacaPosition: + """Build a minimal Alpaca ``Position`` model for testing.""" + return AlpacaPosition( + asset_id=uuid.uuid4(), + symbol=symbol, + exchange=AssetExchange.NASDAQ, + asset_class=AssetClass.US_EQUITY, + asset_marginable=True, + avg_entry_price=avg_entry_price, + qty=qty, + side=PositionSide.LONG, + market_value=market_value, + cost_basis="1500.00", + unrealized_pl=unrealized_pl, + unrealized_plpc="0.0333", + unrealized_intraday_pl="50.00", + unrealized_intraday_plpc="0.0333", + current_price=current_price, + lastday_price="150.00", + change_today="0.0333", + ) + + +def _make_alpaca_account( + *, + equity: str = "100000.00", + cash: str = "50000.00", + buying_power: str = "200000.00", + portfolio_value: str = "100000.00", +) -> TradeAccount: + """Build a minimal Alpaca ``TradeAccount`` model for testing.""" + return TradeAccount( + id=uuid.uuid4(), + account_number="test-123", + status="ACTIVE", + crypto_status=None, + currency="USD", + buying_power=buying_power, + regt_buying_power=buying_power, + daytrading_buying_power=buying_power, + non_marginable_buying_power=cash, + cash=cash, + accrued_fees="0", + pending_transfer_in=None, + portfolio_value=portfolio_value, + pattern_day_trader=False, + trading_blocked=False, + transfers_blocked=False, + account_blocked=False, + created_at=datetime.now(timezone.utc), + trade_suspended_by_user=False, + multiplier="2", + shorting_enabled=True, + equity=equity, + last_equity=equity, + long_market_value="50000.00", + short_market_value="0", + initial_margin="25000.00", + maintenance_margin="15000.00", + last_maintenance_margin="15000.00", + sma="100000.00", + daytrade_count=0, + ) + + +@pytest.fixture +def mock_client() -> MagicMock: + """Return a mocked ``TradingClient``.""" + return MagicMock() + + +@pytest.fixture +def broker(mock_client: MagicMock) -> AlpacaBroker: + """Return an ``AlpacaBroker`` whose internal client is mocked.""" + with patch("shared.broker.alpaca_broker.TradingClient", return_value=mock_client): + b = AlpacaBroker(api_key="test-key", secret_key="test-secret", paper=True) + return b + + +# --------------------------------------------------------------------------- +# Order submission +# --------------------------------------------------------------------------- + + +class TestSubmitMarketOrder: + @pytest.mark.asyncio + async def test_submit_market_order(self, broker: AlpacaBroker, mock_client: MagicMock) -> None: + """A market buy order should be converted and submitted to Alpaca.""" + alpaca_order = _make_alpaca_order( + symbol="AAPL", + side=AlpacaOrderSide.BUY, + qty="10", + status=AlpacaOrderStatus.NEW, + ) + mock_client.submit_order.return_value = alpaca_order + + order = OrderRequest(ticker="AAPL", side=OrderSide.BUY, qty=10.0, order_type=OrderType.MARKET) + result = await broker.submit_order(order) + + assert isinstance(result, OrderResult) + assert result.ticker == "AAPL" + assert result.side == OrderSide.BUY + assert result.qty == 10.0 + assert result.status == OrderStatus.PENDING + assert result.order_id == str(alpaca_order.id) + assert result.filled_price is None + + # Verify the Alpaca client received a MarketOrderRequest + submitted = mock_client.submit_order.call_args[0][0] + assert isinstance(submitted, MarketOrderRequest) + assert submitted.symbol == "AAPL" + assert submitted.qty == 10.0 + assert submitted.side == AlpacaOrderSide.BUY + + @pytest.mark.asyncio + async def test_submit_market_sell(self, broker: AlpacaBroker, mock_client: MagicMock) -> None: + """A market sell order maps the side correctly.""" + alpaca_order = _make_alpaca_order( + symbol="TSLA", + side=AlpacaOrderSide.SELL, + qty="5", + status=AlpacaOrderStatus.FILLED, + filled_avg_price="200.50", + ) + mock_client.submit_order.return_value = alpaca_order + + order = OrderRequest(ticker="TSLA", side=OrderSide.SELL, qty=5.0) + result = await broker.submit_order(order) + + assert result.side == OrderSide.SELL + assert result.status == OrderStatus.FILLED + assert result.filled_price == 200.50 + + +class TestSubmitLimitOrder: + @pytest.mark.asyncio + async def test_submit_limit_order(self, broker: AlpacaBroker, mock_client: MagicMock) -> None: + """A limit order should include the limit price in the Alpaca request.""" + alpaca_order = _make_alpaca_order( + symbol="MSFT", + side=AlpacaOrderSide.BUY, + qty="20", + status=AlpacaOrderStatus.ACCEPTED, + ) + mock_client.submit_order.return_value = alpaca_order + + order = OrderRequest( + ticker="MSFT", + side=OrderSide.BUY, + qty=20.0, + order_type=OrderType.LIMIT, + limit_price=350.00, + ) + result = await broker.submit_order(order) + + assert isinstance(result, OrderResult) + assert result.ticker == "MSFT" + assert result.status == OrderStatus.PENDING + + submitted = mock_client.submit_order.call_args[0][0] + assert isinstance(submitted, LimitOrderRequest) + assert submitted.limit_price == 350.00 + + @pytest.mark.asyncio + async def test_limit_order_missing_price_raises(self, broker: AlpacaBroker, mock_client: MagicMock) -> None: + """A limit order without limit_price should raise ValueError.""" + order = OrderRequest( + ticker="MSFT", + side=OrderSide.BUY, + qty=20.0, + order_type=OrderType.LIMIT, + limit_price=None, + ) + with pytest.raises(ValueError, match="limit_price is required"): + await broker.submit_order(order) + + +class TestSubmitStopOrder: + @pytest.mark.asyncio + async def test_submit_stop_order(self, broker: AlpacaBroker, mock_client: MagicMock) -> None: + """A stop order should include the stop price in the Alpaca request.""" + alpaca_order = _make_alpaca_order( + symbol="GOOG", + side=AlpacaOrderSide.SELL, + qty="15", + status=AlpacaOrderStatus.NEW, + ) + mock_client.submit_order.return_value = alpaca_order + + order = OrderRequest( + ticker="GOOG", + side=OrderSide.SELL, + qty=15.0, + order_type=OrderType.STOP, + stop_price=140.00, + ) + result = await broker.submit_order(order) + + assert isinstance(result, OrderResult) + assert result.ticker == "GOOG" + + submitted = mock_client.submit_order.call_args[0][0] + assert isinstance(submitted, StopOrderRequest) + assert submitted.stop_price == 140.00 + + @pytest.mark.asyncio + async def test_stop_order_missing_price_raises(self, broker: AlpacaBroker, mock_client: MagicMock) -> None: + """A stop order without stop_price should raise ValueError.""" + order = OrderRequest( + ticker="GOOG", + side=OrderSide.SELL, + qty=15.0, + order_type=OrderType.STOP, + stop_price=None, + ) + with pytest.raises(ValueError, match="stop_price is required"): + await broker.submit_order(order) + + +# --------------------------------------------------------------------------- +# Order rejection +# --------------------------------------------------------------------------- + + +class TestSubmitOrderRejected: + @pytest.mark.asyncio + async def test_api_error_returns_rejected_result(self, broker: AlpacaBroker, mock_client: MagicMock) -> None: + """If Alpaca's API raises an error, submit_order returns REJECTED.""" + mock_client.submit_order.side_effect = APIError("insufficient buying power") + + order = OrderRequest(ticker="AAPL", side=OrderSide.BUY, qty=10000.0) + result = await broker.submit_order(order) + + assert isinstance(result, OrderResult) + assert result.status == OrderStatus.REJECTED + assert result.ticker == "AAPL" + assert result.side == OrderSide.BUY + assert result.qty == 10000.0 + assert result.filled_price is None + assert result.order_id == "" + + +# --------------------------------------------------------------------------- +# Cancel order +# --------------------------------------------------------------------------- + + +class TestCancelOrder: + @pytest.mark.asyncio + async def test_cancel_order_success(self, broker: AlpacaBroker, mock_client: MagicMock) -> None: + """Successful cancellation returns True.""" + mock_client.cancel_order_by_id.return_value = None # void method + + result = await broker.cancel_order("some-order-id") + + assert result is True + mock_client.cancel_order_by_id.assert_called_once_with("some-order-id") + + @pytest.mark.asyncio + async def test_cancel_order_failure(self, broker: AlpacaBroker, mock_client: MagicMock) -> None: + """If the API raises an error (e.g. order already filled), return False.""" + mock_client.cancel_order_by_id.side_effect = APIError("order is not cancelable") + + result = await broker.cancel_order("some-order-id") + + assert result is False + + +# --------------------------------------------------------------------------- +# Positions +# --------------------------------------------------------------------------- + + +class TestGetPositions: + @pytest.mark.asyncio + async def test_get_positions_empty(self, broker: AlpacaBroker, mock_client: MagicMock) -> None: + """When there are no open positions, return an empty list.""" + mock_client.get_all_positions.return_value = [] + + positions = await broker.get_positions() + + assert positions == [] + + @pytest.mark.asyncio + async def test_get_positions_with_data(self, broker: AlpacaBroker, mock_client: MagicMock) -> None: + """Positions are correctly converted to PositionInfo.""" + mock_client.get_all_positions.return_value = [ + _make_alpaca_position( + symbol="AAPL", + qty="10", + avg_entry_price="150.00", + current_price="155.00", + unrealized_pl="50.00", + market_value="1550.00", + ), + _make_alpaca_position( + symbol="TSLA", + qty="5", + avg_entry_price="200.00", + current_price="190.00", + unrealized_pl="-50.00", + market_value="950.00", + ), + ] + + positions = await broker.get_positions() + + assert len(positions) == 2 + + aapl = positions[0] + assert isinstance(aapl, PositionInfo) + assert aapl.ticker == "AAPL" + assert aapl.qty == 10.0 + assert aapl.avg_entry == 150.0 + assert aapl.current_price == 155.0 + assert aapl.unrealized_pnl == 50.0 + assert aapl.market_value == 1550.0 + + tsla = positions[1] + assert tsla.ticker == "TSLA" + assert tsla.qty == 5.0 + assert tsla.unrealized_pnl == -50.0 + + +# --------------------------------------------------------------------------- +# Account +# --------------------------------------------------------------------------- + + +class TestGetAccount: + @pytest.mark.asyncio + async def test_get_account(self, broker: AlpacaBroker, mock_client: MagicMock) -> None: + """Account info is correctly converted from TradeAccount.""" + mock_client.get_account.return_value = _make_alpaca_account( + equity="100000.00", + cash="50000.00", + buying_power="200000.00", + portfolio_value="100000.00", + ) + + account = await broker.get_account() + + assert isinstance(account, AccountInfo) + assert account.equity == 100000.0 + assert account.cash == 50000.0 + assert account.buying_power == 200000.0 + assert account.portfolio_value == 100000.0 + + +# --------------------------------------------------------------------------- +# Order status +# --------------------------------------------------------------------------- + + +class TestGetOrderStatus: + @pytest.mark.asyncio + async def test_get_order_status_filled(self, broker: AlpacaBroker, mock_client: MagicMock) -> None: + """A filled order returns FILLED status and the fill price.""" + order_id = str(uuid.uuid4()) + alpaca_order = _make_alpaca_order( + order_id=order_id, + symbol="AAPL", + side=AlpacaOrderSide.BUY, + qty="10", + status=AlpacaOrderStatus.FILLED, + filled_avg_price="152.30", + ) + mock_client.get_order_by_id.return_value = alpaca_order + + result = await broker.get_order_status(order_id) + + assert isinstance(result, OrderResult) + assert result.order_id == order_id + assert result.status == OrderStatus.FILLED + assert result.filled_price == 152.30 + assert result.ticker == "AAPL" + assert result.side == OrderSide.BUY + mock_client.get_order_by_id.assert_called_once_with(order_id) + + @pytest.mark.asyncio + async def test_get_order_status_pending(self, broker: AlpacaBroker, mock_client: MagicMock) -> None: + """A pending order returns PENDING status with no fill price.""" + order_id = str(uuid.uuid4()) + alpaca_order = _make_alpaca_order( + order_id=order_id, + symbol="MSFT", + status=AlpacaOrderStatus.PENDING_NEW, + ) + mock_client.get_order_by_id.return_value = alpaca_order + + result = await broker.get_order_status(order_id) + + assert result.status == OrderStatus.PENDING + assert result.filled_price is None + + @pytest.mark.asyncio + async def test_get_order_status_cancelled(self, broker: AlpacaBroker, mock_client: MagicMock) -> None: + """A cancelled order returns CANCELLED status.""" + order_id = str(uuid.uuid4()) + alpaca_order = _make_alpaca_order( + order_id=order_id, + symbol="TSLA", + status=AlpacaOrderStatus.CANCELED, + ) + mock_client.get_order_by_id.return_value = alpaca_order + + result = await broker.get_order_status(order_id) + + assert result.status == OrderStatus.CANCELLED + + @pytest.mark.asyncio + async def test_get_order_status_rejected(self, broker: AlpacaBroker, mock_client: MagicMock) -> None: + """A rejected order returns REJECTED status.""" + order_id = str(uuid.uuid4()) + alpaca_order = _make_alpaca_order( + order_id=order_id, + symbol="GOOG", + status=AlpacaOrderStatus.REJECTED, + ) + mock_client.get_order_by_id.return_value = alpaca_order + + result = await broker.get_order_status(order_id) + + assert result.status == OrderStatus.REJECTED + + +# --------------------------------------------------------------------------- +# BaseBroker interface +# --------------------------------------------------------------------------- + + +class TestBaseBrokerInterface: + def test_alpaca_broker_is_subclass(self) -> None: + """AlpacaBroker should be a proper subclass of BaseBroker.""" + from shared.broker.base import BaseBroker + + assert issubclass(AlpacaBroker, BaseBroker) + + def test_package_exports(self) -> None: + """The broker package should export BaseBroker and AlpacaBroker.""" + from shared.broker import AlpacaBroker as AB + from shared.broker import BaseBroker as BB + + assert AB is AlpacaBroker + from shared.broker.base import BaseBroker + + assert BB is BaseBroker