feat: API gateway trading endpoints, controls, backtest, WebSocket
This commit is contained in:
parent
e0d138c457
commit
6fe586f722
11 changed files with 1304 additions and 0 deletions
|
|
@ -69,6 +69,28 @@ def create_app(config: ApiGatewayConfig | None = None) -> FastAPI:
|
|||
# Auth routes (unauthenticated)
|
||||
app.include_router(auth_router)
|
||||
|
||||
# Trading routes (authenticated) — imported lazily to avoid circular deps
|
||||
from services.api_gateway.routes.portfolio import router as portfolio_router
|
||||
from services.api_gateway.routes.trades import router as trades_router
|
||||
from services.api_gateway.routes.signals import router as signals_router
|
||||
from services.api_gateway.routes.strategies import router as strategies_router
|
||||
from services.api_gateway.routes.news import router as news_router
|
||||
from services.api_gateway.routes.controls import router as controls_router
|
||||
from services.api_gateway.routes.backtest import router as backtest_router
|
||||
|
||||
app.include_router(portfolio_router)
|
||||
app.include_router(trades_router)
|
||||
app.include_router(signals_router)
|
||||
app.include_router(strategies_router)
|
||||
app.include_router(news_router)
|
||||
app.include_router(controls_router)
|
||||
app.include_router(backtest_router)
|
||||
|
||||
# WebSocket
|
||||
from services.api_gateway.ws import router as ws_router
|
||||
|
||||
app.include_router(ws_router)
|
||||
|
||||
# Health check
|
||||
@app.get("/health", tags=["health"])
|
||||
async def health() -> dict:
|
||||
|
|
|
|||
1
services/api_gateway/routes/__init__.py
Normal file
1
services/api_gateway/routes/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Trading route sub-package."""
|
||||
156
services/api_gateway/routes/backtest.py
Normal file
156
services/api_gateway/routes/backtest.py
Normal file
|
|
@ -0,0 +1,156 @@
|
|||
"""Backtest endpoints — run backtests and retrieve results."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from services.api_gateway.auth.middleware import get_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/backtest", tags=["backtest"])
|
||||
|
||||
|
||||
class BacktestRequest(BaseModel):
|
||||
"""Request body for starting a new backtest."""
|
||||
|
||||
start_date: datetime
|
||||
end_date: datetime
|
||||
initial_capital: float = Field(default=100_000.0, gt=0)
|
||||
commission_per_trade: float = Field(default=0.0, ge=0)
|
||||
slippage_pct: float = Field(default=0.001, ge=0)
|
||||
strategy_weights: dict[str, float] = Field(default_factory=dict)
|
||||
max_position_pct: float = Field(default=0.05, gt=0, le=1.0)
|
||||
signal_threshold: float = Field(default=0.3, ge=0, le=1.0)
|
||||
|
||||
|
||||
@router.post("/run")
|
||||
async def run_backtest(
|
||||
body: BacktestRequest,
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""Start a backtest with the given configuration.
|
||||
|
||||
Returns a ``run_id`` immediately. The backtest runs in a background
|
||||
task and its results can be retrieved via ``GET /api/backtest/{run_id}``.
|
||||
"""
|
||||
run_id = str(uuid.uuid4())
|
||||
redis = request.app.state.redis
|
||||
|
||||
# Store initial status
|
||||
await redis.setex(
|
||||
f"backtest:{run_id}",
|
||||
86400, # 24h TTL
|
||||
json.dumps({
|
||||
"status": "running",
|
||||
"config": body.model_dump(mode="json"),
|
||||
"started_at": datetime.now(tz=timezone.utc).isoformat(),
|
||||
}),
|
||||
)
|
||||
|
||||
# Launch background task
|
||||
asyncio.create_task(_run_backtest_task(run_id, body, redis))
|
||||
|
||||
return {"run_id": run_id, "status": "running"}
|
||||
|
||||
|
||||
async def _run_backtest_task(
|
||||
run_id: str,
|
||||
config: BacktestRequest,
|
||||
redis,
|
||||
) -> None:
|
||||
"""Execute the backtest in the background and store results in Redis."""
|
||||
try:
|
||||
from backtester.config import BacktestConfig
|
||||
from backtester.engine import BacktestEngine
|
||||
from shared.strategies.momentum import MomentumStrategy
|
||||
from shared.strategies.mean_reversion import MeanReversionStrategy
|
||||
from shared.strategies.news_driven import NewsDrivenStrategy
|
||||
|
||||
bt_config = BacktestConfig(
|
||||
start_date=config.start_date,
|
||||
end_date=config.end_date,
|
||||
initial_capital=config.initial_capital,
|
||||
commission_per_trade=config.commission_per_trade,
|
||||
slippage_pct=config.slippage_pct,
|
||||
strategy_weights=config.strategy_weights,
|
||||
max_position_pct=config.max_position_pct,
|
||||
signal_threshold=config.signal_threshold,
|
||||
)
|
||||
|
||||
strategies = [
|
||||
MomentumStrategy(),
|
||||
MeanReversionStrategy(),
|
||||
NewsDrivenStrategy(),
|
||||
]
|
||||
|
||||
engine = BacktestEngine(config=bt_config, strategies=strategies)
|
||||
|
||||
# Use an in-memory stub data loader for now; a full implementation
|
||||
# would read from TimescaleDB.
|
||||
from backtester.data_loader import BacktestDataLoader
|
||||
|
||||
data_loader = BacktestDataLoader(
|
||||
session_factory=None,
|
||||
start_date=config.start_date,
|
||||
end_date=config.end_date,
|
||||
)
|
||||
|
||||
result = await engine.run(data_loader)
|
||||
|
||||
await redis.setex(
|
||||
f"backtest:{run_id}",
|
||||
86400,
|
||||
json.dumps({
|
||||
"status": "completed",
|
||||
"config": config.model_dump(mode="json"),
|
||||
"result": {
|
||||
"total_return": result.total_return,
|
||||
"annualized_return": result.annualized_return,
|
||||
"sharpe_ratio": result.sharpe_ratio,
|
||||
"sortino_ratio": result.sortino_ratio,
|
||||
"max_drawdown": result.max_drawdown,
|
||||
"max_drawdown_duration_days": result.max_drawdown_duration_days,
|
||||
"win_rate": result.win_rate,
|
||||
"trade_count": result.trade_count,
|
||||
"avg_hold_duration_hours": result.avg_hold_duration_hours,
|
||||
"profit_factor": result.profit_factor,
|
||||
},
|
||||
"completed_at": datetime.now(tz=timezone.utc).isoformat(),
|
||||
}),
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception("Backtest %s failed", run_id)
|
||||
await redis.setex(
|
||||
f"backtest:{run_id}",
|
||||
86400,
|
||||
json.dumps({
|
||||
"status": "failed",
|
||||
"error": str(exc),
|
||||
}),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{run_id}")
|
||||
async def get_backtest(
|
||||
run_id: str,
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""Get backtest results by run ID."""
|
||||
redis = request.app.state.redis
|
||||
raw = await redis.get(f"backtest:{run_id}")
|
||||
if raw is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Backtest run not found",
|
||||
)
|
||||
return json.loads(raw)
|
||||
95
services/api_gateway/routes/controls.py
Normal file
95
services/api_gateway/routes/controls.py
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
"""Control endpoints — pause/resume trading, force close positions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from pydantic import BaseModel
|
||||
|
||||
from services.api_gateway.auth.middleware import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/controls", tags=["controls"])
|
||||
|
||||
TRADING_PAUSED_KEY = "trading:paused"
|
||||
|
||||
|
||||
class ClosePositionRequest(BaseModel):
|
||||
"""Body for the force-close-position endpoint."""
|
||||
|
||||
ticker: str
|
||||
|
||||
|
||||
@router.post("/pause")
|
||||
async def pause_trading(
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""Set Redis flag to pause trading."""
|
||||
redis = request.app.state.redis
|
||||
await redis.set(TRADING_PAUSED_KEY, "1")
|
||||
return {"status": "paused"}
|
||||
|
||||
|
||||
@router.post("/resume")
|
||||
async def resume_trading(
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""Clear pause flag to resume trading."""
|
||||
redis = request.app.state.redis
|
||||
await redis.delete(TRADING_PAUSED_KEY)
|
||||
return {"status": "active"}
|
||||
|
||||
|
||||
@router.post("/close-position")
|
||||
async def close_position(
|
||||
body: ClosePositionRequest,
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""Force close an open position by ticker.
|
||||
|
||||
Publishes a close-position command to Redis so the trade executor
|
||||
picks it up asynchronously.
|
||||
"""
|
||||
import json
|
||||
|
||||
from sqlalchemy import select
|
||||
from shared.models.trading import Position
|
||||
|
||||
db = request.app.state.db_session_factory
|
||||
async with db() as session:
|
||||
position = (
|
||||
await session.execute(
|
||||
select(Position).where(Position.ticker == body.ticker.upper())
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if position is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"No open position for {body.ticker}",
|
||||
)
|
||||
|
||||
# Publish close command to Redis for the trade executor
|
||||
redis = request.app.state.redis
|
||||
await redis.publish(
|
||||
"controls:close_position",
|
||||
json.dumps({"ticker": body.ticker.upper(), "qty": position.qty}),
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "close_requested",
|
||||
"ticker": body.ticker.upper(),
|
||||
"qty": position.qty,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
async def get_trading_status(
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""Current trading status — active or paused."""
|
||||
redis = request.app.state.redis
|
||||
paused = await redis.get(TRADING_PAUSED_KEY)
|
||||
return {"status": "paused" if paused else "active"}
|
||||
87
services/api_gateway/routes/news.py
Normal file
87
services/api_gateway/routes/news.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
"""News endpoints — recent scored articles with filtering."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
|
||||
from services.api_gateway.auth.middleware import get_current_user
|
||||
from sqlalchemy import select, desc, func
|
||||
|
||||
router = APIRouter(prefix="/api/news", tags=["news"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_news(
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
ticker: str | None = Query(default=None),
|
||||
source: str | None = Query(default=None),
|
||||
min_score: float | None = Query(default=None, ge=-1.0, le=1.0),
|
||||
max_score: float | None = Query(default=None, ge=-1.0, le=1.0),
|
||||
page: int = Query(default=1, ge=1),
|
||||
per_page: int = Query(default=20, ge=1, le=100),
|
||||
) -> dict:
|
||||
"""Recent scored articles with optional filters."""
|
||||
from shared.models.news import Article, ArticleSentiment
|
||||
|
||||
db = request.app.state.db_session_factory
|
||||
async with db() as session:
|
||||
# Base query joining articles with sentiments
|
||||
query = (
|
||||
select(Article, ArticleSentiment)
|
||||
.join(ArticleSentiment, Article.id == ArticleSentiment.article_id)
|
||||
.order_by(desc(Article.fetched_at))
|
||||
)
|
||||
count_query = (
|
||||
select(func.count())
|
||||
.select_from(Article)
|
||||
.join(ArticleSentiment, Article.id == ArticleSentiment.article_id)
|
||||
)
|
||||
|
||||
if ticker:
|
||||
query = query.where(ArticleSentiment.ticker == ticker.upper())
|
||||
count_query = count_query.where(
|
||||
ArticleSentiment.ticker == ticker.upper()
|
||||
)
|
||||
if source:
|
||||
query = query.where(Article.source == source)
|
||||
count_query = count_query.where(Article.source == source)
|
||||
if min_score is not None:
|
||||
query = query.where(ArticleSentiment.score >= min_score)
|
||||
count_query = count_query.where(ArticleSentiment.score >= min_score)
|
||||
if max_score is not None:
|
||||
query = query.where(ArticleSentiment.score <= max_score)
|
||||
count_query = count_query.where(ArticleSentiment.score <= max_score)
|
||||
|
||||
total = (await session.execute(count_query)).scalar() or 0
|
||||
offset = (page - 1) * per_page
|
||||
query = query.offset(offset).limit(per_page)
|
||||
|
||||
result = await session.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
return {
|
||||
"articles": [
|
||||
{
|
||||
"id": str(article.id),
|
||||
"source": article.source,
|
||||
"url": article.url,
|
||||
"title": article.title,
|
||||
"published_at": (
|
||||
article.published_at.isoformat()
|
||||
if article.published_at
|
||||
else None
|
||||
),
|
||||
"fetched_at": article.fetched_at.isoformat(),
|
||||
"ticker": sentiment.ticker,
|
||||
"sentiment_score": sentiment.score,
|
||||
"confidence": sentiment.confidence,
|
||||
"model_used": sentiment.model_used,
|
||||
}
|
||||
for article, sentiment in rows
|
||||
],
|
||||
"total": total,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"pages": (total + per_page - 1) // per_page if per_page else 0,
|
||||
}
|
||||
125
services/api_gateway/routes/portfolio.py
Normal file
125
services/api_gateway/routes/portfolio.py
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
"""Portfolio endpoints — current value, positions, equity curve."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
|
||||
from services.api_gateway.auth.middleware import get_current_user
|
||||
from sqlalchemy import select, desc
|
||||
|
||||
router = APIRouter(prefix="/api/portfolio", tags=["portfolio"])
|
||||
|
||||
|
||||
class HistoryPeriod(str, Enum):
|
||||
ONE_DAY = "1d"
|
||||
ONE_WEEK = "1w"
|
||||
ONE_MONTH = "1m"
|
||||
THREE_MONTHS = "3m"
|
||||
ONE_YEAR = "1y"
|
||||
|
||||
|
||||
def _period_to_timedelta(period: HistoryPeriod) -> timedelta:
|
||||
"""Convert a period enum value to a timedelta."""
|
||||
mapping = {
|
||||
HistoryPeriod.ONE_DAY: timedelta(days=1),
|
||||
HistoryPeriod.ONE_WEEK: timedelta(weeks=1),
|
||||
HistoryPeriod.ONE_MONTH: timedelta(days=30),
|
||||
HistoryPeriod.THREE_MONTHS: timedelta(days=90),
|
||||
HistoryPeriod.ONE_YEAR: timedelta(days=365),
|
||||
}
|
||||
return mapping[period]
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_portfolio(
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""Current portfolio summary — value, cash, buying power, daily P&L."""
|
||||
from shared.models.timeseries import PortfolioSnapshot
|
||||
|
||||
db = request.app.state.db_session_factory
|
||||
async with db() as session:
|
||||
latest = (
|
||||
await session.execute(
|
||||
select(PortfolioSnapshot)
|
||||
.order_by(desc(PortfolioSnapshot.timestamp))
|
||||
.limit(1)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if latest is None:
|
||||
return {
|
||||
"total_value": 0.0,
|
||||
"cash": 0.0,
|
||||
"buying_power": 0.0,
|
||||
"daily_pnl": 0.0,
|
||||
}
|
||||
|
||||
return {
|
||||
"total_value": latest.total_value,
|
||||
"cash": latest.cash,
|
||||
"buying_power": latest.cash,
|
||||
"daily_pnl": latest.daily_pnl,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/positions")
|
||||
async def get_positions(
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
) -> list[dict]:
|
||||
"""All open positions with unrealized P&L."""
|
||||
from shared.models.trading import Position
|
||||
|
||||
db = request.app.state.db_session_factory
|
||||
async with db() as session:
|
||||
result = await session.execute(select(Position))
|
||||
positions = result.scalars().all()
|
||||
|
||||
return [
|
||||
{
|
||||
"id": str(p.id),
|
||||
"ticker": p.ticker,
|
||||
"qty": p.qty,
|
||||
"avg_entry": p.avg_entry,
|
||||
"unrealized_pnl": p.unrealized_pnl or 0.0,
|
||||
"stop_loss": p.stop_loss,
|
||||
"take_profit": p.take_profit,
|
||||
}
|
||||
for p in positions
|
||||
]
|
||||
|
||||
|
||||
@router.get("/history")
|
||||
async def get_portfolio_history(
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
period: HistoryPeriod = Query(default=HistoryPeriod.ONE_MONTH),
|
||||
) -> list[dict]:
|
||||
"""Equity curve from portfolio_snapshots over a given period."""
|
||||
from shared.models.timeseries import PortfolioSnapshot
|
||||
|
||||
since = datetime.now(timezone.utc) - _period_to_timedelta(period)
|
||||
db = request.app.state.db_session_factory
|
||||
async with db() as session:
|
||||
result = await session.execute(
|
||||
select(PortfolioSnapshot)
|
||||
.where(PortfolioSnapshot.timestamp >= since)
|
||||
.order_by(PortfolioSnapshot.timestamp)
|
||||
)
|
||||
snapshots = result.scalars().all()
|
||||
|
||||
return [
|
||||
{
|
||||
"timestamp": s.timestamp.isoformat(),
|
||||
"total_value": s.total_value,
|
||||
"cash": s.cash,
|
||||
"positions_value": s.positions_value,
|
||||
"daily_pnl": s.daily_pnl,
|
||||
}
|
||||
for s in snapshots
|
||||
]
|
||||
58
services/api_gateway/routes/signals.py
Normal file
58
services/api_gateway/routes/signals.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
"""Signal endpoints — recent signals with filtering."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
|
||||
from services.api_gateway.auth.middleware import get_current_user
|
||||
from sqlalchemy import select, desc, func
|
||||
|
||||
router = APIRouter(prefix="/api/signals", tags=["signals"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_signals(
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
ticker: str | None = Query(default=None),
|
||||
page: int = Query(default=1, ge=1),
|
||||
per_page: int = Query(default=20, ge=1, le=100),
|
||||
) -> dict:
|
||||
"""Recent signals with optional ticker filter and pagination."""
|
||||
from shared.models.trading import Signal
|
||||
|
||||
db = request.app.state.db_session_factory
|
||||
async with db() as session:
|
||||
query = select(Signal).order_by(desc(Signal.created_at))
|
||||
count_query = select(func.count()).select_from(Signal)
|
||||
|
||||
if ticker:
|
||||
query = query.where(Signal.ticker == ticker.upper())
|
||||
count_query = count_query.where(Signal.ticker == ticker.upper())
|
||||
|
||||
total = (await session.execute(count_query)).scalar() or 0
|
||||
offset = (page - 1) * per_page
|
||||
query = query.offset(offset).limit(per_page)
|
||||
|
||||
result = await session.execute(query)
|
||||
signals = result.scalars().all()
|
||||
|
||||
return {
|
||||
"signals": [
|
||||
{
|
||||
"id": str(s.id),
|
||||
"ticker": s.ticker,
|
||||
"direction": s.direction.value,
|
||||
"strength": s.strength,
|
||||
"strategy_sources": s.strategy_sources,
|
||||
"sentiment_score": s.sentiment_score,
|
||||
"acted_on": s.acted_on,
|
||||
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||
}
|
||||
for s in signals
|
||||
],
|
||||
"total": total,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"pages": (total + per_page - 1) // per_page if per_page else 0,
|
||||
}
|
||||
111
services/api_gateway/routes/strategies.py
Normal file
111
services/api_gateway/routes/strategies.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
"""Strategy endpoints — list strategies, weight history, metrics."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
|
||||
from services.api_gateway.auth.middleware import get_current_user
|
||||
from sqlalchemy import select, desc
|
||||
|
||||
router = APIRouter(prefix="/api/strategies", tags=["strategies"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_strategies(
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
) -> list[dict]:
|
||||
"""All strategies with current weights."""
|
||||
from shared.models.trading import Strategy
|
||||
|
||||
db = request.app.state.db_session_factory
|
||||
async with db() as session:
|
||||
result = await session.execute(select(Strategy))
|
||||
strategies = result.scalars().all()
|
||||
|
||||
return [
|
||||
{
|
||||
"id": str(s.id),
|
||||
"name": s.name,
|
||||
"description": s.description,
|
||||
"current_weight": s.current_weight,
|
||||
"active": s.active,
|
||||
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||
}
|
||||
for s in strategies
|
||||
]
|
||||
|
||||
|
||||
@router.get("/{strategy_id}/history")
|
||||
async def get_strategy_weight_history(
|
||||
strategy_id: UUID,
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
) -> list[dict]:
|
||||
"""Weight history for a specific strategy."""
|
||||
from shared.models.trading import StrategyWeightHistory, Strategy
|
||||
|
||||
db = request.app.state.db_session_factory
|
||||
async with db() as session:
|
||||
# Verify strategy exists
|
||||
strategy = (
|
||||
await session.execute(
|
||||
select(Strategy).where(Strategy.id == strategy_id)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if strategy is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Strategy not found",
|
||||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(StrategyWeightHistory)
|
||||
.where(StrategyWeightHistory.strategy_id == strategy_id)
|
||||
.order_by(desc(StrategyWeightHistory.created_at))
|
||||
)
|
||||
history = result.scalars().all()
|
||||
|
||||
return [
|
||||
{
|
||||
"id": str(h.id),
|
||||
"old_weight": h.old_weight,
|
||||
"new_weight": h.new_weight,
|
||||
"reason": h.reason,
|
||||
"created_at": h.created_at.isoformat() if h.created_at else None,
|
||||
}
|
||||
for h in history
|
||||
]
|
||||
|
||||
|
||||
@router.get("/{strategy_id}/metrics")
|
||||
async def get_strategy_metrics(
|
||||
strategy_id: UUID,
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
) -> list[dict]:
|
||||
"""Performance metrics over time for a specific strategy."""
|
||||
from shared.models.timeseries import StrategyMetric
|
||||
|
||||
db = request.app.state.db_session_factory
|
||||
async with db() as session:
|
||||
result = await session.execute(
|
||||
select(StrategyMetric)
|
||||
.where(StrategyMetric.strategy_id == strategy_id)
|
||||
.order_by(desc(StrategyMetric.timestamp))
|
||||
.limit(100)
|
||||
)
|
||||
metrics = result.scalars().all()
|
||||
|
||||
return [
|
||||
{
|
||||
"timestamp": m.timestamp.isoformat(),
|
||||
"win_rate": m.win_rate,
|
||||
"total_pnl": m.total_pnl,
|
||||
"trade_count": m.trade_count,
|
||||
"sharpe_ratio": m.sharpe_ratio,
|
||||
}
|
||||
for m in metrics
|
||||
]
|
||||
125
services/api_gateway/routes/trades.py
Normal file
125
services/api_gateway/routes/trades.py
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
"""Trade endpoints — paginated trade history and detail."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
|
||||
from services.api_gateway.auth.middleware import get_current_user
|
||||
from sqlalchemy import select, desc, func
|
||||
|
||||
router = APIRouter(prefix="/api/trades", tags=["trades"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_trades(
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
ticker: str | None = Query(default=None),
|
||||
start_date: datetime | None = Query(default=None),
|
||||
end_date: datetime | None = Query(default=None),
|
||||
strategy: str | None = Query(default=None),
|
||||
profitable: bool | None = Query(default=None),
|
||||
page: int = Query(default=1, ge=1),
|
||||
per_page: int = Query(default=20, ge=1, le=100),
|
||||
) -> dict:
|
||||
"""Paginated trade history with optional filters."""
|
||||
from shared.models.trading import Trade, Strategy
|
||||
|
||||
db = request.app.state.db_session_factory
|
||||
async with db() as session:
|
||||
query = select(Trade).order_by(desc(Trade.created_at))
|
||||
count_query = select(func.count()).select_from(Trade)
|
||||
|
||||
# Apply filters
|
||||
if ticker:
|
||||
query = query.where(Trade.ticker == ticker.upper())
|
||||
count_query = count_query.where(Trade.ticker == ticker.upper())
|
||||
if start_date:
|
||||
query = query.where(Trade.created_at >= start_date)
|
||||
count_query = count_query.where(Trade.created_at >= start_date)
|
||||
if end_date:
|
||||
query = query.where(Trade.created_at <= end_date)
|
||||
count_query = count_query.where(Trade.created_at <= end_date)
|
||||
if strategy:
|
||||
# Join with Strategy to filter by name
|
||||
query = query.join(Strategy, Trade.strategy_id == Strategy.id).where(
|
||||
Strategy.name == strategy
|
||||
)
|
||||
count_query = count_query.join(
|
||||
Strategy, Trade.strategy_id == Strategy.id
|
||||
).where(Strategy.name == strategy)
|
||||
if profitable is not None:
|
||||
if profitable:
|
||||
query = query.where(Trade.pnl > 0)
|
||||
count_query = count_query.where(Trade.pnl > 0)
|
||||
else:
|
||||
query = query.where(Trade.pnl <= 0)
|
||||
count_query = count_query.where(Trade.pnl <= 0)
|
||||
|
||||
# Pagination
|
||||
total = (await session.execute(count_query)).scalar() or 0
|
||||
offset = (page - 1) * per_page
|
||||
query = query.offset(offset).limit(per_page)
|
||||
|
||||
result = await session.execute(query)
|
||||
trades = result.scalars().all()
|
||||
|
||||
return {
|
||||
"trades": [
|
||||
{
|
||||
"id": str(t.id),
|
||||
"ticker": t.ticker,
|
||||
"side": t.side.value,
|
||||
"qty": t.qty,
|
||||
"price": t.price,
|
||||
"status": t.status.value,
|
||||
"pnl": t.pnl,
|
||||
"strategy_id": str(t.strategy_id) if t.strategy_id else None,
|
||||
"signal_id": str(t.signal_id) if t.signal_id else None,
|
||||
"created_at": t.created_at.isoformat() if t.created_at else None,
|
||||
}
|
||||
for t in trades
|
||||
],
|
||||
"total": total,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"pages": (total + per_page - 1) // per_page if per_page else 0,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{trade_id}")
|
||||
async def get_trade(
|
||||
trade_id: UUID,
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""Single trade detail with linked signal and outcome."""
|
||||
from shared.models.trading import Trade
|
||||
|
||||
db = request.app.state.db_session_factory
|
||||
async with db() as session:
|
||||
trade = (
|
||||
await session.execute(select(Trade).where(Trade.id == trade_id))
|
||||
).scalar_one_or_none()
|
||||
|
||||
if trade is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Trade not found",
|
||||
)
|
||||
|
||||
return {
|
||||
"id": str(trade.id),
|
||||
"ticker": trade.ticker,
|
||||
"side": trade.side.value,
|
||||
"qty": trade.qty,
|
||||
"price": trade.price,
|
||||
"status": trade.status.value,
|
||||
"pnl": trade.pnl,
|
||||
"strategy_id": str(trade.strategy_id) if trade.strategy_id else None,
|
||||
"signal_id": str(trade.signal_id) if trade.signal_id else None,
|
||||
"created_at": trade.created_at.isoformat() if trade.created_at else None,
|
||||
}
|
||||
138
services/api_gateway/ws.py
Normal file
138
services/api_gateway/ws.py
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
"""WebSocket endpoint — authenticated real-time event stream."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
|
||||
|
||||
import jwt as pyjwt
|
||||
|
||||
from services.api_gateway.auth.jwt import decode_token
|
||||
from services.api_gateway.config import ApiGatewayConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["websocket"])
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
"""Manages active WebSocket connections."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.active_connections: list[WebSocket] = []
|
||||
|
||||
async def connect(self, websocket: WebSocket) -> None:
|
||||
await websocket.accept()
|
||||
self.active_connections.append(websocket)
|
||||
|
||||
def disconnect(self, websocket: WebSocket) -> None:
|
||||
if websocket in self.active_connections:
|
||||
self.active_connections.remove(websocket)
|
||||
|
||||
async def broadcast(self, message: dict) -> None:
|
||||
"""Send a message to all connected clients. Remove any that have
|
||||
disconnected."""
|
||||
disconnected: list[WebSocket] = []
|
||||
for connection in self.active_connections:
|
||||
try:
|
||||
await connection.send_json(message)
|
||||
except Exception:
|
||||
disconnected.append(connection)
|
||||
for ws in disconnected:
|
||||
self.disconnect(ws)
|
||||
|
||||
|
||||
manager = ConnectionManager()
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def websocket_endpoint(
|
||||
websocket: WebSocket,
|
||||
token: str = Query(default=""),
|
||||
) -> None:
|
||||
"""Authenticated WebSocket that pushes real-time trading events.
|
||||
|
||||
Connect via ``ws://host/ws?token=<JWT>``. The server subscribes
|
||||
to Redis pub/sub channels and forwards events to all connected
|
||||
clients.
|
||||
|
||||
Event types pushed to clients:
|
||||
- ``trade_executed``
|
||||
- ``signal_generated``
|
||||
- ``portfolio_update``
|
||||
"""
|
||||
# Authenticate via JWT query parameter
|
||||
config: ApiGatewayConfig = websocket.app.state.config
|
||||
try:
|
||||
payload = decode_token(token, config)
|
||||
if payload.get("type") != "access":
|
||||
await websocket.close(code=4001, reason="Invalid token type")
|
||||
return
|
||||
except pyjwt.ExpiredSignatureError:
|
||||
await websocket.close(code=4001, reason="Token expired")
|
||||
return
|
||||
except pyjwt.InvalidTokenError:
|
||||
await websocket.close(code=4001, reason="Invalid token")
|
||||
return
|
||||
|
||||
await manager.connect(websocket)
|
||||
|
||||
# Subscribe to Redis pub/sub channels for real-time events
|
||||
redis = websocket.app.state.redis
|
||||
pubsub = redis.pubsub()
|
||||
await pubsub.subscribe(
|
||||
"events:trade_executed",
|
||||
"events:signal_generated",
|
||||
"events:portfolio_update",
|
||||
)
|
||||
|
||||
try:
|
||||
# Run two concurrent tasks: one reads from Redis pub/sub and pushes
|
||||
# to the client, the other keeps the WebSocket alive by reading
|
||||
# (and discarding) incoming messages.
|
||||
|
||||
async def _redis_listener() -> None:
|
||||
"""Forward Redis pub/sub messages to this WebSocket client."""
|
||||
while True:
|
||||
message = await pubsub.get_message(
|
||||
ignore_subscribe_messages=True, timeout=1.0
|
||||
)
|
||||
if message and message["type"] == "message":
|
||||
try:
|
||||
data = json.loads(message["data"])
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
data = {"raw": str(message["data"])}
|
||||
channel = message["channel"]
|
||||
if isinstance(channel, bytes):
|
||||
channel = channel.decode()
|
||||
event_type = channel.replace("events:", "")
|
||||
await websocket.send_json(
|
||||
{"event": event_type, "data": data}
|
||||
)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def _ws_receiver() -> None:
|
||||
"""Keep the connection alive by reading messages."""
|
||||
while True:
|
||||
await websocket.receive_text()
|
||||
|
||||
# Run both tasks; whichever finishes first (e.g., client disconnect)
|
||||
# will cancel the other.
|
||||
listener_task = asyncio.create_task(_redis_listener())
|
||||
receiver_task = asyncio.create_task(_ws_receiver())
|
||||
done, pending = await asyncio.wait(
|
||||
{listener_task, receiver_task},
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
finally:
|
||||
manager.disconnect(websocket)
|
||||
await pubsub.unsubscribe()
|
||||
await pubsub.aclose()
|
||||
386
tests/services/test_api_routes.py
Normal file
386
tests/services/test_api_routes.py
Normal file
|
|
@ -0,0 +1,386 @@
|
|||
"""Tests for API Gateway trading endpoints (Task 14).
|
||||
|
||||
Uses FastAPI TestClient with mocked DB sessions and Redis.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from services.api_gateway.auth.jwt import create_access_token
|
||||
from services.api_gateway.auth.middleware import get_config, get_current_user
|
||||
from services.api_gateway.config import ApiGatewayConfig
|
||||
from services.api_gateway.main import create_app
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def config() -> ApiGatewayConfig:
|
||||
return ApiGatewayConfig(
|
||||
jwt_secret_key="test-secret-for-routes",
|
||||
database_url="sqlite+aiosqlite:///:memory:",
|
||||
redis_url="redis://localhost:6379/0",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_user() -> dict:
|
||||
return {"sub": "user-test", "username": "tester", "type": "access"}
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def auth_headers(config: ApiGatewayConfig) -> dict[str, str]:
|
||||
token = create_access_token("user-test", "tester", config)
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_redis() -> AsyncMock:
|
||||
"""A fully mocked Redis client."""
|
||||
redis = AsyncMock()
|
||||
redis.get = AsyncMock(return_value=None)
|
||||
redis.set = AsyncMock()
|
||||
redis.setex = AsyncMock()
|
||||
redis.delete = AsyncMock()
|
||||
redis.publish = AsyncMock()
|
||||
return redis
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_session_factory():
|
||||
"""Creates a mock async session factory that returns a mock session."""
|
||||
session = AsyncMock()
|
||||
session.__aenter__ = AsyncMock(return_value=session)
|
||||
session.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
factory = MagicMock()
|
||||
factory.return_value = session
|
||||
return factory, session
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client(
|
||||
config: ApiGatewayConfig,
|
||||
mock_user: dict,
|
||||
mock_redis: AsyncMock,
|
||||
mock_session_factory,
|
||||
) -> TestClient:
|
||||
"""Create a test client with all dependencies mocked."""
|
||||
factory, session = mock_session_factory
|
||||
|
||||
app = create_app(config)
|
||||
|
||||
# Override auth dependency to bypass JWT validation
|
||||
app.dependency_overrides[get_current_user] = lambda: mock_user
|
||||
app.dependency_overrides[get_config] = lambda: config
|
||||
|
||||
# Inject mock state
|
||||
app.state.redis = mock_redis
|
||||
app.state.db_session_factory = factory
|
||||
app.state.db_engine = MagicMock()
|
||||
app.state.config = config
|
||||
|
||||
return TestClient(app, raise_server_exceptions=False)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper: build mock execute results
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_execute_result(rows, scalar=None):
|
||||
"""Build a mock result for session.execute()."""
|
||||
result = MagicMock()
|
||||
result.scalars.return_value.all.return_value = rows
|
||||
result.scalars.return_value.__iter__ = lambda self: iter(rows)
|
||||
result.scalar_one_or_none.return_value = scalar
|
||||
result.scalar.return_value = len(rows) if scalar is None else scalar
|
||||
result.all.return_value = rows
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Portfolio Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPortfolioEndpoint:
|
||||
"""test_portfolio_endpoint."""
|
||||
|
||||
def test_portfolio_returns_defaults_when_empty(
|
||||
self, client: TestClient, mock_session_factory
|
||||
) -> None:
|
||||
_, session = mock_session_factory
|
||||
session.execute = AsyncMock(
|
||||
return_value=_make_execute_result([], scalar=None)
|
||||
)
|
||||
|
||||
resp = client.get("/api/portfolio")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["total_value"] == 0.0
|
||||
assert data["cash"] == 0.0
|
||||
assert data["daily_pnl"] == 0.0
|
||||
|
||||
|
||||
class TestPositionsEndpoint:
|
||||
"""test_positions_endpoint."""
|
||||
|
||||
def test_positions_returns_list(
|
||||
self, client: TestClient, mock_session_factory
|
||||
) -> None:
|
||||
_, session = mock_session_factory
|
||||
|
||||
# Create mock positions
|
||||
pos = MagicMock()
|
||||
pos.id = uuid.uuid4()
|
||||
pos.ticker = "AAPL"
|
||||
pos.qty = 10.0
|
||||
pos.avg_entry = 150.0
|
||||
pos.unrealized_pnl = 50.0
|
||||
pos.stop_loss = 145.0
|
||||
pos.take_profit = 160.0
|
||||
|
||||
session.execute = AsyncMock(
|
||||
return_value=_make_execute_result([pos])
|
||||
)
|
||||
|
||||
resp = client.get("/api/portfolio/positions")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 1
|
||||
assert data[0]["ticker"] == "AAPL"
|
||||
assert data[0]["qty"] == 10.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Trades Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTradesListEndpoint:
|
||||
"""test_trades_list_endpoint."""
|
||||
|
||||
def test_trades_returns_paginated_list(
|
||||
self, client: TestClient, mock_session_factory
|
||||
) -> None:
|
||||
_, session = mock_session_factory
|
||||
|
||||
trade = MagicMock()
|
||||
trade.id = uuid.uuid4()
|
||||
trade.ticker = "TSLA"
|
||||
trade.side.value = "BUY"
|
||||
trade.qty = 5.0
|
||||
trade.price = 200.0
|
||||
trade.status.value = "FILLED"
|
||||
trade.pnl = 25.0
|
||||
trade.strategy_id = None
|
||||
trade.signal_id = None
|
||||
trade.created_at = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
# session.execute will be called twice: count + data
|
||||
count_result = _make_execute_result([], scalar=1)
|
||||
data_result = _make_execute_result([trade])
|
||||
session.execute = AsyncMock(side_effect=[count_result, data_result])
|
||||
|
||||
resp = client.get("/api/trades")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "trades" in data
|
||||
assert "total" in data
|
||||
assert "page" in data
|
||||
|
||||
|
||||
class TestTradesPagination:
|
||||
"""test_trades_pagination."""
|
||||
|
||||
def test_trades_page_and_per_page(
|
||||
self, client: TestClient, mock_session_factory
|
||||
) -> None:
|
||||
_, session = mock_session_factory
|
||||
|
||||
count_result = _make_execute_result([], scalar=50)
|
||||
data_result = _make_execute_result([])
|
||||
session.execute = AsyncMock(side_effect=[count_result, data_result])
|
||||
|
||||
resp = client.get("/api/trades?page=3&per_page=10")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["page"] == 3
|
||||
assert data["per_page"] == 10
|
||||
assert data["pages"] == 5 # 50 / 10
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Strategies Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStrategiesEndpoint:
|
||||
"""test_strategies_endpoint."""
|
||||
|
||||
def test_strategies_returns_list(
|
||||
self, client: TestClient, mock_session_factory
|
||||
) -> None:
|
||||
_, session = mock_session_factory
|
||||
|
||||
strategy = MagicMock()
|
||||
strategy.id = uuid.uuid4()
|
||||
strategy.name = "momentum"
|
||||
strategy.description = "Momentum strategy"
|
||||
strategy.current_weight = 0.333
|
||||
strategy.active = True
|
||||
strategy.created_at = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
session.execute = AsyncMock(
|
||||
return_value=_make_execute_result([strategy])
|
||||
)
|
||||
|
||||
resp = client.get("/api/strategies")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 1
|
||||
assert data[0]["name"] == "momentum"
|
||||
assert data[0]["current_weight"] == 0.333
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# News Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNewsEndpoint:
|
||||
"""test_news_endpoint."""
|
||||
|
||||
def test_news_returns_paginated_articles(
|
||||
self, client: TestClient, mock_session_factory
|
||||
) -> None:
|
||||
_, session = mock_session_factory
|
||||
|
||||
article = MagicMock()
|
||||
article.id = uuid.uuid4()
|
||||
article.source = "reuters"
|
||||
article.url = "https://reuters.com/article/1"
|
||||
article.title = "Stock rises"
|
||||
article.published_at = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
article.fetched_at = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
sentiment = MagicMock()
|
||||
sentiment.ticker = "AAPL"
|
||||
sentiment.score = 0.8
|
||||
sentiment.confidence = 0.9
|
||||
sentiment.model_used = "finbert"
|
||||
|
||||
count_result = _make_execute_result([], scalar=1)
|
||||
data_result = MagicMock()
|
||||
data_result.all.return_value = [(article, sentiment)]
|
||||
session.execute = AsyncMock(side_effect=[count_result, data_result])
|
||||
|
||||
resp = client.get("/api/news")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "articles" in data
|
||||
assert data["total"] == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Controls Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestControlsPauseResume:
|
||||
"""test_controls_pause_resume."""
|
||||
|
||||
def test_pause_sets_redis_key(
|
||||
self, client: TestClient, mock_redis: AsyncMock
|
||||
) -> None:
|
||||
resp = client.post("/api/controls/pause")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "paused"
|
||||
mock_redis.set.assert_called_once_with("trading:paused", "1")
|
||||
|
||||
def test_resume_clears_redis_key(
|
||||
self, client: TestClient, mock_redis: AsyncMock
|
||||
) -> None:
|
||||
resp = client.post("/api/controls/resume")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "active"
|
||||
mock_redis.delete.assert_called_once_with("trading:paused")
|
||||
|
||||
|
||||
class TestControlsStatus:
|
||||
"""test_controls_status."""
|
||||
|
||||
def test_status_active_when_not_paused(
|
||||
self, client: TestClient, mock_redis: AsyncMock
|
||||
) -> None:
|
||||
mock_redis.get = AsyncMock(return_value=None)
|
||||
resp = client.get("/api/controls/status")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "active"
|
||||
|
||||
def test_status_paused_when_flag_set(
|
||||
self, client: TestClient, mock_redis: AsyncMock
|
||||
) -> None:
|
||||
mock_redis.get = AsyncMock(return_value="1")
|
||||
resp = client.get("/api/controls/status")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "paused"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backtest Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBacktestRunEndpoint:
|
||||
"""test_backtest_run_endpoint."""
|
||||
|
||||
def test_backtest_run_returns_run_id(
|
||||
self, client: TestClient, mock_redis: AsyncMock
|
||||
) -> None:
|
||||
resp = client.post(
|
||||
"/api/backtest/run",
|
||||
json={
|
||||
"start_date": "2024-01-01T00:00:00Z",
|
||||
"end_date": "2024-06-01T00:00:00Z",
|
||||
"initial_capital": 100000,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "run_id" in data
|
||||
assert data["status"] == "running"
|
||||
mock_redis.setex.assert_called()
|
||||
|
||||
def test_backtest_get_not_found(
|
||||
self, client: TestClient, mock_redis: AsyncMock
|
||||
) -> None:
|
||||
mock_redis.get = AsyncMock(return_value=None)
|
||||
resp = client.get("/api/backtest/nonexistent-id")
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_backtest_get_returns_result(
|
||||
self, client: TestClient, mock_redis: AsyncMock
|
||||
) -> None:
|
||||
stored = json.dumps({
|
||||
"status": "completed",
|
||||
"result": {"total_return": 0.15, "sharpe_ratio": 1.2},
|
||||
})
|
||||
mock_redis.get = AsyncMock(return_value=stored)
|
||||
|
||||
resp = client.get("/api/backtest/some-run-id")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["status"] == "completed"
|
||||
assert data["result"]["total_return"] == 0.15
|
||||
Loading…
Add table
Add a link
Reference in a new issue