diff --git a/alembic/versions/a1b2c3d4e5f6_initial_schema.py b/alembic/versions/a1b2c3d4e5f6_initial_schema.py index 3c71fea..6a04b4c 100644 --- a/alembic/versions/a1b2c3d4e5f6_initial_schema.py +++ b/alembic/versions/a1b2c3d4e5f6_initial_schema.py @@ -66,8 +66,8 @@ def upgrade() -> None: sa.Enum("BUY", "SELL", name="tradeside"), nullable=False, ), - sa.Column("qty", sa.Float, nullable=False), - sa.Column("price", sa.Float, nullable=False), + sa.Column("qty", sa.Numeric(16, 6), nullable=False), + sa.Column("price", sa.Numeric(16, 6), nullable=False), sa.Column("timestamp", sa.String, nullable=True), sa.Column( "strategy_id", @@ -87,8 +87,8 @@ def upgrade() -> None: nullable=False, server_default="PENDING", ), - sa.Column("pnl", sa.Float, nullable=True), - sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("pnl", sa.Numeric(16, 6), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), index=True), sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), ) @@ -96,11 +96,11 @@ def upgrade() -> None: "positions", sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True), sa.Column("ticker", sa.String(20), unique=True, nullable=False), - sa.Column("qty", sa.Float, nullable=False), - sa.Column("avg_entry", sa.Float, nullable=False), - sa.Column("unrealized_pnl", sa.Float, nullable=True), - sa.Column("stop_loss", sa.Float, nullable=True), - sa.Column("take_profit", sa.Float, nullable=True), + sa.Column("qty", sa.Numeric(16, 6), nullable=False), + sa.Column("avg_entry", sa.Numeric(16, 6), nullable=False), + sa.Column("unrealized_pnl", sa.Numeric(16, 6), nullable=True), + sa.Column("stop_loss", sa.Numeric(16, 6), nullable=True), + sa.Column("take_profit", sa.Numeric(16, 6), nullable=True), sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), ) diff --git a/backtester/simulated_broker.py b/backtester/simulated_broker.py index 33965e2..3237268 100644 --- a/backtester/simulated_broker.py +++ b/backtester/simulated_broker.py @@ -93,10 +93,32 @@ class SimulatedBroker(BaseBroker): # Deduct / credit cash if order.side == OrderSide.BUY: - self.cash -= cost - self.cash -= self.commission_per_trade + total_cost = cost + self.commission_per_trade + if total_cost > self.cash: + return OrderResult( + order_id=str(uuid.uuid4()), + ticker=order.ticker, + side=order.side, + qty=order.qty, + filled_price=None, + status=OrderStatus.REJECTED, + timestamp=datetime.now(tz=timezone.utc), + ) + self.cash -= total_cost self._update_position_buy(order.ticker, order.qty, fill_price) else: + # Validate sufficient shares to sell + current_qty = self._positions.get(order.ticker, {}).get("qty", 0.0) + if order.qty > current_qty: + return OrderResult( + order_id=str(uuid.uuid4()), + ticker=order.ticker, + side=order.side, + qty=order.qty, + filled_price=None, + status=OrderStatus.REJECTED, + timestamp=datetime.now(tz=timezone.utc), + ) self.cash += cost self.cash -= self.commission_per_trade self._update_position_sell(order.ticker, order.qty) diff --git a/docker-compose.yml b/docker-compose.yml index f772ca4..d070b0b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -9,7 +9,7 @@ services: POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-trading} POSTGRES_DB: trading ports: - - "5432:5432" + - "127.0.0.1:5432:5432" volumes: - pgdata:/var/lib/postgresql/data healthcheck: @@ -21,7 +21,7 @@ services: redis: image: redis:7-alpine ports: - - "6379:6379" + - "127.0.0.1:6379:6379" volumes: - redisdata:/data healthcheck: @@ -33,10 +33,27 @@ services: ollama: image: ollama/ollama:latest ports: - - "11434:11434" + - "127.0.0.1:11434:11434" volumes: - ollama_models:/root/.ollama + # --------------------------------------------------------------------------- + # Database migrations — runs once before application services start + # --------------------------------------------------------------------------- + migrations: + build: + context: . + dockerfile: docker/Dockerfile.service + args: + EXTRAS: "dev" + SERVICE_MODULE: "api_gateway" + depends_on: + postgres: + condition: service_healthy + env_file: .env + command: python -m alembic upgrade head + restart: "no" + # --------------------------------------------------------------------------- # Application services # --------------------------------------------------------------------------- @@ -52,6 +69,8 @@ services: condition: service_healthy redis: condition: service_healthy + migrations: + condition: service_completed_successfully env_file: .env restart: unless-stopped @@ -67,6 +86,8 @@ services: condition: service_healthy ollama: condition: service_started + migrations: + condition: service_completed_successfully env_file: .env restart: unless-stopped @@ -82,6 +103,8 @@ services: condition: service_healthy redis: condition: service_healthy + migrations: + condition: service_completed_successfully env_file: .env restart: unless-stopped @@ -97,6 +120,8 @@ services: condition: service_healthy redis: condition: service_healthy + migrations: + condition: service_completed_successfully env_file: .env restart: unless-stopped @@ -112,6 +137,8 @@ services: condition: service_healthy redis: condition: service_healthy + migrations: + condition: service_completed_successfully env_file: .env restart: unless-stopped @@ -127,6 +154,8 @@ services: condition: service_healthy redis: condition: service_healthy + migrations: + condition: service_completed_successfully ports: - "8000:8000" env_file: .env diff --git a/docker/Dockerfile.service b/docker/Dockerfile.service index 485887b..000995c 100644 --- a/docker/Dockerfile.service +++ b/docker/Dockerfile.service @@ -19,13 +19,15 @@ COPY alembic/ alembic/ COPY alembic.ini . ARG EXTRAS="dev" -RUN pip install --no-cache-dir ".[$EXTRAS]" +RUN pip install --no-cache-dir ".[$EXTRAS]" && pip install --no-cache-dir curl_cffi 2>/dev/null || true # --------------------------------------------------------------------------- # Stage 2: slim runtime image # --------------------------------------------------------------------------- FROM python:3.12-slim +RUN apt-get update && apt-get install -y --no-install-recommends curl && rm -rf /var/lib/apt/lists/* + WORKDIR /app # Copy installed packages and CLI entry-points from the builder @@ -37,9 +39,11 @@ COPY --from=builder /app . ARG SERVICE_MODULE="api_gateway" ENV SERVICE_MODULE=${SERVICE_MODULE} +ARG HEALTH_PORT="9090" +ENV HEALTH_PORT=${HEALTH_PORT} -# Simple health check — verify the Python process is running -HEALTHCHECK --interval=30s --timeout=10s --start-period=15s --retries=3 \ - CMD python -c "import sys; sys.exit(0)" || exit 1 +# Check /metrics endpoint (all services expose it via OpenTelemetry) +HEALTHCHECK --interval=30s --timeout=10s --start-period=30s --retries=3 \ + CMD curl -sf http://localhost:${HEALTH_PORT}/metrics > /dev/null || exit 1 CMD python -m services.${SERVICE_MODULE}.main diff --git a/services/api_gateway/auth/routes.py b/services/api_gateway/auth/routes.py index 5d9ac5c..362b276 100644 --- a/services/api_gateway/auth/routes.py +++ b/services/api_gateway/auth/routes.py @@ -4,6 +4,7 @@ from __future__ import annotations import base64 import logging +import os import uuid from typing import Any @@ -52,6 +53,21 @@ async def _get_db(request: Request): yield session +async def _rate_limit_auth(request: Request) -> None: + """Simple fixed-window rate limiter for auth endpoints (10 req/min per IP).""" + redis = request.app.state.redis + client_ip = request.client.host if request.client else "unknown" + key = f"rate_limit:auth:{client_ip}" + current = await redis.incr(key) + if current == 1: + await redis.expire(key, 60) + if current > 10: + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + detail="Too many requests. Please try again later.", + ) + + # --------------------------------------------------------------------------- # Registration # --------------------------------------------------------------------------- @@ -62,6 +78,7 @@ async def register_begin( body: RegisterRequest, request: Request, config: ApiGatewayConfig = Depends(get_config), + _rl: None = Depends(_rate_limit_auth), ) -> dict[str, Any]: """Generate WebAuthn registration options (challenge + relying party info). @@ -125,6 +142,7 @@ async def register_begin( async def register_complete( request: Request, config: ApiGatewayConfig = Depends(get_config), + _rl: None = Depends(_rate_limit_auth), ) -> TokenResponse: """Verify WebAuthn registration response and store credential. @@ -211,8 +229,13 @@ async def login_begin( body: LoginRequest, request: Request, config: ApiGatewayConfig = Depends(get_config), + _rl: None = Depends(_rate_limit_auth), ) -> dict[str, Any]: - """Generate WebAuthn authentication options for an existing user.""" + """Generate WebAuthn authentication options for an existing user. + + Returns valid-looking options even if the user does not exist + to prevent username enumeration. + """ import json as _json redis = await _get_redis(request) @@ -227,26 +250,22 @@ async def login_begin( select(User).where(User.username == body.username) ) ).scalar_one_or_none() - if user is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="User not found", - ) - creds = ( - await session.execute( - select(UserCredential).where( - UserCredential.user_id == user.id + allow_credentials = [] + if user is not None: + creds = ( + await session.execute( + select(UserCredential).where( + UserCredential.user_id == user.id + ) ) - ) - ).scalars().all() - - allow_credentials = [ - PublicKeyCredentialDescriptor( - id=base64.urlsafe_b64decode(c.credential_id), - ) - for c in creds - ] + ).scalars().all() + allow_credentials = [ + PublicKeyCredentialDescriptor( + id=base64.urlsafe_b64decode(c.credential_id), + ) + for c in creds + ] options = generate_authentication_options( rp_id=config.rp_id, @@ -255,16 +274,21 @@ async def login_begin( ) challenge_b64 = base64.urlsafe_b64encode(options.challenge).decode() - redis_key = f"webauthn:login:{body.username}" - await redis.setex( - redis_key, - 300, - _json.dumps({ - "challenge": challenge_b64, - "user_id": str(user.id), - "username": user.username, - }), - ) + + if user is not None: + # Real user — store challenge for verification + redis_key = f"webauthn:login:{body.username}" + await redis.setex( + redis_key, + 300, + _json.dumps({ + "challenge": challenge_b64, + "user_id": str(user.id), + "username": user.username, + }), + ) + # If user doesn't exist, we still return options but don't store + # a challenge. login/complete will fail with a generic error. from webauthn.helpers import options_to_json @@ -275,6 +299,7 @@ async def login_begin( async def login_complete( request: Request, config: ApiGatewayConfig = Depends(get_config), + _rl: None = Depends(_rate_limit_auth), ) -> TokenResponse: """Verify WebAuthn authentication response and issue JWT.""" import json as _json @@ -356,7 +381,11 @@ async def login_complete( @router.post("/refresh") -async def refresh(request: Request, config: ApiGatewayConfig = Depends(get_config)) -> TokenResponse: +async def refresh( + request: Request, + config: ApiGatewayConfig = Depends(get_config), + _rl: None = Depends(_rate_limit_auth), +) -> TokenResponse: """Exchange a valid refresh token for a new access token.""" body = await request.json() refresh_token = body.get("refresh_token", "") diff --git a/services/api_gateway/routes/backtest.py b/services/api_gateway/routes/backtest.py index 32ebebc..9050864 100644 --- a/services/api_gateway/routes/backtest.py +++ b/services/api_gateway/routes/backtest.py @@ -17,6 +17,9 @@ logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/backtest", tags=["backtest"]) +# Store references to background tasks to prevent garbage collection +_background_tasks: set[asyncio.Task] = set() + class BacktestRequest(BaseModel): """Request body for starting a new backtest.""" @@ -56,8 +59,10 @@ async def run_backtest( }), ) - # Launch background task - asyncio.create_task(_run_backtest_task(run_id, body, redis)) + # Launch background task (stored in set to prevent GC) + task = asyncio.create_task(_run_backtest_task(run_id, body, redis)) + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) return {"run_id": run_id, "status": "running"} diff --git a/services/api_gateway/routes/trades.py b/services/api_gateway/routes/trades.py index 00f1016..abb3fc4 100644 --- a/services/api_gateway/routes/trades.py +++ b/services/api_gateway/routes/trades.py @@ -67,27 +67,27 @@ async def list_trades( result = await session.execute(query) trades = result.scalars().all() - return { - "trades": [ - { - "id": str(t.id), - "ticker": t.ticker, - "side": t.side.value, - "qty": t.qty, - "price": t.price, - "status": t.status.value, - "pnl": t.pnl, - "strategy_id": str(t.strategy_id) if t.strategy_id else None, - "signal_id": str(t.signal_id) if t.signal_id else None, - "created_at": t.created_at.isoformat() if t.created_at else None, - } - for t in trades - ], - "total": total, - "page": page, - "per_page": per_page, - "pages": (total + per_page - 1) // per_page if per_page else 0, - } + return { + "trades": [ + { + "id": str(t.id), + "ticker": t.ticker, + "side": t.side.value, + "qty": t.qty, + "price": t.price, + "status": t.status.value, + "pnl": t.pnl, + "strategy_id": str(t.strategy_id) if t.strategy_id else None, + "signal_id": str(t.signal_id) if t.signal_id else None, + "created_at": t.created_at.isoformat() if t.created_at else None, + } + for t in trades + ], + "total": total, + "page": page, + "per_page": per_page, + "pages": (total + per_page - 1) // per_page if per_page else 0, + } @router.get("/{trade_id}") @@ -105,21 +105,21 @@ async def get_trade( await session.execute(select(Trade).where(Trade.id == trade_id)) ).scalar_one_or_none() - if trade is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Trade not found", - ) + if trade is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Trade not found", + ) - return { - "id": str(trade.id), - "ticker": trade.ticker, - "side": trade.side.value, - "qty": trade.qty, - "price": trade.price, - "status": trade.status.value, - "pnl": trade.pnl, - "strategy_id": str(trade.strategy_id) if trade.strategy_id else None, - "signal_id": str(trade.signal_id) if trade.signal_id else None, - "created_at": trade.created_at.isoformat() if trade.created_at else None, - } + return { + "id": str(trade.id), + "ticker": trade.ticker, + "side": trade.side.value, + "qty": trade.qty, + "price": trade.price, + "status": trade.status.value, + "pnl": trade.pnl, + "strategy_id": str(trade.strategy_id) if trade.strategy_id else None, + "signal_id": str(trade.signal_id) if trade.signal_id else None, + "created_at": trade.created_at.isoformat() if trade.created_at else None, + } diff --git a/services/learning_engine/main.py b/services/learning_engine/main.py index 30de800..d877597 100644 --- a/services/learning_engine/main.py +++ b/services/learning_engine/main.py @@ -10,6 +10,7 @@ from __future__ import annotations import asyncio import json import logging +import signal from datetime import datetime, timezone from uuid import UUID @@ -275,24 +276,36 @@ async def run(config: LearningEngineConfig | None = None) -> None: logger.info("Consuming from trades:executed") + # Graceful shutdown on SIGTERM/SIGINT + shutdown_event = asyncio.Event() + loop = asyncio.get_running_loop() + for sig in (signal.SIGTERM, signal.SIGINT): + loop.add_signal_handler(sig, shutdown_event.set) + # --- Consume loop --- - async for _msg_id, data in consumer.consume(): - try: - trade = TradeExecution.model_validate(data) + try: + async for _msg_id, data in consumer.consume(): + if shutdown_event.is_set(): + break + try: + trade = TradeExecution.model_validate(data) - if trade.status.value != "FILLED": - logger.debug("Skipping non-filled trade: %s", trade.status.value) - continue + if trade.status.value != "FILLED": + logger.debug("Skipping non-filled trade: %s", trade.status.value) + continue - adjustments = await process_trade(trade, redis, evaluator, adjuster, counters) - if adjustments: - logger.info( - "Made %d weight adjustment(s) for %s", - len(adjustments), - trade.ticker, - ) - except Exception: - logger.exception("Error processing trade execution: %s", data) + adjustments = await process_trade(trade, redis, evaluator, adjuster, counters) + if adjustments: + logger.info( + "Made %d weight adjustment(s) for %s", + len(adjustments), + trade.ticker, + ) + except Exception: + logger.exception("Error processing trade execution: %s", data) + finally: + await redis.aclose() + logger.info("Learning engine stopped gracefully") def main() -> None: diff --git a/services/news_fetcher/main.py b/services/news_fetcher/main.py index e90603d..87e1bf7 100644 --- a/services/news_fetcher/main.py +++ b/services/news_fetcher/main.py @@ -7,6 +7,7 @@ to the ``news:raw`` Redis Stream. import asyncio import logging +import signal from redis.asyncio import Redis @@ -53,9 +54,10 @@ async def _poll_rss( publisher: StreamPublisher, articles_fetched_counter, fetch_errors_counter, + shutdown_event: asyncio.Event, ) -> None: """Continuously poll RSS feeds at *interval* seconds.""" - while True: + while not shutdown_event.is_set(): try: logger.info("Polling RSS feeds …") articles = await source.fetch() @@ -66,7 +68,11 @@ async def _poll_rss( except Exception: logger.exception("RSS poll cycle failed") fetch_errors_counter.add(1) - await asyncio.sleep(interval) + try: + await asyncio.wait_for(shutdown_event.wait(), timeout=interval) + return # Shutdown signaled + except asyncio.TimeoutError: + pass # Normal timeout — continue polling async def _poll_reddit( @@ -76,9 +82,10 @@ async def _poll_reddit( publisher: StreamPublisher, articles_fetched_counter, fetch_errors_counter, + shutdown_event: asyncio.Event, ) -> None: """Continuously poll Reddit at *interval* seconds.""" - while True: + while not shutdown_event.is_set(): try: logger.info("Polling Reddit …") articles = await source.fetch() @@ -89,7 +96,11 @@ async def _poll_reddit( except Exception: logger.exception("Reddit poll cycle failed") fetch_errors_counter.add(1) - await asyncio.sleep(interval) + try: + await asyncio.wait_for(shutdown_event.wait(), timeout=interval) + return # Shutdown signaled + except asyncio.TimeoutError: + pass # Normal timeout — continue polling async def run() -> None: @@ -124,28 +135,40 @@ async def run() -> None: min_score=config.reddit_min_score, ) + # Graceful shutdown on SIGTERM/SIGINT + shutdown_event = asyncio.Event() + loop = asyncio.get_running_loop() + for sig in (signal.SIGTERM, signal.SIGINT): + loop.add_signal_handler(sig, shutdown_event.set) + # Run pollers concurrently - async with asyncio.TaskGroup() as tg: - tg.create_task( - _poll_rss( - rss_source, - config.rss_poll_interval_seconds, - redis, - publisher, - articles_fetched_counter, - fetch_errors_counter, + try: + async with asyncio.TaskGroup() as tg: + tg.create_task( + _poll_rss( + rss_source, + config.rss_poll_interval_seconds, + redis, + publisher, + articles_fetched_counter, + fetch_errors_counter, + shutdown_event, + ) ) - ) - tg.create_task( - _poll_reddit( - reddit_source, - config.reddit_poll_interval_seconds, - redis, - publisher, - articles_fetched_counter, - fetch_errors_counter, + tg.create_task( + _poll_reddit( + reddit_source, + config.reddit_poll_interval_seconds, + redis, + publisher, + articles_fetched_counter, + fetch_errors_counter, + shutdown_event, + ) ) - ) + finally: + await redis.aclose() + logger.info("News fetcher stopped gracefully") if __name__ == "__main__": diff --git a/services/sentiment_analyzer/main.py b/services/sentiment_analyzer/main.py index ded27fc..a88c1b1 100644 --- a/services/sentiment_analyzer/main.py +++ b/services/sentiment_analyzer/main.py @@ -10,6 +10,7 @@ from __future__ import annotations import asyncio import logging +import signal import time from redis.asyncio import Redis @@ -151,13 +152,25 @@ async def run(config: SentimentAnalyzerConfig | None = None) -> None: logger.info("Consuming from news:raw, publishing to news:scored") + # Graceful shutdown on SIGTERM/SIGINT + shutdown_event = asyncio.Event() + loop = asyncio.get_running_loop() + for sig in (signal.SIGTERM, signal.SIGINT): + loop.add_signal_handler(sig, shutdown_event.set) + # --- Consume loop --- - async for _msg_id, data in consumer.consume(): - try: - article = RawArticle.model_validate(data) - await process_article(article, finbert, ollama, publisher, config, counters) - except Exception: - logger.exception("Error processing article: %s", data.get("title", "")) + try: + async for _msg_id, data in consumer.consume(): + if shutdown_event.is_set(): + break + try: + article = RawArticle.model_validate(data) + await process_article(article, finbert, ollama, publisher, config, counters) + except Exception: + logger.exception("Error processing article: %s", data.get("title", "")) + finally: + await redis.aclose() + logger.info("Sentiment analyzer stopped gracefully") def main() -> None: diff --git a/services/signal_generator/main.py b/services/signal_generator/main.py index ae6d735..e8a04d0 100644 --- a/services/signal_generator/main.py +++ b/services/signal_generator/main.py @@ -9,6 +9,7 @@ from __future__ import annotations import asyncio import logging +import signal from collections import defaultdict, deque from redis.asyncio import Redis @@ -101,59 +102,71 @@ async def run(config: SignalGeneratorConfig | None = None) -> None: logger.info("Consuming from news:scored, publishing to signals:generated") + # Graceful shutdown on SIGTERM/SIGINT + shutdown_event = asyncio.Event() + loop = asyncio.get_running_loop() + for sig in (signal.SIGTERM, signal.SIGINT): + loop.add_signal_handler(sig, shutdown_event.set) + # --- Consume loop --- - async for _msg_id, data in consumer.consume(): - try: - article = ScoredArticle.model_validate(data) - ticker = article.ticker + try: + async for _msg_id, data in consumer.consume(): + if shutdown_event.is_set(): + break + try: + article = ScoredArticle.model_validate(data) + ticker = article.ticker - # Update sentiment accumulators - sentiment_scores[ticker].append(article.sentiment_score) - sentiment_confidences[ticker].append(article.confidence) + # Update sentiment accumulators + sentiment_scores[ticker].append(article.sentiment_score) + sentiment_confidences[ticker].append(article.confidence) - # Build sentiment context - sentiment = _build_sentiment_context( - ticker, - sentiment_scores[ticker], - sentiment_confidences[ticker], - ) - - # Get market snapshot (may be None if no bars received yet) - snapshot = market_data.get_snapshot(ticker) - if snapshot is None: - # Create a minimal snapshot from sentiment data alone - # (the news_driven strategy does not require market indicators) - from shared.schemas.trading import MarketSnapshot - - snapshot = MarketSnapshot( - ticker=ticker, - current_price=0.0, - open=0.0, - high=0.0, - low=0.0, - close=0.0, - volume=0.0, - ) - - # Run ensemble - signal = await ensemble.evaluate(ticker, snapshot, sentiment, weights) - - if signal is not None: - await publisher.publish(signal.model_dump(mode="json")) - signals_generated.add(1) - for src in signal.strategy_sources: - strategy_name = src.split(":")[0] - per_strategy_signal_count.add(1, {"strategy": strategy_name}) - logger.info( - "Signal generated: %s %s strength=%.4f sources=%s", - signal.direction.value, + # Build sentiment context + sentiment = _build_sentiment_context( ticker, - signal.strength, - signal.strategy_sources, + sentiment_scores[ticker], + sentiment_confidences[ticker], ) - except Exception: - logger.exception("Error processing scored article: %s", data.get("title", "")) + # Get market snapshot (may be None if no bars received yet) + snapshot = market_data.get_snapshot(ticker) + if snapshot is None: + # Create a minimal snapshot from sentiment data alone + # (the news_driven strategy does not require market indicators) + from shared.schemas.trading import MarketSnapshot + + snapshot = MarketSnapshot( + ticker=ticker, + current_price=0.0, + open=0.0, + high=0.0, + low=0.0, + close=0.0, + volume=0.0, + ) + + # Run ensemble + signal_result = await ensemble.evaluate(ticker, snapshot, sentiment, weights) + + if signal_result is not None: + await publisher.publish(signal_result.model_dump(mode="json")) + signals_generated.add(1) + for src in signal_result.strategy_sources: + strategy_name = src.split(":")[0] + per_strategy_signal_count.add(1, {"strategy": strategy_name}) + logger.info( + "Signal generated: %s %s strength=%.4f sources=%s", + signal_result.direction.value, + ticker, + signal_result.strength, + signal_result.strategy_sources, + ) + + except Exception: + logger.exception("Error processing scored article: %s", data.get("title", "")) + finally: + await redis.aclose() + logger.info("Signal generator stopped gracefully") def main() -> None: diff --git a/services/trade_executor/main.py b/services/trade_executor/main.py index 5c89fc9..5d3d7d0 100644 --- a/services/trade_executor/main.py +++ b/services/trade_executor/main.py @@ -10,6 +10,7 @@ from __future__ import annotations import asyncio import logging +import signal import time import uuid @@ -158,13 +159,25 @@ async def run(config: TradeExecutorConfig | None = None) -> None: logger.info("Consuming from signals:generated, publishing to trades:executed") + # Graceful shutdown on SIGTERM/SIGINT + shutdown_event = asyncio.Event() + loop = asyncio.get_running_loop() + for sig in (signal.SIGTERM, signal.SIGINT): + loop.add_signal_handler(sig, shutdown_event.set) + # --- Consume loop --- - async for _msg_id, data in consumer.consume(): - try: - signal = TradeSignal.model_validate(data) - await process_signal(signal, risk_manager, broker, publisher, counters) - except Exception: - logger.exception("Error processing signal: %s", data) + try: + async for _msg_id, data in consumer.consume(): + if shutdown_event.is_set(): + break + try: + signal_msg = TradeSignal.model_validate(data) + await process_signal(signal_msg, risk_manager, broker, publisher, counters) + except Exception: + logger.exception("Error processing signal: %s", data) + finally: + await redis.aclose() + logger.info("Trade executor stopped gracefully") def main() -> None: diff --git a/shared/redis_streams.py b/shared/redis_streams.py index 904d57a..718ed16 100644 --- a/shared/redis_streams.py +++ b/shared/redis_streams.py @@ -40,9 +40,12 @@ class StreamConsumer: try: await self.redis.xgroup_create(self.stream, self.group, id="0", mkstream=True) logger.info("Created consumer group %s on %s", self.group, self.stream) - except Exception: - # Group already exists — this is expected on subsequent starts. - pass + except Exception as exc: + # BUSYGROUP means group already exists — expected on subsequent starts. + if "BUSYGROUP" in str(exc): + logger.debug("Consumer group %s already exists on %s", self.group, self.stream) + else: + raise async def consume( self, batch_size: int = 10, block_ms: int = 5000