fix: resolve 13 important issues from code review

I1: Add graceful shutdown (SIGTERM/SIGINT) to all 5 background services
I2: Fix Dockerfile healthcheck to use curl on /metrics endpoint
I3: Fix StreamConsumer.ensure_group() to only catch BUSYGROUP errors
I4: Fix SimulatedBroker to reject orders with insufficient cash/shares
I5: Move ORM attribute access inside DB session context in trades routes
I6: Add Redis-based rate limiting (10 req/min/IP) on all auth endpoints
I8: Prevent backtest background task garbage collection
I9: Use Numeric(16,6) instead of Float for financial columns in migration
I10: Add index on trades.created_at for time-range queries
I11: Bind infrastructure ports to 127.0.0.1 in docker-compose
I12: Add migrations init service; all app services depend on it
I13: Fix user enumeration in login_begin (return options for non-existent users)
This commit is contained in:
Viktor Barzin 2026-02-22 17:58:01 +00:00
parent 2a56727267
commit 5a6b20c8f1
No known key found for this signature in database
GPG key ID: 0EB088298288D958
13 changed files with 355 additions and 188 deletions

View file

@ -66,8 +66,8 @@ def upgrade() -> None:
sa.Enum("BUY", "SELL", name="tradeside"), sa.Enum("BUY", "SELL", name="tradeside"),
nullable=False, nullable=False,
), ),
sa.Column("qty", sa.Float, nullable=False), sa.Column("qty", sa.Numeric(16, 6), nullable=False),
sa.Column("price", sa.Float, nullable=False), sa.Column("price", sa.Numeric(16, 6), nullable=False),
sa.Column("timestamp", sa.String, nullable=True), sa.Column("timestamp", sa.String, nullable=True),
sa.Column( sa.Column(
"strategy_id", "strategy_id",
@ -87,8 +87,8 @@ def upgrade() -> None:
nullable=False, nullable=False,
server_default="PENDING", server_default="PENDING",
), ),
sa.Column("pnl", sa.Float, nullable=True), sa.Column("pnl", sa.Numeric(16, 6), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), index=True),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
) )
@ -96,11 +96,11 @@ def upgrade() -> None:
"positions", "positions",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True), sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column("ticker", sa.String(20), unique=True, nullable=False), sa.Column("ticker", sa.String(20), unique=True, nullable=False),
sa.Column("qty", sa.Float, nullable=False), sa.Column("qty", sa.Numeric(16, 6), nullable=False),
sa.Column("avg_entry", sa.Float, nullable=False), sa.Column("avg_entry", sa.Numeric(16, 6), nullable=False),
sa.Column("unrealized_pnl", sa.Float, nullable=True), sa.Column("unrealized_pnl", sa.Numeric(16, 6), nullable=True),
sa.Column("stop_loss", sa.Float, nullable=True), sa.Column("stop_loss", sa.Numeric(16, 6), nullable=True),
sa.Column("take_profit", sa.Float, nullable=True), sa.Column("take_profit", sa.Numeric(16, 6), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
) )

View file

@ -93,10 +93,32 @@ class SimulatedBroker(BaseBroker):
# Deduct / credit cash # Deduct / credit cash
if order.side == OrderSide.BUY: if order.side == OrderSide.BUY:
self.cash -= cost total_cost = cost + self.commission_per_trade
self.cash -= self.commission_per_trade if total_cost > self.cash:
return OrderResult(
order_id=str(uuid.uuid4()),
ticker=order.ticker,
side=order.side,
qty=order.qty,
filled_price=None,
status=OrderStatus.REJECTED,
timestamp=datetime.now(tz=timezone.utc),
)
self.cash -= total_cost
self._update_position_buy(order.ticker, order.qty, fill_price) self._update_position_buy(order.ticker, order.qty, fill_price)
else: else:
# Validate sufficient shares to sell
current_qty = self._positions.get(order.ticker, {}).get("qty", 0.0)
if order.qty > current_qty:
return OrderResult(
order_id=str(uuid.uuid4()),
ticker=order.ticker,
side=order.side,
qty=order.qty,
filled_price=None,
status=OrderStatus.REJECTED,
timestamp=datetime.now(tz=timezone.utc),
)
self.cash += cost self.cash += cost
self.cash -= self.commission_per_trade self.cash -= self.commission_per_trade
self._update_position_sell(order.ticker, order.qty) self._update_position_sell(order.ticker, order.qty)

