138 lines
4.5 KiB
Python
138 lines
4.5 KiB
Python
"""WebSocket endpoint — authenticated real-time event stream."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
|
|
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
|
|
|
|
import jwt as pyjwt
|
|
|
|
from services.api_gateway.auth.jwt import decode_token
|
|
from services.api_gateway.config import ApiGatewayConfig
|
|
|
|
# Per-module logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
|
|
|
|
# Router carrying the ``/ws`` WebSocket route defined in this module.
router: APIRouter = APIRouter(tags=["websocket"])
|
|
|
|
|
|
class ConnectionManager:
    """Track accepted WebSocket connections and fan messages out to them.

    ``active_connections`` holds every socket passed to :meth:`connect` that
    has not since been removed by :meth:`disconnect` or by a failed send in
    :meth:`broadcast`.
    """

    def __init__(self) -> None:
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket) -> None:
        """Accept the handshake for *websocket* and start tracking it."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket) -> None:
        """Stop tracking *websocket*; a no-op if it is not tracked."""
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def broadcast(self, message: dict) -> None:
        """Send *message* as JSON to every tracked client.

        Delivery is best-effort: a client whose ``send_json`` raises is
        assumed to be gone and is dropped from ``active_connections``; the
        other clients still receive the message.
        """
        disconnected: list[WebSocket] = []
        # Iterate over a snapshot. Every ``await`` yields to the event loop,
        # where other coroutines may connect()/disconnect(); mutating the list
        # being iterated would make this loop skip clients.
        for connection in list(self.active_connections):
            try:
                await connection.send_json(message)
            except Exception:
                disconnected.append(connection)

        for ws in disconnected:
            self.disconnect(ws)
|
|
|
|
|
|
# Module-level connection registry; websocket_endpoint registers each
# authenticated socket here and removes it when the stream ends.
manager: ConnectionManager = ConnectionManager()
|
|
|
|
|
|
# Prefix shared by every Redis channel this endpoint relays; it is stripped
# to form the client-facing ``event`` name.
_EVENT_PREFIX = "events:"

# Redis pub/sub channels each WebSocket connection subscribes to.
_EVENT_CHANNELS = (
    "events:trade_executed",
    "events:signal_generated",
    "events:portfolio_update",
)


def _pubsub_message_to_event(message: dict) -> dict:
    """Convert a Redis pub/sub ``message`` into the JSON frame sent to clients.

    Returns ``{"event": <channel minus the "events:" prefix>, "data": ...}``
    where ``data`` is the JSON-decoded payload, or ``{"raw": <payload text>}``
    when the payload is not valid JSON.
    """
    payload = message["data"]
    try:
        data = json.loads(payload)
    except (ValueError, TypeError):
        # ValueError covers json.JSONDecodeError *and* the UnicodeDecodeError
        # json.loads raises for undecodable bytes; TypeError covers payloads
        # that are neither str nor bytes.
        if isinstance(payload, (bytes, bytearray)):
            # Decode rather than str() so clients get text, not "b'...'".
            payload = payload.decode("utf-8", errors="replace")
        data = {"raw": str(payload)}

    channel = message["channel"]
    if isinstance(channel, bytes):
        channel = channel.decode()
    if channel.startswith(_EVENT_PREFIX):
        channel = channel[len(_EVENT_PREFIX):]
    return {"event": channel, "data": data}


@router.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket,
    token: str = Query(default=""),
) -> None:
    """Authenticated WebSocket that pushes real-time trading events.

    Connect via ``ws://host/ws?token=<JWT>``. The token must decode via
    ``decode_token`` and carry ``"type": "access"``; otherwise the socket is
    closed with code 4001. After authentication this connection opens its
    own Redis pub/sub subscription and forwards each event to this client as
    ``{"event": <type>, "data": <payload>}`` until the client disconnects.

    Event types pushed to clients:

    - ``trade_executed``
    - ``signal_generated``
    - ``portfolio_update``
    """
    # Authenticate via JWT query parameter.
    # NOTE(review): close() is called before accept(). Per the ASGI spec, a
    # close sent before the handshake is accepted is answered with HTTP 403,
    # so the 4001 code/reason may never reach the client - confirm intent.
    config: ApiGatewayConfig = websocket.app.state.config
    try:
        payload = decode_token(token, config)
    except pyjwt.ExpiredSignatureError:
        await websocket.close(code=4001, reason="Token expired")
        return
    except pyjwt.InvalidTokenError:
        await websocket.close(code=4001, reason="Invalid token")
        return
    if payload.get("type") != "access":
        await websocket.close(code=4001, reason="Invalid token type")
        return

    pubsub = None
    tasks: list[asyncio.Task] = []
    try:
        # Registration and subscription sit inside the try so the finally
        # block always unregisters the socket and releases Redis resources,
        # even if accept() or subscribe() fails.
        await manager.connect(websocket)
        pubsub = websocket.app.state.redis.pubsub()
        await pubsub.subscribe(*_EVENT_CHANNELS)

        async def _redis_listener() -> None:
            """Forward Redis pub/sub messages to this WebSocket client."""
            while True:
                message = await pubsub.get_message(
                    ignore_subscribe_messages=True, timeout=1.0
                )
                if message and message["type"] == "message":
                    await websocket.send_json(_pubsub_message_to_event(message))
                else:
                    # Back off only when idle, so a burst of events is
                    # drained immediately rather than one per 100 ms.
                    await asyncio.sleep(0.1)

        async def _ws_receiver() -> None:
            """Read and discard client frames until the client disconnects."""
            while True:
                # receive() (unlike receive_text()) also accepts binary
                # frames and returns the disconnect message instead of raising.
                message = await websocket.receive()
                if message["type"] == "websocket.disconnect":
                    return

        tasks = [
            asyncio.create_task(_redis_listener()),
            asyncio.create_task(_ws_receiver()),
        ]
        # Whichever task finishes first (normally the receiver, on client
        # disconnect) ends the stream; the other is cancelled in finally.
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            # asyncio.wait() never raises a task's exception; retrieve it here
            # so failures are logged instead of silently dropped.
            if task.cancelled():
                continue
            exc = task.exception()
            if exc is not None and not isinstance(exc, WebSocketDisconnect):
                logger.warning(
                    "WebSocket event stream ended with an error", exc_info=exc
                )
    finally:
        for task in tasks:
            task.cancel()
        # Wait for cancellation to finish so no task is still using pubsub or
        # the socket while they are torn down below.
        await asyncio.gather(*tasks, return_exceptions=True)
        manager.disconnect(websocket)
        if pubsub is not None:
            try:
                await pubsub.unsubscribe()
            except Exception:
                logger.warning("Failed to unsubscribe Redis pub/sub", exc_info=True)
            finally:
                await pubsub.aclose()
|