feat: API gateway trading endpoints, controls, backtest, WebSocket

This commit is contained in:
Viktor Barzin 2026-02-22 15:54:20 +00:00
parent e0d138c457
commit 6fe586f722
No known key found for this signature in database
GPG key ID: 0EB088298288D958
11 changed files with 1304 additions and 0 deletions

View file

@ -69,6 +69,28 @@ def create_app(config: ApiGatewayConfig | None = None) -> FastAPI:
# Auth routes (unauthenticated)
app.include_router(auth_router)
# Trading routes (authenticated) — imported lazily to avoid circular deps
from services.api_gateway.routes.portfolio import router as portfolio_router
from services.api_gateway.routes.trades import router as trades_router
from services.api_gateway.routes.signals import router as signals_router
from services.api_gateway.routes.strategies import router as strategies_router
from services.api_gateway.routes.news import router as news_router
from services.api_gateway.routes.controls import router as controls_router
from services.api_gateway.routes.backtest import router as backtest_router
app.include_router(portfolio_router)
app.include_router(trades_router)
app.include_router(signals_router)
app.include_router(strategies_router)
app.include_router(news_router)
app.include_router(controls_router)
app.include_router(backtest_router)
# WebSocket
from services.api_gateway.ws import router as ws_router
app.include_router(ws_router)
# Health check
@app.get("/health", tags=["health"])
async def health() -> dict:

View file

@ -0,0 +1 @@
"""Trading route sub-package."""

View file

@ -0,0 +1,156 @@
"""Backtest endpoints — run backtests and retrieve results."""
from __future__ import annotations
import asyncio
import json
import logging
import uuid
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from pydantic import BaseModel, Field
from services.api_gateway.auth.middleware import get_current_user
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/backtest", tags=["backtest"])
class BacktestRequest(BaseModel):
"""Request body for starting a new backtest."""
start_date: datetime
end_date: datetime
initial_capital: float = Field(default=100_000.0, gt=0)
commission_per_trade: float = Field(default=0.0, ge=0)
slippage_pct: float = Field(default=0.001, ge=0)
strategy_weights: dict[str, float] = Field(default_factory=dict)
max_position_pct: float = Field(default=0.05, gt=0, le=1.0)
signal_threshold: float = Field(default=0.3, ge=0, le=1.0)
@router.post("/run")
async def run_backtest(
body: BacktestRequest,
request: Request,
_user: dict = Depends(get_current_user),
) -> dict:
"""Start a backtest with the given configuration.
Returns a ``run_id`` immediately. The backtest runs in a background
task and its results can be retrieved via ``GET /api/backtest/{run_id}``.
"""
run_id = str(uuid.uuid4())
redis = request.app.state.redis
# Store initial status
await redis.setex(
f"backtest:{run_id}",
86400, # 24h TTL
json.dumps({
"status": "running",
"config": body.model_dump(mode="json"),
"started_at": datetime.now(tz=timezone.utc).isoformat(),
}),
)
# Launch background task
asyncio.create_task(_run_backtest_task(run_id, body, redis))
return {"run_id": run_id, "status": "running"}
async def _run_backtest_task(
run_id: str,
config: BacktestRequest,
redis,
) -> None:
"""Execute the backtest in the background and store results in Redis."""
try:
from backtester.config import BacktestConfig
from backtester.engine import BacktestEngine
from shared.strategies.momentum import MomentumStrategy
from shared.strategies.mean_reversion import MeanReversionStrategy
from shared.strategies.news_driven import NewsDrivenStrategy
bt_config = BacktestConfig(
start_date=config.start_date,
end_date=config.end_date,
initial_capital=config.initial_capital,
commission_per_trade=config.commission_per_trade,
slippage_pct=config.slippage_pct,
strategy_weights=config.strategy_weights,
max_position_pct=config.max_position_pct,
signal_threshold=config.signal_threshold,
)
strategies = [
MomentumStrategy(),
MeanReversionStrategy(),
NewsDrivenStrategy(),
]
engine = BacktestEngine(config=bt_config, strategies=strategies)
# Use an in-memory stub data loader for now; a full implementation
# would read from TimescaleDB.
from backtester.data_loader import BacktestDataLoader
data_loader = BacktestDataLoader(
session_factory=None,
start_date=config.start_date,
end_date=config.end_date,
)
result = await engine.run(data_loader)
await redis.setex(
f"backtest:{run_id}",
86400,
json.dumps({
"status": "completed",
"config": config.model_dump(mode="json"),
"result": {
"total_return": result.total_return,
"annualized_return": result.annualized_return,
"sharpe_ratio": result.sharpe_ratio,
"sortino_ratio": result.sortino_ratio,
"max_drawdown": result.max_drawdown,
"max_drawdown_duration_days": result.max_drawdown_duration_days,
"win_rate": result.win_rate,
"trade_count": result.trade_count,
"avg_hold_duration_hours": result.avg_hold_duration_hours,
"profit_factor": result.profit_factor,
},
"completed_at": datetime.now(tz=timezone.utc).isoformat(),
}),
)
except Exception as exc:
logger.exception("Backtest %s failed", run_id)
await redis.setex(
f"backtest:{run_id}",
86400,
json.dumps({
"status": "failed",
"error": str(exc),
}),
)
@router.get("/{run_id}")
async def get_backtest(
run_id: str,
request: Request,
_user: dict = Depends(get_current_user),
) -> dict:
"""Get backtest results by run ID."""
redis = request.app.state.redis
raw = await redis.get(f"backtest:{run_id}")
if raw is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Backtest run not found",
)
return json.loads(raw)

View file

@ -0,0 +1,95 @@
"""Control endpoints — pause/resume trading, force close positions."""
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel
from services.api_gateway.auth.middleware import get_current_user
router = APIRouter(prefix="/api/controls", tags=["controls"])
TRADING_PAUSED_KEY = "trading:paused"
class ClosePositionRequest(BaseModel):
"""Body for the force-close-position endpoint."""
ticker: str
@router.post("/pause")
async def pause_trading(
request: Request,
_user: dict = Depends(get_current_user),
) -> dict:
"""Set Redis flag to pause trading."""
redis = request.app.state.redis
await redis.set(TRADING_PAUSED_KEY, "1")
return {"status": "paused"}
@router.post("/resume")
async def resume_trading(
request: Request,
_user: dict = Depends(get_current_user),
) -> dict:
"""Clear pause flag to resume trading."""
redis = request.app.state.redis
await redis.delete(TRADING_PAUSED_KEY)
return {"status": "active"}
@router.post("/close-position")
async def close_position(
body: ClosePositionRequest,
request: Request,
_user: dict = Depends(get_current_user),
) -> dict:
"""Force close an open position by ticker.
Publishes a close-position command to Redis so the trade executor
picks it up asynchronously.
"""
import json
from sqlalchemy import select
from shared.models.trading import Position
db = request.app.state.db_session_factory
async with db() as session:
position = (
await session.execute(
select(Position).where(Position.ticker == body.ticker.upper())
)
).scalar_one_or_none()
if position is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"No open position for {body.ticker}",
)
# Publish close command to Redis for the trade executor
redis = request.app.state.redis
await redis.publish(
"controls:close_position",
json.dumps({"ticker": body.ticker.upper(), "qty": position.qty}),
)
return {
"status": "close_requested",
"ticker": body.ticker.upper(),
"qty": position.qty,
}
@router.get("/status")
async def get_trading_status(
request: Request,
_user: dict = Depends(get_current_user),
) -> dict:
"""Current trading status — active or paused."""
redis = request.app.state.redis
paused = await redis.get(TRADING_PAUSED_KEY)
return {"status": "paused" if paused else "active"}

View file

