"""WebSocket endpoint — authenticated real-time event stream.""" from __future__ import annotations import asyncio import json import logging from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect import jwt as pyjwt from services.api_gateway.auth.jwt import decode_token from services.api_gateway.config import ApiGatewayConfig logger = logging.getLogger(__name__) router = APIRouter(tags=["websocket"]) class ConnectionManager: """Manages active WebSocket connections.""" def __init__(self) -> None: self.active_connections: list[WebSocket] = [] async def connect(self, websocket: WebSocket) -> None: await websocket.accept() self.active_connections.append(websocket) def disconnect(self, websocket: WebSocket) -> None: if websocket in self.active_connections: self.active_connections.remove(websocket) async def broadcast(self, message: dict) -> None: """Send a message to all connected clients. Remove any that have disconnected.""" disconnected: list[WebSocket] = [] for connection in self.active_connections: try: await connection.send_json(message) except Exception: disconnected.append(connection) for ws in disconnected: self.disconnect(ws) manager = ConnectionManager() @router.websocket("/ws") async def websocket_endpoint( websocket: WebSocket, token: str = Query(default=""), ) -> None: """Authenticated WebSocket that pushes real-time trading events. Connect via ``ws://host/ws?token=``. The server subscribes to Redis pub/sub channels and forwards events to all connected clients. Event types pushed to clients: - ``trade_executed`` - ``signal_generated`` - ``portfolio_update`` """ # Authenticate via JWT query parameter config: ApiGatewayConfig = websocket.app.state.config try: payload = decode_token(token, config) if payload.get("type") != "access": await websocket.close(code=4001, reason="Invalid token type") return except pyjwt.ExpiredSignatureError: await websocket.close(code=4001, reason="Token expired") return except pyjwt.InvalidTokenError: await websocket.close(code=4001, reason="Invalid token") return await manager.connect(websocket) # Subscribe to Redis pub/sub channels for real-time events redis = websocket.app.state.redis pubsub = redis.pubsub() await pubsub.subscribe( "events:trade_executed", "events:signal_generated", "events:portfolio_update", ) try: # Run two concurrent tasks: one reads from Redis pub/sub and pushes # to the client, the other keeps the WebSocket alive by reading # (and discarding) incoming messages. async def _redis_listener() -> None: """Forward Redis pub/sub messages to this WebSocket client.""" while True: message = await pubsub.get_message( ignore_subscribe_messages=True, timeout=1.0 ) if message and message["type"] == "message": try: data = json.loads(message["data"]) except (json.JSONDecodeError, TypeError): data = {"raw": str(message["data"])} channel = message["channel"] if isinstance(channel, bytes): channel = channel.decode() event_type = channel.replace("events:", "") await websocket.send_json( {"event": event_type, "data": data} ) await asyncio.sleep(0.1) async def _ws_receiver() -> None: """Keep the connection alive by reading messages.""" while True: await websocket.receive_text() # Run both tasks; whichever finishes first (e.g., client disconnect) # will cancel the other. listener_task = asyncio.create_task(_redis_listener()) receiver_task = asyncio.create_task(_ws_receiver()) done, pending = await asyncio.wait( {listener_task, receiver_task}, return_when=asyncio.FIRST_COMPLETED, ) for task in pending: task.cancel() except WebSocketDisconnect: pass finally: manager.disconnect(websocket) await pubsub.unsubscribe() await pubsub.aclose()