feat: API gateway trading endpoints, controls, backtest, WebSocket
This commit is contained in:
parent
e0d138c457
commit
6fe586f722
11 changed files with 1304 additions and 0 deletions
|
|
@ -69,6 +69,28 @@ def create_app(config: ApiGatewayConfig | None = None) -> FastAPI:
|
||||||
# Auth routes (unauthenticated)
|
# Auth routes (unauthenticated)
|
||||||
app.include_router(auth_router)
|
app.include_router(auth_router)
|
||||||
|
|
||||||
|
# Trading routes (authenticated) — imported lazily to avoid circular deps
|
||||||
|
from services.api_gateway.routes.portfolio import router as portfolio_router
|
||||||
|
from services.api_gateway.routes.trades import router as trades_router
|
||||||
|
from services.api_gateway.routes.signals import router as signals_router
|
||||||
|
from services.api_gateway.routes.strategies import router as strategies_router
|
||||||
|
from services.api_gateway.routes.news import router as news_router
|
||||||
|
from services.api_gateway.routes.controls import router as controls_router
|
||||||
|
from services.api_gateway.routes.backtest import router as backtest_router
|
||||||
|
|
||||||
|
app.include_router(portfolio_router)
|
||||||
|
app.include_router(trades_router)
|
||||||
|
app.include_router(signals_router)
|
||||||
|
app.include_router(strategies_router)
|
||||||
|
app.include_router(news_router)
|
||||||
|
app.include_router(controls_router)
|
||||||
|
app.include_router(backtest_router)
|
||||||
|
|
||||||
|
# WebSocket
|
||||||
|
from services.api_gateway.ws import router as ws_router
|
||||||
|
|
||||||
|
app.include_router(ws_router)
|
||||||
|
|
||||||
# Health check
|
# Health check
|
||||||
@app.get("/health", tags=["health"])
|
@app.get("/health", tags=["health"])
|
||||||
async def health() -> dict:
|
async def health() -> dict:
|
||||||
|
|
|
||||||
1
services/api_gateway/routes/__init__.py
Normal file
1
services/api_gateway/routes/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
"""Trading route sub-package."""
|
||||||
156
services/api_gateway/routes/backtest.py
Normal file
156
services/api_gateway/routes/backtest.py
Normal file
|
|
@ -0,0 +1,156 @@
|
||||||
|
"""Backtest endpoints — run backtests and retrieve results."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from services.api_gateway.auth.middleware import get_current_user
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/backtest", tags=["backtest"])
|
||||||
|
|
||||||
|
|
||||||
|
class BacktestRequest(BaseModel):
|
||||||
|
"""Request body for starting a new backtest."""
|
||||||
|
|
||||||
|
start_date: datetime
|
||||||
|
end_date: datetime
|
||||||
|
initial_capital: float = Field(default=100_000.0, gt=0)
|
||||||
|
commission_per_trade: float = Field(default=0.0, ge=0)
|
||||||
|
slippage_pct: float = Field(default=0.001, ge=0)
|
||||||
|
strategy_weights: dict[str, float] = Field(default_factory=dict)
|
||||||
|
max_position_pct: float = Field(default=0.05, gt=0, le=1.0)
|
||||||
|
signal_threshold: float = Field(default=0.3, ge=0, le=1.0)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/run")
|
||||||
|
async def run_backtest(
|
||||||
|
body: BacktestRequest,
|
||||||
|
request: Request,
|
||||||
|
_user: dict = Depends(get_current_user),
|
||||||
|
) -> dict:
|
||||||
|
"""Start a backtest with the given configuration.
|
||||||
|
|
||||||
|
Returns a ``run_id`` immediately. The backtest runs in a background
|
||||||
|
task and its results can be retrieved via ``GET /api/backtest/{run_id}``.
|
||||||
|
"""
|
||||||
|
run_id = str(uuid.uuid4())
|
||||||
|
redis = request.app.state.redis
|
||||||
|
|
||||||
|
# Store initial status
|
||||||
|
await redis.setex(
|
||||||
|
f"backtest:{run_id}",
|
||||||
|
86400, # 24h TTL
|
||||||
|
json.dumps({
|
||||||
|
"status": "running",
|
||||||
|
"config": body.model_dump(mode="json"),
|
||||||
|
"started_at": datetime.now(tz=timezone.utc).isoformat(),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Launch background task
|
||||||
|
asyncio.create_task(_run_backtest_task(run_id, body, redis))
|
||||||
|
|
||||||
|
return {"run_id": run_id, "status": "running"}
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_backtest_task(
|
||||||
|
run_id: str,
|
||||||
|
config: BacktestRequest,
|
||||||
|
redis,
|
||||||
|
) -> None:
|
||||||
|
"""Execute the backtest in the background and store results in Redis."""
|
||||||
|
try:
|
||||||
|
from backtester.config import BacktestConfig
|
||||||
|
from backtester.engine import BacktestEngine
|
||||||
|
from shared.strategies.momentum import MomentumStrategy
|
||||||
|
from shared.strategies.mean_reversion import MeanReversionStrategy
|
||||||
|
from shared.strategies.news_driven import NewsDrivenStrategy
|
||||||
|
|
||||||
|
bt_config = BacktestConfig(
|
||||||
|
start_date=config.start_date,
|
||||||
|
end_date=config.end_date,
|
||||||
|
initial_capital=config.initial_capital,
|
||||||
|
commission_per_trade=config.commission_per_trade,
|
||||||
|
slippage_pct=config.slippage_pct,
|
||||||
|
strategy_weights=config.strategy_weights,
|
||||||
|
max_position_pct=config.max_position_pct,
|
||||||
|
signal_threshold=config.signal_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
|
strategies = [
|
||||||
|
MomentumStrategy(),
|
||||||
|
MeanReversionStrategy(),
|
||||||
|
NewsDrivenStrategy(),
|
||||||
|
]
|
||||||
|
|
||||||
|
engine = BacktestEngine(config=bt_config, strategies=strategies)
|
||||||
|
|
||||||
|
# Use an in-memory stub data loader for now; a full implementation
|
||||||
|
# would read from TimescaleDB.
|
||||||
|
from backtester.data_loader import BacktestDataLoader
|
||||||
|
|
||||||
|
data_loader = BacktestDataLoader(
|
||||||
|
session_factory=None,
|
||||||
|
start_date=config.start_date,
|
||||||
|
end_date=config.end_date,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await engine.run(data_loader)
|
||||||
|
|
||||||
|
await redis.setex(
|
||||||
|
f"backtest:{run_id}",
|
||||||
|
86400,
|
||||||
|
json.dumps({
|
||||||
|
"status": "completed",
|
||||||
|
"config": config.model_dump(mode="json"),
|
||||||
|
"result": {
|
||||||
|
"total_return": result.total_return,
|
||||||
|
"annualized_return": result.annualized_return,
|
||||||
|
"sharpe_ratio": result.sharpe_ratio,
|
||||||
|
"sortino_ratio": result.sortino_ratio,
|
||||||
|
"max_drawdown": result.max_drawdown,
|
||||||
|
"max_drawdown_duration_days": result.max_drawdown_duration_days,
|
||||||
|
"win_rate": result.win_rate,
|
||||||
|
"trade_count": result.trade_count,
|
||||||
|
"avg_hold_duration_hours": result.avg_hold_duration_hours,
|
||||||
|
"profit_factor": result.profit_factor,
|
||||||
|
},
|
||||||
|
"completed_at": datetime.now(tz=timezone.utc).isoformat(),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Backtest %s failed", run_id)
|
||||||
|
await redis.setex(
|
||||||
|
f"backtest:{run_id}",
|
||||||
|
86400,
|
||||||
|
json.dumps({
|
||||||
|
"status": "failed",
|
||||||
|
"error": str(exc),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{run_id}")
|
||||||
|
async def get_backtest(
|
||||||
|
run_id: str,
|
||||||
|
request: Request,
|
||||||
|
_user: dict = Depends(get_current_user),
|
||||||
|
) -> dict:
|
||||||
|
"""Get backtest results by run ID."""
|
||||||
|
redis = request.app.state.redis
|
||||||
|
raw = await redis.get(f"backtest:{run_id}")
|
||||||
|
if raw is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Backtest run not found",
|
||||||
|
)
|
||||||
|
return json.loads(raw)
|
||||||
95
services/api_gateway/routes/controls.py
Normal file
95
services/api_gateway/routes/controls.py
Normal file
|
|
@ -0,0 +1,95 @@
|
||||||
|
"""Control endpoints — pause/resume trading, force close positions."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from services.api_gateway.auth.middleware import get_current_user
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/controls", tags=["controls"])
|
||||||
|
|
||||||
|
TRADING_PAUSED_KEY = "trading:paused"
|
||||||
|
|
||||||
|
|
||||||
|
class ClosePositionRequest(BaseModel):
|
||||||
|
"""Body for the force-close-position endpoint."""
|
||||||
|
|
||||||
|
ticker: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/pause")
|
||||||
|
async def pause_trading(
|
||||||
|
request: Request,
|
||||||
|
_user: dict = Depends(get_current_user),
|
||||||
|
) -> dict:
|
||||||
|
"""Set Redis flag to pause trading."""
|
||||||
|
redis = request.app.state.redis
|
||||||
|
await redis.set(TRADING_PAUSED_KEY, "1")
|
||||||
|
return {"status": "paused"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/resume")
|
||||||
|
async def resume_trading(
|
||||||
|
request: Request,
|
||||||
|
_user: dict = Depends(get_current_user),
|
||||||
|
) -> dict:
|
||||||
|
"""Clear pause flag to resume trading."""
|
||||||
|
redis = request.app.state.redis
|
||||||
|
await redis.delete(TRADING_PAUSED_KEY)
|
||||||
|
return {"status": "active"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/close-position")
|
||||||
|
async def close_position(
|
||||||
|
body: ClosePositionRequest,
|
||||||
|
request: Request,
|
||||||
|
_user: dict = Depends(get_current_user),
|
||||||
|
) -> dict:
|
||||||
|
"""Force close an open position by ticker.
|
||||||
|
|
||||||
|
Publishes a close-position command to Redis so the trade executor
|
||||||
|
picks it up asynchronously.
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from shared.models.trading import Position
|
||||||
|
|
||||||
|
db = request.app.state.db_session_factory
|
||||||
|
async with db() as session:
|
||||||
|
position = (
|
||||||
|
await session.execute(
|
||||||
|
select(Position).where(Position.ticker == body.ticker.upper())
|
||||||
|
)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
|
||||||
|
if position is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"No open position for {body.ticker}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Publish close command to Redis for the trade executor
|
||||||
|
redis = request.app.state.redis
|
||||||
|
await redis.publish(
|
||||||
|
"controls:close_position",
|
||||||
|
json.dumps({"ticker": body.ticker.upper(), "qty": position.qty}),
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "close_requested",
|
||||||
|
"ticker": body.ticker.upper(),
|
||||||
|
"qty": position.qty,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/status")
|
||||||
|
async def get_trading_status(
|
||||||
|
request: Request,
|
||||||
|
_user: dict = Depends(get_current_user),
|
||||||
|
) -> dict:
|
||||||
|
"""Current trading status — active or paused."""
|
||||||
|
redis = request.app.state.redis
|
||||||
|
paused = await redis.get(TRADING_PAUSED_KEY)
|
||||||
|
return {"status": "paused" if paused else "active"}
|
||||||
87
services/api_gateway/routes/news.py
Normal file
87
services/api_gateway/routes/news.py
Normal file
|
|
@ -0,0 +1,87 @@
|
||||||
|
"""News endpoints — recent scored articles with filtering."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Query, Request
|
||||||
|
|
||||||
|
from services.api_gateway.auth.middleware import get_current_user
|
||||||
|
from sqlalchemy import select, desc, func
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/news", tags=["news"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
async def list_news(
|
||||||
|
request: Request,
|
||||||
|
_user: dict = Depends(get_current_user),
|
||||||
|
ticker: str | None = Query(default=None),
|
||||||
|
source: str | None = Query(default=None),
|
||||||
|
min_score: float | None = Query(default=None, ge=-1.0, le=1.0),
|
||||||
|
max_score: float | None = Query(default=None, ge=-1.0, le=1.0),
|
||||||
|
page: int = Query(default=1, ge=1),
|
||||||
|
per_page: int = Query(default=20, ge=1, le=100),
|
||||||
|
) -> dict:
|
||||||
|
"""Recent scored articles with optional filters."""
|
||||||
|
from shared.models.news import Article, ArticleSentiment
|
||||||
|
|
||||||
|
db = request.app.state.db_session_factory
|
||||||
|
async with db() as session:
|
||||||
|
# Base query joining articles with sentiments
|
||||||
|
query = (
|
||||||
|
select(Article, ArticleSentiment)
|
||||||
|
.join(ArticleSentiment, Article.id == ArticleSentiment.article_id)
|
||||||
|
.order_by(desc(Article.fetched_at))
|
||||||
|
)
|
||||||
|
count_query = (
|
||||||
|
select(func.count())
|
||||||
|
.select_from(Article)
|
||||||
|
.join(ArticleSentiment, Article.id == ArticleSentiment.article_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
if ticker:
|
||||||
|
query = query.where(ArticleSentiment.ticker == ticker.upper())
|
||||||
|
count_query = count_query.where(
|
||||||
|
ArticleSentiment.ticker == ticker.upper()
|
||||||
|
)
|
||||||
|
if source:
|
||||||
|
query = query.where(Article.source == source)
|
||||||
|
count_query = count_query.where(Article.source == source)
|
||||||
|
if min_score is not None:
|
||||||
|
query = query.where(ArticleSentiment.score >= min_score)
|
||||||
|
count_query = count_query.where(ArticleSentiment.score >= min_score)
|
||||||
|
if max_score is not None:
|
||||||
|
query = query.where(ArticleSentiment.score <= max_score)
|
||||||
|
count_query = count_query.where(ArticleSentiment.score <= max_score)
|
||||||
|
|
||||||
|
total = (await session.execute(count_query)).scalar() or 0
|
||||||
|
offset = (page - 1) * per_page
|
||||||
|
query = query.offset(offset).limit(per_page)
|
||||||
|
|
||||||
|
result = await session.execute(query)
|
||||||
|
rows = result.all()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"articles": [
|
||||||
|
{
|
||||||
|
"id": str(article.id),
|
||||||
|
"source": article.source,
|
||||||
|
"url": article.url,
|
||||||
|
"title": article.title,
|
||||||
|
"published_at": (
|
||||||
|
article.published_at.isoformat()
|
||||||
|
if article.published_at
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
"fetched_at": article.fetched_at.isoformat(),
|
||||||
|
"ticker": sentiment.ticker,
|
||||||
|
"sentiment_score": sentiment.score,
|
||||||
|
"confidence": sentiment.confidence,
|
||||||
|
"model_used": sentiment.model_used,
|
||||||
|
}
|
||||||
|
for article, sentiment in rows
|
||||||
|
],
|
||||||
|
"total": total,
|
||||||
|
"page": page,
|
||||||
|
"per_page": per_page,
|
||||||
|
"pages": (total + per_page - 1) // per_page if per_page else 0,
|
||||||
|
}
|
||||||
125
services/api_gateway/routes/portfolio.py
Normal file
125
services/api_gateway/routes/portfolio.py
Normal file
|
|
@ -0,0 +1,125 @@
|
||||||
|
"""Portfolio endpoints — current value, positions, equity curve."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Query, Request
|
||||||
|
|
||||||
|
from services.api_gateway.auth.middleware import get_current_user
|
||||||
|
from sqlalchemy import select, desc
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/portfolio", tags=["portfolio"])
|
||||||
|
|
||||||
|
|
||||||
|
class HistoryPeriod(str, Enum):
|
||||||
|
ONE_DAY = "1d"
|
||||||
|
ONE_WEEK = "1w"
|
||||||
|
ONE_MONTH = "1m"
|
||||||
|
THREE_MONTHS = "3m"
|
||||||
|
ONE_YEAR = "1y"
|
||||||
|
|
||||||
|
|
||||||
|
def _period_to_timedelta(period: HistoryPeriod) -> timedelta:
|
||||||
|
"""Convert a period enum value to a timedelta."""
|
||||||
|
mapping = {
|
||||||
|
HistoryPeriod.ONE_DAY: timedelta(days=1),
|
||||||
|
HistoryPeriod.ONE_WEEK: timedelta(weeks=1),
|
||||||
|
HistoryPeriod.ONE_MONTH: timedelta(days=30),
|
||||||
|
HistoryPeriod.THREE_MONTHS: timedelta(days=90),
|
||||||
|
HistoryPeriod.ONE_YEAR: timedelta(days=365),
|
||||||
|
}
|
||||||
|
return mapping[period]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
async def get_portfolio(
|
||||||
|
request: Request,
|
||||||
|
_user: dict = Depends(get_current_user),
|
||||||
|
) -> dict:
|
||||||
|
"""Current portfolio summary — value, cash, buying power, daily P&L."""
|
||||||
|
from shared.models.timeseries import PortfolioSnapshot
|
||||||
|
|
||||||
|
db = request.app.state.db_session_factory
|
||||||
|
async with db() as session:
|
||||||
|
latest = (
|
||||||
|
await session.execute(
|
||||||
|
select(PortfolioSnapshot)
|
||||||
|
.order_by(desc(PortfolioSnapshot.timestamp))
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
|
||||||
|
if latest is None:
|
||||||
|
return {
|
||||||
|
"total_value": 0.0,
|
||||||
|
"cash": 0.0,
|
||||||
|
"buying_power": 0.0,
|
||||||
|
"daily_pnl": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_value": latest.total_value,
|
||||||
|
"cash": latest.cash,
|
||||||
|
"buying_power": latest.cash,
|
||||||
|
"daily_pnl": latest.daily_pnl,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/positions")
|
||||||
|
async def get_positions(
|
||||||
|
request: Request,
|
||||||
|
_user: dict = Depends(get_current_user),
|
||||||
|
) -> list[dict]:
|
||||||
|
"""All open positions with unrealized P&L."""
|
||||||
|
from shared.models.trading import Position
|
||||||
|
|
||||||
|
db = request.app.state.db_session_factory
|
||||||
|
async with db() as session:
|
||||||
|
result = await session.execute(select(Position))
|
||||||
|
positions = result.scalars().all()
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": str(p.id),
|
||||||
|
"ticker": p.ticker,
|
||||||
|
"qty": p.qty,
|
||||||
|
"avg_entry": p.avg_entry,
|
||||||
|
"unrealized_pnl": p.unrealized_pnl or 0.0,
|
||||||
|
"stop_loss": p.stop_loss,
|
||||||
|
"take_profit": p.take_profit,
|
||||||
|
}
|
||||||
|
for p in positions
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/history")
|
||||||
|
async def get_portfolio_history(
|
||||||
|
request: Request,
|
||||||
|
_user: dict = Depends(get_current_user),
|
||||||
|
period: HistoryPeriod = Query(default=HistoryPeriod.ONE_MONTH),
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Equity curve from portfolio_snapshots over a given period."""
|
||||||
|
from shared.models.timeseries import PortfolioSnapshot
|
||||||
|
|
||||||
|
since = datetime.now(timezone.utc) - _period_to_timedelta(period)
|
||||||
|
db = request.app.state.db_session_factory
|
||||||
|
async with db() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(PortfolioSnapshot)
|
||||||
|
.where(PortfolioSnapshot.timestamp >= since)
|
||||||
|
.order_by(PortfolioSnapshot.timestamp)
|
||||||
|
)
|
||||||
|
snapshots = result.scalars().all()
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"timestamp": s.timestamp.isoformat(),
|
||||||
|
"total_value": s.total_value,
|
||||||
|
"cash": s.cash,
|
||||||
|
"positions_value": s.positions_value,
|
||||||
|
"daily_pnl": s.daily_pnl,
|
||||||
|
}
|
||||||
|
for s in snapshots
|
||||||
|
]
|
||||||
58
services/api_gateway/routes/signals.py
Normal file
58
services/api_gateway/routes/signals.py
Normal file
|
|
@ -0,0 +1,58 @@
|
||||||
|
"""Signal endpoints — recent signals with filtering."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Query, Request
|
||||||
|
|
||||||
|
from services.api_gateway.auth.middleware import get_current_user
|
||||||
|
from sqlalchemy import select, desc, func
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/signals", tags=["signals"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
async def list_signals(
|
||||||
|
request: Request,
|
||||||
|
_user: dict = Depends(get_current_user),
|
||||||
|
ticker: str | None = Query(default=None),
|
||||||
|
page: int = Query(default=1, ge=1),
|
||||||
|
per_page: int = Query(default=20, ge=1, le=100),
|
||||||
|
) -> dict:
|
||||||
|
"""Recent signals with optional ticker filter and pagination."""
|
||||||
|
from shared.models.trading import Signal
|
||||||
|
|
||||||
|
db = request.app.state.db_session_factory
|
||||||
|
async with db() as session:
|
||||||
|
query = select(Signal).order_by(desc(Signal.created_at))
|
||||||
|
count_query = select(func.count()).select_from(Signal)
|
||||||
|
|
||||||
|
if ticker:
|
||||||
|
query = query.where(Signal.ticker == ticker.upper())
|
||||||
|
count_query = count_query.where(Signal.ticker == ticker.upper())
|
||||||
|
|
||||||
|
total = (await session.execute(count_query)).scalar() or 0
|
||||||
|
offset = (page - 1) * per_page
|
||||||
|
query = query.offset(offset).limit(per_page)
|
||||||
|
|
||||||
|
result = await session.execute(query)
|
||||||
|
signals = result.scalars().all()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"signals": [
|
||||||
|
{
|
||||||
|
"id": str(s.id),
|
||||||
|
"ticker": s.ticker,
|
||||||
|
"direction": s.direction.value,
|
||||||
|
"strength": s.strength,
|
||||||
|
"strategy_sources": s.strategy_sources,
|
||||||
|
"sentiment_score": s.sentiment_score,
|
||||||
|
"acted_on": s.acted_on,
|
||||||
|
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||||
|
}
|
||||||
|
for s in signals
|
||||||
|
],
|
||||||
|
"total": total,
|
||||||
|
"page": page,
|
||||||
|
"per_page": per_page,
|
||||||
|
"pages": (total + per_page - 1) // per_page if per_page else 0,
|
||||||
|
}
|
||||||
111
services/api_gateway/routes/strategies.py
Normal file
111
services/api_gateway/routes/strategies.py
Normal file
|
|
@ -0,0 +1,111 @@
|
||||||
|
"""Strategy endpoints — list strategies, weight history, metrics."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
|
|
||||||
|
from services.api_gateway.auth.middleware import get_current_user
|
||||||
|
from sqlalchemy import select, desc
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/strategies", tags=["strategies"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
async def list_strategies(
|
||||||
|
request: Request,
|
||||||
|
_user: dict = Depends(get_current_user),
|
||||||
|
) -> list[dict]:
|
||||||
|
"""All strategies with current weights."""
|
||||||
|
from shared.models.trading import Strategy
|
||||||
|
|
||||||
|
db = request.app.state.db_session_factory
|
||||||
|
async with db() as session:
|
||||||
|
result = await session.execute(select(Strategy))
|
||||||
|
strategies = result.scalars().all()
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": str(s.id),
|
||||||
|
"name": s.name,
|
||||||
|
"description": s.description,
|
||||||
|
"current_weight": s.current_weight,
|
||||||
|
"active": s.active,
|
||||||
|
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||||
|
}
|
||||||
|
for s in strategies
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{strategy_id}/history")
|
||||||
|
async def get_strategy_weight_history(
|
||||||
|
strategy_id: UUID,
|
||||||
|
request: Request,
|
||||||
|
_user: dict = Depends(get_current_user),
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Weight history for a specific strategy."""
|
||||||
|
from shared.models.trading import StrategyWeightHistory, Strategy
|
||||||
|
|
||||||
|
db = request.app.state.db_session_factory
|
||||||
|
async with db() as session:
|
||||||
|
# Verify strategy exists
|
||||||
|
strategy = (
|
||||||
|
await session.execute(
|
||||||
|
select(Strategy).where(Strategy.id == strategy_id)
|
||||||
|
)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
if strategy is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Strategy not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await session.execute(
|
||||||
|
select(StrategyWeightHistory)
|
||||||
|
.where(StrategyWeightHistory.strategy_id == strategy_id)
|
||||||
|
.order_by(desc(StrategyWeightHistory.created_at))
|
||||||
|
)
|
||||||
|
history = result.scalars().all()
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": str(h.id),
|
||||||
|
"old_weight": h.old_weight,
|
||||||
|
"new_weight": h.new_weight,
|
||||||
|
"reason": h.reason,
|
||||||
|
"created_at": h.created_at.isoformat() if h.created_at else None,
|
||||||
|
}
|
||||||
|
for h in history
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{strategy_id}/metrics")
|
||||||
|
async def get_strategy_metrics(
|
||||||
|
strategy_id: UUID,
|
||||||
|
request: Request,
|
||||||
|
_user: dict = Depends(get_current_user),
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Performance metrics over time for a specific strategy."""
|
||||||
|
from shared.models.timeseries import StrategyMetric
|
||||||
|
|
||||||
|
db = request.app.state.db_session_factory
|
||||||
|
async with db() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(StrategyMetric)
|
||||||
|
.where(StrategyMetric.strategy_id == strategy_id)
|
||||||
|
.order_by(desc(StrategyMetric.timestamp))
|
||||||
|
.limit(100)
|
||||||
|
)
|
||||||
|
metrics = result.scalars().all()
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"timestamp": m.timestamp.isoformat(),
|
||||||
|
"win_rate": m.win_rate,
|
||||||
|
"total_pnl": m.total_pnl,
|
||||||
|
"trade_count": m.trade_count,
|
||||||
|
"sharpe_ratio": m.sharpe_ratio,
|
||||||
|
}
|
||||||
|
for m in metrics
|
||||||
|
]
|
||||||
125
services/api_gateway/routes/trades.py
Normal file
125
services/api_gateway/routes/trades.py
Normal file
|
|
@ -0,0 +1,125 @@
|
||||||
|
"""Trade endpoints — paginated trade history and detail."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||||
|
|
||||||
|
from services.api_gateway.auth.middleware import get_current_user
|
||||||
|
from sqlalchemy import select, desc, func
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/trades", tags=["trades"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
async def list_trades(
|
||||||
|
request: Request,
|
||||||
|
_user: dict = Depends(get_current_user),
|
||||||
|
ticker: str | None = Query(default=None),
|
||||||
|
start_date: datetime | None = Query(default=None),
|
||||||
|
end_date: datetime | None = Query(default=None),
|
||||||
|
strategy: str | None = Query(default=None),
|
||||||
|
profitable: bool | None = Query(default=None),
|
||||||
|
page: int = Query(default=1, ge=1),
|
||||||
|
per_page: int = Query(default=20, ge=1, le=100),
|
||||||
|
) -> dict:
|
||||||
|
"""Paginated trade history with optional filters."""
|
||||||
|
from shared.models.trading import Trade, Strategy
|
||||||
|
|
||||||
|
db = request.app.state.db_session_factory
|
||||||
|
async with db() as session:
|
||||||
|
query = select(Trade).order_by(desc(Trade.created_at))
|
||||||
|
count_query = select(func.count()).select_from(Trade)
|
||||||
|
|
||||||
|
# Apply filters
|
||||||
|
if ticker:
|
||||||
|
query = query.where(Trade.ticker == ticker.upper())
|
||||||
|
count_query = count_query.where(Trade.ticker == ticker.upper())
|
||||||
|
if start_date:
|
||||||
|
query = query.where(Trade.created_at >= start_date)
|
||||||
|
count_query = count_query.where(Trade.created_at >= start_date)
|
||||||
|
if end_date:
|
||||||
|
query = query.where(Trade.created_at <= end_date)
|
||||||
|
count_query = count_query.where(Trade.created_at <= end_date)
|
||||||
|
if strategy:
|
||||||
|
# Join with Strategy to filter by name
|
||||||
|
query = query.join(Strategy, Trade.strategy_id == Strategy.id).where(
|
||||||
|
Strategy.name == strategy
|
||||||
|
)
|
||||||
|
count_query = count_query.join(
|
||||||
|
Strategy, Trade.strategy_id == Strategy.id
|
||||||
|
).where(Strategy.name == strategy)
|
||||||
|
if profitable is not None:
|
||||||
|
if profitable:
|
||||||
|
query = query.where(Trade.pnl > 0)
|
||||||
|
count_query = count_query.where(Trade.pnl > 0)
|
||||||
|
else:
|
||||||
|
query = query.where(Trade.pnl <= 0)
|
||||||
|
count_query = count_query.where(Trade.pnl <= 0)
|
||||||
|
|
||||||
|
# Pagination
|
||||||
|
total = (await session.execute(count_query)).scalar() or 0
|
||||||
|
offset = (page - 1) * per_page
|
||||||
|
query = query.offset(offset).limit(per_page)
|
||||||
|
|
||||||
|
result = await session.execute(query)
|
||||||
|
trades = result.scalars().all()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"trades": [
|
||||||
|
{
|
||||||
|
"id": str(t.id),
|
||||||
|
"ticker": t.ticker,
|
||||||
|
"side": t.side.value,
|
||||||
|
"qty": t.qty,
|
||||||
|
"price": t.price,
|
||||||
|
"status": t.status.value,
|
||||||
|
"pnl": t.pnl,
|
||||||
|
"strategy_id": str(t.strategy_id) if t.strategy_id else None,
|
||||||
|
"signal_id": str(t.signal_id) if t.signal_id else None,
|
||||||
|
"created_at": t.created_at.isoformat() if t.created_at else None,
|
||||||
|
}
|
||||||
|
for t in trades
|
||||||
|
],
|
||||||
|
"total": total,
|
||||||
|
"page": page,
|
||||||
|
"per_page": per_page,
|
||||||
|
"pages": (total + per_page - 1) // per_page if per_page else 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{trade_id}")
|
||||||
|
async def get_trade(
|
||||||
|
trade_id: UUID,
|
||||||
|
request: Request,
|
||||||
|
_user: dict = Depends(get_current_user),
|
||||||
|
) -> dict:
|
||||||
|
"""Single trade detail with linked signal and outcome."""
|
||||||
|
from shared.models.trading import Trade
|
||||||
|
|
||||||
|
db = request.app.state.db_session_factory
|
||||||
|
async with db() as session:
|
||||||
|
trade = (
|
||||||
|
await session.execute(select(Trade).where(Trade.id == trade_id))
|
||||||
|
).scalar_one_or_none()
|
||||||
|
|
||||||
|
if trade is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Trade not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": str(trade.id),
|
||||||
|
"ticker": trade.ticker,
|
||||||
|
"side": trade.side.value,
|
||||||
|
"qty": trade.qty,
|
||||||
|
"price": trade.price,
|
||||||
|
"status": trade.status.value,
|
||||||
|
"pnl": trade.pnl,
|
||||||
|
"strategy_id": str(trade.strategy_id) if trade.strategy_id else None,
|
||||||
|
"signal_id": str(trade.signal_id) if trade.signal_id else None,
|
||||||
|
"created_at": trade.created_at.isoformat() if trade.created_at else None,
|
||||||
|
}
|
||||||
138
services/api_gateway/ws.py
Normal file
138
services/api_gateway/ws.py
Normal file
|
|
@ -0,0 +1,138 @@
|
||||||
|
"""WebSocket endpoint — authenticated real-time event stream."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
|
||||||
|
|
||||||
|
import jwt as pyjwt
|
||||||
|
|
||||||
|
from services.api_gateway.auth.jwt import decode_token
|
||||||
|
from services.api_gateway.config import ApiGatewayConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(tags=["websocket"])
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionManager:
|
||||||
|
"""Manages active WebSocket connections."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.active_connections: list[WebSocket] = []
|
||||||
|
|
||||||
|
async def connect(self, websocket: WebSocket) -> None:
|
||||||
|
await websocket.accept()
|
||||||
|
self.active_connections.append(websocket)
|
||||||
|
|
||||||
|
def disconnect(self, websocket: WebSocket) -> None:
|
||||||
|
if websocket in self.active_connections:
|
||||||
|
self.active_connections.remove(websocket)
|
||||||
|
|
||||||
|
async def broadcast(self, message: dict) -> None:
|
||||||
|
"""Send a message to all connected clients. Remove any that have
|
||||||
|
disconnected."""
|
||||||
|
disconnected: list[WebSocket] = []
|
||||||
|
for connection in self.active_connections:
|
||||||
|
try:
|
||||||
|
await connection.send_json(message)
|
||||||
|
except Exception:
|
||||||
|
disconnected.append(connection)
|
||||||
|
for ws in disconnected:
|
||||||
|
self.disconnect(ws)
|
||||||
|
|
||||||
|
|
||||||
|
manager = ConnectionManager()
|
||||||
|
|
||||||
|
|
||||||
|
@router.websocket("/ws")
|
||||||
|
async def websocket_endpoint(
|
||||||
|
websocket: WebSocket,
|
||||||
|
token: str = Query(default=""),
|
||||||
|
) -> None:
|
||||||
|
"""Authenticated WebSocket that pushes real-time trading events.
|
||||||
|
|
||||||
|
Connect via ``ws://host/ws?token=<JWT>``. The server subscribes
|
||||||
|
to Redis pub/sub channels and forwards events to all connected
|
||||||
|
clients.
|
||||||
|
|
||||||
|
Event types pushed to clients:
|
||||||
|
- ``trade_executed``
|
||||||
|
- ``signal_generated``
|
||||||
|
- ``portfolio_update``
|
||||||
|
"""
|
||||||
|
# Authenticate via JWT query parameter
|
||||||
|
config: ApiGatewayConfig = websocket.app.state.config
|
||||||
|
try:
|
||||||
|
payload = decode_token(token, config)
|
||||||
|
if payload.get("type") != "access":
|
||||||
|
await websocket.close(code=4001, reason="Invalid token type")
|
||||||
|
return
|
||||||
|
except pyjwt.ExpiredSignatureError:
|
||||||
|
await websocket.close(code=4001, reason="Token expired")
|
||||||
|
return
|
||||||
|
except pyjwt.InvalidTokenError:
|
||||||
|
await websocket.close(code=4001, reason="Invalid token")
|
||||||
|
return
|
||||||
|
|
||||||
|
await manager.connect(websocket)
|
||||||
|
|
||||||
|
# Subscribe to Redis pub/sub channels for real-time events
|
||||||
|
redis = websocket.app.state.redis
|
||||||
|
pubsub = redis.pubsub()
|
||||||
|
await pubsub.subscribe(
|
||||||
|
"events:trade_executed",
|
||||||
|
"events:signal_generated",
|
||||||
|
"events:portfolio_update",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Run two concurrent tasks: one reads from Redis pub/sub and pushes
|
||||||
|
# to the client, the other keeps the WebSocket alive by reading
|
||||||
|
# (and discarding) incoming messages.
|
||||||
|
|
||||||
|
async def _redis_listener() -> None:
|
||||||
|
"""Forward Redis pub/sub messages to this WebSocket client."""
|
||||||
|
while True:
|
||||||
|
message = await pubsub.get_message(
|
||||||
|
ignore_subscribe_messages=True, timeout=1.0
|
||||||
|
)
|
||||||
|
if message and message["type"] == "message":
|
||||||
|
try:
|
||||||
|
data = json.loads(message["data"])
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
data = {"raw": str(message["data"])}
|
||||||
|
channel = message["channel"]
|
||||||
|
if isinstance(channel, bytes):
|
||||||
|
channel = channel.decode()
|
||||||
|
event_type = channel.replace("events:", "")
|
||||||
|
await websocket.send_json(
|
||||||
|
{"event": event_type, "data": data}
|
||||||
|
)
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
async def _ws_receiver() -> None:
|
||||||
|
"""Keep the connection alive by reading messages."""
|
||||||
|
while True:
|
||||||
|
await websocket.receive_text()
|
||||||
|
|
||||||
|
# Run both tasks; whichever finishes first (e.g., client disconnect)
|
||||||
|
# will cancel the other.
|
||||||
|
listener_task = asyncio.create_task(_redis_listener())
|
||||||
|
receiver_task = asyncio.create_task(_ws_receiver())
|
||||||
|
done, pending = await asyncio.wait(
|
||||||
|
{listener_task, receiver_task},
|
||||||
|
return_when=asyncio.FIRST_COMPLETED,
|
||||||
|
)
|
||||||
|
for task in pending:
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
manager.disconnect(websocket)
|
||||||
|
await pubsub.unsubscribe()
|
||||||
|
await pubsub.aclose()
|
||||||
386
tests/services/test_api_routes.py
Normal file
386
tests/services/test_api_routes.py
Normal file
|
|
@ -0,0 +1,386 @@
|
||||||
|
"""Tests for API Gateway trading endpoints (Task 14).
|
||||||
|
|
||||||
|
Uses FastAPI TestClient with mocked DB sessions and Redis.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from services.api_gateway.auth.jwt import create_access_token
|
||||||
|
from services.api_gateway.auth.middleware import get_config, get_current_user
|
||||||
|
from services.api_gateway.config import ApiGatewayConfig
|
||||||
|
from services.api_gateway.main import create_app
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def config() -> ApiGatewayConfig:
|
||||||
|
return ApiGatewayConfig(
|
||||||
|
jwt_secret_key="test-secret-for-routes",
|
||||||
|
database_url="sqlite+aiosqlite:///:memory:",
|
||||||
|
redis_url="redis://localhost:6379/0",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def mock_user() -> dict:
|
||||||
|
return {"sub": "user-test", "username": "tester", "type": "access"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def auth_headers(config: ApiGatewayConfig) -> dict[str, str]:
|
||||||
|
token = create_access_token("user-test", "tester", config)
|
||||||
|
return {"Authorization": f"Bearer {token}"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def mock_redis() -> AsyncMock:
|
||||||
|
"""A fully mocked Redis client."""
|
||||||
|
redis = AsyncMock()
|
||||||
|
redis.get = AsyncMock(return_value=None)
|
||||||
|
redis.set = AsyncMock()
|
||||||
|
redis.setex = AsyncMock()
|
||||||
|
redis.delete = AsyncMock()
|
||||||
|
redis.publish = AsyncMock()
|
||||||
|
return redis
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def mock_session_factory():
|
||||||
|
"""Creates a mock async session factory that returns a mock session."""
|
||||||
|
session = AsyncMock()
|
||||||
|
session.__aenter__ = AsyncMock(return_value=session)
|
||||||
|
session.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
factory = MagicMock()
|
||||||
|
factory.return_value = session
|
||||||
|
return factory, session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def client(
|
||||||
|
config: ApiGatewayConfig,
|
||||||
|
mock_user: dict,
|
||||||
|
mock_redis: AsyncMock,
|
||||||
|
mock_session_factory,
|
||||||
|
) -> TestClient:
|
||||||
|
"""Create a test client with all dependencies mocked."""
|
||||||
|
factory, session = mock_session_factory
|
||||||
|
|
||||||
|
app = create_app(config)
|
||||||
|
|
||||||
|
# Override auth dependency to bypass JWT validation
|
||||||
|
app.dependency_overrides[get_current_user] = lambda: mock_user
|
||||||
|
app.dependency_overrides[get_config] = lambda: config
|
||||||
|
|
||||||
|
# Inject mock state
|
||||||
|
app.state.redis = mock_redis
|
||||||
|
app.state.db_session_factory = factory
|
||||||
|
app.state.db_engine = MagicMock()
|
||||||
|
app.state.config = config
|
||||||
|
|
||||||
|
return TestClient(app, raise_server_exceptions=False)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helper: build mock execute results
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_execute_result(rows, scalar=None):
|
||||||
|
"""Build a mock result for session.execute()."""
|
||||||
|
result = MagicMock()
|
||||||
|
result.scalars.return_value.all.return_value = rows
|
||||||
|
result.scalars.return_value.__iter__ = lambda self: iter(rows)
|
||||||
|
result.scalar_one_or_none.return_value = scalar
|
||||||
|
result.scalar.return_value = len(rows) if scalar is None else scalar
|
||||||
|
result.all.return_value = rows
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Portfolio Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestPortfolioEndpoint:
|
||||||
|
"""test_portfolio_endpoint."""
|
||||||
|
|
||||||
|
def test_portfolio_returns_defaults_when_empty(
|
||||||
|
self, client: TestClient, mock_session_factory
|
||||||
|
) -> None:
|
||||||
|
_, session = mock_session_factory
|
||||||
|
session.execute = AsyncMock(
|
||||||
|
return_value=_make_execute_result([], scalar=None)
|
||||||
|
)
|
||||||
|
|
||||||
|
resp = client.get("/api/portfolio")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["total_value"] == 0.0
|
||||||
|
assert data["cash"] == 0.0
|
||||||
|
assert data["daily_pnl"] == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestPositionsEndpoint:
|
||||||
|
"""test_positions_endpoint."""
|
||||||
|
|
||||||
|
def test_positions_returns_list(
|
||||||
|
self, client: TestClient, mock_session_factory
|
||||||
|
) -> None:
|
||||||
|
_, session = mock_session_factory
|
||||||
|
|
||||||
|
# Create mock positions
|
||||||
|
pos = MagicMock()
|
||||||
|
pos.id = uuid.uuid4()
|
||||||
|
pos.ticker = "AAPL"
|
||||||
|
pos.qty = 10.0
|
||||||
|
pos.avg_entry = 150.0
|
||||||
|
pos.unrealized_pnl = 50.0
|
||||||
|
pos.stop_loss = 145.0
|
||||||
|
pos.take_profit = 160.0
|
||||||
|
|
||||||
|
session.execute = AsyncMock(
|
||||||
|
return_value=_make_execute_result([pos])
|
||||||
|
)
|
||||||
|
|
||||||
|
resp = client.get("/api/portfolio/positions")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert len(data) == 1
|
||||||
|
assert data[0]["ticker"] == "AAPL"
|
||||||
|
assert data[0]["qty"] == 10.0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Trades Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestTradesListEndpoint:
|
||||||
|
"""test_trades_list_endpoint."""
|
||||||
|
|
||||||
|
def test_trades_returns_paginated_list(
|
||||||
|
self, client: TestClient, mock_session_factory
|
||||||
|
) -> None:
|
||||||
|
_, session = mock_session_factory
|
||||||
|
|
||||||
|
trade = MagicMock()
|
||||||
|
trade.id = uuid.uuid4()
|
||||||
|
trade.ticker = "TSLA"
|
||||||
|
trade.side.value = "BUY"
|
||||||
|
trade.qty = 5.0
|
||||||
|
trade.price = 200.0
|
||||||
|
trade.status.value = "FILLED"
|
||||||
|
trade.pnl = 25.0
|
||||||
|
trade.strategy_id = None
|
||||||
|
trade.signal_id = None
|
||||||
|
trade.created_at = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
# session.execute will be called twice: count + data
|
||||||
|
count_result = _make_execute_result([], scalar=1)
|
||||||
|
data_result = _make_execute_result([trade])
|
||||||
|
session.execute = AsyncMock(side_effect=[count_result, data_result])
|
||||||
|
|
||||||
|
resp = client.get("/api/trades")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert "trades" in data
|
||||||
|
assert "total" in data
|
||||||
|
assert "page" in data
|
||||||
|
|
||||||
|
|
||||||
|
class TestTradesPagination:
|
||||||
|
"""test_trades_pagination."""
|
||||||
|
|
||||||
|
def test_trades_page_and_per_page(
|
||||||
|
self, client: TestClient, mock_session_factory
|
||||||
|
) -> None:
|
||||||
|
_, session = mock_session_factory
|
||||||
|
|
||||||
|
count_result = _make_execute_result([], scalar=50)
|
||||||
|
data_result = _make_execute_result([])
|
||||||
|
session.execute = AsyncMock(side_effect=[count_result, data_result])
|
||||||
|
|
||||||
|
resp = client.get("/api/trades?page=3&per_page=10")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["page"] == 3
|
||||||
|
assert data["per_page"] == 10
|
||||||
|
assert data["pages"] == 5 # 50 / 10
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Strategies Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestStrategiesEndpoint:
|
||||||
|
"""test_strategies_endpoint."""
|
||||||
|
|
||||||
|
def test_strategies_returns_list(
|
||||||
|
self, client: TestClient, mock_session_factory
|
||||||
|
) -> None:
|
||||||
|
_, session = mock_session_factory
|
||||||
|
|
||||||
|
strategy = MagicMock()
|
||||||
|
strategy.id = uuid.uuid4()
|
||||||
|
strategy.name = "momentum"
|
||||||
|
strategy.description = "Momentum strategy"
|
||||||
|
strategy.current_weight = 0.333
|
||||||
|
strategy.active = True
|
||||||
|
strategy.created_at = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
session.execute = AsyncMock(
|
||||||
|
return_value=_make_execute_result([strategy])
|
||||||
|
)
|
||||||
|
|
||||||
|
resp = client.get("/api/strategies")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert len(data) == 1
|
||||||
|
assert data[0]["name"] == "momentum"
|
||||||
|
assert data[0]["current_weight"] == 0.333
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# News Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestNewsEndpoint:
|
||||||
|
"""test_news_endpoint."""
|
||||||
|
|
||||||
|
def test_news_returns_paginated_articles(
|
||||||
|
self, client: TestClient, mock_session_factory
|
||||||
|
) -> None:
|
||||||
|
_, session = mock_session_factory
|
||||||
|
|
||||||
|
article = MagicMock()
|
||||||
|
article.id = uuid.uuid4()
|
||||||
|
article.source = "reuters"
|
||||||
|
article.url = "https://reuters.com/article/1"
|
||||||
|
article.title = "Stock rises"
|
||||||
|
article.published_at = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||||
|
article.fetched_at = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
sentiment = MagicMock()
|
||||||
|
sentiment.ticker = "AAPL"
|
||||||
|
sentiment.score = 0.8
|
||||||
|
sentiment.confidence = 0.9
|
||||||
|
sentiment.model_used = "finbert"
|
||||||
|
|
||||||
|
count_result = _make_execute_result([], scalar=1)
|
||||||
|
data_result = MagicMock()
|
||||||
|
data_result.all.return_value = [(article, sentiment)]
|
||||||
|
session.execute = AsyncMock(side_effect=[count_result, data_result])
|
||||||
|
|
||||||
|
resp = client.get("/api/news")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert "articles" in data
|
||||||
|
assert data["total"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Controls Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestControlsPauseResume:
|
||||||
|
"""test_controls_pause_resume."""
|
||||||
|
|
||||||
|
def test_pause_sets_redis_key(
|
||||||
|
self, client: TestClient, mock_redis: AsyncMock
|
||||||
|
) -> None:
|
||||||
|
resp = client.post("/api/controls/pause")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["status"] == "paused"
|
||||||
|
mock_redis.set.assert_called_once_with("trading:paused", "1")
|
||||||
|
|
||||||
|
def test_resume_clears_redis_key(
|
||||||
|
self, client: TestClient, mock_redis: AsyncMock
|
||||||
|
) -> None:
|
||||||
|
resp = client.post("/api/controls/resume")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["status"] == "active"
|
||||||
|
mock_redis.delete.assert_called_once_with("trading:paused")
|
||||||
|
|
||||||
|
|
||||||
|
class TestControlsStatus:
|
||||||
|
"""test_controls_status."""
|
||||||
|
|
||||||
|
def test_status_active_when_not_paused(
|
||||||
|
self, client: TestClient, mock_redis: AsyncMock
|
||||||
|
) -> None:
|
||||||
|
mock_redis.get = AsyncMock(return_value=None)
|
||||||
|
resp = client.get("/api/controls/status")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["status"] == "active"
|
||||||
|
|
||||||
|
def test_status_paused_when_flag_set(
|
||||||
|
self, client: TestClient, mock_redis: AsyncMock
|
||||||
|
) -> None:
|
||||||
|
mock_redis.get = AsyncMock(return_value="1")
|
||||||
|
resp = client.get("/api/controls/status")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["status"] == "paused"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Backtest Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestBacktestRunEndpoint:
|
||||||
|
"""test_backtest_run_endpoint."""
|
||||||
|
|
||||||
|
def test_backtest_run_returns_run_id(
|
||||||
|
self, client: TestClient, mock_redis: AsyncMock
|
||||||
|
) -> None:
|
||||||
|
resp = client.post(
|
||||||
|
"/api/backtest/run",
|
||||||
|
json={
|
||||||
|
"start_date": "2024-01-01T00:00:00Z",
|
||||||
|
"end_date": "2024-06-01T00:00:00Z",
|
||||||
|
"initial_capital": 100000,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert "run_id" in data
|
||||||
|
assert data["status"] == "running"
|
||||||
|
mock_redis.setex.assert_called()
|
||||||
|
|
||||||
|
def test_backtest_get_not_found(
|
||||||
|
self, client: TestClient, mock_redis: AsyncMock
|
||||||
|
) -> None:
|
||||||
|
mock_redis.get = AsyncMock(return_value=None)
|
||||||
|
resp = client.get("/api/backtest/nonexistent-id")
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
def test_backtest_get_returns_result(
|
||||||
|
self, client: TestClient, mock_redis: AsyncMock
|
||||||
|
) -> None:
|
||||||
|
stored = json.dumps({
|
||||||
|
"status": "completed",
|
||||||
|
"result": {"total_return": 0.15, "sharpe_ratio": 1.2},
|
||||||
|
})
|
||||||
|
mock_redis.get = AsyncMock(return_value=stored)
|
||||||
|
|
||||||
|
resp = client.get("/api/backtest/some-run-id")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["status"] == "completed"
|
||||||
|
assert data["result"]["total_return"] == 0.15
|
||||||
Loading…
Add table
Add a link
Reference in a new issue