@ -0,0 +1,87 @@
"""News endpoints — recent scored articles with filtering."""
from __future__ import annotations
from fastapi import APIRouter, Depends, Query, Request
from services.api_gateway.auth.middleware import get_current_user
from sqlalchemy import select, desc, func
router = APIRouter(prefix="/api/news", tags=["news"])
@router.get("")
async def list_news(
request: Request,
_user: dict = Depends(get_current_user),
ticker: str | None = Query(default=None),
source: str | None = Query(default=None),
min_score: float | None = Query(default=None, ge=-1.0, le=1.0),
max_score: float | None = Query(default=None, ge=-1.0, le=1.0),
page: int = Query(default=1, ge=1),
per_page: int = Query(default=20, ge=1, le=100),
) -> dict:
"""Recent scored articles with optional filters."""
from shared.models.news import Article, ArticleSentiment
db = request.app.state.db_session_factory
async with db() as session:
# Base query joining articles with sentiments
query = (
select(Article, ArticleSentiment)
.join(ArticleSentiment, Article.id == ArticleSentiment.article_id)
.order_by(desc(Article.fetched_at))
)
count_query = (
select(func.count())
.select_from(Article)
.join(ArticleSentiment, Article.id == ArticleSentiment.article_id)
)
if ticker:
query = query.where(ArticleSentiment.ticker == ticker.upper())
count_query = count_query.where(
ArticleSentiment.ticker == ticker.upper()
)
if source:
query = query.where(Article.source == source)
count_query = count_query.where(Article.source == source)
if min_score is not None:
query = query.where(ArticleSentiment.score >= min_score)
count_query = count_query.where(ArticleSentiment.score >= min_score)
if max_score is not None:
query = query.where(ArticleSentiment.score <= max_score)
count_query = count_query.where(ArticleSentiment.score <= max_score)
total = (await session.execute(count_query)).scalar() or 0
offset = (page - 1) * per_page
query = query.offset(offset).limit(per_page)
result = await session.execute(query)
rows = result.all()
return {
"articles": [
{
"id": str(article.id),
"source": article.source,
"url": article.url,
"title": article.title,
"published_at": (
article.published_at.isoformat()
if article.published_at
else None
),
"fetched_at": article.fetched_at.isoformat(),
"ticker": sentiment.ticker,
"sentiment_score": sentiment.score,
"confidence": sentiment.confidence,
"model_used": sentiment.model_used,
}
for article, sentiment in rows
],
"total": total,
"page": page,
"per_page": per_page,
"pages": (total + per_page - 1) // per_page if per_page else 0,
}

View file

@ -0,0 +1,125 @@
"""Portfolio endpoints — current value, positions, equity curve."""
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from enum import Enum
from fastapi import APIRouter, Depends, Query, Request
from services.api_gateway.auth.middleware import get_current_user
from sqlalchemy import select, desc
router = APIRouter(prefix="/api/portfolio", tags=["portfolio"])
class HistoryPeriod(str, Enum):
ONE_DAY = "1d"
ONE_WEEK = "1w"
ONE_MONTH = "1m"
THREE_MONTHS = "3m"
ONE_YEAR = "1y"
def _period_to_timedelta(period: HistoryPeriod) -> timedelta:
"""Convert a period enum value to a timedelta."""
mapping = {
HistoryPeriod.ONE_DAY: timedelta(days=1),
HistoryPeriod.ONE_WEEK: timedelta(weeks=1),
HistoryPeriod.ONE_MONTH: timedelta(days=30),
HistoryPeriod.THREE_MONTHS: timedelta(days=90),
HistoryPeriod.ONE_YEAR: timedelta(days=365),
}
return mapping[period]
@router.get("")
async def get_portfolio(
request: Request,
_user: dict = Depends(get_current_user),
) -> dict:
"""Current portfolio summary — value, cash, buying power, daily P&L."""
from shared.models.timeseries import PortfolioSnapshot
db = request.app.state.db_session_factory
async with db() as session:
latest = (
await session.execute(
select(PortfolioSnapshot)
.order_by(desc(PortfolioSnapshot.timestamp))
.limit(1)
)
).scalar_one_or_none()
if latest is None:
return {
"total_value": 0.0,
"cash": 0.0,
"buying_power": 0.0,
"daily_pnl": 0.0,
}
return {
"total_value": latest.total_value,
"cash": latest.cash,
"buying_power": latest.cash,
"daily_pnl": latest.daily_pnl,
}
@router.get("/positions")
async def get_positions(
request: Request,
_user: dict = Depends(get_current_user),
) -> list[dict]:
"""All open positions with unrealized P&L."""
from shared.models.trading import Position
db = request.app.state.db_session_factory
async with db() as session:
result = await session.execute(select(Position))
positions = result.scalars().all()
return [
{
"id": str(p.id),
"ticker": p.ticker,
"qty": p.qty,
"avg_entry": p.avg_entry,
"unrealized_pnl": p.unrealized_pnl or 0.0,
"stop_loss": p.stop_loss,
"take_profit": p.take_profit,
}
for p in positions
]
@router.get("/history")
async def get_portfolio_history(
request: Request,
_user: dict = Depends(get_current_user),
period: HistoryPeriod = Query(default=HistoryPeriod.ONE_MONTH),
) -> list[dict]:
"""Equity curve from portfolio_snapshots over a given period."""
from shared.models.timeseries import PortfolioSnapshot
since = datetime.now(timezone.utc) - _period_to_timedelta(period)
db = request.app.state.db_session_factory
async with db() as session:
result = await session.execute(
select(PortfolioSnapshot)
.where(PortfolioSnapshot.timestamp >= since)
.order_by(PortfolioSnapshot.timestamp)
)
snapshots = result.scalars().all()
return [
{
"timestamp": s.timestamp.isoformat(),
"total_value": s.total_value,
"cash": s.cash,
"positions_value": s.positions_value,
"daily_pnl": s.daily_pnl,
}
for s in snapshots
]

