fix: resolve 13 important issues from code review
I1: Add graceful shutdown (SIGTERM/SIGINT) to all 5 background services I2: Fix Dockerfile healthcheck to use curl on /metrics endpoint I3: Fix StreamConsumer.ensure_group() to only catch BUSYGROUP errors I4: Fix SimulatedBroker to reject orders with insufficient cash/shares I5: Move ORM attribute access inside DB session context in trades routes I6: Add Redis-based rate limiting (10 req/min/IP) on all auth endpoints I8: Prevent backtest background task garbage collection I9: Use Numeric(16,6) instead of Float for financial columns in migration I10: Add index on trades.created_at for time-range queries I11: Bind infrastructure ports to 127.0.0.1 in docker-compose I12: Add migrations init service; all app services depend on it I13: Fix user enumeration in login_begin (return options for non-existent users)
This commit is contained in:
parent
2a56727267
commit
5a6b20c8f1
13 changed files with 355 additions and 188 deletions
|
|
@ -66,8 +66,8 @@ def upgrade() -> None:
|
||||||
sa.Enum("BUY", "SELL", name="tradeside"),
|
sa.Enum("BUY", "SELL", name="tradeside"),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
),
|
),
|
||||||
sa.Column("qty", sa.Float, nullable=False),
|
sa.Column("qty", sa.Numeric(16, 6), nullable=False),
|
||||||
sa.Column("price", sa.Float, nullable=False),
|
sa.Column("price", sa.Numeric(16, 6), nullable=False),
|
||||||
sa.Column("timestamp", sa.String, nullable=True),
|
sa.Column("timestamp", sa.String, nullable=True),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"strategy_id",
|
"strategy_id",
|
||||||
|
|
@ -87,8 +87,8 @@ def upgrade() -> None:
|
||||||
nullable=False,
|
nullable=False,
|
||||||
server_default="PENDING",
|
server_default="PENDING",
|
||||||
),
|
),
|
||||||
sa.Column("pnl", sa.Float, nullable=True),
|
sa.Column("pnl", sa.Numeric(16, 6), nullable=True),
|
||||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), index=True),
|
||||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -96,11 +96,11 @@ def upgrade() -> None:
|
||||||
"positions",
|
"positions",
|
||||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||||
sa.Column("ticker", sa.String(20), unique=True, nullable=False),
|
sa.Column("ticker", sa.String(20), unique=True, nullable=False),
|
||||||
sa.Column("qty", sa.Float, nullable=False),
|
sa.Column("qty", sa.Numeric(16, 6), nullable=False),
|
||||||
sa.Column("avg_entry", sa.Float, nullable=False),
|
sa.Column("avg_entry", sa.Numeric(16, 6), nullable=False),
|
||||||
sa.Column("unrealized_pnl", sa.Float, nullable=True),
|
sa.Column("unrealized_pnl", sa.Numeric(16, 6), nullable=True),
|
||||||
sa.Column("stop_loss", sa.Float, nullable=True),
|
sa.Column("stop_loss", sa.Numeric(16, 6), nullable=True),
|
||||||
sa.Column("take_profit", sa.Float, nullable=True),
|
sa.Column("take_profit", sa.Numeric(16, 6), nullable=True),
|
||||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -93,10 +93,32 @@ class SimulatedBroker(BaseBroker):
|
||||||
|
|
||||||
# Deduct / credit cash
|
# Deduct / credit cash
|
||||||
if order.side == OrderSide.BUY:
|
if order.side == OrderSide.BUY:
|
||||||
self.cash -= cost
|
total_cost = cost + self.commission_per_trade
|
||||||
self.cash -= self.commission_per_trade
|
if total_cost > self.cash:
|
||||||
|
return OrderResult(
|
||||||
|
order_id=str(uuid.uuid4()),
|
||||||
|
ticker=order.ticker,
|
||||||
|
side=order.side,
|
||||||
|
qty=order.qty,
|
||||||
|
filled_price=None,
|
||||||
|
status=OrderStatus.REJECTED,
|
||||||
|
timestamp=datetime.now(tz=timezone.utc),
|
||||||
|
)
|
||||||
|
self.cash -= total_cost
|
||||||
self._update_position_buy(order.ticker, order.qty, fill_price)
|
self._update_position_buy(order.ticker, order.qty, fill_price)
|
||||||
else:
|
else:
|
||||||
|
# Validate sufficient shares to sell
|
||||||
|
current_qty = self._positions.get(order.ticker, {}).get("qty", 0.0)
|
||||||
|
if order.qty > current_qty:
|
||||||
|
return OrderResult(
|
||||||
|
order_id=str(uuid.uuid4()),
|
||||||
|
ticker=order.ticker,
|
||||||
|
side=order.side,
|
||||||
|
qty=order.qty,
|
||||||
|
filled_price=None,
|
||||||
|
status=OrderStatus.REJECTED,
|
||||||
|
timestamp=datetime.now(tz=timezone.utc),
|
||||||
|
)
|
||||||
self.cash += cost
|
self.cash += cost
|
||||||
self.cash -= self.commission_per_trade
|
self.cash -= self.commission_per_trade
|
||||||
self._update_position_sell(order.ticker, order.qty)
|
self._update_position_sell(order.ticker, order.qty)
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ services:
|
||||||
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-trading}
|
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-trading}
|
||||||
POSTGRES_DB: trading
|
POSTGRES_DB: trading
|
||||||
ports:
|
ports:
|
||||||
- "5432:5432"
|
- "127.0.0.1:5432:5432"
|
||||||
volumes:
|
volumes:
|
||||||
- pgdata:/var/lib/postgresql/data
|
- pgdata:/var/lib/postgresql/data
|
||||||
healthcheck:
|
healthcheck:
|
||||||
|
|
@ -21,7 +21,7 @@ services:
|
||||||
redis:
|
redis:
|
||||||
image: redis:7-alpine
|
image: redis:7-alpine
|
||||||
ports:
|
ports:
|
||||||
- "6379:6379"
|
- "127.0.0.1:6379:6379"
|
||||||
volumes:
|
volumes:
|
||||||
- redisdata:/data
|
- redisdata:/data
|
||||||
healthcheck:
|
healthcheck:
|
||||||
|
|
@ -33,10 +33,27 @@ services:
|
||||||
ollama:
|
ollama:
|
||||||
image: ollama/ollama:latest
|
image: ollama/ollama:latest
|
||||||
ports:
|
ports:
|
||||||
- "11434:11434"
|
- "127.0.0.1:11434:11434"
|
||||||
volumes:
|
volumes:
|
||||||
- ollama_models:/root/.ollama
|
- ollama_models:/root/.ollama
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Database migrations — runs once before application services start
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
migrations:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: docker/Dockerfile.service
|
||||||
|
args:
|
||||||
|
EXTRAS: "dev"
|
||||||
|
SERVICE_MODULE: "api_gateway"
|
||||||
|
depends_on:
|
||||||
|
postgres:
|
||||||
|
condition: service_healthy
|
||||||
|
env_file: .env
|
||||||
|
command: python -m alembic upgrade head
|
||||||
|
restart: "no"
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Application services
|
# Application services
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -52,6 +69,8 @@ services:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
redis:
|
redis:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
|
migrations:
|
||||||
|
condition: service_completed_successfully
|
||||||
env_file: .env
|
env_file: .env
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
|
|
@ -67,6 +86,8 @@ services:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
ollama:
|
ollama:
|
||||||
condition: service_started
|
condition: service_started
|
||||||
|
migrations:
|
||||||
|
condition: service_completed_successfully
|
||||||
env_file: .env
|
env_file: .env
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
|
|
@ -82,6 +103,8 @@ services:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
redis:
|
redis:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
|
migrations:
|
||||||
|
condition: service_completed_successfully
|
||||||
env_file: .env
|
env_file: .env
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
|
|
@ -97,6 +120,8 @@ services:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
redis:
|
redis:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
|
migrations:
|
||||||
|
condition: service_completed_successfully
|
||||||
env_file: .env
|
env_file: .env
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
|
|
@ -112,6 +137,8 @@ services:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
redis:
|
redis:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
|
migrations:
|
||||||
|
condition: service_completed_successfully
|
||||||
env_file: .env
|
env_file: .env
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
|
|
@ -127,6 +154,8 @@ services:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
redis:
|
redis:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
|
migrations:
|
||||||
|
condition: service_completed_successfully
|
||||||
ports:
|
ports:
|
||||||
- "8000:8000"
|
- "8000:8000"
|
||||||
env_file: .env
|
env_file: .env
|
||||||
|
|
|
||||||
|
|
@ -19,13 +19,15 @@ COPY alembic/ alembic/
|
||||||
COPY alembic.ini .
|
COPY alembic.ini .
|
||||||
|
|
||||||
ARG EXTRAS="dev"
|
ARG EXTRAS="dev"
|
||||||
RUN pip install --no-cache-dir ".[$EXTRAS]"
|
RUN pip install --no-cache-dir ".[$EXTRAS]" && pip install --no-cache-dir curl_cffi 2>/dev/null || true
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Stage 2: slim runtime image
|
# Stage 2: slim runtime image
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
FROM python:3.12-slim
|
FROM python:3.12-slim
|
||||||
|
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends curl && rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
# Copy installed packages and CLI entry-points from the builder
|
# Copy installed packages and CLI entry-points from the builder
|
||||||
|
|
@ -37,9 +39,11 @@ COPY --from=builder /app .
|
||||||
|
|
||||||
ARG SERVICE_MODULE="api_gateway"
|
ARG SERVICE_MODULE="api_gateway"
|
||||||
ENV SERVICE_MODULE=${SERVICE_MODULE}
|
ENV SERVICE_MODULE=${SERVICE_MODULE}
|
||||||
|
ARG HEALTH_PORT="9090"
|
||||||
|
ENV HEALTH_PORT=${HEALTH_PORT}
|
||||||
|
|
||||||
# Simple health check — verify the Python process is running
|
# Check /metrics endpoint (all services expose it via OpenTelemetry)
|
||||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=15s --retries=3 \
|
HEALTHCHECK --interval=30s --timeout=10s --start-period=30s --retries=3 \
|
||||||
CMD python -c "import sys; sys.exit(0)" || exit 1
|
CMD curl -sf http://localhost:${HEALTH_PORT}/metrics > /dev/null || exit 1
|
||||||
|
|
||||||
CMD python -m services.${SERVICE_MODULE}.main
|
CMD python -m services.${SERVICE_MODULE}.main
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
@ -52,6 +53,21 @@ async def _get_db(request: Request):
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
async def _rate_limit_auth(request: Request) -> None:
|
||||||
|
"""Simple fixed-window rate limiter for auth endpoints (10 req/min per IP)."""
|
||||||
|
redis = request.app.state.redis
|
||||||
|
client_ip = request.client.host if request.client else "unknown"
|
||||||
|
key = f"rate_limit:auth:{client_ip}"
|
||||||
|
current = await redis.incr(key)
|
||||||
|
if current == 1:
|
||||||
|
await redis.expire(key, 60)
|
||||||
|
if current > 10:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||||
|
detail="Too many requests. Please try again later.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Registration
|
# Registration
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -62,6 +78,7 @@ async def register_begin(
|
||||||
body: RegisterRequest,
|
body: RegisterRequest,
|
||||||
request: Request,
|
request: Request,
|
||||||
config: ApiGatewayConfig = Depends(get_config),
|
config: ApiGatewayConfig = Depends(get_config),
|
||||||
|
_rl: None = Depends(_rate_limit_auth),
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Generate WebAuthn registration options (challenge + relying party info).
|
"""Generate WebAuthn registration options (challenge + relying party info).
|
||||||
|
|
||||||
|
|
@ -125,6 +142,7 @@ async def register_begin(
|
||||||
async def register_complete(
|
async def register_complete(
|
||||||
request: Request,
|
request: Request,
|
||||||
config: ApiGatewayConfig = Depends(get_config),
|
config: ApiGatewayConfig = Depends(get_config),
|
||||||
|
_rl: None = Depends(_rate_limit_auth),
|
||||||
) -> TokenResponse:
|
) -> TokenResponse:
|
||||||
"""Verify WebAuthn registration response and store credential.
|
"""Verify WebAuthn registration response and store credential.
|
||||||
|
|
||||||
|
|
@ -211,8 +229,13 @@ async def login_begin(
|
||||||
body: LoginRequest,
|
body: LoginRequest,
|
||||||
request: Request,
|
request: Request,
|
||||||
config: ApiGatewayConfig = Depends(get_config),
|
config: ApiGatewayConfig = Depends(get_config),
|
||||||
|
_rl: None = Depends(_rate_limit_auth),
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Generate WebAuthn authentication options for an existing user."""
|
"""Generate WebAuthn authentication options for an existing user.
|
||||||
|
|
||||||
|
Returns valid-looking options even if the user does not exist
|
||||||
|
to prevent username enumeration.
|
||||||
|
"""
|
||||||
import json as _json
|
import json as _json
|
||||||
|
|
||||||
redis = await _get_redis(request)
|
redis = await _get_redis(request)
|
||||||
|
|
@ -227,26 +250,22 @@ async def login_begin(
|
||||||
select(User).where(User.username == body.username)
|
select(User).where(User.username == body.username)
|
||||||
)
|
)
|
||||||
).scalar_one_or_none()
|
).scalar_one_or_none()
|
||||||
if user is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="User not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
creds = (
|
allow_credentials = []
|
||||||
await session.execute(
|
if user is not None:
|
||||||
select(UserCredential).where(
|
creds = (
|
||||||
UserCredential.user_id == user.id
|
await session.execute(
|
||||||
|
select(UserCredential).where(
|
||||||
|
UserCredential.user_id == user.id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
).scalars().all()
|
||||||
).scalars().all()
|
allow_credentials = [
|
||||||
|
PublicKeyCredentialDescriptor(
|
||||||
allow_credentials = [
|
id=base64.urlsafe_b64decode(c.credential_id),
|
||||||
PublicKeyCredentialDescriptor(
|
)
|
||||||
id=base64.urlsafe_b64decode(c.credential_id),
|
for c in creds
|
||||||
)
|
]
|
||||||
for c in creds
|
|
||||||
]
|
|
||||||
|
|
||||||
options = generate_authentication_options(
|
options = generate_authentication_options(
|
||||||
rp_id=config.rp_id,
|
rp_id=config.rp_id,
|
||||||
|
|
@ -255,16 +274,21 @@ async def login_begin(
|
||||||
)
|
)
|
||||||
|
|
||||||
challenge_b64 = base64.urlsafe_b64encode(options.challenge).decode()
|
challenge_b64 = base64.urlsafe_b64encode(options.challenge).decode()
|
||||||
redis_key = f"webauthn:login:{body.username}"
|
|
||||||
await redis.setex(
|
if user is not None:
|
||||||
redis_key,
|
# Real user — store challenge for verification
|
||||||
300,
|
redis_key = f"webauthn:login:{body.username}"
|
||||||
_json.dumps({
|
await redis.setex(
|
||||||
"challenge": challenge_b64,
|
redis_key,
|
||||||
"user_id": str(user.id),
|
300,
|
||||||
"username": user.username,
|
_json.dumps({
|
||||||
}),
|
"challenge": challenge_b64,
|
||||||
)
|
"user_id": str(user.id),
|
||||||
|
"username": user.username,
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
# If user doesn't exist, we still return options but don't store
|
||||||
|
# a challenge. login/complete will fail with a generic error.
|
||||||
|
|
||||||
from webauthn.helpers import options_to_json
|
from webauthn.helpers import options_to_json
|
||||||
|
|
||||||
|
|
@ -275,6 +299,7 @@ async def login_begin(
|
||||||
async def login_complete(
|
async def login_complete(
|
||||||
request: Request,
|
request: Request,
|
||||||
config: ApiGatewayConfig = Depends(get_config),
|
config: ApiGatewayConfig = Depends(get_config),
|
||||||
|
_rl: None = Depends(_rate_limit_auth),
|
||||||
) -> TokenResponse:
|
) -> TokenResponse:
|
||||||
"""Verify WebAuthn authentication response and issue JWT."""
|
"""Verify WebAuthn authentication response and issue JWT."""
|
||||||
import json as _json
|
import json as _json
|
||||||
|
|
@ -356,7 +381,11 @@ async def login_complete(
|
||||||
|
|
||||||
|
|
||||||
@router.post("/refresh")
|
@router.post("/refresh")
|
||||||
async def refresh(request: Request, config: ApiGatewayConfig = Depends(get_config)) -> TokenResponse:
|
async def refresh(
|
||||||
|
request: Request,
|
||||||
|
config: ApiGatewayConfig = Depends(get_config),
|
||||||
|
_rl: None = Depends(_rate_limit_auth),
|
||||||
|
) -> TokenResponse:
|
||||||
"""Exchange a valid refresh token for a new access token."""
|
"""Exchange a valid refresh token for a new access token."""
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
refresh_token = body.get("refresh_token", "")
|
refresh_token = body.get("refresh_token", "")
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,9 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/backtest", tags=["backtest"])
|
router = APIRouter(prefix="/api/backtest", tags=["backtest"])
|
||||||
|
|
||||||
|
# Store references to background tasks to prevent garbage collection
|
||||||
|
_background_tasks: set[asyncio.Task] = set()
|
||||||
|
|
||||||
|
|
||||||
class BacktestRequest(BaseModel):
|
class BacktestRequest(BaseModel):
|
||||||
"""Request body for starting a new backtest."""
|
"""Request body for starting a new backtest."""
|
||||||
|
|
@ -56,8 +59,10 @@ async def run_backtest(
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Launch background task
|
# Launch background task (stored in set to prevent GC)
|
||||||
asyncio.create_task(_run_backtest_task(run_id, body, redis))
|
task = asyncio.create_task(_run_backtest_task(run_id, body, redis))
|
||||||
|
_background_tasks.add(task)
|
||||||
|
task.add_done_callback(_background_tasks.discard)
|
||||||
|
|
||||||
return {"run_id": run_id, "status": "running"}
|
return {"run_id": run_id, "status": "running"}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -67,27 +67,27 @@ async def list_trades(
|
||||||
result = await session.execute(query)
|
result = await session.execute(query)
|
||||||
trades = result.scalars().all()
|
trades = result.scalars().all()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"trades": [
|
"trades": [
|
||||||
{
|
{
|
||||||
"id": str(t.id),
|
"id": str(t.id),
|
||||||
"ticker": t.ticker,
|
"ticker": t.ticker,
|
||||||
"side": t.side.value,
|
"side": t.side.value,
|
||||||
"qty": t.qty,
|
"qty": t.qty,
|
||||||
"price": t.price,
|
"price": t.price,
|
||||||
"status": t.status.value,
|
"status": t.status.value,
|
||||||
"pnl": t.pnl,
|
"pnl": t.pnl,
|
||||||
"strategy_id": str(t.strategy_id) if t.strategy_id else None,
|
"strategy_id": str(t.strategy_id) if t.strategy_id else None,
|
||||||
"signal_id": str(t.signal_id) if t.signal_id else None,
|
"signal_id": str(t.signal_id) if t.signal_id else None,
|
||||||
"created_at": t.created_at.isoformat() if t.created_at else None,
|
"created_at": t.created_at.isoformat() if t.created_at else None,
|
||||||
}
|
}
|
||||||
for t in trades
|
for t in trades
|
||||||
],
|
],
|
||||||
"total": total,
|
"total": total,
|
||||||
"page": page,
|
"page": page,
|
||||||
"per_page": per_page,
|
"per_page": per_page,
|
||||||
"pages": (total + per_page - 1) // per_page if per_page else 0,
|
"pages": (total + per_page - 1) // per_page if per_page else 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{trade_id}")
|
@router.get("/{trade_id}")
|
||||||
|
|
@ -105,21 +105,21 @@ async def get_trade(
|
||||||
await session.execute(select(Trade).where(Trade.id == trade_id))
|
await session.execute(select(Trade).where(Trade.id == trade_id))
|
||||||
).scalar_one_or_none()
|
).scalar_one_or_none()
|
||||||
|
|
||||||
if trade is None:
|
if trade is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail="Trade not found",
|
detail="Trade not found",
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"id": str(trade.id),
|
"id": str(trade.id),
|
||||||
"ticker": trade.ticker,
|
"ticker": trade.ticker,
|
||||||
"side": trade.side.value,
|
"side": trade.side.value,
|
||||||
"qty": trade.qty,
|
"qty": trade.qty,
|
||||||
"price": trade.price,
|
"price": trade.price,
|
||||||
"status": trade.status.value,
|
"status": trade.status.value,
|
||||||
"pnl": trade.pnl,
|
"pnl": trade.pnl,
|
||||||
"strategy_id": str(trade.strategy_id) if trade.strategy_id else None,
|
"strategy_id": str(trade.strategy_id) if trade.strategy_id else None,
|
||||||
"signal_id": str(trade.signal_id) if trade.signal_id else None,
|
"signal_id": str(trade.signal_id) if trade.signal_id else None,
|
||||||
"created_at": trade.created_at.isoformat() if trade.created_at else None,
|
"created_at": trade.created_at.isoformat() if trade.created_at else None,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ from __future__ import annotations
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import signal
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
|
@ -275,24 +276,36 @@ async def run(config: LearningEngineConfig | None = None) -> None:
|
||||||
|
|
||||||
logger.info("Consuming from trades:executed")
|
logger.info("Consuming from trades:executed")
|
||||||
|
|
||||||
|
# Graceful shutdown on SIGTERM/SIGINT
|
||||||
|
shutdown_event = asyncio.Event()
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
for sig in (signal.SIGTERM, signal.SIGINT):
|
||||||
|
loop.add_signal_handler(sig, shutdown_event.set)
|
||||||
|
|
||||||
# --- Consume loop ---
|
# --- Consume loop ---
|
||||||
async for _msg_id, data in consumer.consume():
|
try:
|
||||||
try:
|
async for _msg_id, data in consumer.consume():
|
||||||
trade = TradeExecution.model_validate(data)
|
if shutdown_event.is_set():
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
trade = TradeExecution.model_validate(data)
|
||||||
|
|
||||||
if trade.status.value != "FILLED":
|
if trade.status.value != "FILLED":
|
||||||
logger.debug("Skipping non-filled trade: %s", trade.status.value)
|
logger.debug("Skipping non-filled trade: %s", trade.status.value)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
adjustments = await process_trade(trade, redis, evaluator, adjuster, counters)
|
adjustments = await process_trade(trade, redis, evaluator, adjuster, counters)
|
||||||
if adjustments:
|
if adjustments:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Made %d weight adjustment(s) for %s",
|
"Made %d weight adjustment(s) for %s",
|
||||||
len(adjustments),
|
len(adjustments),
|
||||||
trade.ticker,
|
trade.ticker,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error processing trade execution: %s", data)
|
logger.exception("Error processing trade execution: %s", data)
|
||||||
|
finally:
|
||||||
|
await redis.aclose()
|
||||||
|
logger.info("Learning engine stopped gracefully")
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ to the ``news:raw`` Redis Stream.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import signal
|
||||||
|
|
||||||
from redis.asyncio import Redis
|
from redis.asyncio import Redis
|
||||||
|
|
||||||
|
|
@ -53,9 +54,10 @@ async def _poll_rss(
|
||||||
publisher: StreamPublisher,
|
publisher: StreamPublisher,
|
||||||
articles_fetched_counter,
|
articles_fetched_counter,
|
||||||
fetch_errors_counter,
|
fetch_errors_counter,
|
||||||
|
shutdown_event: asyncio.Event,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Continuously poll RSS feeds at *interval* seconds."""
|
"""Continuously poll RSS feeds at *interval* seconds."""
|
||||||
while True:
|
while not shutdown_event.is_set():
|
||||||
try:
|
try:
|
||||||
logger.info("Polling RSS feeds …")
|
logger.info("Polling RSS feeds …")
|
||||||
articles = await source.fetch()
|
articles = await source.fetch()
|
||||||
|
|
@ -66,7 +68,11 @@ async def _poll_rss(
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("RSS poll cycle failed")
|
logger.exception("RSS poll cycle failed")
|
||||||
fetch_errors_counter.add(1)
|
fetch_errors_counter.add(1)
|
||||||
await asyncio.sleep(interval)
|
try:
|
||||||
|
await asyncio.wait_for(shutdown_event.wait(), timeout=interval)
|
||||||
|
return # Shutdown signaled
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pass # Normal timeout — continue polling
|
||||||
|
|
||||||
|
|
||||||
async def _poll_reddit(
|
async def _poll_reddit(
|
||||||
|
|
@ -76,9 +82,10 @@ async def _poll_reddit(
|
||||||
publisher: StreamPublisher,
|
publisher: StreamPublisher,
|
||||||
articles_fetched_counter,
|
articles_fetched_counter,
|
||||||
fetch_errors_counter,
|
fetch_errors_counter,
|
||||||
|
shutdown_event: asyncio.Event,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Continuously poll Reddit at *interval* seconds."""
|
"""Continuously poll Reddit at *interval* seconds."""
|
||||||
while True:
|
while not shutdown_event.is_set():
|
||||||
try:
|
try:
|
||||||
logger.info("Polling Reddit …")
|
logger.info("Polling Reddit …")
|
||||||
articles = await source.fetch()
|
articles = await source.fetch()
|
||||||
|
|
@ -89,7 +96,11 @@ async def _poll_reddit(
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Reddit poll cycle failed")
|
logger.exception("Reddit poll cycle failed")
|
||||||
fetch_errors_counter.add(1)
|
fetch_errors_counter.add(1)
|
||||||
await asyncio.sleep(interval)
|
try:
|
||||||
|
await asyncio.wait_for(shutdown_event.wait(), timeout=interval)
|
||||||
|
return # Shutdown signaled
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pass # Normal timeout — continue polling
|
||||||
|
|
||||||
|
|
||||||
async def run() -> None:
|
async def run() -> None:
|
||||||
|
|
@ -124,28 +135,40 @@ async def run() -> None:
|
||||||
min_score=config.reddit_min_score,
|
min_score=config.reddit_min_score,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Graceful shutdown on SIGTERM/SIGINT
|
||||||
|
shutdown_event = asyncio.Event()
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
for sig in (signal.SIGTERM, signal.SIGINT):
|
||||||
|
loop.add_signal_handler(sig, shutdown_event.set)
|
||||||
|
|
||||||
# Run pollers concurrently
|
# Run pollers concurrently
|
||||||
async with asyncio.TaskGroup() as tg:
|
try:
|
||||||
tg.create_task(
|
async with asyncio.TaskGroup() as tg:
|
||||||
_poll_rss(
|
tg.create_task(
|
||||||
rss_source,
|
_poll_rss(
|
||||||
config.rss_poll_interval_seconds,
|
rss_source,
|
||||||
redis,
|
config.rss_poll_interval_seconds,
|
||||||
publisher,
|
redis,
|
||||||
articles_fetched_counter,
|
publisher,
|
||||||
fetch_errors_counter,
|
articles_fetched_counter,
|
||||||
|
fetch_errors_counter,
|
||||||
|
shutdown_event,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
tg.create_task(
|
||||||
tg.create_task(
|
_poll_reddit(
|
||||||
_poll_reddit(
|
reddit_source,
|
||||||
reddit_source,
|
config.reddit_poll_interval_seconds,
|
||||||
config.reddit_poll_interval_seconds,
|
redis,
|
||||||
redis,
|
publisher,
|
||||||
publisher,
|
articles_fetched_counter,
|
||||||
articles_fetched_counter,
|
fetch_errors_counter,
|
||||||
fetch_errors_counter,
|
shutdown_event,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
finally:
|
||||||
|
await redis.aclose()
|
||||||
|
logger.info("News fetcher stopped gracefully")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import signal
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from redis.asyncio import Redis
|
from redis.asyncio import Redis
|
||||||
|
|
@ -151,13 +152,25 @@ async def run(config: SentimentAnalyzerConfig | None = None) -> None:
|
||||||
|
|
||||||
logger.info("Consuming from news:raw, publishing to news:scored")
|
logger.info("Consuming from news:raw, publishing to news:scored")
|
||||||
|
|
||||||
|
# Graceful shutdown on SIGTERM/SIGINT
|
||||||
|
shutdown_event = asyncio.Event()
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
for sig in (signal.SIGTERM, signal.SIGINT):
|
||||||
|
loop.add_signal_handler(sig, shutdown_event.set)
|
||||||
|
|
||||||
# --- Consume loop ---
|
# --- Consume loop ---
|
||||||
async for _msg_id, data in consumer.consume():
|
try:
|
||||||
try:
|
async for _msg_id, data in consumer.consume():
|
||||||
article = RawArticle.model_validate(data)
|
if shutdown_event.is_set():
|
||||||
await process_article(article, finbert, ollama, publisher, config, counters)
|
break
|
||||||
except Exception:
|
try:
|
||||||
logger.exception("Error processing article: %s", data.get("title", "<unknown>"))
|
article = RawArticle.model_validate(data)
|
||||||
|
await process_article(article, finbert, ollama, publisher, config, counters)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error processing article: %s", data.get("title", "<unknown>"))
|
||||||
|
finally:
|
||||||
|
await redis.aclose()
|
||||||
|
logger.info("Sentiment analyzer stopped gracefully")
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import signal
|
||||||
from collections import defaultdict, deque
|
from collections import defaultdict, deque
|
||||||
|
|
||||||
from redis.asyncio import Redis
|
from redis.asyncio import Redis
|
||||||
|
|
@ -101,59 +102,71 @@ async def run(config: SignalGeneratorConfig | None = None) -> None:
|
||||||
|
|
||||||
logger.info("Consuming from news:scored, publishing to signals:generated")
|
logger.info("Consuming from news:scored, publishing to signals:generated")
|
||||||
|
|
||||||
|
# Graceful shutdown on SIGTERM/SIGINT
|
||||||
|
shutdown_event = asyncio.Event()
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
for sig in (signal.SIGTERM, signal.SIGINT):
|
||||||
|
loop.add_signal_handler(sig, shutdown_event.set)
|
||||||
|
|
||||||
# --- Consume loop ---
|
# --- Consume loop ---
|
||||||
async for _msg_id, data in consumer.consume():
|
try:
|
||||||
try:
|
async for _msg_id, data in consumer.consume():
|
||||||
article = ScoredArticle.model_validate(data)
|
if shutdown_event.is_set():
|
||||||
ticker = article.ticker
|
break
|
||||||
|
try:
|
||||||
|
article = ScoredArticle.model_validate(data)
|
||||||
|
ticker = article.ticker
|
||||||
|
|
||||||
# Update sentiment accumulators
|
# Update sentiment accumulators
|
||||||
sentiment_scores[ticker].append(article.sentiment_score)
|
sentiment_scores[ticker].append(article.sentiment_score)
|
||||||
sentiment_confidences[ticker].append(article.confidence)
|
sentiment_confidences[ticker].append(article.confidence)
|
||||||
|
|
||||||
# Build sentiment context
|
# Build sentiment context
|
||||||
sentiment = _build_sentiment_context(
|
sentiment = _build_sentiment_context(
|
||||||
ticker,
|
|
||||||
sentiment_scores[ticker],
|
|
||||||
sentiment_confidences[ticker],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get market snapshot (may be None if no bars received yet)
|
|
||||||
snapshot = market_data.get_snapshot(ticker)
|
|
||||||
if snapshot is None:
|
|
||||||
# Create a minimal snapshot from sentiment data alone
|
|
||||||
# (the news_driven strategy does not require market indicators)
|
|
||||||
from shared.schemas.trading import MarketSnapshot
|
|
||||||
|
|
||||||
snapshot = MarketSnapshot(
|
|
||||||
ticker=ticker,
|
|
||||||
current_price=0.0,
|
|
||||||
open=0.0,
|
|
||||||
high=0.0,
|
|
||||||
low=0.0,
|
|
||||||
close=0.0,
|
|
||||||
volume=0.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Run ensemble
|
|
||||||
signal = await ensemble.evaluate(ticker, snapshot, sentiment, weights)
|
|
||||||
|
|
||||||
if signal is not None:
|
|
||||||
await publisher.publish(signal.model_dump(mode="json"))
|
|
||||||
signals_generated.add(1)
|
|
||||||
for src in signal.strategy_sources:
|
|
||||||
strategy_name = src.split(":")[0]
|
|
||||||
per_strategy_signal_count.add(1, {"strategy": strategy_name})
|
|
||||||
logger.info(
|
|
||||||
"Signal generated: %s %s strength=%.4f sources=%s",
|
|
||||||
signal.direction.value,
|
|
||||||
ticker,
|
ticker,
|
||||||
signal.strength,
|
sentiment_scores[ticker],
|
||||||
signal.strategy_sources,
|
sentiment_confidences[ticker],
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception:
|
# Get market snapshot (may be None if no bars received yet)
|
||||||
logger.exception("Error processing scored article: %s", data.get("title", "<unknown>"))
|
snapshot = market_data.get_snapshot(ticker)
|
||||||
|
if snapshot is None:
|
||||||
|
# Create a minimal snapshot from sentiment data alone
|
||||||
|
# (the news_driven strategy does not require market indicators)
|
||||||
|
from shared.schemas.trading import MarketSnapshot
|
||||||
|
|
||||||
|
snapshot = MarketSnapshot(
|
||||||
|
ticker=ticker,
|
||||||
|
current_price=0.0,
|
||||||
|
open=0.0,
|
||||||
|
high=0.0,
|
||||||
|
low=0.0,
|
||||||
|
close=0.0,
|
||||||
|
volume=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run ensemble
|
||||||
|
signal_result = await ensemble.evaluate(ticker, snapshot, sentiment, weights)
|
||||||
|
|
||||||
|
if signal_result is not None:
|
||||||
|
await publisher.publish(signal_result.model_dump(mode="json"))
|
||||||
|
signals_generated.add(1)
|
||||||
|
for src in signal_result.strategy_sources:
|
||||||
|
strategy_name = src.split(":")[0]
|
||||||
|
per_strategy_signal_count.add(1, {"strategy": strategy_name})
|
||||||
|
logger.info(
|
||||||
|
"Signal generated: %s %s strength=%.4f sources=%s",
|
||||||
|
signal_result.direction.value,
|
||||||
|
ticker,
|
||||||
|
signal_result.strength,
|
||||||
|
signal_result.strategy_sources,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error processing scored article: %s", data.get("title", "<unknown>"))
|
||||||
|
finally:
|
||||||
|
await redis.aclose()
|
||||||
|
logger.info("Signal generator stopped gracefully")
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import signal
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
|
@ -158,13 +159,25 @@ async def run(config: TradeExecutorConfig | None = None) -> None:
|
||||||
|
|
||||||
logger.info("Consuming from signals:generated, publishing to trades:executed")
|
logger.info("Consuming from signals:generated, publishing to trades:executed")
|
||||||
|
|
||||||
|
# Graceful shutdown on SIGTERM/SIGINT
|
||||||
|
shutdown_event = asyncio.Event()
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
for sig in (signal.SIGTERM, signal.SIGINT):
|
||||||
|
loop.add_signal_handler(sig, shutdown_event.set)
|
||||||
|
|
||||||
# --- Consume loop ---
|
# --- Consume loop ---
|
||||||
async for _msg_id, data in consumer.consume():
|
try:
|
||||||
try:
|
async for _msg_id, data in consumer.consume():
|
||||||
signal = TradeSignal.model_validate(data)
|
if shutdown_event.is_set():
|
||||||
await process_signal(signal, risk_manager, broker, publisher, counters)
|
break
|
||||||
except Exception:
|
try:
|
||||||
logger.exception("Error processing signal: %s", data)
|
signal_msg = TradeSignal.model_validate(data)
|
||||||
|
await process_signal(signal_msg, risk_manager, broker, publisher, counters)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error processing signal: %s", data)
|
||||||
|
finally:
|
||||||
|
await redis.aclose()
|
||||||
|
logger.info("Trade executor stopped gracefully")
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
|
|
|
||||||
|
|
@ -40,9 +40,12 @@ class StreamConsumer:
|
||||||
try:
|
try:
|
||||||
await self.redis.xgroup_create(self.stream, self.group, id="0", mkstream=True)
|
await self.redis.xgroup_create(self.stream, self.group, id="0", mkstream=True)
|
||||||
logger.info("Created consumer group %s on %s", self.group, self.stream)
|
logger.info("Created consumer group %s on %s", self.group, self.stream)
|
||||||
except Exception:
|
except Exception as exc:
|
||||||
# Group already exists — this is expected on subsequent starts.
|
# BUSYGROUP means group already exists — expected on subsequent starts.
|
||||||
pass
|
if "BUSYGROUP" in str(exc):
|
||||||
|
logger.debug("Consumer group %s already exists on %s", self.group, self.stream)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
async def consume(
|
async def consume(
|
||||||
self, batch_size: int = 10, block_ms: int = 5000
|
self, batch_size: int = 10, block_ms: int = 5000
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue