feat: API gateway trading endpoints, controls, backtest, WebSocket

This commit is contained in:
Viktor Barzin 2026-02-22 15:54:20 +00:00
parent e0d138c457
commit 6fe586f722
No known key found for this signature in database
GPG key ID: 0EB088298288D958
11 changed files with 1304 additions and 0 deletions

View file

@ -69,6 +69,28 @@ def create_app(config: ApiGatewayConfig | None = None) -> FastAPI:
# Auth routes (unauthenticated)
app.include_router(auth_router)
# Trading routes (authenticated) — imported lazily to avoid circular deps
from services.api_gateway.routes.portfolio import router as portfolio_router
from services.api_gateway.routes.trades import router as trades_router
from services.api_gateway.routes.signals import router as signals_router
from services.api_gateway.routes.strategies import router as strategies_router
from services.api_gateway.routes.news import router as news_router
from services.api_gateway.routes.controls import router as controls_router
from services.api_gateway.routes.backtest import router as backtest_router
app.include_router(portfolio_router)
app.include_router(trades_router)
app.include_router(signals_router)
app.include_router(strategies_router)
app.include_router(news_router)
app.include_router(controls_router)
app.include_router(backtest_router)
# WebSocket
from services.api_gateway.ws import router as ws_router
app.include_router(ws_router)
# Health check
@app.get("/health", tags=["health"])
async def health() -> dict:

View file

@ -0,0 +1 @@
"""Trading route sub-package."""

View file

@ -0,0 +1,156 @@
"""Backtest endpoints — run backtests and retrieve results."""
from __future__ import annotations
import asyncio
import json
import logging
import uuid
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from pydantic import BaseModel, Field
from services.api_gateway.auth.middleware import get_current_user
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/backtest", tags=["backtest"])
class BacktestRequest(BaseModel):
"""Request body for starting a new backtest."""
start_date: datetime
end_date: datetime
initial_capital: float = Field(default=100_000.0, gt=0)
commission_per_trade: float = Field(default=0.0, ge=0)
slippage_pct: float = Field(default=0.001, ge=0)
strategy_weights: dict[str, float] = Field(default_factory=dict)
max_position_pct: float = Field(default=0.05, gt=0, le=1.0)
signal_threshold: float = Field(default=0.3, ge=0, le=1.0)
@router.post("/run")
async def run_backtest(
body: BacktestRequest,
request: Request,
_user: dict = Depends(get_current_user),
) -> dict:
"""Start a backtest with the given configuration.
Returns a ``run_id`` immediately. The backtest runs in a background
task and its results can be retrieved via ``GET /api/backtest/{run_id}``.
"""
run_id = str(uuid.uuid4())
redis = request.app.state.redis
# Store initial status
await redis.setex(
f"backtest:{run_id}",
86400, # 24h TTL
json.dumps({
"status": "running",
"config": body.model_dump(mode="json"),
"started_at": datetime.now(tz=timezone.utc).isoformat(),
}),
)
# Launch background task
asyncio.create_task(_run_backtest_task(run_id, body, redis))
return {"run_id": run_id, "status": "running"}
async def _run_backtest_task(
run_id: str,
config: BacktestRequest,
redis,
) -> None:
"""Execute the backtest in the background and store results in Redis."""
try:
from backtester.config import BacktestConfig
from backtester.engine import BacktestEngine
from shared.strategies.momentum import MomentumStrategy
from shared.strategies.mean_reversion import MeanReversionStrategy
from shared.strategies.news_driven import NewsDrivenStrategy
bt_config = BacktestConfig(
start_date=config.start_date,
end_date=config.end_date,
initial_capital=config.initial_capital,
commission_per_trade=config.commission_per_trade,
slippage_pct=config.slippage_pct,
strategy_weights=config.strategy_weights,
max_position_pct=config.max_position_pct,
signal_threshold=config.signal_threshold,
)
strategies = [
MomentumStrategy(),
MeanReversionStrategy(),
NewsDrivenStrategy(),
]
engine = BacktestEngine(config=bt_config, strategies=strategies)
# Use an in-memory stub data loader for now; a full implementation
# would read from TimescaleDB.
from backtester.data_loader import BacktestDataLoader
data_loader = BacktestDataLoader(
session_factory=None,
start_date=config.start_date,
end_date=config.end_date,
)
result = await engine.run(data_loader)
await redis.setex(
f"backtest:{run_id}",
86400,
json.dumps({
"status": "completed",
"config": config.model_dump(mode="json"),
"result": {
"total_return": result.total_return,
"annualized_return": result.annualized_return,
"sharpe_ratio": result.sharpe_ratio,
"sortino_ratio": result.sortino_ratio,
"max_drawdown": result.max_drawdown,
"max_drawdown_duration_days": result.max_drawdown_duration_days,
"win_rate": result.win_rate,
"trade_count": result.trade_count,
"avg_hold_duration_hours": result.avg_hold_duration_hours,
"profit_factor": result.profit_factor,
},
"completed_at": datetime.now(tz=timezone.utc).isoformat(),
}),
)
except Exception as exc:
logger.exception("Backtest %s failed", run_id)
await redis.setex(
f"backtest:{run_id}",
86400,
json.dumps({
"status": "failed",
"error": str(exc),
}),
)
@router.get("/{run_id}")
async def get_backtest(
run_id: str,
request: Request,
_user: dict = Depends(get_current_user),
) -> dict:
"""Get backtest results by run ID."""
redis = request.app.state.redis
raw = await redis.get(f"backtest:{run_id}")
if raw is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Backtest run not found",
)
return json.loads(raw)