View file

@ -0,0 +1,58 @@
"""Signal endpoints — recent signals with filtering."""
from __future__ import annotations
from fastapi import APIRouter, Depends, Query, Request
from services.api_gateway.auth.middleware import get_current_user
from sqlalchemy import select, desc, func
router = APIRouter(prefix="/api/signals", tags=["signals"])
@router.get("")
async def list_signals(
request: Request,
_user: dict = Depends(get_current_user),
ticker: str | None = Query(default=None),
page: int = Query(default=1, ge=1),
per_page: int = Query(default=20, ge=1, le=100),
) -> dict:
"""Recent signals with optional ticker filter and pagination."""
from shared.models.trading import Signal
db = request.app.state.db_session_factory
async with db() as session:
query = select(Signal).order_by(desc(Signal.created_at))
count_query = select(func.count()).select_from(Signal)
if ticker:
query = query.where(Signal.ticker == ticker.upper())
count_query = count_query.where(Signal.ticker == ticker.upper())
total = (await session.execute(count_query)).scalar() or 0
offset = (page - 1) * per_page
query = query.offset(offset).limit(per_page)
result = await session.execute(query)
signals = result.scalars().all()
return {
"signals": [
{
"id": str(s.id),
"ticker": s.ticker,
"direction": s.direction.value,
"strength": s.strength,
"strategy_sources": s.strategy_sources,
"sentiment_score": s.sentiment_score,
"acted_on": s.acted_on,
"created_at": s.created_at.isoformat() if s.created_at else None,
}
for s in signals
],
"total": total,
"page": page,
"per_page": per_page,
"pages": (total + per_page - 1) // per_page if per_page else 0,
}

View file

@ -0,0 +1,111 @@
"""Strategy endpoints — list strategies, weight history, metrics."""
from __future__ import annotations
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Request, status
from services.api_gateway.auth.middleware import get_current_user
from sqlalchemy import select, desc
router = APIRouter(prefix="/api/strategies", tags=["strategies"])
@router.get("")
async def list_strategies(
request: Request,
_user: dict = Depends(get_current_user),
) -> list[dict]:
"""All strategies with current weights."""
from shared.models.trading import Strategy
db = request.app.state.db_session_factory
async with db() as session:
result = await session.execute(select(Strategy))
strategies = result.scalars().all()
return [
{
"id": str(s.id),
"name": s.name,
"description": s.description,
"current_weight": s.current_weight,
"active": s.active,
"created_at": s.created_at.isoformat() if s.created_at else None,
}
for s in strategies
]
@router.get("/{strategy_id}/history")
async def get_strategy_weight_history(
strategy_id: UUID,
request: Request,
_user: dict = Depends(get_current_user),
) -> list[dict]:
"""Weight history for a specific strategy."""
from shared.models.trading import StrategyWeightHistory, Strategy
db = request.app.state.db_session_factory
async with db() as session:
# Verify strategy exists
strategy = (
await session.execute(
select(Strategy).where(Strategy.id == strategy_id)
)
).scalar_one_or_none()
if strategy is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Strategy not found",
)
result = await session.execute(
select(StrategyWeightHistory)
.where(StrategyWeightHistory.strategy_id == strategy_id)
.order_by(desc(StrategyWeightHistory.created_at))
)
history = result.scalars().all()
return [
{
"id": str(h.id),
"old_weight": h.old_weight,
"new_weight": h.new_weight,
"reason": h.reason,
"created_at": h.created_at.isoformat() if h.created_at else None,
}
for h in history
]
@router.get("/{strategy_id}/metrics")
async def get_strategy_metrics(
strategy_id: UUID,
request: Request,
_user: dict = Depends(get_current_user),
) -> list[dict]:
"""Performance metrics over time for a specific strategy."""
from shared.models.timeseries import StrategyMetric
db = request.app.state.db_session_factory
async with db() as session:
result = await session.execute(
select(StrategyMetric)
.where(StrategyMetric.strategy_id == strategy_id)
.order_by(desc(StrategyMetric.timestamp))
.limit(100)
)
metrics = result.scalars().all()
return [
{
"timestamp": m.timestamp.isoformat(),
"win_rate": m.win_rate,
"total_pnl": m.total_pnl,
"trade_count": m.trade_count,
"sharpe_ratio": m.sharpe_ratio,
}
for m in metrics
]