View file

@ -9,7 +9,7 @@ services:
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-trading} POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-trading}
POSTGRES_DB: trading POSTGRES_DB: trading
ports: ports:
- "5432:5432" - "127.0.0.1:5432:5432"
volumes: volumes:
- pgdata:/var/lib/postgresql/data - pgdata:/var/lib/postgresql/data
healthcheck: healthcheck:
@ -21,7 +21,7 @@ services:
redis: redis:
image: redis:7-alpine image: redis:7-alpine
ports: ports:
- "6379:6379" - "127.0.0.1:6379:6379"
volumes: volumes:
- redisdata:/data - redisdata:/data
healthcheck: healthcheck:
@ -33,10 +33,27 @@ services:
ollama: ollama:
image: ollama/ollama:latest image: ollama/ollama:latest
ports: ports:
- "11434:11434" - "127.0.0.1:11434:11434"
volumes: volumes:
- ollama_models:/root/.ollama - ollama_models:/root/.ollama
# ---------------------------------------------------------------------------
# Database migrations — runs once before application services start
# ---------------------------------------------------------------------------
migrations:
build:
context: .
dockerfile: docker/Dockerfile.service
args:
EXTRAS: "dev"
SERVICE_MODULE: "api_gateway"
depends_on:
postgres:
condition: service_healthy
env_file: .env
command: python -m alembic upgrade head
restart: "no"
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Application services # Application services
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -52,6 +69,8 @@ services:
condition: service_healthy condition: service_healthy
redis: redis:
condition: service_healthy condition: service_healthy
migrations:
condition: service_completed_successfully
env_file: .env env_file: .env
restart: unless-stopped restart: unless-stopped
@ -67,6 +86,8 @@ services:
condition: service_healthy condition: service_healthy
ollama: ollama:
condition: service_started condition: service_started
migrations:
condition: service_completed_successfully
env_file: .env env_file: .env
restart: unless-stopped restart: unless-stopped
@ -82,6 +103,8 @@ services:
condition: service_healthy condition: service_healthy
redis: redis:
condition: service_healthy condition: service_healthy
migrations:
condition: service_completed_successfully
env_file: .env env_file: .env
restart: unless-stopped restart: unless-stopped
@ -97,6 +120,8 @@ services:
condition: service_healthy condition: service_healthy
redis: redis:
condition: service_healthy condition: service_healthy
migrations:
condition: service_completed_successfully
env_file: .env env_file: .env
restart: unless-stopped restart: unless-stopped
@ -112,6 +137,8 @@ services:
condition: service_healthy condition: service_healthy
redis: redis:
condition: service_healthy condition: service_healthy
migrations:
condition: service_completed_successfully
env_file: .env env_file: .env
restart: unless-stopped restart: unless-stopped
@ -127,6 +154,8 @@ services:
condition: service_healthy condition: service_healthy
redis: redis:
condition: service_healthy condition: service_healthy
migrations:
condition: service_completed_successfully
ports: ports:
- "8000:8000" - "8000:8000"
env_file: .env env_file: .env

View file

