feat: API gateway trading endpoints, controls, backtest, WebSocket
This commit is contained in:
parent
e0d138c457
commit
6fe586f722
11 changed files with 1304 additions and 0 deletions
|
|
@ -69,6 +69,28 @@ def create_app(config: ApiGatewayConfig | None = None) -> FastAPI:
|
|||
# Auth routes (unauthenticated)
|
||||
app.include_router(auth_router)
|
||||
|
||||
# Trading routes (authenticated) — imported lazily to avoid circular deps
|
||||
from services.api_gateway.routes.portfolio import router as portfolio_router
|
||||
from services.api_gateway.routes.trades import router as trades_router
|
||||
from services.api_gateway.routes.signals import router as signals_router
|
||||
from services.api_gateway.routes.strategies import router as strategies_router
|
||||
from services.api_gateway.routes.news import router as news_router
|
||||
from services.api_gateway.routes.controls import router as controls_router
|
||||
from services.api_gateway.routes.backtest import router as backtest_router
|
||||
|
||||
app.include_router(portfolio_router)
|
||||
app.include_router(trades_router)
|
||||
app.include_router(signals_router)
|
||||
app.include_router(strategies_router)
|
||||
app.include_router(news_router)
|
||||
app.include_router(controls_router)
|
||||
app.include_router(backtest_router)
|
||||
|
||||
# WebSocket
|
||||
from services.api_gateway.ws import router as ws_router
|
||||
|
||||
app.include_router(ws_router)
|
||||
|
||||
# Health check
|
||||
@app.get("/health", tags=["health"])
|
||||
async def health() -> dict:
|
||||
|
|
|
|||
1
services/api_gateway/routes/__init__.py
Normal file
1
services/api_gateway/routes/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Trading route sub-package."""
|
||||
156
services/api_gateway/routes/backtest.py
Normal file
156
services/api_gateway/routes/backtest.py
Normal file
|
|
@ -0,0 +1,156 @@
|
|||
"""Backtest endpoints — run backtests and retrieve results."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from services.api_gateway.auth.middleware import get_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/backtest", tags=["backtest"])
|
||||
|
||||
|
||||
class BacktestRequest(BaseModel):
|
||||
"""Request body for starting a new backtest."""
|
||||
|
||||
start_date: datetime
|
||||
end_date: datetime
|
||||
initial_capital: float = Field(default=100_000.0, gt=0)
|
||||
commission_per_trade: float = Field(default=0.0, ge=0)
|
||||
slippage_pct: float = Field(default=0.001, ge=0)
|
||||
strategy_weights: dict[str, float] = Field(default_factory=dict)
|
||||
max_position_pct: float = Field(default=0.05, gt=0, le=1.0)
|
||||
signal_threshold: float = Field(default=0.3, ge=0, le=1.0)
|
||||
|
||||
|
||||
@router.post("/run")
|
||||
async def run_backtest(
|
||||
body: BacktestRequest,
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""Start a backtest with the given configuration.
|
||||
|
||||
Returns a ``run_id`` immediately. The backtest runs in a background
|
||||
task and its results can be retrieved via ``GET /api/backtest/{run_id}``.
|
||||
"""
|
||||
run_id = str(uuid.uuid4())
|
||||
redis = request.app.state.redis
|
||||
|
||||
# Store initial status
|
||||
await redis.setex(
|
||||
f"backtest:{run_id}",
|
||||
86400, # 24h TTL
|
||||
json.dumps({
|
||||
"status": "running",
|
||||
"config": body.model_dump(mode="json"),
|
||||
"started_at": datetime.now(tz=timezone.utc).isoformat(),
|
||||
}),
|
||||
)
|
||||
|
||||
# Launch background task
|
||||
asyncio.create_task(_run_backtest_task(run_id, body, redis))
|
||||
|
||||
return {"run_id": run_id, "status": "running"}
|
||||
|
||||
|
||||
async def _run_backtest_task(
|
||||
run_id: str,
|
||||
config: BacktestRequest,
|
||||
redis,
|
||||
) -> None:
|
||||
"""Execute the backtest in the background and store results in Redis."""
|
||||
try:
|
||||
from backtester.config import BacktestConfig
|
||||
from backtester.engine import BacktestEngine
|
||||
from shared.strategies.momentum import MomentumStrategy
|
||||
from shared.strategies.mean_reversion import MeanReversionStrategy
|
||||
from shared.strategies.news_driven import NewsDrivenStrategy
|
||||
|
||||
bt_config = BacktestConfig(
|
||||
start_date=config.start_date,
|
||||
end_date=config.end_date,
|
||||
initial_capital=config.initial_capital,
|
||||
commission_per_trade=config.commission_per_trade,
|
||||
slippage_pct=config.slippage_pct,
|
||||
strategy_weights=config.strategy_weights,
|
||||
max_position_pct=config.max_position_pct,
|
||||
signal_threshold=config.signal_threshold,
|
||||
)
|
||||
|
||||
strategies = [
|
||||
MomentumStrategy(),
|
||||
MeanReversionStrategy(),
|
||||
NewsDrivenStrategy(),
|
||||
]
|
||||
|
||||
engine = BacktestEngine(config=bt_config, strategies=strategies)
|
||||
|
||||
# Use an in-memory stub data loader for now; a full implementation
|
||||
# would read from TimescaleDB.
|
||||
from backtester.data_loader import BacktestDataLoader
|
||||
|
||||
data_loader = BacktestDataLoader(
|
||||
session_factory=None,
|
||||
start_date=config.start_date,
|
||||
end_date=config.end_date,
|
||||
)
|
||||
|
||||
result = await engine.run(data_loader)
|
||||
|
||||
await redis.setex(
|
||||
f"backtest:{run_id}",
|
||||
86400,
|
||||
json.dumps({
|
||||
"status": "completed",
|
||||
"config": config.model_dump(mode="json"),
|
||||
"result": {
|
||||
"total_return": result.total_return,
|
||||
"annualized_return": result.annualized_return,
|
||||
"sharpe_ratio": result.sharpe_ratio,
|
||||
"sortino_ratio": result.sortino_ratio,
|
||||
"max_drawdown": result.max_drawdown,
|
||||
"max_drawdown_duration_days": result.max_drawdown_duration_days,
|
||||
"win_rate": result.win_rate,
|
||||
"trade_count": result.trade_count,
|
||||
"avg_hold_duration_hours": result.avg_hold_duration_hours,
|
||||
"profit_factor": result.profit_factor,
|
||||
},
|
||||
"completed_at": datetime.now(tz=timezone.utc).isoformat(),
|
||||
}),
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception("Backtest %s failed", run_id)
|
||||
await redis.setex(
|
||||
f"backtest:{run_id}",
|
||||
86400,
|
||||
json.dumps({
|
||||
"status": "failed",
|
||||
"error": str(exc),
|
||||
}),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{run_id}")
|
||||
async def get_backtest(
|
||||
run_id: str,
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""Get backtest results by run ID."""
|
||||
redis = request.app.state.redis
|
||||
raw = await redis.get(f"backtest:{run_id}")
|
||||
if raw is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Backtest run not found",
|
||||
)
|
||||
return json.loads(raw)
|
||||
95
services/api_gateway/routes/controls.py
Normal file
95
services/api_gateway/routes/controls.py
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
"""Control endpoints — pause/resume trading, force close positions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from pydantic import BaseModel
|
||||
|
||||
from services.api_gateway.auth.middleware import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/controls", tags=["controls"])
|
||||
|
||||
TRADING_PAUSED_KEY = "trading:paused"
|
||||
|
||||
|
||||
class ClosePositionRequest(BaseModel):
|
||||
"""Body for the force-close-position endpoint."""
|
||||
|
||||
ticker: str
|
||||
|
||||
|
||||
@router.post("/pause")
|
||||
async def pause_trading(
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""Set Redis flag to pause trading."""
|
||||
redis = request.app.state.redis
|
||||
await redis.set(TRADING_PAUSED_KEY, "1")
|
||||
return {"status": "paused"}
|
||||
|
||||
|
||||
@router.post("/resume")
|
||||
async def resume_trading(
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""Clear pause flag to resume trading."""
|
||||
redis = request.app.state.redis
|
||||
await redis.delete(TRADING_PAUSED_KEY)
|
||||
return {"status": "active"}
|
||||
|
||||
|
||||
@router.post("/close-position")
|
||||
async def close_position(
|
||||
body: ClosePositionRequest,
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""Force close an open position by ticker.
|
||||
|
||||
Publishes a close-position command to Redis so the trade executor
|
||||
picks it up asynchronously.
|
||||
"""
|
||||
import json
|
||||
|
||||
from sqlalchemy import select
|
||||
from shared.models.trading import Position
|
||||
|
||||
db = request.app.state.db_session_factory
|
||||
async with db() as session:
|
||||
position = (
|
||||
await session.execute(
|
||||
select(Position).where(Position.ticker == body.ticker.upper())
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if position is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"No open position for {body.ticker}",
|
||||
)
|
||||
|
||||
# Publish close command to Redis for the trade executor
|
||||
redis = request.app.state.redis
|
||||
await redis.publish(
|
||||
"controls:close_position",
|
||||
json.dumps({"ticker": body.ticker.upper(), "qty": position.qty}),
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "close_requested",
|
||||
"ticker": body.ticker.upper(),
|
||||
"qty": position.qty,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
async def get_trading_status(
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""Current trading status — active or paused."""
|
||||
redis = request.app.state.redis
|
||||
paused = await redis.get(TRADING_PAUSED_KEY)
|
||||
return {"status": "paused" if paused else "active"}
|
||||
87
services/api_gateway/routes/news.py
Normal file
87
services/api_gateway/routes/news.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
"""News endpoints — recent scored articles with filtering."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
|
||||
from services.api_gateway.auth.middleware import get_current_user
|
||||
from sqlalchemy import select, desc, func
|
||||
|
||||
router = APIRouter(prefix="/api/news", tags=["news"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_news(
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
ticker: str | None = Query(default=None),
|
||||
source: str | None = Query(default=None),
|
||||
min_score: float | None = Query(default=None, ge=-1.0, le=1.0),
|
||||
max_score: float | None = Query(default=None, ge=-1.0, le=1.0),
|
||||
page: int = Query(default=1, ge=1),
|
||||
per_page: int = Query(default=20, ge=1, le=100),
|
||||
) -> dict:
|
||||
"""Recent scored articles with optional filters."""
|
||||
from shared.models.news import Article, ArticleSentiment
|
||||
|
||||
db = request.app.state.db_session_factory
|
||||
async with db() as session:
|
||||
# Base query joining articles with sentiments
|
||||
query = (
|
||||
select(Article, ArticleSentiment)
|
||||
.join(ArticleSentiment, Article.id == ArticleSentiment.article_id)
|
||||
.order_by(desc(Article.fetched_at))
|
||||
)
|
||||
count_query = (
|
||||
select(func.count())
|
||||
.select_from(Article)
|
||||
.join(ArticleSentiment, Article.id == ArticleSentiment.article_id)
|
||||
)
|
||||
|
||||
if ticker:
|
||||
query = query.where(ArticleSentiment.ticker == ticker.upper())
|
||||
count_query = count_query.where(
|
||||
ArticleSentiment.ticker == ticker.upper()
|
||||
)
|
||||
if source:
|
||||
query = query.where(Article.source == source)
|
||||
count_query = count_query.where(Article.source == source)
|
||||
if min_score is not None:
|
||||
query = query.where(ArticleSentiment.score >= min_score)
|
||||
count_query = count_query.where(ArticleSentiment.score >= min_score)
|
||||
if max_score is not None:
|
||||
query = query.where(ArticleSentiment.score <= max_score)
|
||||
count_query = count_query.where(ArticleSentiment.score <= max_score)
|
||||
|
||||
total = (await session.execute(count_query)).scalar() or 0
|
||||
offset = (page - 1) * per_page
|
||||
query = query.offset(offset).limit(per_page)
|
||||
|
||||
result = await session.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
return {
|
||||
"articles": [
|
||||
{
|
||||
"id": str(article.id),
|
||||
"source": article.source,
|
||||
"url": article.url,
|
||||
"title": article.title,
|
||||
"published_at": (
|
||||
article.published_at.isoformat()
|
||||
if article.published_at
|
||||
else None
|
||||
),
|
||||
"fetched_at": article.fetched_at.isoformat(),
|
||||
"ticker": sentiment.ticker,
|
||||
"sentiment_score": sentiment.score,
|
||||
"confidence": sentiment.confidence,
|
||||
"model_used": sentiment.model_used,
|
||||
}
|
||||
for article, sentiment in rows
|
||||
],
|
||||
"total": total,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"pages": (total + per_page - 1) // per_page if per_page else 0,
|
||||
}
|
||||
125
services/api_gateway/routes/portfolio.py
Normal file
125
services/api_gateway/routes/portfolio.py
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
"""Portfolio endpoints — current value, positions, equity curve."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
|
||||
from services.api_gateway.auth.middleware import get_current_user
|
||||
from sqlalchemy import select, desc
|
||||
|
||||
router = APIRouter(prefix="/api/portfolio", tags=["portfolio"])
|
||||
|
||||
|
||||
class HistoryPeriod(str, Enum):
|
||||
ONE_DAY = "1d"
|
||||
ONE_WEEK = "1w"
|
||||
ONE_MONTH = "1m"
|
||||
THREE_MONTHS = "3m"
|
||||
ONE_YEAR = "1y"
|
||||
|
||||
|
||||
def _period_to_timedelta(period: HistoryPeriod) -> timedelta:
|
||||
"""Convert a period enum value to a timedelta."""
|
||||
mapping = {
|
||||
HistoryPeriod.ONE_DAY: timedelta(days=1),
|
||||
HistoryPeriod.ONE_WEEK: timedelta(weeks=1),
|
||||
HistoryPeriod.ONE_MONTH: timedelta(days=30),
|
||||
HistoryPeriod.THREE_MONTHS: timedelta(days=90),
|
||||
HistoryPeriod.ONE_YEAR: timedelta(days=365),
|
||||
}
|
||||
return mapping[period]
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_portfolio(
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""Current portfolio summary — value, cash, buying power, daily P&L."""
|
||||
from shared.models.timeseries import PortfolioSnapshot
|
||||
|
||||
db = request.app.state.db_session_factory
|
||||
async with db() as session:
|
||||
latest = (
|
||||
await session.execute(
|
||||
select(PortfolioSnapshot)
|
||||
.order_by(desc(PortfolioSnapshot.timestamp))
|
||||
.limit(1)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if latest is None:
|
||||
return {
|
||||
"total_value": 0.0,
|
||||
"cash": 0.0,
|
||||
"buying_power": 0.0,
|
||||
"daily_pnl": 0.0,
|
||||
}
|
||||
|
||||
return {
|
||||
"total_value": latest.total_value,
|
||||
"cash": latest.cash,
|
||||
"buying_power": latest.cash,
|
||||
"daily_pnl": latest.daily_pnl,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/positions")
|
||||
async def get_positions(
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
) -> list[dict]:
|
||||
"""All open positions with unrealized P&L."""
|
||||
from shared.models.trading import Position
|
||||
|
||||
db = request.app.state.db_session_factory
|
||||
async with db() as session:
|
||||
result = await session.execute(select(Position))
|
||||
positions = result.scalars().all()
|
||||
|
||||
return [
|
||||
{
|
||||
"id": str(p.id),
|
||||
"ticker": p.ticker,
|
||||
"qty": p.qty,
|
||||
"avg_entry": p.avg_entry,
|
||||
"unrealized_pnl": p.unrealized_pnl or 0.0,
|
||||
"stop_loss": p.stop_loss,
|
||||
"take_profit": p.take_profit,
|
||||
}
|
||||
for p in positions
|
||||
]
|
||||
|
||||
|
||||
@router.get("/history")
|
||||
async def get_portfolio_history(
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
period: HistoryPeriod = Query(default=HistoryPeriod.ONE_MONTH),
|
||||
) -> list[dict]:
|
||||
"""Equity curve from portfolio_snapshots over a given period."""
|
||||
from shared.models.timeseries import PortfolioSnapshot
|
||||
|
||||
since = datetime.now(timezone.utc) - _period_to_timedelta(period)
|
||||
db = request.app.state.db_session_factory
|
||||
async with db() as session:
|
||||
result = await session.execute(
|
||||
select(PortfolioSnapshot)
|
||||
.where(PortfolioSnapshot.timestamp >= since)
|
||||
.order_by(PortfolioSnapshot.timestamp)
|
||||
)
|
||||
snapshots = result.scalars().all()
|
||||
|
||||
return [
|
||||
{
|
||||
"timestamp": s.timestamp.isoformat(),
|
||||
"total_value": s.total_value,
|
||||
"cash": s.cash,
|
||||
"positions_value": s.positions_value,
|
||||
"daily_pnl": s.daily_pnl,
|
||||
}
|
||||
for s in snapshots
|
||||
]
|
||||
58
services/api_gateway/routes/signals.py
Normal file
58
services/api_gateway/routes/signals.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
"""Signal endpoints — recent signals with filtering."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
|
||||
from services.api_gateway.auth.middleware import get_current_user
|
||||
from sqlalchemy import select, desc, func
|
||||
|
||||
router = APIRouter(prefix="/api/signals", tags=["signals"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_signals(
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
ticker: str | None = Query(default=None),
|
||||
page: int = Query(default=1, ge=1),
|
||||
per_page: int = Query(default=20, ge=1, le=100),
|
||||
) -> dict:
|
||||
"""Recent signals with optional ticker filter and pagination."""
|
||||
from shared.models.trading import Signal
|
||||
|
||||
db = request.app.state.db_session_factory
|
||||
async with db() as session:
|
||||
query = select(Signal).order_by(desc(Signal.created_at))
|
||||
count_query = select(func.count()).select_from(Signal)
|
||||
|
||||
if ticker:
|
||||
query = query.where(Signal.ticker == ticker.upper())
|
||||
count_query = count_query.where(Signal.ticker == ticker.upper())
|
||||
|
||||
total = (await session.execute(count_query)).scalar() or 0
|
||||
offset = (page - 1) * per_page
|
||||
query = query.offset(offset).limit(per_page)
|
||||
|
||||
result = await session.execute(query)
|
||||
signals = result.scalars().all()
|
||||
|
||||
return {
|
||||
"signals": [
|
||||
{
|
||||
"id": str(s.id),
|
||||
"ticker": s.ticker,
|
||||
"direction": s.direction.value,
|
||||
"strength": s.strength,
|
||||
"strategy_sources": s.strategy_sources,
|
||||
"sentiment_score": s.sentiment_score,
|
||||
"acted_on": s.acted_on,
|
||||
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||
}
|
||||
for s in signals
|
||||
],
|
||||
"total": total,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"pages": (total + per_page - 1) // per_page if per_page else 0,
|
||||
}
|
||||
111
services/api_gateway/routes/strategies.py
Normal file
111
services/api_gateway/routes/strategies.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
"""Strategy endpoints — list strategies, weight history, metrics."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
|
||||
from services.api_gateway.auth.middleware import get_current_user
|
||||
from sqlalchemy import select, desc
|
||||
|
||||
router = APIRouter(prefix="/api/strategies", tags=["strategies"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_strategies(
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
) -> list[dict]:
|
||||
"""All strategies with current weights."""
|
||||
from shared.models.trading import Strategy
|
||||
|
||||
db = request.app.state.db_session_factory
|
||||
async with db() as session:
|
||||
result = await session.execute(select(Strategy))
|
||||
strategies = result.scalars().all()
|
||||
|
||||
return [
|
||||
{
|
||||
"id": str(s.id),
|
||||
"name": s.name,
|
||||
"description": s.description,
|
||||
"current_weight": s.current_weight,
|
||||
"active": s.active,
|
||||
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||
}
|
||||
for s in strategies
|
||||
]
|
||||
|
||||
|
||||
@router.get("/{strategy_id}/history")
|
||||
async def get_strategy_weight_history(
|
||||
strategy_id: UUID,
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
) -> list[dict]:
|
||||
"""Weight history for a specific strategy."""
|
||||
from shared.models.trading import StrategyWeightHistory, Strategy
|
||||
|
||||
db = request.app.state.db_session_factory
|
||||
async with db() as session:
|
||||
# Verify strategy exists
|
||||
strategy = (
|
||||
await session.execute(
|
||||
select(Strategy).where(Strategy.id == strategy_id)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if strategy is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Strategy not found",
|
||||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(StrategyWeightHistory)
|
||||
.where(StrategyWeightHistory.strategy_id == strategy_id)
|
||||
.order_by(desc(StrategyWeightHistory.created_at))
|
||||
)
|
||||
history = result.scalars().all()
|
||||
|
||||
return [
|
||||
{
|
||||
"id": str(h.id),
|
||||
"old_weight": h.old_weight,
|
||||
"new_weight": h.new_weight,
|
||||
"reason": h.reason,
|
||||
"created_at": h.created_at.isoformat() if h.created_at else None,
|
||||
}
|
||||
for h in history
|
||||
]
|
||||
|
||||
|
||||
@router.get("/{strategy_id}/metrics")
|
||||
async def get_strategy_metrics(
|
||||
strategy_id: UUID,
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
) -> list[dict]:
|
||||
"""Performance metrics over time for a specific strategy."""
|
||||
from shared.models.timeseries import StrategyMetric
|
||||
|
||||
db = request.app.state.db_session_factory
|
||||
async with db() as session:
|
||||
result = await session.execute(
|
||||
select(StrategyMetric)
|
||||
.where(StrategyMetric.strategy_id == strategy_id)
|
||||
.order_by(desc(StrategyMetric.timestamp))
|
||||
.limit(100)
|
||||
)
|
||||
metrics = result.scalars().all()
|
||||
|
||||
return [
|
||||
{
|
||||
"timestamp": m.timestamp.isoformat(),
|
||||
"win_rate": m.win_rate,
|
||||
"total_pnl": m.total_pnl,
|
||||
"trade_count": m.trade_count,
|
||||
"sharpe_ratio": m.sharpe_ratio,
|
||||
}
|
||||
for m in metrics
|
||||
]
|
||||
125
services/api_gateway/routes/trades.py
Normal file
125
services/api_gateway/routes/trades.py
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
"""Trade endpoints — paginated trade history and detail."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
|
||||
from services.api_gateway.auth.middleware import get_current_user
|
||||
from sqlalchemy import select, desc, func
|
||||
|
||||
router = APIRouter(prefix="/api/trades", tags=["trades"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_trades(
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
ticker: str | None = Query(default=None),
|
||||
start_date: datetime | None = Query(default=None),
|
||||
end_date: datetime | None = Query(default=None),
|
||||
strategy: str | None = Query(default=None),
|
||||
profitable: bool | None = Query(default=None),
|
||||
page: int = Query(default=1, ge=1),
|
||||
per_page: int = Query(default=20, ge=1, le=100),
|
||||
) -> dict:
|
||||
"""Paginated trade history with optional filters."""
|
||||
from shared.models.trading import Trade, Strategy
|
||||
|
||||
db = request.app.state.db_session_factory
|
||||
async with db() as session:
|
||||
query = select(Trade).order_by(desc(Trade.created_at))
|
||||
count_query = select(func.count()).select_from(Trade)
|
||||
|
||||
# Apply filters
|
||||
if ticker:
|
||||
query = query.where(Trade.ticker == ticker.upper())
|
||||
count_query = count_query.where(Trade.ticker == ticker.upper())
|
||||
if start_date:
|
||||
query = query.where(Trade.created_at >= start_date)
|
||||
count_query = count_query.where(Trade.created_at >= start_date)
|
||||
if end_date:
|
||||
query = query.where(Trade.created_at <= end_date)
|
||||
count_query = count_query.where(Trade.created_at <= end_date)
|
||||
if strategy:
|
||||
# Join with Strategy to filter by name
|
||||
query = query.join(Strategy, Trade.strategy_id == Strategy.id).where(
|
||||
Strategy.name == strategy
|
||||
)
|
||||
count_query = count_query.join(
|
||||
Strategy, Trade.strategy_id == Strategy.id
|
||||
).where(Strategy.name == strategy)
|
||||
if profitable is not None:
|
||||
if profitable:
|
||||
query = query.where(Trade.pnl > 0)
|
||||
count_query = count_query.where(Trade.pnl > 0)
|
||||
else:
|
||||
query = query.where(Trade.pnl <= 0)
|
||||
count_query = count_query.where(Trade.pnl <= 0)
|
||||
|
||||
# Pagination
|
||||
total = (await session.execute(count_query)).scalar() or 0
|
||||
offset = (page - 1) * per_page
|
||||
query = query.offset(offset).limit(per_page)
|
||||
|
||||
result = await session.execute(query)
|
||||
trades = result.scalars().all()
|
||||
|
||||
return {
|
||||
"trades": [
|
||||
{
|
||||
"id": str(t.id),
|
||||
"ticker": t.ticker,
|
||||
"side": t.side.value,
|
||||
"qty": t.qty,
|
||||
"price": t.price,
|
||||
"status": t.status.value,
|
||||
"pnl": t.pnl,
|
||||
"strategy_id": str(t.strategy_id) if t.strategy_id else None,
|
||||
"signal_id": str(t.signal_id) if t.signal_id else None,
|
||||
"created_at": t.created_at.isoformat() if t.created_at else None,
|
||||
}
|
||||
for t in trades
|
||||
],
|
||||
"total": total,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"pages": (total + per_page - 1) // per_page if per_page else 0,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{trade_id}")
|
||||
async def get_trade(
|
||||
trade_id: UUID,
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""Single trade detail with linked signal and outcome."""
|
||||
from shared.models.trading import Trade
|
||||
|
||||
db = request.app.state.db_session_factory
|
||||
async with db() as session:
|
||||
trade = (
|
||||
await session.execute(select(Trade).where(Trade.id == trade_id))
|
||||
).scalar_one_or_none()
|
||||
|
||||
if trade is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Trade not found",
|
||||
)
|
||||
|
||||
return {
|
||||
"id": str(trade.id),
|
||||
"ticker": trade.ticker,
|
||||
"side": trade.side.value,
|
||||
"qty": trade.qty,
|
||||
"price": trade.price,
|
||||
"status": trade.status.value,
|
||||
"pnl": trade.pnl,
|
||||
"strategy_id": str(trade.strategy_id) if trade.strategy_id else None,
|
||||
"signal_id": str(trade.signal_id) if trade.signal_id else None,
|
||||
"created_at": trade.created_at.isoformat() if trade.created_at else None,
|
||||
}
|
||||
138
services/api_gateway/ws.py
Normal file
138
services/api_gateway/ws.py
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
"""WebSocket endpoint — authenticated real-time event stream."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
|
||||
|
||||
import jwt as pyjwt
|
||||
|
||||
from services.api_gateway.auth.jwt import decode_token
|
||||
from services.api_gateway.config import ApiGatewayConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["websocket"])
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
"""Manages active WebSocket connections."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.active_connections: list[WebSocket] = []
|
||||
|
||||
async def connect(self, websocket: WebSocket) -> None:
|
||||
await websocket.accept()
|
||||
self.active_connections.append(websocket)
|
||||
|
||||
def disconnect(self, websocket: WebSocket) -> None:
|
||||
if websocket in self.active_connections:
|
||||
self.active_connections.remove(websocket)
|
||||
|
||||
async def broadcast(self, message: dict) -> None:
|
||||
"""Send a message to all connected clients. Remove any that have
|
||||
disconnected."""
|
||||
disconnected: list[WebSocket] = []
|
||||
for connection in self.active_connections:
|
||||
try:
|
||||
await connection.send_json(message)
|
||||
except Exception:
|
||||
disconnected.append(connection)
|
||||
for ws in disconnected:
|
||||
self.disconnect(ws)
|
||||
|
||||
|
||||
manager = ConnectionManager()
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def websocket_endpoint(
|
||||
websocket: WebSocket,
|
||||
token: str = Query(default=""),
|
||||
) -> None:
|
||||
"""Authenticated WebSocket that pushes real-time trading events.
|
||||
|
||||
Connect via ``ws://host/ws?token=<JWT>``. The server subscribes
|
||||
to Redis pub/sub channels and forwards events to all connected
|
||||
clients.
|
||||
|
||||
Event types pushed to clients:
|
||||
- ``trade_executed``
|
||||
- ``signal_generated``
|
||||
- ``portfolio_update``
|
||||
"""
|
||||
# Authenticate via JWT query parameter
|
||||
config: ApiGatewayConfig = websocket.app.state.config
|
||||
try:
|
||||
payload = decode_token(token, config)
|
||||
if payload.get("type") != "access":
|
||||
await websocket.close(code=4001, reason="Invalid token type")
|
||||
return
|
||||
except pyjwt.ExpiredSignatureError:
|
||||
await websocket.close(code=4001, reason="Token expired")
|
||||
return
|
||||
except pyjwt.InvalidTokenError:
|
||||
await websocket.close(code=4001, reason="Invalid token")
|
||||
return
|
||||
|
||||
await manager.connect(websocket)
|
||||
|
||||
# Subscribe to Redis pub/sub channels for real-time events
|
||||
redis = websocket.app.state.redis
|
||||
pubsub = redis.pubsub()
|
||||
await pubsub.subscribe(
|
||||
"events:trade_executed",
|
||||
"events:signal_generated",
|
||||
"events:portfolio_update",
|
||||
)
|
||||
|
||||
try:
|
||||
# Run two concurrent tasks: one reads from Redis pub/sub and pushes
|
||||
# to the client, the other keeps the WebSocket alive by reading
|
||||
# (and discarding) incoming messages.
|
||||
|
||||
async def _redis_listener() -> None:
|
||||
"""Forward Redis pub/sub messages to this WebSocket client."""
|
||||
while True:
|
||||
message = await pubsub.get_message(
|
||||
ignore_subscribe_messages=True, timeout=1.0
|
||||
)
|
||||
if message and message["type"] == "message":
|
||||
try:
|
||||
data = json.loads(message["data"])
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
data = {"raw": str(message["data"])}
|
||||
channel = message["channel"]
|
||||
if isinstance(channel, bytes):
|
||||
channel = channel.decode()
|
||||
event_type = channel.replace("events:", "")
|
||||
await websocket.send_json(
|
||||
{"event": event_type, "data": data}
|
||||
)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def _ws_receiver() -> None:
|
||||
"""Keep the connection alive by reading messages."""
|
||||
while True:
|
||||
await websocket.receive_text()
|
||||
|
||||
# Run both tasks; whichever finishes first (e.g., client disconnect)
|
||||
# will cancel the other.
|
||||
listener_task = asyncio.create_task(_redis_listener())
|
||||
receiver_task = asyncio.create_task(_ws_receiver())
|
||||
done, pending = await asyncio.wait(
|
||||
{listener_task, receiver_task},
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
finally:
|
||||
manager.disconnect(websocket)
|
||||
await pubsub.unsubscribe()
|
||||
await pubsub.aclose()
|
||||
Loading…
Add table
Add a link
Reference in a new issue