trading/services/api_gateway/ws.py

138 lines
4.5 KiB
Python

"""WebSocket endpoint — authenticated real-time event stream."""
from __future__ import annotations
import asyncio
import json
import logging
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
import jwt as pyjwt
from services.api_gateway.auth.jwt import decode_token
from services.api_gateway.config import ApiGatewayConfig
# Module logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Router on which this module registers its WebSocket route (tagged "websocket").
router = APIRouter(tags=["websocket"])
class ConnectionManager:
    """Registry of accepted WebSocket connections.

    Connections are added by :meth:`connect` and removed either explicitly
    by :meth:`disconnect` or automatically by :meth:`broadcast` when sending
    to them fails.
    """

    def __init__(self) -> None:
        # Kept in insertion order, so broadcasts go out in connection order.
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket) -> None:
        """Accept the WebSocket handshake, then register the connection."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket) -> None:
        """Unregister *websocket*; a no-op if it is not registered."""
        try:
            self.active_connections.remove(websocket)
        except ValueError:
            pass

    async def broadcast(self, message: dict) -> None:
        """Send *message* as JSON to every registered client.

        A client whose send raises is treated as gone and is unregistered
        after the broadcast loop has finished.
        """
        failed: list[WebSocket] = []
        # Iterate over a snapshot: every ``await`` yields to the event loop,
        # where other handlers may connect()/disconnect() and mutate the live
        # list. Mutating the list mid-iteration would make this loop skip
        # clients.
        for connection in list(self.active_connections):
            try:
                await connection.send_json(message)
            except Exception:
                # Best effort: a failed send means the client is unreachable.
                failed.append(connection)
        for connection in failed:
            self.disconnect(connection)


# Shared registry used by every connection handled by this module's endpoint.
manager = ConnectionManager()
# Redis pub/sub channels relayed to clients. The part after the prefix
# becomes the ``event`` field of each message pushed to the client.
_EVENT_CHANNEL_PREFIX = "events:"
_EVENT_CHANNELS = (
    "events:trade_executed",
    "events:signal_generated",
    "events:portfolio_update",
)


def _pubsub_message_to_event(message: dict) -> dict:
    """Convert a Redis pub/sub ``message`` into the client event envelope.

    Returns ``{"event": <channel without prefix>, "data": <payload>}``. The
    payload is JSON-decoded when possible. Anything else (non-JSON text,
    bytes that are not valid UTF-8, non-string payloads) is wrapped as
    ``{"raw": <text>}``, so one malformed publish cannot end the stream.
    """
    raw = message["data"]
    try:
        data = json.loads(raw)
    except (ValueError, TypeError):
        # ValueError covers JSONDecodeError and also the UnicodeDecodeError
        # that json.loads raises for undecodable bytes.
        if isinstance(raw, bytes):
            raw = raw.decode("utf-8", errors="replace")
        data = {"raw": str(raw)}
    channel = message["channel"]
    if isinstance(channel, bytes):
        channel = channel.decode()
    if channel.startswith(_EVENT_CHANNEL_PREFIX):
        channel = channel[len(_EVENT_CHANNEL_PREFIX):]
    return {"event": channel, "data": data}


async def _forward_redis_events(websocket: WebSocket, pubsub) -> None:
    """Relay messages from *pubsub* to *websocket* until cancelled or a send fails."""
    while True:
        message = await pubsub.get_message(
            ignore_subscribe_messages=True, timeout=1.0
        )
        if message is None:
            # Nothing available. Sleep briefly so an early return from
            # get_message cannot turn this into a busy loop. Delivered
            # messages are never delayed.
            await asyncio.sleep(0.1)
            continue
        if message["type"] == "message":
            await websocket.send_json(_pubsub_message_to_event(message))


async def _discard_client_messages(websocket: WebSocket) -> None:
    """Read and drop client text frames.

    Under Starlette, this ends by raising ``WebSocketDisconnect`` when the
    client goes away.
    """
    while True:
        await websocket.receive_text()


@router.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket,
    token: str = Query(default=""),
) -> None:
    """Authenticated WebSocket that pushes real-time trading events.

    Connect via ``ws://host/ws?token=<JWT>``. The token must decode via
    ``decode_token`` and carry ``"type": "access"``; otherwise the socket is
    closed with code 4001. Once accepted, each connection gets its own Redis
    pub/sub subscription and receives ``{"event": <type>, "data": <payload>}``
    messages until either side goes away.

    Event types pushed to clients:
    - ``trade_executed``
    - ``signal_generated``
    - ``portfolio_update``
    """
    config: ApiGatewayConfig = websocket.app.state.config
    # NOTE(review): these close() calls happen before accept(). Starlette/ASGI
    # servers typically turn that into an HTTP 403 handshake rejection, so
    # clients may never see code 4001 or the reason text. Confirm intent.
    try:
        payload = decode_token(token, config)
    except pyjwt.ExpiredSignatureError:
        await websocket.close(code=4001, reason="Token expired")
        return
    except pyjwt.InvalidTokenError:
        await websocket.close(code=4001, reason="Invalid token")
        return
    if payload.get("type") != "access":
        await websocket.close(code=4001, reason="Invalid token type")
        return

    await manager.connect(websocket)
    pubsub = websocket.app.state.redis.pubsub()
    tasks: list[asyncio.Task] = []
    try:
        # Subscribing inside the try ensures the connection is unregistered
        # and pubsub closed even when Redis is unreachable.
        await pubsub.subscribe(*_EVENT_CHANNELS)
        tasks = [
            asyncio.create_task(_forward_redis_events(websocket, pubsub)),
            asyncio.create_task(_discard_client_messages(websocket)),
        ]
        # Whichever side finishes first ends the session: client disconnect,
        # failed send, or Redis error.
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        # asyncio.wait does not raise task exceptions; inspect them here so a
        # client disconnect is treated as normal and anything else is logged.
        for task in done:
            if task.cancelled():
                continue
            exc = task.exception()
            if exc is not None and not isinstance(exc, WebSocketDisconnect):
                logger.warning("WebSocket event stream stopped", exc_info=exc)
    finally:
        # Also reached when this endpoint itself is cancelled: never leave
        # the helper tasks running.
        for task in tasks:
            task.cancel()
        if tasks:
            # Wait for cancellation to finish, collecting any exceptions, so
            # no task is still using pubsub or the websocket during cleanup.
            await asyncio.gather(*tasks, return_exceptions=True)
        manager.disconnect(websocket)
        try:
            await pubsub.unsubscribe()
        finally:
            await pubsub.aclose()