View file

@ -0,0 +1,95 @@
"""Control endpoints — pause/resume trading, force close positions."""
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel
from services.api_gateway.auth.middleware import get_current_user
router = APIRouter(prefix="/api/controls", tags=["controls"])
TRADING_PAUSED_KEY = "trading:paused"
class ClosePositionRequest(BaseModel):
"""Body for the force-close-position endpoint."""
ticker: str
@router.post("/pause")
async def pause_trading(
request: Request,
_user: dict = Depends(get_current_user),
) -> dict:
"""Set Redis flag to pause trading."""
redis = request.app.state.redis
await redis.set(TRADING_PAUSED_KEY, "1")
return {"status": "paused"}
@router.post("/resume")
async def resume_trading(
request: Request,
_user: dict = Depends(get_current_user),
) -> dict:
"""Clear pause flag to resume trading."""
redis = request.app.state.redis
await redis.delete(TRADING_PAUSED_KEY)
return {"status": "active"}
@router.post("/close-position")
async def close_position(
body: ClosePositionRequest,
request: Request,
_user: dict = Depends(get_current_user),
) -> dict:
"""Force close an open position by ticker.
Publishes a close-position command to Redis so the trade executor
picks it up asynchronously.
"""
import json
from sqlalchemy import select
from shared.models.trading import Position
db = request.app.state.db_session_factory
async with db() as session:
position = (
await session.execute(
select(Position).where(Position.ticker == body.ticker.upper())
)
).scalar_one_or_none()
if position is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"No open position for {body.ticker}",
)
# Publish close command to Redis for the trade executor
redis = request.app.state.redis
await redis.publish(
"controls:close_position",
json.dumps({"ticker": body.ticker.upper(), "qty": position.qty}),
)
return {
"status": "close_requested",
"ticker": body.ticker.upper(),
"qty": position.qty,
}
@router.get("/status")
async def get_trading_status(
request: Request,
_user: dict = Depends(get_current_user),
) -> dict:
"""Current trading status — active or paused."""
redis = request.app.state.redis
paused = await redis.get(TRADING_PAUSED_KEY)
return {"status": "paused" if paused else "active"}

View file

