fix: resolve 13 important issues from code review
I1: Add graceful shutdown (SIGTERM/SIGINT) to all 5 background services I2: Fix Dockerfile healthcheck to use curl on /metrics endpoint I3: Fix StreamConsumer.ensure_group() to only catch BUSYGROUP errors I4: Fix SimulatedBroker to reject orders with insufficient cash/shares I5: Move ORM attribute access inside DB session context in trades routes I6: Add Redis-based rate limiting (10 req/min/IP) on all auth endpoints I8: Prevent backtest background task garbage collection I9: Use Numeric(16,6) instead of Float for financial columns in migration I10: Add index on trades.created_at for time-range queries I11: Bind infrastructure ports to 127.0.0.1 in docker-compose I12: Add migrations init service; all app services depend on it I13: Fix user enumeration in login_begin (return options for non-existent users)
This commit is contained in:
parent
2a56727267
commit
5a6b20c8f1
13 changed files with 355 additions and 188 deletions
|
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
|||
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -52,6 +53,21 @@ async def _get_db(request: Request):
|
|||
yield session
|
||||
|
||||
|
||||
async def _rate_limit_auth(request: Request) -> None:
|
||||
"""Simple fixed-window rate limiter for auth endpoints (10 req/min per IP)."""
|
||||
redis = request.app.state.redis
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
key = f"rate_limit:auth:{client_ip}"
|
||||
current = await redis.incr(key)
|
||||
if current == 1:
|
||||
await redis.expire(key, 60)
|
||||
if current > 10:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail="Too many requests. Please try again later.",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -62,6 +78,7 @@ async def register_begin(
|
|||
body: RegisterRequest,
|
||||
request: Request,
|
||||
config: ApiGatewayConfig = Depends(get_config),
|
||||
_rl: None = Depends(_rate_limit_auth),
|
||||
) -> dict[str, Any]:
|
||||
"""Generate WebAuthn registration options (challenge + relying party info).
|
||||
|
||||
|
|
@ -125,6 +142,7 @@ async def register_begin(
|
|||
async def register_complete(
|
||||
request: Request,
|
||||
config: ApiGatewayConfig = Depends(get_config),
|
||||
_rl: None = Depends(_rate_limit_auth),
|
||||
) -> TokenResponse:
|
||||
"""Verify WebAuthn registration response and store credential.
|
||||
|
||||
|
|
@ -211,8 +229,13 @@ async def login_begin(
|
|||
body: LoginRequest,
|
||||
request: Request,
|
||||
config: ApiGatewayConfig = Depends(get_config),
|
||||
_rl: None = Depends(_rate_limit_auth),
|
||||
) -> dict[str, Any]:
|
||||
"""Generate WebAuthn authentication options for an existing user."""
|
||||
"""Generate WebAuthn authentication options for an existing user.
|
||||
|
||||
Returns valid-looking options even if the user does not exist
|
||||
to prevent username enumeration.
|
||||
"""
|
||||
import json as _json
|
||||
|
||||
redis = await _get_redis(request)
|
||||
|
|
@ -227,26 +250,22 @@ async def login_begin(
|
|||
select(User).where(User.username == body.username)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found",
|
||||
)
|
||||
|
||||
creds = (
|
||||
await session.execute(
|
||||
select(UserCredential).where(
|
||||
UserCredential.user_id == user.id
|
||||
allow_credentials = []
|
||||
if user is not None:
|
||||
creds = (
|
||||
await session.execute(
|
||||
select(UserCredential).where(
|
||||
UserCredential.user_id == user.id
|
||||
)
|
||||
)
|
||||
)
|
||||
).scalars().all()
|
||||
|
||||
allow_credentials = [
|
||||
PublicKeyCredentialDescriptor(
|
||||
id=base64.urlsafe_b64decode(c.credential_id),
|
||||
)
|
||||
for c in creds
|
||||
]
|
||||
).scalars().all()
|
||||
allow_credentials = [
|
||||
PublicKeyCredentialDescriptor(
|
||||
id=base64.urlsafe_b64decode(c.credential_id),
|
||||
)
|
||||
for c in creds
|
||||
]
|
||||
|
||||
options = generate_authentication_options(
|
||||
rp_id=config.rp_id,
|
||||
|
|
@ -255,16 +274,21 @@ async def login_begin(
|
|||
)
|
||||
|
||||
challenge_b64 = base64.urlsafe_b64encode(options.challenge).decode()
|
||||
redis_key = f"webauthn:login:{body.username}"
|
||||
await redis.setex(
|
||||
redis_key,
|
||||
300,
|
||||
_json.dumps({
|
||||
"challenge": challenge_b64,
|
||||
"user_id": str(user.id),
|
||||
"username": user.username,
|
||||
}),
|
||||
)
|
||||
|
||||
if user is not None:
|
||||
# Real user — store challenge for verification
|
||||
redis_key = f"webauthn:login:{body.username}"
|
||||
await redis.setex(
|
||||
redis_key,
|
||||
300,
|
||||
_json.dumps({
|
||||
"challenge": challenge_b64,
|
||||
"user_id": str(user.id),
|
||||
"username": user.username,
|
||||
}),
|
||||
)
|
||||
# If user doesn't exist, we still return options but don't store
|
||||
# a challenge. login/complete will fail with a generic error.
|
||||
|
||||
from webauthn.helpers import options_to_json
|
||||
|
||||
|
|
@ -275,6 +299,7 @@ async def login_begin(
|
|||
async def login_complete(
|
||||
request: Request,
|
||||
config: ApiGatewayConfig = Depends(get_config),
|
||||
_rl: None = Depends(_rate_limit_auth),
|
||||
) -> TokenResponse:
|
||||
"""Verify WebAuthn authentication response and issue JWT."""
|
||||
import json as _json
|
||||
|
|
@ -356,7 +381,11 @@ async def login_complete(
|
|||
|
||||
|
||||
@router.post("/refresh")
|
||||
async def refresh(request: Request, config: ApiGatewayConfig = Depends(get_config)) -> TokenResponse:
|
||||
async def refresh(
|
||||
request: Request,
|
||||
config: ApiGatewayConfig = Depends(get_config),
|
||||
_rl: None = Depends(_rate_limit_auth),
|
||||
) -> TokenResponse:
|
||||
"""Exchange a valid refresh token for a new access token."""
|
||||
body = await request.json()
|
||||
refresh_token = body.get("refresh_token", "")
|
||||
|
|
|
|||
|
|
@ -17,6 +17,9 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
router = APIRouter(prefix="/api/backtest", tags=["backtest"])
|
||||
|
||||
# Store references to background tasks to prevent garbage collection
|
||||
_background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
class BacktestRequest(BaseModel):
|
||||
"""Request body for starting a new backtest."""
|
||||
|
|
@ -56,8 +59,10 @@ async def run_backtest(
|
|||
}),
|
||||
)
|
||||
|
||||
# Launch background task
|
||||
asyncio.create_task(_run_backtest_task(run_id, body, redis))
|
||||
# Launch background task (stored in set to prevent GC)
|
||||
task = asyncio.create_task(_run_backtest_task(run_id, body, redis))
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
return {"run_id": run_id, "status": "running"}
|
||||
|
||||
|
|
|
|||
|
|
@ -67,27 +67,27 @@ async def list_trades(
|
|||
result = await session.execute(query)
|
||||
trades = result.scalars().all()
|
||||
|
||||
return {
|
||||
"trades": [
|
||||
{
|
||||
"id": str(t.id),
|
||||
"ticker": t.ticker,
|
||||
"side": t.side.value,
|
||||
"qty": t.qty,
|
||||
"price": t.price,
|
||||
"status": t.status.value,
|
||||
"pnl": t.pnl,
|
||||
"strategy_id": str(t.strategy_id) if t.strategy_id else None,
|
||||
"signal_id": str(t.signal_id) if t.signal_id else None,
|
||||
"created_at": t.created_at.isoformat() if t.created_at else None,
|
||||
}
|
||||
for t in trades
|
||||
],
|
||||
"total": total,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"pages": (total + per_page - 1) // per_page if per_page else 0,
|
||||
}
|
||||
return {
|
||||
"trades": [
|
||||
{
|
||||
"id": str(t.id),
|
||||
"ticker": t.ticker,
|
||||
"side": t.side.value,
|
||||
"qty": t.qty,
|
||||
"price": t.price,
|
||||
"status": t.status.value,
|
||||
"pnl": t.pnl,
|
||||
"strategy_id": str(t.strategy_id) if t.strategy_id else None,
|
||||
"signal_id": str(t.signal_id) if t.signal_id else None,
|
||||
"created_at": t.created_at.isoformat() if t.created_at else None,
|
||||
}
|
||||
for t in trades
|
||||
],
|
||||
"total": total,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"pages": (total + per_page - 1) // per_page if per_page else 0,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{trade_id}")
|
||||
|
|
@ -105,21 +105,21 @@ async def get_trade(
|
|||
await session.execute(select(Trade).where(Trade.id == trade_id))
|
||||
).scalar_one_or_none()
|
||||
|
||||
if trade is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Trade not found",
|
||||
)
|
||||
if trade is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Trade not found",
|
||||
)
|
||||
|
||||
return {
|
||||
"id": str(trade.id),
|
||||
"ticker": trade.ticker,
|
||||
"side": trade.side.value,
|
||||
"qty": trade.qty,
|
||||
"price": trade.price,
|
||||
"status": trade.status.value,
|
||||
"pnl": trade.pnl,
|
||||
"strategy_id": str(trade.strategy_id) if trade.strategy_id else None,
|
||||
"signal_id": str(trade.signal_id) if trade.signal_id else None,
|
||||
"created_at": trade.created_at.isoformat() if trade.created_at else None,
|
||||
}
|
||||
return {
|
||||
"id": str(trade.id),
|
||||
"ticker": trade.ticker,
|
||||
"side": trade.side.value,
|
||||
"qty": trade.qty,
|
||||
"price": trade.price,
|
||||
"status": trade.status.value,
|
||||
"pnl": trade.pnl,
|
||||
"strategy_id": str(trade.strategy_id) if trade.strategy_id else None,
|
||||
"signal_id": str(trade.signal_id) if trade.signal_id else None,
|
||||
"created_at": trade.created_at.isoformat() if trade.created_at else None,
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue