trading/services/api_gateway/routes/trades.py

125 lines
4.3 KiB
Python

"""Trade endpoints — paginated trade history and detail."""
from __future__ import annotations
from datetime import datetime
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from services.api_gateway.auth.middleware import get_current_user
from sqlalchemy import select, desc, func
# Every route below is mounted under /api/trades; the "trades" tag groups
# them together in the generated OpenAPI documentation.
router = APIRouter(prefix="/api/trades", tags=["trades"])
def _trade_to_dict(trade) -> dict:
    """Serialize one ``Trade`` row into the JSON-ready dict used by ``list_trades``.

    ``side`` and ``status`` are emitted through their ``.value``; ``id`` and
    the optional foreign keys are stringified, and falsy ``strategy_id`` /
    ``signal_id`` / ``created_at`` become ``None``.
    """
    return {
        "id": str(trade.id),
        "ticker": trade.ticker,
        "side": trade.side.value,
        "qty": trade.qty,
        "price": trade.price,
        "status": trade.status.value,
        "pnl": trade.pnl,
        "strategy_id": str(trade.strategy_id) if trade.strategy_id else None,
        "signal_id": str(trade.signal_id) if trade.signal_id else None,
        "created_at": trade.created_at.isoformat() if trade.created_at else None,
    }


@router.get("")
async def list_trades(
    request: Request,
    _user: dict = Depends(get_current_user),
    ticker: str | None = Query(default=None),
    start_date: datetime | None = Query(default=None),
    end_date: datetime | None = Query(default=None),
    strategy: str | None = Query(default=None),
    profitable: bool | None = Query(default=None),
    page: int = Query(default=1, ge=1),
    per_page: int = Query(default=20, ge=1, le=100),
) -> dict:
    """Return one page of trade history, newest first, with optional filters.

    All supplied filters are combined with AND:
        ticker: exact match on ``Trade.ticker`` after upper-casing the input.
        start_date / end_date: inclusive bounds on ``Trade.created_at``.
        strategy: exact match on ``Strategy.name`` (joined via
            ``Trade.strategy_id``).
        profitable: ``True`` keeps ``pnl > 0``; ``False`` keeps ``pnl <= 0``.

    Returns:
        A dict with ``trades`` (serialized rows of the requested page),
        ``total`` (number of rows matching the filters), ``page``,
        ``per_page`` and ``pages`` (``total`` divided by ``per_page``,
        rounded up).
    """
    from shared.models.trading import Strategy, Trade

    # Build every WHERE criterion exactly once so the page query and the
    # count query are guaranteed to filter identically.
    conditions = []
    if ticker:
        conditions.append(Trade.ticker == ticker.upper())
    if start_date is not None:
        conditions.append(Trade.created_at >= start_date)
    if end_date is not None:
        conditions.append(Trade.created_at <= end_date)
    if strategy:
        conditions.append(Strategy.name == strategy)
    if profitable is not None:
        # SQL comparisons against NULL are not true, so rows with a NULL pnl
        # (if any exist) match neither branch.
        conditions.append((Trade.pnl > 0) if profitable else (Trade.pnl <= 0))

    query = select(Trade)
    count_query = select(func.count()).select_from(Trade)
    if strategy:
        # The strategy name lives on a separate table, so filtering by it
        # needs the join (on both statements, to keep the count accurate).
        on_clause = Trade.strategy_id == Strategy.id
        query = query.join(Strategy, on_clause)
        count_query = count_query.join(Strategy, on_clause)
    if conditions:
        query = query.where(*conditions)
        count_query = count_query.where(*conditions)

    # created_at is not guaranteed unique; the id tiebreaker makes the order
    # total, so OFFSET/LIMIT pages never repeat or skip trades that share a
    # timestamp.
    query = (
        query.order_by(desc(Trade.created_at), desc(Trade.id))
        .offset((page - 1) * per_page)
        .limit(per_page)
    )

    session_factory = request.app.state.db_session_factory
    async with session_factory() as session:
        total = (await session.execute(count_query)).scalar() or 0
        trades = (await session.execute(query)).scalars().all()
        return {
            "trades": [_trade_to_dict(t) for t in trades],
            "total": total,
            "page": page,
            "per_page": per_page,
            "pages": (total + per_page - 1) // per_page if per_page else 0,
        }
@router.get("/{trade_id}")
async def get_trade(
    trade_id: UUID,
    request: Request,
    _user: dict = Depends(get_current_user),
) -> dict:
    """Fetch a single trade by its id and return its serialized fields.

    The linked strategy and signal are reported by id only
    (``strategy_id`` / ``signal_id``); this endpoint does not query the
    related rows itself.

    Raises:
        HTTPException: 404 when no trade exists with ``trade_id``.
    """
    from shared.models.trading import Trade

    make_session = request.app.state.db_session_factory
    async with make_session() as session:
        stmt = select(Trade).where(Trade.id == trade_id)
        found = (await session.execute(stmt)).scalar_one_or_none()

        if found is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Trade not found",
            )

        def _opt_str(value):
            # Falsy ids (e.g. None) are reported as null rather than "None".
            return str(value) if value else None

        return {
            "id": str(found.id),
            "ticker": found.ticker,
            "side": found.side.value,
            "qty": found.qty,
            "price": found.price,
            "status": found.status.value,
            "pnl": found.pnl,
            "strategy_id": _opt_str(found.strategy_id),
            "signal_id": _opt_str(found.signal_id),
            "created_at": (
                found.created_at.isoformat() if found.created_at else None
            ),
        }