@ -0,0 +1,87 @@
"""News endpoints — recent scored articles with filtering."""
from __future__ import annotations
from fastapi import APIRouter, Depends, Query, Request
from services.api_gateway.auth.middleware import get_current_user
from sqlalchemy import select, desc, func
router = APIRouter(prefix="/api/news", tags=["news"])
@router.get("")
async def list_news(
request: Request,
_user: dict = Depends(get_current_user),
ticker: str | None = Query(default=None),
source: str | None = Query(default=None),
min_score: float | None = Query(default=None, ge=-1.0, le=1.0),
max_score: float | None = Query(default=None, ge=-1.0, le=1.0),
page: int = Query(default=1, ge=1),
per_page: int = Query(default=20, ge=1, le=100),
) -> dict:
"""Recent scored articles with optional filters."""
from shared.models.news import Article, ArticleSentiment
db = request.app.state.db_session_factory
async with db() as session:
# Base query joining articles with sentiments
query = (
select(Article, ArticleSentiment)
.join(ArticleSentiment, Article.id == ArticleSentiment.article_id)
.order_by(desc(Article.fetched_at))
)
count_query = (
select(func.count())
.select_from(Article)
.join(ArticleSentiment, Article.id == ArticleSentiment.article_id)
)
if ticker:
query = query.where(ArticleSentiment.ticker == ticker.upper())
count_query = count_query.where(
ArticleSentiment.ticker == ticker.upper()
)
if source:
query = query.where(Article.source == source)
count_query = count_query.where(Article.source == source)
if min_score is not None:
query = query.where(ArticleSentiment.score >= min_score)
count_query = count_query.where(ArticleSentiment.score >= min_score)
if max_score is not None:
query = query.where(ArticleSentiment.score <= max_score)
count_query = count_query.where(ArticleSentiment.score <= max_score)
total = (await session.execute(count_query)).scalar() or 0
offset = (page - 1) * per_page
query = query.offset(offset).limit(per_page)
result = await session.execute(query)
rows = result.all()
return {
"articles": [
{
"id": str(article.id),
"source": article.source,
"url": article.url,
"title": article.title,
"published_at": (
article.published_at.isoformat()
if article.published_at
else None
),
"fetched_at": article.fetched_at.isoformat(),
"ticker": sentiment.ticker,
"sentiment_score": sentiment.score,
"confidence": sentiment.confidence,
"model_used": sentiment.model_used,
}
for article, sentiment in rows
],
"total": total,
"page": page,
"per_page": per_page,
"pages": (total + per_page - 1) // per_page if per_page else 0,
}

View file

@ -0,0 +1,125 @@
"""Portfolio endpoints — current value, positions, equity curve."""
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from enum import Enum
from fastapi import APIRouter, Depends, Query, Request
from services.api_gateway.auth.middleware import get_current_user
from sqlalchemy import select, desc
router = APIRouter(prefix="/api/portfolio", tags=["portfolio"])
class HistoryPeriod(str, Enum):
ONE_DAY = "1d"
ONE_WEEK = "1w"
ONE_MONTH = "1m"
THREE_MONTHS = "3m"
ONE_YEAR = "1y"
def _period_to_timedelta(period: HistoryPeriod) -> timedelta:
"""Convert a period enum value to a timedelta."""
mapping = {
HistoryPeriod.ONE_DAY: timedelta(days=1),
HistoryPeriod.ONE_WEEK: timedelta(weeks=1),
HistoryPeriod.ONE_MONTH: timedelta(days=30),
HistoryPeriod.THREE_MONTHS: timedelta(days=90),
HistoryPeriod.ONE_YEAR: timedelta(days=365),
}
return mapping[period]
@router.get("")
async def get_portfolio(
request: Request,
_user: dict = Depends(get_current_user),
) -> dict:
"""Current portfolio summary — value, cash, buying power, daily P&L."""
from shared.models.timeseries import PortfolioSnapshot
db = request.app.state.db_session_factory
async with db() as session:
latest = (
await session.execute(
select(PortfolioSnapshot)
.order_by(desc(PortfolioSnapshot.timestamp))
.limit(1)
)
).scalar_one_or_none()
if latest is None:
return {
"total_value": 0.0,
"cash": 0.0,
"buying_power": 0.0,
"daily_pnl": 0.0,
}
return {
"total_value": latest.total_value,
"cash": latest.cash,
"buying_power": latest.cash,
"daily_pnl": latest.daily_pnl,
}
@router.get("/positions")
async def get_positions(
request: Request,
_user: dict = Depends(get_current_user),
) -> list[dict]:
"""All open positions with unrealized P&L."""
from shared.models.trading import Position
db = request.app.state.db_session_factory
async with db() as session:
result = await session.execute(select(Position))
positions = result.scalars().all()
return [
{
"id": str(p.id),
"ticker": p.ticker,
"qty": p.qty,
"avg_entry": p.avg_entry,
"unrealized_pnl": p.unrealized_pnl or 0.0,
"stop_loss": p.stop_loss,
"take_profit": p.take_profit,
}
for p in positions
]
@router.get("/history")
async def get_portfolio_history(
request: Request,
_user: dict = Depends(get_current_user),
period: HistoryPeriod = Query(default=HistoryPeriod.ONE_MONTH),
) -> list[dict]:
"""Equity curve from portfolio_snapshots over a given period."""
from shared.models.timeseries import PortfolioSnapshot
since = datetime.now(timezone.utc) - _period_to_timedelta(period)
db = request.app.state.db_session_factory
async with db() as session:
result = await session.execute(
select(PortfolioSnapshot)
.where(PortfolioSnapshot.timestamp >= since)
.order_by(PortfolioSnapshot.timestamp)
)
snapshots = result.scalars().all()
return [
{
"timestamp": s.timestamp.isoformat(),
"total_value": s.total_value,
"cash": s.cash,
"positions_value": s.positions_value,
"daily_pnl": s.daily_pnl,
}
for s in snapshots
]

View file

@ -0,0 +1,58 @@
"""Signal endpoints — recent signals with filtering."""
from __future__ import annotations
from fastapi import APIRouter, Depends, Query, Request
from services.api_gateway.auth.middleware import get_current_user
from sqlalchemy import select, desc, func
router = APIRouter(prefix="/api/signals", tags=["signals"])
@router.get("")
async def list_signals(
request: Request,
_user: dict = Depends(get_current_user),
ticker: str | None = Query(default=None),
page: int = Query(default=1, ge=1),
per_page: int = Query(default=20, ge=1, le=100),
) -> dict:
"""Recent signals with optional ticker filter and pagination."""
from shared.models.trading import Signal
db = request.app.state.db_session_factory
async with db() as session:
query = select(Signal).order_by(desc(Signal.created_at))
count_query = select(func.count()).select_from(Signal)
if ticker:
query = query.where(Signal.ticker == ticker.upper())
count_query = count_query.where(Signal.ticker == ticker.upper())
total = (await session.execute(count_query)).scalar() or 0
offset = (page - 1) * per_page
query = query.offset(offset).limit(per_page)
result = await session.execute(query)
signals = result.scalars().all()
return {
"signals": [
{
"id": str(s.id),
"ticker": s.ticker,
"direction": s.direction.value,
"strength": s.strength,
"strategy_sources": s.strategy_sources,
"sentiment_score": s.sentiment_score,
"acted_on": s.acted_on,
"created_at": s.created_at.isoformat() if s.created_at else None,
}
for s in signals
],
"total": total,
"page": page,
"per_page": per_page,
"pages": (total + per_page - 1) // per_page if per_page else 0,
}

View file

@ -0,0 +1,111 @@
"""Strategy endpoints — list strategies, weight history, metrics."""
from __future__ import annotations
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Request, status
from services.api_gateway.auth.middleware import get_current_user
from sqlalchemy import select, desc
router = APIRouter(prefix="/api/strategies", tags=["strategies"])
@router.get("")
async def list_strategies(
request: Request,
_user: dict = Depends(get_current_user),
) -> list[dict]:
"""All strategies with current weights."""
from shared.models.trading import Strategy
db = request.app.state.db_session_factory
async with db() as session:
result = await session.execute(select(Strategy))
strategies = result.scalars().all()
return [
{
"id": str(s.id),
"name": s.name,
"description": s.description,
"current_weight": s.current_weight,
"active": s.active,
"created_at": s.created_at.isoformat() if s.created_at else None,
}
for s in strategies
]
@router.get("/{strategy_id}/history")
async def get_strategy_weight_history(
strategy_id: UUID,
request: Request,
_user: dict = Depends(get_current_user),
) -> list[dict]:
"""Weight history for a specific strategy."""
from shared.models.trading import StrategyWeightHistory, Strategy
db = request.app.state.db_session_factory
async with db() as session:
# Verify strategy exists
strategy = (
await session.execute(
select(Strategy).where(Strategy.id == strategy_id)
)
).scalar_one_or_none()
if strategy is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Strategy not found",
)
result = await session.execute(
select(StrategyWeightHistory)
.where(StrategyWeightHistory.strategy_id == strategy_id)
.order_by(desc(StrategyWeightHistory.created_at))
)
history = result.scalars().all()
return [
{
"id": str(h.id),
"old_weight": h.old_weight,
"new_weight": h.new_weight,
"reason": h.reason,
"created_at": h.created_at.isoformat() if h.created_at else None,
}
for h in history
]
@router.get("/{strategy_id}/metrics")
async def get_strategy_metrics(
strategy_id: UUID,
request: Request,
_user: dict = Depends(get_current_user),
) -> list[dict]:
"""Performance metrics over time for a specific strategy."""
from shared.models.timeseries import StrategyMetric
db = request.app.state.db_session_factory
async with db() as session:
result = await session.execute(
select(StrategyMetric)
.where(StrategyMetric.strategy_id == strategy_id)
.order_by(desc(StrategyMetric.timestamp))
.limit(100)
)
metrics = result.scalars().all()
return [
{
"timestamp": m.timestamp.isoformat(),
"win_rate": m.win_rate,
"total_pnl": m.total_pnl,
"trade_count": m.trade_count,
"sharpe_ratio": m.sharpe_ratio,
}
for m in metrics
]

View file

@ -0,0 +1,125 @@
"""Trade endpoints — paginated trade history and detail."""
from __future__ import annotations
from datetime import datetime
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from services.api_gateway.auth.middleware import get_current_user
from sqlalchemy import select, desc, func
router = APIRouter(prefix="/api/trades", tags=["trades"])
@router.get("")
async def list_trades(
request: Request,
_user: dict = Depends(get_current_user),
ticker: str | None = Query(default=None),
start_date: datetime | None = Query(default=None),
end_date: datetime | None = Query(default=None),
strategy: str | None = Query(default=None),
profitable: bool | None = Query(default=None),
page: int = Query(default=1, ge=1),
per_page: int = Query(default=20, ge=1, le=100),
) -> dict:
"""Paginated trade history with optional filters."""
from shared.models.trading import Trade, Strategy
db = request.app.state.db_session_factory
async with db() as session:
query = select(Trade).order_by(desc(Trade.created_at))
count_query = select(func.count()).select_from(Trade)
# Apply filters
if ticker:
query = query.where(Trade.ticker == ticker.upper())
count_query = count_query.where(Trade.ticker == ticker.upper())
if start_date:
query = query.where(Trade.created_at >= start_date)
count_query = count_query.where(Trade.created_at >= start_date)
if end_date:
query = query.where(Trade.created_at <= end_date)
count_query = count_query.where(Trade.created_at <= end_date)
if strategy:
# Join with Strategy to filter by name
query = query.join(Strategy, Trade.strategy_id == Strategy.id).where(
Strategy.name == strategy
)
count_query = count_query.join(
Strategy, Trade.strategy_id == Strategy.id
).where(Strategy.name == strategy)
if profitable is not None:
if profitable:
query = query.where(Trade.pnl > 0)
count_query = count_query.where(Trade.pnl > 0)
else:
query = query.where(Trade.pnl <= 0)
count_query = count_query.where(Trade.pnl <= 0)
# Pagination
total = (await session.execute(count_query)).scalar() or 0
offset = (page - 1) * per_page
query = query.offset(offset).limit(per_page)
result = await session.execute(query)
trades = result.scalars().all()
return {
"trades": [
{
"id": str(t.id),
"ticker": t.ticker,
"side": t.side.value,
"qty": t.qty,
"price": t.price,
"status": t.status.value,
"pnl": t.pnl,
"strategy_id": str(t.strategy_id) if t.strategy_id else None,
"signal_id": str(t.signal_id) if t.signal_id else None,
"created_at": t.created_at.isoformat() if t.created_at else None,
}
for t in trades
],
"total": total,
"page": page,
"per_page": per_page,
"pages": (total + per_page - 1) // per_page if per_page else 0,
}
@router.get("/{trade_id}")
async def get_trade(
trade_id: UUID,
request: Request,
_user: dict = Depends(get_current_user),
) -> dict:
"""Single trade detail with linked signal and outcome."""
from shared.models.trading import Trade
db = request.app.state.db_session_factory
async with db() as session:
trade = (
await session.execute(select(Trade).where(Trade.id == trade_id))
).scalar_one_or_none()
if trade is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Trade not found",
)
return {
"id": str(trade.id),
"ticker": trade.ticker,
"side": trade.side.value,
"qty": trade.qty,
"price": trade.price,
"status": trade.status.value,
"pnl": trade.pnl,
"strategy_id": str(trade.strategy_id) if trade.strategy_id else None,
"signal_id": str(trade.signal_id) if trade.signal_id else None,
"created_at": trade.created_at.isoformat() if trade.created_at else None,
}

138
services/api_gateway/ws.py Normal file
View file