@ -19,13 +19,15 @@ COPY alembic/ alembic/
COPY alembic.ini . COPY alembic.ini .
ARG EXTRAS="dev" ARG EXTRAS="dev"
RUN pip install --no-cache-dir ".[$EXTRAS]" RUN pip install --no-cache-dir ".[$EXTRAS]" && pip install --no-cache-dir curl_cffi 2>/dev/null || true
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Stage 2: slim runtime image # Stage 2: slim runtime image
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
FROM python:3.12-slim FROM python:3.12-slim
RUN apt-get update && apt-get install -y --no-install-recommends curl && rm -rf /var/lib/apt/lists/*
WORKDIR /app WORKDIR /app
# Copy installed packages and CLI entry-points from the builder # Copy installed packages and CLI entry-points from the builder
@ -37,9 +39,11 @@ COPY --from=builder /app .
ARG SERVICE_MODULE="api_gateway" ARG SERVICE_MODULE="api_gateway"
ENV SERVICE_MODULE=${SERVICE_MODULE} ENV SERVICE_MODULE=${SERVICE_MODULE}
ARG HEALTH_PORT="9090"
ENV HEALTH_PORT=${HEALTH_PORT}
# Simple health check — verify the Python process is running # Check /metrics endpoint (all services expose it via OpenTelemetry)
HEALTHCHECK --interval=30s --timeout=10s --start-period=15s --retries=3 \ HEALTHCHECK --interval=30s --timeout=10s --start-period=30s --retries=3 \
CMD python -c "import sys; sys.exit(0)" || exit 1 CMD curl -sf http://localhost:${HEALTH_PORT}/metrics > /dev/null || exit 1
CMD python -m services.${SERVICE_MODULE}.main CMD python -m services.${SERVICE_MODULE}.main

View file

@ -4,6 +4,7 @@ from __future__ import annotations
import base64 import base64
import logging import logging
import os
import uuid import uuid
from typing import Any from typing import Any
@ -52,6 +53,21 @@ async def _get_db(request: Request):
yield session yield session
async def _rate_limit_auth(request: Request) -> None:
"""Simple fixed-window rate limiter for auth endpoints (10 req/min per IP)."""
redis = request.app.state.redis
client_ip = request.client.host if request.client else "unknown"
key = f"rate_limit:auth:{client_ip}"
current = await redis.incr(key)
if current == 1:
await redis.expire(key, 60)
if current > 10:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Too many requests. Please try again later.",
)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Registration # Registration
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -62,6 +78,7 @@ async def register_begin(
body: RegisterRequest, body: RegisterRequest,
request: Request, request: Request,
config: ApiGatewayConfig = Depends(get_config), config: ApiGatewayConfig = Depends(get_config),
_rl: None = Depends(_rate_limit_auth),
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Generate WebAuthn registration options (challenge + relying party info). """Generate WebAuthn registration options (challenge + relying party info).
@ -125,6 +142,7 @@ async def register_begin(
async def register_complete( async def register_complete(
request: Request, request: Request,
config: ApiGatewayConfig = Depends(get_config), config: ApiGatewayConfig = Depends(get_config),
_rl: None = Depends(_rate_limit_auth),
) -> TokenResponse: ) -> TokenResponse:
"""Verify WebAuthn registration response and store credential. """Verify WebAuthn registration response and store credential.
@ -211,8 +229,13 @@ async def login_begin(
body: LoginRequest, body: LoginRequest,
request: Request, request: Request,
config: ApiGatewayConfig = Depends(get_config), config: ApiGatewayConfig = Depends(get_config),
_rl: None = Depends(_rate_limit_auth),
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Generate WebAuthn authentication options for an existing user.""" """Generate WebAuthn authentication options for an existing user.
Returns valid-looking options even if the user does not exist
to prevent username enumeration.
"""
import json as _json import json as _json
redis = await _get_redis(request) redis = await _get_redis(request)
@ -227,26 +250,22 @@ async def login_begin(
select(User).where(User.username == body.username) select(User).where(User.username == body.username)
) )
).scalar_one_or_none() ).scalar_one_or_none()
if user is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found",
)
creds = ( allow_credentials = []
await session.execute( if user is not None:
select(UserCredential).where( creds = (
UserCredential.user_id == user.id await session.execute(
select(UserCredential).where(
UserCredential.user_id == user.id
)
) )
) ).scalars().all()
).scalars().all() allow_credentials = [
PublicKeyCredentialDescriptor(
allow_credentials = [ id=base64.urlsafe_b64decode(c.credential_id),
PublicKeyCredentialDescriptor( )
id=base64.urlsafe_b64decode(c.credential_id), for c in creds
) ]
for c in creds
]
options = generate_authentication_options( options = generate_authentication_options(
rp_id=config.rp_id, rp_id=config.rp_id,
@ -255,16 +274,21 @@ async def login_begin(
) )
challenge_b64 = base64.urlsafe_b64encode(options.challenge).decode() challenge_b64 = base64.urlsafe_b64encode(options.challenge).decode()
redis_key = f"webauthn:login:{body.username}"
await redis.setex( if user is not None:
redis_key, # Real user — store challenge for verification
300, redis_key = f"webauthn:login:{body.username}"
_json.dumps({ await redis.setex(
"challenge": challenge_b64, redis_key,
"user_id": str(user.id), 300,
"username": user.username, _json.dumps({
}), "challenge": challenge_b64,
) "user_id": str(user.id),
"username": user.username,
}),
)
# If user doesn't exist, we still return options but don't store
# a challenge. login/complete will fail with a generic error.
from webauthn.helpers import options_to_json from webauthn.helpers import options_to_json
@ -275,6 +299,7 @@ async def login_begin(
async def login_complete( async def login_complete(
request: Request, request: Request,
config: ApiGatewayConfig = Depends(get_config), config: ApiGatewayConfig = Depends(get_config),
_rl: None = Depends(_rate_limit_auth),
) -> TokenResponse: ) -> TokenResponse:
"""Verify WebAuthn authentication response and issue JWT.""" """Verify WebAuthn authentication response and issue JWT."""
import json as _json import json as _json
@ -356,7 +381,11 @@ async def login_complete(
@router.post("/refresh") @router.post("/refresh")
async def refresh(request: Request, config: ApiGatewayConfig = Depends(get_config)) -> TokenResponse: async def refresh(
request: Request,
config: ApiGatewayConfig = Depends(get_config),
_rl: None = Depends(_rate_limit_auth),
) -> TokenResponse:
"""Exchange a valid refresh token for a new access token.""" """Exchange a valid refresh token for a new access token."""
body = await request.json() body = await request.json()
refresh_token = body.get("refresh_token", "") refresh_token = body.get("refresh_token", "")

View file

@ -17,6 +17,9 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/backtest", tags=["backtest"]) router = APIRouter(prefix="/api/backtest", tags=["backtest"])
# Store references to background tasks to prevent garbage collection
_background_tasks: set[asyncio.Task] = set()
class BacktestRequest(BaseModel): class BacktestRequest(BaseModel):
"""Request body for starting a new backtest.""" """Request body for starting a new backtest."""
@ -56,8 +59,10 @@ async def run_backtest(
}), }),
) )
# Launch background task # Launch background task (stored in set to prevent GC)
asyncio.create_task(_run_backtest_task(run_id, body, redis)) task = asyncio.create_task(_run_backtest_task(run_id, body, redis))
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
return {"run_id": run_id, "status": "running"} return {"run_id": run_id, "status": "running"}

View file

@ -67,27 +67,27 @@ async def list_trades(
result = await session.execute(query) result = await session.execute(query)
trades = result.scalars().all() trades = result.scalars().all()
return { return {
"trades": [ "trades": [
{ {
"id": str(t.id), "id": str(t.id),
"ticker": t.ticker, "ticker": t.ticker,
"side": t.side.value, "side": t.side.value,
"qty": t.qty, "qty": t.qty,
"price": t.price, "price": t.price,
"status": t.status.value, "status": t.status.value,
"pnl": t.pnl, "pnl": t.pnl,
"strategy_id": str(t.strategy_id) if t.strategy_id else None, "strategy_id": str(t.strategy_id) if t.strategy_id else None,
"signal_id": str(t.signal_id) if t.signal_id else None, "signal_id": str(t.signal_id) if t.signal_id else None,
"created_at": t.created_at.isoformat() if t.created_at else None, "created_at": t.created_at.isoformat() if t.created_at else None,
} }
for t in trades for t in trades
], ],
"total": total, "total": total,
"page": page, "page": page,
"per_page": per_page, "per_page": per_page,
"pages": (total + per_page - 1) // per_page if per_page else 0, "pages": (total + per_page - 1) // per_page if per_page else 0,
} }
@router.get("/{trade_id}") @router.get("/{trade_id}")
@ -105,21 +105,21 @@ async def get_trade(
await session.execute(select(Trade).where(Trade.id == trade_id)) await session.execute(select(Trade).where(Trade.id == trade_id))
).scalar_one_or_none() ).scalar_one_or_none()
if trade is None: if trade is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail="Trade not found", detail="Trade not found",
) )
return { return {
"id": str(trade.id), "id": str(trade.id),
"ticker": trade.ticker, "ticker": trade.ticker,
"side": trade.side.value, "side": trade.side.value,
"qty": trade.qty, "qty": trade.qty,
"price": trade.price, "price": trade.price,
"status": trade.status.value, "status": trade.status.value,
"pnl": trade.pnl, "pnl": trade.pnl,
"strategy_id": str(trade.strategy_id) if trade.strategy_id else None, "strategy_id": str(trade.strategy_id) if trade.strategy_id else None,
"signal_id": str(trade.signal_id) if trade.signal_id else None, "signal_id": str(trade.signal_id) if trade.signal_id else None,
"created_at": trade.created_at.isoformat() if trade.created_at else None, "created_at": trade.created_at.isoformat() if trade.created_at else None,
} }

View file

@ -10,6 +10,7 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
import logging import logging
import signal
from datetime import datetime, timezone from datetime import datetime, timezone
from uuid import UUID from uuid import UUID
@ -275,24 +276,36 @@ async def run(config: LearningEngineConfig | None = None) -> None:
logger.info("Consuming from trades:executed") logger.info("Consuming from trades:executed")
# Graceful shutdown on SIGTERM/SIGINT
shutdown_event = asyncio.Event()
loop = asyncio.get_running_loop()
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, shutdown_event.set)
# --- Consume loop --- # --- Consume loop ---
async for _msg_id, data in consumer.consume(): try:
try: async for _msg_id, data in consumer.consume():
trade = TradeExecution.model_validate(data) if shutdown_event.is_set():
break
try:
trade = TradeExecution.model_validate(data)
if trade.status.value != "FILLED": if trade.status.value != "FILLED":
logger.debug("Skipping non-filled trade: %s", trade.status.value) logger.debug("Skipping non-filled trade: %s", trade.status.value)
continue continue
adjustments = await process_trade(trade, redis, evaluator, adjuster, counters) adjustments = await process_trade(trade, redis, evaluator, adjuster, counters)
if adjustments: if adjustments:
logger.info( logger.info(
"Made %d weight adjustment(s) for %s", "Made %d weight adjustment(s) for %s",
len(adjustments), len(adjustments),
trade.ticker, trade.ticker,
) )
except Exception: except Exception:
logger.exception("Error processing trade execution: %s", data) logger.exception("Error processing trade execution: %s", data)
finally:
await redis.aclose()
logger.info("Learning engine stopped gracefully")
def main() -> None: def main() -> None:

View file

@ -7,6 +7,7 @@ to the ``news:raw`` Redis Stream.
import asyncio import asyncio
import logging import logging
import signal
from redis.asyncio import Redis from redis.asyncio import Redis
@ -53,9 +54,10 @@ async def _poll_rss(
publisher: StreamPublisher, publisher: StreamPublisher,
articles_fetched_counter, articles_fetched_counter,
fetch_errors_counter, fetch_errors_counter,
shutdown_event: asyncio.Event,
) -> None: ) -> None:
"""Continuously poll RSS feeds at *interval* seconds.""" """Continuously poll RSS feeds at *interval* seconds."""
while True: while not shutdown_event.is_set():
try: try:
logger.info("Polling RSS feeds …") logger.info("Polling RSS feeds …")
articles = await source.fetch() articles = await source.fetch()
@ -66,7 +68,11 @@ async def _poll_rss(
except Exception: except Exception:
logger.exception("RSS poll cycle failed") logger.exception("RSS poll cycle failed")
fetch_errors_counter.add(1) fetch_errors_counter.add(1)
await asyncio.sleep(interval) try:
await asyncio.wait_for(shutdown_event.wait(), timeout=interval)
return # Shutdown signaled
except asyncio.TimeoutError:
pass # Normal timeout — continue polling
async def _poll_reddit( async def _poll_reddit(
@ -76,9 +82,10 @@ async def _poll_reddit(
publisher: StreamPublisher, publisher: StreamPublisher,
articles_fetched_counter, articles_fetched_counter,
fetch_errors_counter, fetch_errors_counter,
shutdown_event: asyncio.Event,
) -> None: ) -> None:
"""Continuously poll Reddit at *interval* seconds.""" """Continuously poll Reddit at *interval* seconds."""
while True: while not shutdown_event.is_set():
try: try:
logger.info("Polling Reddit …") logger.info("Polling Reddit …")
articles = await source.fetch() articles = await source.fetch()
@ -89,7 +96,11 @@ async def _poll_reddit(
except Exception: except Exception:
logger.exception("Reddit poll cycle failed") logger.exception("Reddit poll cycle failed")
fetch_errors_counter.add(1) fetch_errors_counter.add(1)
await asyncio.sleep(interval) try:
await asyncio.wait_for(shutdown_event.wait(), timeout=interval)
return # Shutdown signaled
except asyncio.TimeoutError:
pass # Normal timeout — continue polling
async def run() -> None: async def run() -> None:
@ -124,28 +135,40 @@ async def run() -> None:
min_score=config.reddit_min_score, min_score=config.reddit_min_score,
) )
# Graceful shutdown on SIGTERM/SIGINT
shutdown_event = asyncio.Event()
loop = asyncio.get_running_loop()
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, shutdown_event.set)
# Run pollers concurrently # Run pollers concurrently
async with asyncio.TaskGroup() as tg: try:
tg.create_task( async with asyncio.TaskGroup() as tg:
_poll_rss( tg.create_task(
rss_source, _poll_rss(
config.rss_poll_interval_seconds, rss_source,
redis, config.rss_poll_interval_seconds,
publisher, redis,
articles_fetched_counter, publisher,
fetch_errors_counter, articles_fetched_counter,
fetch_errors_counter,
shutdown_event,
)
) )
) tg.create_task(
tg.create_task( _poll_reddit(
_poll_reddit( reddit_source,
reddit_source, config.reddit_poll_interval_seconds,
config.reddit_poll_interval_seconds, redis,
redis, publisher,
publisher, articles_fetched_counter,
articles_fetched_counter, fetch_errors_counter,
fetch_errors_counter, shutdown_event,
)
) )
) finally:
await redis.aclose()
logger.info("News fetcher stopped gracefully")
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -10,6 +10,7 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import signal
import time import time
from redis.asyncio import Redis from redis.asyncio import Redis
@ -151,13 +152,25 @@ async def run(config: SentimentAnalyzerConfig | None = None) -> None:
logger.info("Consuming from news:raw, publishing to news:scored") logger.info("Consuming from news:raw, publishing to news:scored")
# Graceful shutdown on SIGTERM/SIGINT
shutdown_event = asyncio.Event()
loop = asyncio.get_running_loop()
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, shutdown_event.set)
# --- Consume loop --- # --- Consume loop ---
async for _msg_id, data in consumer.consume(): try:
try: async for _msg_id, data in consumer.consume():
article = RawArticle.model_validate(data) if shutdown_event.is_set():
await process_article(article, finbert, ollama, publisher, config, counters) break
except Exception: try:
logger.exception("Error processing article: %s", data.get("title", "<unknown>")) article = RawArticle.model_validate(data)
await process_article(article, finbert, ollama, publisher, config, counters)
except Exception:
logger.exception("Error processing article: %s", data.get("title", "<unknown>"))
finally:
await redis.aclose()
logger.info("Sentiment analyzer stopped gracefully")
def main() -> None: def main() -> None:

View file

@ -9,6 +9,7 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import signal
from collections import defaultdict, deque from collections import defaultdict, deque
from redis.asyncio import Redis from redis.asyncio import Redis
@ -101,59 +102,71 @@ async def run(config: SignalGeneratorConfig | None = None) -> None:
logger.info("Consuming from news:scored, publishing to signals:generated") logger.info("Consuming from news:scored, publishing to signals:generated")
# Graceful shutdown on SIGTERM/SIGINT
shutdown_event = asyncio.Event()
loop = asyncio.get_running_loop()
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, shutdown_event.set)
# --- Consume loop --- # --- Consume loop ---
async for _msg_id, data in consumer.consume(): try:
try: async for _msg_id, data in consumer.consume():
article = ScoredArticle.model_validate(data) if shutdown_event.is_set():
ticker = article.ticker break
try:
article = ScoredArticle.model_validate(data)
ticker = article.ticker
# Update sentiment accumulators # Update sentiment accumulators
sentiment_scores[ticker].append(article.sentiment_score) sentiment_scores[ticker].append(article.sentiment_score)
sentiment_confidences[ticker].append(article.confidence) sentiment_confidences[ticker].append(article.confidence)
# Build sentiment context # Build sentiment context
sentiment = _build_sentiment_context( sentiment = _build_sentiment_context(
ticker,
sentiment_scores[ticker],
sentiment_confidences[ticker],
)
# Get market snapshot (may be None if no bars received yet)
snapshot = market_data.get_snapshot(ticker)
if snapshot is None:
# Create a minimal snapshot from sentiment data alone
# (the news_driven strategy does not require market indicators)
from shared.schemas.trading import MarketSnapshot
snapshot = MarketSnapshot(
ticker=ticker,
current_price=0.0,
open=0.0,
high=0.0,
low=0.0,
close=0.0,
volume=0.0,
)
# Run ensemble
signal = await ensemble.evaluate(ticker, snapshot, sentiment, weights)
if signal is not None:
await publisher.publish(signal.model_dump(mode="json"))
signals_generated.add(1)
for src in signal.strategy_sources:
strategy_name = src.split(":")[0]
per_strategy_signal_count.add(1, {"strategy": strategy_name})
logger.info(
"Signal generated: %s %s strength=%.4f sources=%s",
signal.direction.value,
ticker, ticker,
signal.strength, sentiment_scores[ticker],
signal.strategy_sources, sentiment_confidences[ticker],
) )
except Exception: # Get market snapshot (may be None if no bars received yet)
logger.exception("Error processing scored article: %s", data.get("title", "<unknown>")) snapshot = market_data.get_snapshot(ticker)
if snapshot is None:
# Create a minimal snapshot from sentiment data alone
# (the news_driven strategy does not require market indicators)
from shared.schemas.trading import MarketSnapshot
snapshot = MarketSnapshot(
ticker=ticker,
current_price=0.0,
open=0.0,
high=0.0,
low=0.0,
close=0.0,
volume=0.0,
)
# Run ensemble
signal_result = await ensemble.evaluate(ticker, snapshot, sentiment, weights)
if signal_result is not None:
await publisher.publish(signal_result.model_dump(mode="json"))
signals_generated.add(1)
for src in signal_result.strategy_sources:
strategy_name = src.split(":")[0]
per_strategy_signal_count.add(1, {"strategy": strategy_name})
logger.info(
"Signal generated: %s %s strength=%.4f sources=%s",
signal_result.direction.value,
ticker,
signal_result.strength,
signal_result.strategy_sources,
)
except Exception:
logger.exception("Error processing scored article: %s", data.get("title", "<unknown>"))
finally:
await redis.aclose()
logger.info("Signal generator stopped gracefully")
def main() -> None: def main() -> None:

View file

@ -10,6 +10,7 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import signal
import time import time
import uuid import uuid
@ -158,13 +159,25 @@ async def run(config: TradeExecutorConfig | None = None) -> None:
logger.info("Consuming from signals:generated, publishing to trades:executed") logger.info("Consuming from signals:generated, publishing to trades:executed")
# Graceful shutdown on SIGTERM/SIGINT
shutdown_event = asyncio.Event()
loop = asyncio.get_running_loop()
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, shutdown_event.set)
# --- Consume loop --- # --- Consume loop ---
async for _msg_id, data in consumer.consume(): try:
try: async for _msg_id, data in consumer.consume():
signal = TradeSignal.model_validate(data) if shutdown_event.is_set():
await process_signal(signal, risk_manager, broker, publisher, counters) break
except Exception: try:
logger.exception("Error processing signal: %s", data) signal_msg = TradeSignal.model_validate(data)
await process_signal(signal_msg, risk_manager, broker, publisher, counters)
except Exception:
logger.exception("Error processing signal: %s", data)
finally:
await redis.aclose()
logger.info("Trade executor stopped gracefully")
def main() -> None: def main() -> None:

View file

@ -40,9 +40,12 @@ class StreamConsumer:
try: try:
await self.redis.xgroup_create(self.stream, self.group, id="0", mkstream=True) await self.redis.xgroup_create(self.stream, self.group, id="0", mkstream=True)
logger.info("Created consumer group %s on %s", self.group, self.stream) logger.info("Created consumer group %s on %s", self.group, self.stream)
except Exception: except Exception as exc:
# Group already exists — this is expected on subsequent starts. # BUSYGROUP means group already exists — expected on subsequent starts.
pass if "BUSYGROUP" in str(exc):
logger.debug("Consumer group %s already exists on %s", self.group, self.stream)
else:
raise
async def consume( async def consume(
self, batch_size: int = 10, block_ms: int = 5000 self, batch_size: int = 10, block_ms: int = 5000