diff --git a/services/api_gateway/main.py b/services/api_gateway/main.py index 6edf979..f286f45 100644 --- a/services/api_gateway/main.py +++ b/services/api_gateway/main.py @@ -69,6 +69,28 @@ def create_app(config: ApiGatewayConfig | None = None) -> FastAPI: # Auth routes (unauthenticated) app.include_router(auth_router) + # Trading routes (authenticated) — imported lazily to avoid circular deps + from services.api_gateway.routes.portfolio import router as portfolio_router + from services.api_gateway.routes.trades import router as trades_router + from services.api_gateway.routes.signals import router as signals_router + from services.api_gateway.routes.strategies import router as strategies_router + from services.api_gateway.routes.news import router as news_router + from services.api_gateway.routes.controls import router as controls_router + from services.api_gateway.routes.backtest import router as backtest_router + + app.include_router(portfolio_router) + app.include_router(trades_router) + app.include_router(signals_router) + app.include_router(strategies_router) + app.include_router(news_router) + app.include_router(controls_router) + app.include_router(backtest_router) + + # WebSocket + from services.api_gateway.ws import router as ws_router + + app.include_router(ws_router) + # Health check @app.get("/health", tags=["health"]) async def health() -> dict: diff --git a/services/api_gateway/routes/__init__.py b/services/api_gateway/routes/__init__.py new file mode 100644 index 0000000..f2e55ae --- /dev/null +++ b/services/api_gateway/routes/__init__.py @@ -0,0 +1 @@ +"""Trading route sub-package.""" diff --git a/services/api_gateway/routes/backtest.py b/services/api_gateway/routes/backtest.py new file mode 100644 index 0000000..728b11d --- /dev/null +++ b/services/api_gateway/routes/backtest.py @@ -0,0 +1,156 @@ +"""Backtest endpoints — run backtests and retrieve results.""" + +from __future__ import annotations + +import asyncio +import json +import logging +import uuid +from datetime import datetime, timezone + +from fastapi import APIRouter, Depends, HTTPException, Query, Request, status +from pydantic import BaseModel, Field + +from services.api_gateway.auth.middleware import get_current_user + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/backtest", tags=["backtest"]) + + +class BacktestRequest(BaseModel): + """Request body for starting a new backtest.""" + + start_date: datetime + end_date: datetime + initial_capital: float = Field(default=100_000.0, gt=0) + commission_per_trade: float = Field(default=0.0, ge=0) + slippage_pct: float = Field(default=0.001, ge=0) + strategy_weights: dict[str, float] = Field(default_factory=dict) + max_position_pct: float = Field(default=0.05, gt=0, le=1.0) + signal_threshold: float = Field(default=0.3, ge=0, le=1.0) + + +@router.post("/run") +async def run_backtest( + body: BacktestRequest, + request: Request, + _user: dict = Depends(get_current_user), +) -> dict: + """Start a backtest with the given configuration. + + Returns a ``run_id`` immediately. The backtest runs in a background + task and its results can be retrieved via ``GET /api/backtest/{run_id}``. + """ + run_id = str(uuid.uuid4()) + redis = request.app.state.redis + + # Store initial status + await redis.setex( + f"backtest:{run_id}", + 86400, # 24h TTL + json.dumps({ + "status": "running", + "config": body.model_dump(mode="json"), + "started_at": datetime.now(tz=timezone.utc).isoformat(), + }), + ) + + # Launch background task + asyncio.create_task(_run_backtest_task(run_id, body, redis)) + + return {"run_id": run_id, "status": "running"} + + +async def _run_backtest_task( + run_id: str, + config: BacktestRequest, + redis, +) -> None: + """Execute the backtest in the background and store results in Redis.""" + try: + from backtester.config import BacktestConfig + from backtester.engine import BacktestEngine + from shared.strategies.momentum import MomentumStrategy + from shared.strategies.mean_reversion import MeanReversionStrategy + from shared.strategies.news_driven import NewsDrivenStrategy + + bt_config = BacktestConfig( + start_date=config.start_date, + end_date=config.end_date, + initial_capital=config.initial_capital, + commission_per_trade=config.commission_per_trade, + slippage_pct=config.slippage_pct, + strategy_weights=config.strategy_weights, + max_position_pct=config.max_position_pct, + signal_threshold=config.signal_threshold, + ) + + strategies = [ + MomentumStrategy(), + MeanReversionStrategy(), + NewsDrivenStrategy(), + ] + + engine = BacktestEngine(config=bt_config, strategies=strategies) + + # Use an in-memory stub data loader for now; a full implementation + # would read from TimescaleDB. + from backtester.data_loader import BacktestDataLoader + + data_loader = BacktestDataLoader( + session_factory=None, + start_date=config.start_date, + end_date=config.end_date, + ) + + result = await engine.run(data_loader) + + await redis.setex( + f"backtest:{run_id}", + 86400, + json.dumps({ + "status": "completed", + "config": config.model_dump(mode="json"), + "result": { + "total_return": result.total_return, + "annualized_return": result.annualized_return, + "sharpe_ratio": result.sharpe_ratio, + "sortino_ratio": result.sortino_ratio, + "max_drawdown": result.max_drawdown, + "max_drawdown_duration_days": result.max_drawdown_duration_days, + "win_rate": result.win_rate, + "trade_count": result.trade_count, + "avg_hold_duration_hours": result.avg_hold_duration_hours, + "profit_factor": result.profit_factor, + }, + "completed_at": datetime.now(tz=timezone.utc).isoformat(), + }), + ) + except Exception as exc: + logger.exception("Backtest %s failed", run_id) + await redis.setex( + f"backtest:{run_id}", + 86400, + json.dumps({ + "status": "failed", + "error": str(exc), + }), + ) + + +@router.get("/{run_id}") +async def get_backtest( + run_id: str, + request: Request, + _user: dict = Depends(get_current_user), +) -> dict: + """Get backtest results by run ID.""" + redis = request.app.state.redis + raw = await redis.get(f"backtest:{run_id}") + if raw is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Backtest run not found", + ) + return json.loads(raw) diff --git a/services/api_gateway/routes/controls.py b/services/api_gateway/routes/controls.py new file mode 100644 index 0000000..456430f --- /dev/null +++ b/services/api_gateway/routes/controls.py @@ -0,0 +1,95 @@ +"""Control endpoints — pause/resume trading, force close positions.""" + +from __future__ import annotations + +from fastapi import APIRouter, Depends, HTTPException, Request, status +from pydantic import BaseModel + +from services.api_gateway.auth.middleware import get_current_user + +router = APIRouter(prefix="/api/controls", tags=["controls"]) + +TRADING_PAUSED_KEY = "trading:paused" + + +class ClosePositionRequest(BaseModel): + """Body for the force-close-position endpoint.""" + + ticker: str + + +@router.post("/pause") +async def pause_trading( + request: Request, + _user: dict = Depends(get_current_user), +) -> dict: + """Set Redis flag to pause trading.""" + redis = request.app.state.redis + await redis.set(TRADING_PAUSED_KEY, "1") + return {"status": "paused"} + + +@router.post("/resume") +async def resume_trading( + request: Request, + _user: dict = Depends(get_current_user), +) -> dict: + """Clear pause flag to resume trading.""" + redis = request.app.state.redis + await redis.delete(TRADING_PAUSED_KEY) + return {"status": "active"} + + +@router.post("/close-position") +async def close_position( + body: ClosePositionRequest, + request: Request, + _user: dict = Depends(get_current_user), +) -> dict: + """Force close an open position by ticker. + + Publishes a close-position command to Redis so the trade executor + picks it up asynchronously. + """ + import json + + from sqlalchemy import select + from shared.models.trading import Position + + db = request.app.state.db_session_factory + async with db() as session: + position = ( + await session.execute( + select(Position).where(Position.ticker == body.ticker.upper()) + ) + ).scalar_one_or_none() + + if position is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"No open position for {body.ticker}", + ) + + # Publish close command to Redis for the trade executor + redis = request.app.state.redis + await redis.publish( + "controls:close_position", + json.dumps({"ticker": body.ticker.upper(), "qty": position.qty}), + ) + + return { + "status": "close_requested", + "ticker": body.ticker.upper(), + "qty": position.qty, + } + + +@router.get("/status") +async def get_trading_status( + request: Request, + _user: dict = Depends(get_current_user), +) -> dict: + """Current trading status — active or paused.""" + redis = request.app.state.redis + paused = await redis.get(TRADING_PAUSED_KEY) + return {"status": "paused" if paused else "active"} diff --git a/services/api_gateway/routes/news.py b/services/api_gateway/routes/news.py new file mode 100644 index 0000000..d71580b --- /dev/null +++ b/services/api_gateway/routes/news.py @@ -0,0 +1,87 @@ +"""News endpoints — recent scored articles with filtering.""" + +from __future__ import annotations + +from fastapi import APIRouter, Depends, Query, Request + +from services.api_gateway.auth.middleware import get_current_user +from sqlalchemy import select, desc, func + +router = APIRouter(prefix="/api/news", tags=["news"]) + + +@router.get("") +async def list_news( + request: Request, + _user: dict = Depends(get_current_user), + ticker: str | None = Query(default=None), + source: str | None = Query(default=None), + min_score: float | None = Query(default=None, ge=-1.0, le=1.0), + max_score: float | None = Query(default=None, ge=-1.0, le=1.0), + page: int = Query(default=1, ge=1), + per_page: int = Query(default=20, ge=1, le=100), +) -> dict: + """Recent scored articles with optional filters.""" + from shared.models.news import Article, ArticleSentiment + + db = request.app.state.db_session_factory + async with db() as session: + # Base query joining articles with sentiments + query = ( + select(Article, ArticleSentiment) + .join(ArticleSentiment, Article.id == ArticleSentiment.article_id) + .order_by(desc(Article.fetched_at)) + ) + count_query = ( + select(func.count()) + .select_from(Article) + .join(ArticleSentiment, Article.id == ArticleSentiment.article_id) + ) + + if ticker: + query = query.where(ArticleSentiment.ticker == ticker.upper()) + count_query = count_query.where( + ArticleSentiment.ticker == ticker.upper() + ) + if source: + query = query.where(Article.source == source) + count_query = count_query.where(Article.source == source) + if min_score is not None: + query = query.where(ArticleSentiment.score >= min_score) + count_query = count_query.where(ArticleSentiment.score >= min_score) + if max_score is not None: + query = query.where(ArticleSentiment.score <= max_score) + count_query = count_query.where(ArticleSentiment.score <= max_score) + + total = (await session.execute(count_query)).scalar() or 0 + offset = (page - 1) * per_page + query = query.offset(offset).limit(per_page) + + result = await session.execute(query) + rows = result.all() + + return { + "articles": [ + { + "id": str(article.id), + "source": article.source, + "url": article.url, + "title": article.title, + "published_at": ( + article.published_at.isoformat() + if article.published_at + else None + ), + "fetched_at": article.fetched_at.isoformat(), + "ticker": sentiment.ticker, + "sentiment_score": sentiment.score, + "confidence": sentiment.confidence, + "model_used": sentiment.model_used, + } + for article, sentiment in rows + ], + "total": total, + "page": page, + "per_page": per_page, + "pages": (total + per_page - 1) // per_page if per_page else 0, + } diff --git a/services/api_gateway/routes/portfolio.py b/services/api_gateway/routes/portfolio.py new file mode 100644 index 0000000..f512071 --- /dev/null +++ b/services/api_gateway/routes/portfolio.py @@ -0,0 +1,125 @@ +"""Portfolio endpoints — current value, positions, equity curve.""" + +from __future__ import annotations + +from datetime import datetime, timedelta, timezone +from enum import Enum + +from fastapi import APIRouter, Depends, Query, Request + +from services.api_gateway.auth.middleware import get_current_user +from sqlalchemy import select, desc + +router = APIRouter(prefix="/api/portfolio", tags=["portfolio"]) + + +class HistoryPeriod(str, Enum): + ONE_DAY = "1d" + ONE_WEEK = "1w" + ONE_MONTH = "1m" + THREE_MONTHS = "3m" + ONE_YEAR = "1y" + + +def _period_to_timedelta(period: HistoryPeriod) -> timedelta: + """Convert a period enum value to a timedelta.""" + mapping = { + HistoryPeriod.ONE_DAY: timedelta(days=1), + HistoryPeriod.ONE_WEEK: timedelta(weeks=1), + HistoryPeriod.ONE_MONTH: timedelta(days=30), + HistoryPeriod.THREE_MONTHS: timedelta(days=90), + HistoryPeriod.ONE_YEAR: timedelta(days=365), + } + return mapping[period] + + +@router.get("") +async def get_portfolio( + request: Request, + _user: dict = Depends(get_current_user), +) -> dict: + """Current portfolio summary — value, cash, buying power, daily P&L.""" + from shared.models.timeseries import PortfolioSnapshot + + db = request.app.state.db_session_factory + async with db() as session: + latest = ( + await session.execute( + select(PortfolioSnapshot) + .order_by(desc(PortfolioSnapshot.timestamp)) + .limit(1) + ) + ).scalar_one_or_none() + + if latest is None: + return { + "total_value": 0.0, + "cash": 0.0, + "buying_power": 0.0, + "daily_pnl": 0.0, + } + + return { + "total_value": latest.total_value, + "cash": latest.cash, + "buying_power": latest.cash, + "daily_pnl": latest.daily_pnl, + } + + +@router.get("/positions") +async def get_positions( + request: Request, + _user: dict = Depends(get_current_user), +) -> list[dict]: + """All open positions with unrealized P&L.""" + from shared.models.trading import Position + + db = request.app.state.db_session_factory + async with db() as session: + result = await session.execute(select(Position)) + positions = result.scalars().all() + + return [ + { + "id": str(p.id), + "ticker": p.ticker, + "qty": p.qty, + "avg_entry": p.avg_entry, + "unrealized_pnl": p.unrealized_pnl or 0.0, + "stop_loss": p.stop_loss, + "take_profit": p.take_profit, + } + for p in positions + ] + + +@router.get("/history") +async def get_portfolio_history( + request: Request, + _user: dict = Depends(get_current_user), + period: HistoryPeriod = Query(default=HistoryPeriod.ONE_MONTH), +) -> list[dict]: + """Equity curve from portfolio_snapshots over a given period.""" + from shared.models.timeseries import PortfolioSnapshot + + since = datetime.now(timezone.utc) - _period_to_timedelta(period) + db = request.app.state.db_session_factory + async with db() as session: + result = await session.execute( + select(PortfolioSnapshot) + .where(PortfolioSnapshot.timestamp >= since) + .order_by(PortfolioSnapshot.timestamp) + ) + snapshots = result.scalars().all() + + return [ + { + "timestamp": s.timestamp.isoformat(), + "total_value": s.total_value, + "cash": s.cash, + "positions_value": s.positions_value, + "daily_pnl": s.daily_pnl, + } + for s in snapshots + ] diff --git a/services/api_gateway/routes/signals.py b/services/api_gateway/routes/signals.py new file mode 100644 index 0000000..8e63a16 --- /dev/null +++ b/services/api_gateway/routes/signals.py @@ -0,0 +1,58 @@ +"""Signal endpoints — recent signals with filtering.""" + +from __future__ import annotations + +from fastapi import APIRouter, Depends, Query, Request + +from services.api_gateway.auth.middleware import get_current_user +from sqlalchemy import select, desc, func + +router = APIRouter(prefix="/api/signals", tags=["signals"]) + + +@router.get("") +async def list_signals( + request: Request, + _user: dict = Depends(get_current_user), + ticker: str | None = Query(default=None), + page: int = Query(default=1, ge=1), + per_page: int = Query(default=20, ge=1, le=100), +) -> dict: + """Recent signals with optional ticker filter and pagination.""" + from shared.models.trading import Signal + + db = request.app.state.db_session_factory + async with db() as session: + query = select(Signal).order_by(desc(Signal.created_at)) + count_query = select(func.count()).select_from(Signal) + + if ticker: + query = query.where(Signal.ticker == ticker.upper()) + count_query = count_query.where(Signal.ticker == ticker.upper()) + + total = (await session.execute(count_query)).scalar() or 0 + offset = (page - 1) * per_page + query = query.offset(offset).limit(per_page) + + result = await session.execute(query) + signals = result.scalars().all() + + return { + "signals": [ + { + "id": str(s.id), + "ticker": s.ticker, + "direction": s.direction.value, + "strength": s.strength, + "strategy_sources": s.strategy_sources, + "sentiment_score": s.sentiment_score, + "acted_on": s.acted_on, + "created_at": s.created_at.isoformat() if s.created_at else None, + } + for s in signals + ], + "total": total, + "page": page, + "per_page": per_page, + "pages": (total + per_page - 1) // per_page if per_page else 0, + } diff --git a/services/api_gateway/routes/strategies.py b/services/api_gateway/routes/strategies.py new file mode 100644 index 0000000..412ef1a --- /dev/null +++ b/services/api_gateway/routes/strategies.py @@ -0,0 +1,111 @@ +"""Strategy endpoints — list strategies, weight history, metrics.""" + +from __future__ import annotations + +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, Request, status + +from services.api_gateway.auth.middleware import get_current_user +from sqlalchemy import select, desc + +router = APIRouter(prefix="/api/strategies", tags=["strategies"]) + + +@router.get("") +async def list_strategies( + request: Request, + _user: dict = Depends(get_current_user), +) -> list[dict]: + """All strategies with current weights.""" + from shared.models.trading import Strategy + + db = request.app.state.db_session_factory + async with db() as session: + result = await session.execute(select(Strategy)) + strategies = result.scalars().all() + + return [ + { + "id": str(s.id), + "name": s.name, + "description": s.description, + "current_weight": s.current_weight, + "active": s.active, + "created_at": s.created_at.isoformat() if s.created_at else None, + } + for s in strategies + ] + + +@router.get("/{strategy_id}/history") +async def get_strategy_weight_history( + strategy_id: UUID, + request: Request, + _user: dict = Depends(get_current_user), +) -> list[dict]: + """Weight history for a specific strategy.""" + from shared.models.trading import StrategyWeightHistory, Strategy + + db = request.app.state.db_session_factory + async with db() as session: + # Verify strategy exists + strategy = ( + await session.execute( + select(Strategy).where(Strategy.id == strategy_id) + ) + ).scalar_one_or_none() + if strategy is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Strategy not found", + ) + + result = await session.execute( + select(StrategyWeightHistory) + .where(StrategyWeightHistory.strategy_id == strategy_id) + .order_by(desc(StrategyWeightHistory.created_at)) + ) + history = result.scalars().all() + + return [ + { + "id": str(h.id), + "old_weight": h.old_weight, + "new_weight": h.new_weight, + "reason": h.reason, + "created_at": h.created_at.isoformat() if h.created_at else None, + } + for h in history + ] + + +@router.get("/{strategy_id}/metrics") +async def get_strategy_metrics( + strategy_id: UUID, + request: Request, + _user: dict = Depends(get_current_user), +) -> list[dict]: + """Performance metrics over time for a specific strategy.""" + from shared.models.timeseries import StrategyMetric + + db = request.app.state.db_session_factory + async with db() as session: + result = await session.execute( + select(StrategyMetric) + .where(StrategyMetric.strategy_id == strategy_id) + .order_by(desc(StrategyMetric.timestamp)) + .limit(100) + ) + metrics = result.scalars().all() + + return [ + { + "timestamp": m.timestamp.isoformat(), + "win_rate": m.win_rate, + "total_pnl": m.total_pnl, + "trade_count": m.trade_count, + "sharpe_ratio": m.sharpe_ratio, + } + for m in metrics + ] diff --git a/services/api_gateway/routes/trades.py b/services/api_gateway/routes/trades.py new file mode 100644 index 0000000..00f1016 --- /dev/null +++ b/services/api_gateway/routes/trades.py @@ -0,0 +1,125 @@ +"""Trade endpoints — paginated trade history and detail.""" + +from __future__ import annotations + +from datetime import datetime +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, Query, Request, status + +from services.api_gateway.auth.middleware import get_current_user +from sqlalchemy import select, desc, func + +router = APIRouter(prefix="/api/trades", tags=["trades"]) + + +@router.get("") +async def list_trades( + request: Request, + _user: dict = Depends(get_current_user), + ticker: str | None = Query(default=None), + start_date: datetime | None = Query(default=None), + end_date: datetime | None = Query(default=None), + strategy: str | None = Query(default=None), + profitable: bool | None = Query(default=None), + page: int = Query(default=1, ge=1), + per_page: int = Query(default=20, ge=1, le=100), +) -> dict: + """Paginated trade history with optional filters.""" + from shared.models.trading import Trade, Strategy + + db = request.app.state.db_session_factory + async with db() as session: + query = select(Trade).order_by(desc(Trade.created_at)) + count_query = select(func.count()).select_from(Trade) + + # Apply filters + if ticker: + query = query.where(Trade.ticker == ticker.upper()) + count_query = count_query.where(Trade.ticker == ticker.upper()) + if start_date: + query = query.where(Trade.created_at >= start_date) + count_query = count_query.where(Trade.created_at >= start_date) + if end_date: + query = query.where(Trade.created_at <= end_date) + count_query = count_query.where(Trade.created_at <= end_date) + if strategy: + # Join with Strategy to filter by name + query = query.join(Strategy, Trade.strategy_id == Strategy.id).where( + Strategy.name == strategy + ) + count_query = count_query.join( + Strategy, Trade.strategy_id == Strategy.id + ).where(Strategy.name == strategy) + if profitable is not None: + if profitable: + query = query.where(Trade.pnl > 0) + count_query = count_query.where(Trade.pnl > 0) + else: + query = query.where(Trade.pnl <= 0) + count_query = count_query.where(Trade.pnl <= 0) + + # Pagination + total = (await session.execute(count_query)).scalar() or 0 + offset = (page - 1) * per_page + query = query.offset(offset).limit(per_page) + + result = await session.execute(query) + trades = result.scalars().all() + + return { + "trades": [ + { + "id": str(t.id), + "ticker": t.ticker, + "side": t.side.value, + "qty": t.qty, + "price": t.price, + "status": t.status.value, + "pnl": t.pnl, + "strategy_id": str(t.strategy_id) if t.strategy_id else None, + "signal_id": str(t.signal_id) if t.signal_id else None, + "created_at": t.created_at.isoformat() if t.created_at else None, + } + for t in trades + ], + "total": total, + "page": page, + "per_page": per_page, + "pages": (total + per_page - 1) // per_page if per_page else 0, + } + + +@router.get("/{trade_id}") +async def get_trade( + trade_id: UUID, + request: Request, + _user: dict = Depends(get_current_user), +) -> dict: + """Single trade detail with linked signal and outcome.""" + from shared.models.trading import Trade + + db = request.app.state.db_session_factory + async with db() as session: + trade = ( + await session.execute(select(Trade).where(Trade.id == trade_id)) + ).scalar_one_or_none() + + if trade is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Trade not found", + ) + + return { + "id": str(trade.id), + "ticker": trade.ticker, + "side": trade.side.value, + "qty": trade.qty, + "price": trade.price, + "status": trade.status.value, + "pnl": trade.pnl, + "strategy_id": str(trade.strategy_id) if trade.strategy_id else None, + "signal_id": str(trade.signal_id) if trade.signal_id else None, + "created_at": trade.created_at.isoformat() if trade.created_at else None, + } diff --git a/services/api_gateway/ws.py b/services/api_gateway/ws.py new file mode 100644 index 0000000..3c26aaf --- /dev/null +++ b/services/api_gateway/ws.py @@ -0,0 +1,138 @@ +"""WebSocket endpoint — authenticated real-time event stream.""" + +from __future__ import annotations + +import asyncio +import json +import logging + +from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect + +import jwt as pyjwt + +from services.api_gateway.auth.jwt import decode_token +from services.api_gateway.config import ApiGatewayConfig + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["websocket"]) + + +class ConnectionManager: + """Manages active WebSocket connections.""" + + def __init__(self) -> None: + self.active_connections: list[WebSocket] = [] + + async def connect(self, websocket: WebSocket) -> None: + await websocket.accept() + self.active_connections.append(websocket) + + def disconnect(self, websocket: WebSocket) -> None: + if websocket in self.active_connections: + self.active_connections.remove(websocket) + + async def broadcast(self, message: dict) -> None: + """Send a message to all connected clients. Remove any that have + disconnected.""" + disconnected: list[WebSocket] = [] + for connection in self.active_connections: + try: + await connection.send_json(message) + except Exception: + disconnected.append(connection) + for ws in disconnected: + self.disconnect(ws) + + +manager = ConnectionManager() + + +@router.websocket("/ws") +async def websocket_endpoint( + websocket: WebSocket, + token: str = Query(default=""), +) -> None: + """Authenticated WebSocket that pushes real-time trading events. + + Connect via ``ws://host/ws?token=``. The server subscribes + to Redis pub/sub channels and forwards events to all connected + clients. + + Event types pushed to clients: + - ``trade_executed`` + - ``signal_generated`` + - ``portfolio_update`` + """ + # Authenticate via JWT query parameter + config: ApiGatewayConfig = websocket.app.state.config + try: + payload = decode_token(token, config) + if payload.get("type") != "access": + await websocket.close(code=4001, reason="Invalid token type") + return + except pyjwt.ExpiredSignatureError: + await websocket.close(code=4001, reason="Token expired") + return + except pyjwt.InvalidTokenError: + await websocket.close(code=4001, reason="Invalid token") + return + + await manager.connect(websocket) + + # Subscribe to Redis pub/sub channels for real-time events + redis = websocket.app.state.redis + pubsub = redis.pubsub() + await pubsub.subscribe( + "events:trade_executed", + "events:signal_generated", + "events:portfolio_update", + ) + + try: + # Run two concurrent tasks: one reads from Redis pub/sub and pushes + # to the client, the other keeps the WebSocket alive by reading + # (and discarding) incoming messages. + + async def _redis_listener() -> None: + """Forward Redis pub/sub messages to this WebSocket client.""" + while True: + message = await pubsub.get_message( + ignore_subscribe_messages=True, timeout=1.0 + ) + if message and message["type"] == "message": + try: + data = json.loads(message["data"]) + except (json.JSONDecodeError, TypeError): + data = {"raw": str(message["data"])} + channel = message["channel"] + if isinstance(channel, bytes): + channel = channel.decode() + event_type = channel.replace("events:", "") + await websocket.send_json( + {"event": event_type, "data": data} + ) + await asyncio.sleep(0.1) + + async def _ws_receiver() -> None: + """Keep the connection alive by reading messages.""" + while True: + await websocket.receive_text() + + # Run both tasks; whichever finishes first (e.g., client disconnect) + # will cancel the other. + listener_task = asyncio.create_task(_redis_listener()) + receiver_task = asyncio.create_task(_ws_receiver()) + done, pending = await asyncio.wait( + {listener_task, receiver_task}, + return_when=asyncio.FIRST_COMPLETED, + ) + for task in pending: + task.cancel() + + except WebSocketDisconnect: + pass + finally: + manager.disconnect(websocket) + await pubsub.unsubscribe() + await pubsub.aclose() diff --git a/tests/services/test_api_routes.py b/tests/services/test_api_routes.py new file mode 100644 index 0000000..575a4f6 --- /dev/null +++ b/tests/services/test_api_routes.py @@ -0,0 +1,386 @@ +"""Tests for API Gateway trading endpoints (Task 14). + +Uses FastAPI TestClient with mocked DB sessions and Redis. +""" + +from __future__ import annotations + +import json +import uuid +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi.testclient import TestClient + +from services.api_gateway.auth.jwt import create_access_token +from services.api_gateway.auth.middleware import get_config, get_current_user +from services.api_gateway.config import ApiGatewayConfig +from services.api_gateway.main import create_app + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def config() -> ApiGatewayConfig: + return ApiGatewayConfig( + jwt_secret_key="test-secret-for-routes", + database_url="sqlite+aiosqlite:///:memory:", + redis_url="redis://localhost:6379/0", + ) + + +@pytest.fixture() +def mock_user() -> dict: + return {"sub": "user-test", "username": "tester", "type": "access"} + + +@pytest.fixture() +def auth_headers(config: ApiGatewayConfig) -> dict[str, str]: + token = create_access_token("user-test", "tester", config) + return {"Authorization": f"Bearer {token}"} + + +@pytest.fixture() +def mock_redis() -> AsyncMock: + """A fully mocked Redis client.""" + redis = AsyncMock() + redis.get = AsyncMock(return_value=None) + redis.set = AsyncMock() + redis.setex = AsyncMock() + redis.delete = AsyncMock() + redis.publish = AsyncMock() + return redis + + +@pytest.fixture() +def mock_session_factory(): + """Creates a mock async session factory that returns a mock session.""" + session = AsyncMock() + session.__aenter__ = AsyncMock(return_value=session) + session.__aexit__ = AsyncMock(return_value=False) + + factory = MagicMock() + factory.return_value = session + return factory, session + + +@pytest.fixture() +def client( + config: ApiGatewayConfig, + mock_user: dict, + mock_redis: AsyncMock, + mock_session_factory, +) -> TestClient: + """Create a test client with all dependencies mocked.""" + factory, session = mock_session_factory + + app = create_app(config) + + # Override auth dependency to bypass JWT validation + app.dependency_overrides[get_current_user] = lambda: mock_user + app.dependency_overrides[get_config] = lambda: config + + # Inject mock state + app.state.redis = mock_redis + app.state.db_session_factory = factory + app.state.db_engine = MagicMock() + app.state.config = config + + return TestClient(app, raise_server_exceptions=False) + + +# --------------------------------------------------------------------------- +# Helper: build mock execute results +# --------------------------------------------------------------------------- + + +def _make_execute_result(rows, scalar=None): + """Build a mock result for session.execute().""" + result = MagicMock() + result.scalars.return_value.all.return_value = rows + result.scalars.return_value.__iter__ = lambda self: iter(rows) + result.scalar_one_or_none.return_value = scalar + result.scalar.return_value = len(rows) if scalar is None else scalar + result.all.return_value = rows + return result + + +# --------------------------------------------------------------------------- +# Portfolio Tests +# --------------------------------------------------------------------------- + + +class TestPortfolioEndpoint: + """test_portfolio_endpoint.""" + + def test_portfolio_returns_defaults_when_empty( + self, client: TestClient, mock_session_factory + ) -> None: + _, session = mock_session_factory + session.execute = AsyncMock( + return_value=_make_execute_result([], scalar=None) + ) + + resp = client.get("/api/portfolio") + assert resp.status_code == 200 + data = resp.json() + assert data["total_value"] == 0.0 + assert data["cash"] == 0.0 + assert data["daily_pnl"] == 0.0 + + +class TestPositionsEndpoint: + """test_positions_endpoint.""" + + def test_positions_returns_list( + self, client: TestClient, mock_session_factory + ) -> None: + _, session = mock_session_factory + + # Create mock positions + pos = MagicMock() + pos.id = uuid.uuid4() + pos.ticker = "AAPL" + pos.qty = 10.0 + pos.avg_entry = 150.0 + pos.unrealized_pnl = 50.0 + pos.stop_loss = 145.0 + pos.take_profit = 160.0 + + session.execute = AsyncMock( + return_value=_make_execute_result([pos]) + ) + + resp = client.get("/api/portfolio/positions") + assert resp.status_code == 200 + data = resp.json() + assert len(data) == 1 + assert data[0]["ticker"] == "AAPL" + assert data[0]["qty"] == 10.0 + + +# --------------------------------------------------------------------------- +# Trades Tests +# --------------------------------------------------------------------------- + + +class TestTradesListEndpoint: + """test_trades_list_endpoint.""" + + def test_trades_returns_paginated_list( + self, client: TestClient, mock_session_factory + ) -> None: + _, session = mock_session_factory + + trade = MagicMock() + trade.id = uuid.uuid4() + trade.ticker = "TSLA" + trade.side.value = "BUY" + trade.qty = 5.0 + trade.price = 200.0 + trade.status.value = "FILLED" + trade.pnl = 25.0 + trade.strategy_id = None + trade.signal_id = None + trade.created_at = datetime(2024, 1, 1, tzinfo=timezone.utc) + + # session.execute will be called twice: count + data + count_result = _make_execute_result([], scalar=1) + data_result = _make_execute_result([trade]) + session.execute = AsyncMock(side_effect=[count_result, data_result]) + + resp = client.get("/api/trades") + assert resp.status_code == 200 + data = resp.json() + assert "trades" in data + assert "total" in data + assert "page" in data + + +class TestTradesPagination: + """test_trades_pagination.""" + + def test_trades_page_and_per_page( + self, client: TestClient, mock_session_factory + ) -> None: + _, session = mock_session_factory + + count_result = _make_execute_result([], scalar=50) + data_result = _make_execute_result([]) + session.execute = AsyncMock(side_effect=[count_result, data_result]) + + resp = client.get("/api/trades?page=3&per_page=10") + assert resp.status_code == 200 + data = resp.json() + assert data["page"] == 3 + assert data["per_page"] == 10 + assert data["pages"] == 5 # 50 / 10 + + +# --------------------------------------------------------------------------- +# Strategies Tests +# --------------------------------------------------------------------------- + + +class TestStrategiesEndpoint: + """test_strategies_endpoint.""" + + def test_strategies_returns_list( + self, client: TestClient, mock_session_factory + ) -> None: + _, session = mock_session_factory + + strategy = MagicMock() + strategy.id = uuid.uuid4() + strategy.name = "momentum" + strategy.description = "Momentum strategy" + strategy.current_weight = 0.333 + strategy.active = True + strategy.created_at = datetime(2024, 1, 1, tzinfo=timezone.utc) + + session.execute = AsyncMock( + return_value=_make_execute_result([strategy]) + ) + + resp = client.get("/api/strategies") + assert resp.status_code == 200 + data = resp.json() + assert len(data) == 1 + assert data[0]["name"] == "momentum" + assert data[0]["current_weight"] == 0.333 + + +# --------------------------------------------------------------------------- +# News Tests +# --------------------------------------------------------------------------- + + +class TestNewsEndpoint: + """test_news_endpoint.""" + + def test_news_returns_paginated_articles( + self, client: TestClient, mock_session_factory + ) -> None: + _, session = mock_session_factory + + article = MagicMock() + article.id = uuid.uuid4() + article.source = "reuters" + article.url = "https://reuters.com/article/1" + article.title = "Stock rises" + article.published_at = datetime(2024, 1, 1, tzinfo=timezone.utc) + article.fetched_at = datetime(2024, 1, 1, tzinfo=timezone.utc) + + sentiment = MagicMock() + sentiment.ticker = "AAPL" + sentiment.score = 0.8 + sentiment.confidence = 0.9 + sentiment.model_used = "finbert" + + count_result = _make_execute_result([], scalar=1) + data_result = MagicMock() + data_result.all.return_value = [(article, sentiment)] + session.execute = AsyncMock(side_effect=[count_result, data_result]) + + resp = client.get("/api/news") + assert resp.status_code == 200 + data = resp.json() + assert "articles" in data + assert data["total"] == 1 + + +# --------------------------------------------------------------------------- +# Controls Tests +# --------------------------------------------------------------------------- + + +class TestControlsPauseResume: + """test_controls_pause_resume.""" + + def test_pause_sets_redis_key( + self, client: TestClient, mock_redis: AsyncMock + ) -> None: + resp = client.post("/api/controls/pause") + assert resp.status_code == 200 + assert resp.json()["status"] == "paused" + mock_redis.set.assert_called_once_with("trading:paused", "1") + + def test_resume_clears_redis_key( + self, client: TestClient, mock_redis: AsyncMock + ) -> None: + resp = client.post("/api/controls/resume") + assert resp.status_code == 200 + assert resp.json()["status"] == "active" + mock_redis.delete.assert_called_once_with("trading:paused") + + +class TestControlsStatus: + """test_controls_status.""" + + def test_status_active_when_not_paused( + self, client: TestClient, mock_redis: AsyncMock + ) -> None: + mock_redis.get = AsyncMock(return_value=None) + resp = client.get("/api/controls/status") + assert resp.status_code == 200 + assert resp.json()["status"] == "active" + + def test_status_paused_when_flag_set( + self, client: TestClient, mock_redis: AsyncMock + ) -> None: + mock_redis.get = AsyncMock(return_value="1") + resp = client.get("/api/controls/status") + assert resp.status_code == 200 + assert resp.json()["status"] == "paused" + + +# --------------------------------------------------------------------------- +# Backtest Tests +# --------------------------------------------------------------------------- + + +class TestBacktestRunEndpoint: + """test_backtest_run_endpoint.""" + + def test_backtest_run_returns_run_id( + self, client: TestClient, mock_redis: AsyncMock + ) -> None: + resp = client.post( + "/api/backtest/run", + json={ + "start_date": "2024-01-01T00:00:00Z", + "end_date": "2024-06-01T00:00:00Z", + "initial_capital": 100000, + }, + ) + assert resp.status_code == 200 + data = resp.json() + assert "run_id" in data + assert data["status"] == "running" + mock_redis.setex.assert_called() + + def test_backtest_get_not_found( + self, client: TestClient, mock_redis: AsyncMock + ) -> None: + mock_redis.get = AsyncMock(return_value=None) + resp = client.get("/api/backtest/nonexistent-id") + assert resp.status_code == 404 + + def test_backtest_get_returns_result( + self, client: TestClient, mock_redis: AsyncMock + ) -> None: + stored = json.dumps({ + "status": "completed", + "result": {"total_return": 0.15, "sharpe_ratio": 1.2}, + }) + mock_redis.get = AsyncMock(return_value=stored) + + resp = client.get("/api/backtest/some-run-id") + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "completed" + assert data["result"]["total_return"] == 0.15