@ -0,0 +1,138 @@
"""WebSocket endpoint — authenticated real-time event stream."""
from __future__ import annotations
import asyncio
import json
import logging
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
import jwt as pyjwt
from services.api_gateway.auth.jwt import decode_token
from services.api_gateway.config import ApiGatewayConfig
logger = logging.getLogger(__name__)
router = APIRouter(tags=["websocket"])
class ConnectionManager:
"""Manages active WebSocket connections."""
def __init__(self) -> None:
self.active_connections: list[WebSocket] = []
async def connect(self, websocket: WebSocket) -> None:
await websocket.accept()
self.active_connections.append(websocket)
def disconnect(self, websocket: WebSocket) -> None:
if websocket in self.active_connections:
self.active_connections.remove(websocket)
async def broadcast(self, message: dict) -> None:
"""Send a message to all connected clients. Remove any that have
disconnected."""
disconnected: list[WebSocket] = []
for connection in self.active_connections:
try:
await connection.send_json(message)
except Exception:
disconnected.append(connection)
for ws in disconnected:
self.disconnect(ws)
manager = ConnectionManager()
@router.websocket("/ws")
async def websocket_endpoint(
websocket: WebSocket,
token: str = Query(default=""),
) -> None:
"""Authenticated WebSocket that pushes real-time trading events.
Connect via ``ws://host/ws?token=<JWT>``. The server subscribes
to Redis pub/sub channels and forwards events to all connected
clients.
Event types pushed to clients:
- ``trade_executed``
- ``signal_generated``
- ``portfolio_update``
"""
# Authenticate via JWT query parameter
config: ApiGatewayConfig = websocket.app.state.config
try:
payload = decode_token(token, config)
if payload.get("type") != "access":
await websocket.close(code=4001, reason="Invalid token type")
return
except pyjwt.ExpiredSignatureError:
await websocket.close(code=4001, reason="Token expired")
return
except pyjwt.InvalidTokenError:
await websocket.close(code=4001, reason="Invalid token")
return
await manager.connect(websocket)
# Subscribe to Redis pub/sub channels for real-time events
redis = websocket.app.state.redis
pubsub = redis.pubsub()
await pubsub.subscribe(
"events:trade_executed",
"events:signal_generated",
"events:portfolio_update",
)
try:
# Run two concurrent tasks: one reads from Redis pub/sub and pushes
# to the client, the other keeps the WebSocket alive by reading
# (and discarding) incoming messages.
async def _redis_listener() -> None:
"""Forward Redis pub/sub messages to this WebSocket client."""
while True:
message = await pubsub.get_message(
ignore_subscribe_messages=True, timeout=1.0
)
if message and message["type"] == "message":
try:
data = json.loads(message["data"])
except (json.JSONDecodeError, TypeError):
data = {"raw": str(message["data"])}
channel = message["channel"]
if isinstance(channel, bytes):
channel = channel.decode()
event_type = channel.replace("events:", "")
await websocket.send_json(
{"event": event_type, "data": data}
)
await asyncio.sleep(0.1)
async def _ws_receiver() -> None:
"""Keep the connection alive by reading messages."""
while True:
await websocket.receive_text()
# Run both tasks; whichever finishes first (e.g., client disconnect)
# will cancel the other.
listener_task = asyncio.create_task(_redis_listener())
receiver_task = asyncio.create_task(_ws_receiver())
done, pending = await asyncio.wait(
{listener_task, receiver_task},
return_when=asyncio.FIRST_COMPLETED,
)
for task in pending:
task.cancel()
except WebSocketDisconnect:
pass
finally:
manager.disconnect(websocket)
await pubsub.unsubscribe()
await pubsub.aclose()

View file