View file

@ -0,0 +1,125 @@
"""Trade endpoints — paginated trade history and detail."""
from __future__ import annotations
from datetime import datetime
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from services.api_gateway.auth.middleware import get_current_user
from sqlalchemy import select, desc, func
router = APIRouter(prefix="/api/trades", tags=["trades"])
@router.get("")
async def list_trades(
request: Request,
_user: dict = Depends(get_current_user),
ticker: str | None = Query(default=None),
start_date: datetime | None = Query(default=None),
end_date: datetime | None = Query(default=None),
strategy: str | None = Query(default=None),
profitable: bool | None = Query(default=None),
page: int = Query(default=1, ge=1),
per_page: int = Query(default=20, ge=1, le=100),
) -> dict:
"""Paginated trade history with optional filters."""
from shared.models.trading import Trade, Strategy
db = request.app.state.db_session_factory
async with db() as session:
query = select(Trade).order_by(desc(Trade.created_at))
count_query = select(func.count()).select_from(Trade)
# Apply filters
if ticker:
query = query.where(Trade.ticker == ticker.upper())
count_query = count_query.where(Trade.ticker == ticker.upper())
if start_date:
query = query.where(Trade.created_at >= start_date)
count_query = count_query.where(Trade.created_at >= start_date)
if end_date:
query = query.where(Trade.created_at <= end_date)
count_query = count_query.where(Trade.created_at <= end_date)
if strategy:
# Join with Strategy to filter by name
query = query.join(Strategy, Trade.strategy_id == Strategy.id).where(
Strategy.name == strategy
)
count_query = count_query.join(
Strategy, Trade.strategy_id == Strategy.id
).where(Strategy.name == strategy)
if profitable is not None:
if profitable:
query = query.where(Trade.pnl > 0)
count_query = count_query.where(Trade.pnl > 0)
else:
query = query.where(Trade.pnl <= 0)
count_query = count_query.where(Trade.pnl <= 0)
# Pagination
total = (await session.execute(count_query)).scalar() or 0
offset = (page - 1) * per_page
query = query.offset(offset).limit(per_page)
result = await session.execute(query)
trades = result.scalars().all()
return {
"trades": [
{
"id": str(t.id),
"ticker": t.ticker,
"side": t.side.value,
"qty": t.qty,
"price": t.price,
"status": t.status.value,
"pnl": t.pnl,
"strategy_id": str(t.strategy_id) if t.strategy_id else None,
"signal_id": str(t.signal_id) if t.signal_id else None,
"created_at": t.created_at.isoformat() if t.created_at else None,
}
for t in trades
],
"total": total,
"page": page,
"per_page": per_page,
"pages": (total + per_page - 1) // per_page if per_page else 0,
}
@router.get("/{trade_id}")
async def get_trade(
trade_id: UUID,
request: Request,
_user: dict = Depends(get_current_user),
) -> dict:
"""Single trade detail with linked signal and outcome."""
from shared.models.trading import Trade
db = request.app.state.db_session_factory
async with db() as session:
trade = (
await session.execute(select(Trade).where(Trade.id == trade_id))
).scalar_one_or_none()
if trade is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Trade not found",
)
return {
"id": str(trade.id),
"ticker": trade.ticker,
"side": trade.side.value,
"qty": trade.qty,
"price": trade.price,
"status": trade.status.value,
"pnl": trade.pnl,
"strategy_id": str(trade.strategy_id) if trade.strategy_id else None,
"signal_id": str(trade.signal_id) if trade.signal_id else None,
"created_at": trade.created_at.isoformat() if trade.created_at else None,
}