@ -0,0 +1,386 @@
"""Tests for API Gateway trading endpoints (Task 14).
Uses FastAPI TestClient with mocked DB sessions and Redis.
"""
from __future__ import annotations
import json
import uuid
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from services.api_gateway.auth.jwt import create_access_token
from services.api_gateway.auth.middleware import get_config, get_current_user
from services.api_gateway.config import ApiGatewayConfig
from services.api_gateway.main import create_app
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture()
def config() -> ApiGatewayConfig:
return ApiGatewayConfig(
jwt_secret_key="test-secret-for-routes",
database_url="sqlite+aiosqlite:///:memory:",
redis_url="redis://localhost:6379/0",
)
@pytest.fixture()
def mock_user() -> dict:
return {"sub": "user-test", "username": "tester", "type": "access"}
@pytest.fixture()
def auth_headers(config: ApiGatewayConfig) -> dict[str, str]:
token = create_access_token("user-test", "tester", config)
return {"Authorization": f"Bearer {token}"}
@pytest.fixture()
def mock_redis() -> AsyncMock:
"""A fully mocked Redis client."""
redis = AsyncMock()
redis.get = AsyncMock(return_value=None)
redis.set = AsyncMock()
redis.setex = AsyncMock()
redis.delete = AsyncMock()
redis.publish = AsyncMock()
return redis
@pytest.fixture()
def mock_session_factory():
"""Creates a mock async session factory that returns a mock session."""
session = AsyncMock()
session.__aenter__ = AsyncMock(return_value=session)
session.__aexit__ = AsyncMock(return_value=False)
factory = MagicMock()
factory.return_value = session
return factory, session
@pytest.fixture()
def client(
config: ApiGatewayConfig,
mock_user: dict,
mock_redis: AsyncMock,
mock_session_factory,
) -> TestClient:
"""Create a test client with all dependencies mocked."""
factory, session = mock_session_factory
app = create_app(config)
# Override auth dependency to bypass JWT validation
app.dependency_overrides[get_current_user] = lambda: mock_user
app.dependency_overrides[get_config] = lambda: config
# Inject mock state
app.state.redis = mock_redis
app.state.db_session_factory = factory
app.state.db_engine = MagicMock()
app.state.config = config
return TestClient(app, raise_server_exceptions=False)
# ---------------------------------------------------------------------------
# Helper: build mock execute results
# ---------------------------------------------------------------------------
def _make_execute_result(rows, scalar=None):
"""Build a mock result for session.execute()."""
result = MagicMock()
result.scalars.return_value.all.return_value = rows
result.scalars.return_value.__iter__ = lambda self: iter(rows)
result.scalar_one_or_none.return_value = scalar
result.scalar.return_value = len(rows) if scalar is None else scalar
result.all.return_value = rows
return result
# ---------------------------------------------------------------------------
# Portfolio Tests
# ---------------------------------------------------------------------------
class TestPortfolioEndpoint:
"""test_portfolio_endpoint."""
def test_portfolio_returns_defaults_when_empty(
self, client: TestClient, mock_session_factory
) -> None:
_, session = mock_session_factory
session.execute = AsyncMock(
return_value=_make_execute_result([], scalar=None)
)
resp = client.get("/api/portfolio")
assert resp.status_code == 200
data = resp.json()
assert data["total_value"] == 0.0
assert data["cash"] == 0.0
assert data["daily_pnl"] == 0.0
class TestPositionsEndpoint:
"""test_positions_endpoint."""
def test_positions_returns_list(
self, client: TestClient, mock_session_factory
) -> None:
_, session = mock_session_factory
# Create mock positions
pos = MagicMock()
pos.id = uuid.uuid4()
pos.ticker = "AAPL"
pos.qty = 10.0
pos.avg_entry = 150.0
pos.unrealized_pnl = 50.0
pos.stop_loss = 145.0
pos.take_profit = 160.0
session.execute = AsyncMock(
return_value=_make_execute_result([pos])
)
resp = client.get("/api/portfolio/positions")
assert resp.status_code == 200
data = resp.json()
assert len(data) == 1
assert data[0]["ticker"] == "AAPL"
assert data[0]["qty"] == 10.0
# ---------------------------------------------------------------------------
# Trades Tests
# ---------------------------------------------------------------------------
class TestTradesListEndpoint:
"""test_trades_list_endpoint."""
def test_trades_returns_paginated_list(
self, client: TestClient, mock_session_factory
) -> None:
_, session = mock_session_factory
trade = MagicMock()
trade.id = uuid.uuid4()
trade.ticker = "TSLA"
trade.side.value = "BUY"
trade.qty = 5.0
trade.price = 200.0
trade.status.value = "FILLED"
trade.pnl = 25.0
trade.strategy_id = None
trade.signal_id = None
trade.created_at = datetime(2024, 1, 1, tzinfo=timezone.utc)
# session.execute will be called twice: count + data
count_result = _make_execute_result([], scalar=1)
data_result = _make_execute_result([trade])
session.execute = AsyncMock(side_effect=[count_result, data_result])
resp = client.get("/api/trades")
assert resp.status_code == 200
data = resp.json()
assert "trades" in data
assert "total" in data
assert "page" in data
class TestTradesPagination:
"""test_trades_pagination."""
def test_trades_page_and_per_page(
self, client: TestClient, mock_session_factory
) -> None:
_, session = mock_session_factory
count_result = _make_execute_result([], scalar=50)
data_result = _make_execute_result([])
session.execute = AsyncMock(side_effect=[count_result, data_result])
resp = client.get("/api/trades?page=3&per_page=10")
assert resp.status_code == 200
data = resp.json()
assert data["page"] == 3
assert data["per_page"] == 10
assert data["pages"] == 5 # 50 / 10
# ---------------------------------------------------------------------------
# Strategies Tests
# ---------------------------------------------------------------------------
class TestStrategiesEndpoint:
"""test_strategies_endpoint."""
def test_strategies_returns_list(
self, client: TestClient, mock_session_factory
) -> None:
_, session = mock_session_factory
strategy = MagicMock()
strategy.id = uuid.uuid4()
strategy.name = "momentum"
strategy.description = "Momentum strategy"
strategy.current_weight = 0.333
strategy.active = True
strategy.created_at = datetime(2024, 1, 1, tzinfo=timezone.utc)
session.execute = AsyncMock(
return_value=_make_execute_result([strategy])
)
resp = client.get("/api/strategies")
assert resp.status_code == 200
data = resp.json()
assert len(data) == 1
assert data[0]["name"] == "momentum"
assert data[0]["current_weight"] == 0.333
# ---------------------------------------------------------------------------
# News Tests
# ---------------------------------------------------------------------------
class TestNewsEndpoint:
"""test_news_endpoint."""
def test_news_returns_paginated_articles(
self, client: TestClient, mock_session_factory
) -> None:
_, session = mock_session_factory
article = MagicMock()
article.id = uuid.uuid4()
article.source = "reuters"
article.url = "https://reuters.com/article/1"
article.title = "Stock rises"
article.published_at = datetime(2024, 1, 1, tzinfo=timezone.utc)
article.fetched_at = datetime(2024, 1, 1, tzinfo=timezone.utc)
sentiment = MagicMock()
sentiment.ticker = "AAPL"
sentiment.score = 0.8
sentiment.confidence = 0.9
sentiment.model_used = "finbert"
count_result = _make_execute_result([], scalar=1)
data_result = MagicMock()
data_result.all.return_value = [(article, sentiment)]
session.execute = AsyncMock(side_effect=[count_result, data_result])
resp = client.get("/api/news")
assert resp.status_code == 200
data = resp.json()
assert "articles" in data
assert data["total"] == 1
# ---------------------------------------------------------------------------
# Controls Tests
# ---------------------------------------------------------------------------
class TestControlsPauseResume:
"""test_controls_pause_resume."""
def test_pause_sets_redis_key(
self, client: TestClient, mock_redis: AsyncMock
) -> None:
resp = client.post("/api/controls/pause")
assert resp.status_code == 200
assert resp.json()["status"] == "paused"
mock_redis.set.assert_called_once_with("trading:paused", "1")
def test_resume_clears_redis_key(
self, client: TestClient, mock_redis: AsyncMock
) -> None:
resp = client.post("/api/controls/resume")
assert resp.status_code == 200
assert resp.json()["status"] == "active"
mock_redis.delete.assert_called_once_with("trading:paused")
class TestControlsStatus:
"""test_controls_status."""
def test_status_active_when_not_paused(
self, client: TestClient, mock_redis: AsyncMock
) -> None:
mock_redis.get = AsyncMock(return_value=None)
resp = client.get("/api/controls/status")
assert resp.status_code == 200
assert resp.json()["status"] == "active"
def test_status_paused_when_flag_set(
self, client: TestClient, mock_redis: AsyncMock
) -> None:
mock_redis.get = AsyncMock(return_value="1")
resp = client.get("/api/controls/status")
assert resp.status_code == 200
assert resp.json()["status"] == "paused"
# ---------------------------------------------------------------------------
# Backtest Tests
# ---------------------------------------------------------------------------
class TestBacktestRunEndpoint:
"""test_backtest_run_endpoint."""
def test_backtest_run_returns_run_id(
self, client: TestClient, mock_redis: AsyncMock
) -> None:
resp = client.post(
"/api/backtest/run",
json={
"start_date": "2024-01-01T00:00:00Z",
"end_date": "2024-06-01T00:00:00Z",
"initial_capital": 100000,
},
)
assert resp.status_code == 200
data = resp.json()
assert "run_id" in data
assert data["status"] == "running"
mock_redis.setex.assert_called()
def test_backtest_get_not_found(
self, client: TestClient, mock_redis: AsyncMock
) -> None:
mock_redis.get = AsyncMock(return_value=None)
resp = client.get("/api/backtest/nonexistent-id")
assert resp.status_code == 404
def test_backtest_get_returns_result(
self, client: TestClient, mock_redis: AsyncMock
) -> None:
stored = json.dumps({
"status": "completed",
"result": {"total_return": 0.15, "sharpe_ratio": 1.2},
})
mock_redis.get = AsyncMock(return_value=stored)
resp = client.get("/api/backtest/some-run-id")
assert resp.status_code == 200
data = resp.json()
assert data["status"] == "completed"
assert data["result"]["total_return"] == 0.15