138
services/api_gateway/ws.py Normal file
View file

@ -0,0 +1,138 @@
"""WebSocket endpoint — authenticated real-time event stream."""
from __future__ import annotations
import asyncio
import json
import logging
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
import jwt as pyjwt
from services.api_gateway.auth.jwt import decode_token
from services.api_gateway.config import ApiGatewayConfig
logger = logging.getLogger(__name__)
router = APIRouter(tags=["websocket"])
class ConnectionManager:
"""Manages active WebSocket connections."""
def __init__(self) -> None:
self.active_connections: list[WebSocket] = []
async def connect(self, websocket: WebSocket) -> None:
await websocket.accept()
self.active_connections.append(websocket)
def disconnect(self, websocket: WebSocket) -> None:
if websocket in self.active_connections:
self.active_connections.remove(websocket)
async def broadcast(self, message: dict) -> None:
"""Send a message to all connected clients. Remove any that have
disconnected."""
disconnected: list[WebSocket] = []
for connection in self.active_connections:
try:
await connection.send_json(message)
except Exception:
disconnected.append(connection)
for ws in disconnected:
self.disconnect(ws)
manager = ConnectionManager()
@router.websocket("/ws")
async def websocket_endpoint(
websocket: WebSocket,
token: str = Query(default=""),
) -> None:
"""Authenticated WebSocket that pushes real-time trading events.
Connect via ``ws://host/ws?token=<JWT>``. The server subscribes
to Redis pub/sub channels and forwards events to all connected
clients.
Event types pushed to clients:
- ``trade_executed``
- ``signal_generated``
- ``portfolio_update``
"""
# Authenticate via JWT query parameter
config: ApiGatewayConfig = websocket.app.state.config
try:
payload = decode_token(token, config)
if payload.get("type") != "access":
await websocket.close(code=4001, reason="Invalid token type")
return
except pyjwt.ExpiredSignatureError:
await websocket.close(code=4001, reason="Token expired")
return
except pyjwt.InvalidTokenError:
await websocket.close(code=4001, reason="Invalid token")
return
await manager.connect(websocket)
# Subscribe to Redis pub/sub channels for real-time events
redis = websocket.app.state.redis
pubsub = redis.pubsub()
await pubsub.subscribe(
"events:trade_executed",
"events:signal_generated",
"events:portfolio_update",
)
try:
# Run two concurrent tasks: one reads from Redis pub/sub and pushes
# to the client, the other keeps the WebSocket alive by reading
# (and discarding) incoming messages.
async def _redis_listener() -> None:
"""Forward Redis pub/sub messages to this WebSocket client."""
while True:
message = await pubsub.get_message(
ignore_subscribe_messages=True, timeout=1.0
)
if message and message["type"] == "message":
try:
data = json.loads(message["data"])
except (json.JSONDecodeError, TypeError):
data = {"raw": str(message["data"])}
channel = message["channel"]
if isinstance(channel, bytes):
channel = channel.decode()
event_type = channel.replace("events:", "")
await websocket.send_json(
{"event": event_type, "data": data}
)
await asyncio.sleep(0.1)
async def _ws_receiver() -> None:
"""Keep the connection alive by reading messages."""
while True:
await websocket.receive_text()
# Run both tasks; whichever finishes first (e.g., client disconnect)
# will cancel the other.
listener_task = asyncio.create_task(_redis_listener())
receiver_task = asyncio.create_task(_ws_receiver())
done, pending = await asyncio.wait(
{listener_task, receiver_task},
return_when=asyncio.FIRST_COMPLETED,
)
for task in pending:
task.cancel()
except WebSocketDisconnect:
pass
finally:
manager.disconnect(websocket)
await pubsub.unsubscribe()
await pubsub.aclose()