I1: Add graceful shutdown (SIGTERM/SIGINT) to all 5 background services.
I2: Fix Dockerfile healthcheck to use curl on the /metrics endpoint.
I3: Fix StreamConsumer.ensure_group() to only catch BUSYGROUP errors.
I4: Fix SimulatedBroker to reject orders with insufficient cash/shares.
I5: Move ORM attribute access inside the DB session context in trades routes.
I6: Add Redis-based rate limiting (10 req/min/IP) on all auth endpoints.
I8: Prevent backtest background task garbage collection.
I9: Use Numeric(16,6) instead of Float for financial columns in the migration.
I10: Add an index on trades.created_at for time-range queries.
I11: Bind infrastructure ports to 127.0.0.1 in docker-compose.
I12: Add a migrations init service; all app services depend on it.
I13: Fix user enumeration in login_begin (return options for non-existent users).
317 lines
10 KiB
Python
317 lines
10 KiB
Python
"""Learning Engine service -- main entry point.
|
|
|
|
Consumes ``trades:executed`` from Redis Streams, evaluates closed positions,
|
|
attributes credit to contributing strategies, adjusts strategy weights via
|
|
a multi-armed bandit approach, and stores all adjustments for auditability.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import signal
|
|
from datetime import datetime, timezone
|
|
from uuid import UUID
|
|
|
|
from redis.asyncio import Redis
|
|
|
|
from services.learning_engine.config import LearningEngineConfig
|
|
from services.learning_engine.evaluator import TradeEvaluator
|
|
from services.learning_engine.weight_adjuster import WeightAdjuster
|
|
from shared.redis_streams import StreamConsumer
|
|
from shared.schemas.learning import TradeOutcomeSchema, WeightAdjustment
|
|
from shared.schemas.trading import OrderSide, TradeExecution
|
|
from shared.telemetry import setup_telemetry
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Redis key under which the strategy weights are cached; the value is read
# and written as a JSON-encoded mapping of strategy name -> weight.
_STRATEGY_WEIGHTS_KEY = "strategy:weights"
|
|
|
|
|
|
async def _load_strategy_weights(redis: Redis) -> dict[str, float]:
    """Return the strategy weights currently cached in Redis.

    When nothing is cached under the weights key (missing or empty value),
    a default, roughly equal split across the three built-in strategies is
    returned instead. The defaults are not written back to Redis here.
    """
    cached = await redis.get(_STRATEGY_WEIGHTS_KEY)
    if not cached:
        # No cached weights yet -- start from (approximately) equal shares.
        return {
            "momentum": 0.333,
            "mean_reversion": 0.333,
            "news_driven": 0.334,
        }
    return json.loads(cached)
|
|
|
|
|
|
async def _save_strategy_weights(redis: Redis, weights: dict[str, float]) -> None:
    """Write *weights* to the Redis weights key as a JSON string."""
    serialized = json.dumps(weights)
    await redis.set(_STRATEGY_WEIGHTS_KEY, serialized)
|
|
|
|
|
|
async def _find_opening_trade(
    redis: Redis,
    ticker: str,
    closing_side: OrderSide,
) -> dict | None:
    """Fetch the stored opening trade for *ticker*, if there is one.

    The entry is read from the ``positions:history`` Redis hash, keyed by
    ticker, and decoded from JSON. ``None`` is returned when the hash has
    no (or an empty) value for the ticker.

    ``closing_side`` is accepted but not used by the lookup.
    """
    stored = await redis.hget("positions:history", ticker)
    return json.loads(stored) if stored else None
|
|
|
|
|
|
async def _store_opening_trade(
    redis: Redis,
    ticker: str,
    trade_data: dict,
) -> None:
    """Record *trade_data* as the opening trade for *ticker*.

    The data is JSON-encoded and written to the ``positions:history`` hash
    under the ticker, replacing any entry already stored for it.
    """
    payload = json.dumps(trade_data)
    await redis.hset("positions:history", ticker, payload)
|
|
|
|
|
|
async def _clear_opening_trade(redis: Redis, ticker: str) -> None:
    """Remove the opening-trade entry for *ticker* from ``positions:history``."""
    history_hash = "positions:history"
    await redis.hdel(history_hash, ticker)
|
|
|
|
|
|
def _is_position_close(trade: TradeExecution, opening: dict | None) -> bool:
    """Tell whether *trade* closes the position described by *opening*.

    A close requires a stored opening trade whose ``"side"`` value is the
    opposite of the incoming trade's side (SELL against a BUY opening, or
    BUY against a SELL opening). With no opening trade, or one on the same
    side, the result is ``False``.
    """
    if opening is None:
        return False

    stored_side = opening.get("side", "")
    closes_long = trade.side == OrderSide.SELL and stored_side == OrderSide.BUY.value
    closes_short = trade.side == OrderSide.BUY and stored_side == OrderSide.SELL.value
    return bool(closes_long or closes_short)
|
|
|
|
|
|
async def process_trade(
    trade: TradeExecution,
    redis: Redis,
    evaluator: TradeEvaluator,
    adjuster: WeightAdjuster,
    counters: dict,
) -> list[WeightAdjustment]:
    """Run one trade execution through the learning pipeline.

    A trade that does not close a stored position is saved as the opening
    trade for its ticker and nothing else happens. A closing trade goes
    through these stages:

    1. The outcome is evaluated against the stored entry.
    2. Credit is attributed to the strategies recorded on the entry.
    3. Each credited strategy gets a trade and a reward recorded.
    4. Strategies the adjuster deems eligible have their weight adjusted.
    5. If anything changed, weights are normalised, cached in Redis, and
       the per-strategy drift from an equal share is recorded.
    6. The stored opening trade is removed.

    Args:
        trade: The execution to process.
        redis: Redis client used for the position history and weight cache.
        evaluator: Computes the trade outcome and per-strategy rewards.
        adjuster: Tracks per-strategy trades/rewards and computes new weights.
        counters: Telemetry instruments; ``"adjustments_made"`` and
            ``"weight_drift"`` are used here.

    Returns:
        The weight adjustments made for this trade (empty if none).
    """
    prior = await _find_opening_trade(redis, trade.ticker, trade.side)

    if not _is_position_close(trade, prior):
        # Not a close: remember this trade as the entry for the ticker.
        entry_record = {
            "trade_id": str(trade.trade_id),
            "side": trade.side.value,
            "price": trade.price,
            "qty": trade.qty,
            "timestamp": trade.timestamp.isoformat(),
            "strategy_sources": [],  # would come from signal
        }
        await _store_opening_trade(redis, trade.ticker, entry_record)
        return []

    # --- Position close detected ---
    entry_price = prior["price"]
    entry_qty = prior.get("qty", trade.qty)
    opened_at_raw = prior.get("timestamp", "")
    contributing = prior.get("strategy_sources", [])

    # +1 when the position was opened with a BUY, -1 otherwise.
    was_opened_with_buy = prior.get("side", "") == OrderSide.BUY.value
    direction_sign = 1.0 if was_opened_with_buy else -1.0

    # Hold duration falls back to zero when the stored timestamp is missing
    # or cannot be parsed / subtracted.
    held_seconds = 0.0
    if opened_at_raw:
        try:
            opened_at = datetime.fromisoformat(opened_at_raw)
            held_seconds = (trade.timestamp - opened_at).total_seconds()
        except (ValueError, TypeError):
            held_seconds = 0.0

    # Stage 1: evaluate the closed quantity (the smaller of entry and exit).
    outcome = evaluator.evaluate_trade(
        trade_id=trade.trade_id,
        entry_price=entry_price,
        exit_price=trade.price,
        qty=min(trade.qty, entry_qty),
        direction_sign=direction_sign,
        hold_duration_seconds=max(held_seconds, 0.0),
    )
    logger.info(
        "Trade outcome: %s PnL=%.2f ROI=%.2f%% profitable=%s",
        trade.ticker,
        outcome.realized_pnl,
        outcome.roi_pct,
        outcome.was_profitable,
    )

    # Stages 2-3: attribute credit and record it per strategy.
    rewards = evaluator.attribute_credit(outcome, contributing)
    for credited_name, credited_reward in rewards.items():
        adjuster.record_trade(credited_name)
        adjuster.record_reward(credited_name, credited_reward)

    weights = await _load_strategy_weights(redis)

    # Stage 4: adjust the weight of every eligible strategy.
    adjustments: list[WeightAdjustment] = []
    for strategy_name, reward in rewards.items():
        if not adjuster.should_adjust(strategy_name):
            logger.debug(
                "Strategy %s has %d trades (need %d) -- skipping adjustment",
                strategy_name,
                adjuster.trade_counts.get(strategy_name, 0),
                adjuster.config.min_trades_before_adjustment,
            )
            continue

        previous = weights.get(strategy_name, adjuster.config.weight_floor)
        updated = adjuster.adjust_weight(
            previous,
            adjuster.get_decayed_reward(strategy_name),
        )
        weights[strategy_name] = updated

        adjustments.append(
            WeightAdjustment(
                strategy_id=UUID(int=0),  # placeholder -- DB would assign real ID
                strategy_name=strategy_name,
                old_weight=previous,
                new_weight=updated,
                reason=f"bandit_adjustment roi={outcome.roi_pct:.2f}%",
                reward_signal=reward,
                timestamp=datetime.now(timezone.utc),
            )
        )
        counters["adjustments_made"].add(1)
        logger.info(
            "Weight adjusted: %s %.4f -> %.4f (reward=%.4f)",
            strategy_name,
            previous,
            updated,
            reward,
        )

    # Stage 5: an adjustment is appended exactly when a weight changed, so a
    # non-empty list means the weights need normalising and persisting.
    if adjustments:
        weights = adjuster.normalize_weights(weights)
        await _save_strategy_weights(redis, weights)

        if weights:
            equal_share = 1.0 / len(weights)
            for name, weight in weights.items():
                counters["weight_drift"].record(
                    abs(weight - equal_share),
                    {"strategy": name},
                )

    # Stage 6: the position is closed, so drop its stored entry.
    await _clear_opening_trade(redis, trade.ticker)

    return adjustments
|
|
|
|
|
|
async def run(config: LearningEngineConfig | None = None) -> None:
    """Main service loop.

    Connects to Redis, initialises the evaluator and weight adjuster, then
    consumes ``trades:executed`` and passes every ``FILLED`` execution to
    ``process_trade`` until the stream ends or a shutdown signal arrives.
    Errors raised while handling a single message are logged and the loop
    continues; errors raised by the stream itself propagate to the caller.

    Shutdown handling: SIGTERM/SIGINT set an event. Each wait for the next
    stream message is raced against that event, so the service stops even
    when no trades are arriving (previously the event was only checked after
    a message had been received). Only the pending stream read is cancelled;
    a trade that is already being processed is allowed to finish.

    Args:
        config: Service configuration. A default ``LearningEngineConfig()``
            is created when omitted.
    """
    if config is None:
        config = LearningEngineConfig()

    logging.basicConfig(level=config.log_level)
    logger.info("Starting Learning Engine service")

    # --- Telemetry ---
    meter = setup_telemetry("learning-engine", config.otel_metrics_port)
    counters = {
        "adjustments_made": meter.create_counter(
            "adjustments_made",
            description="Total strategy weight adjustments performed",
        ),
        "weight_drift": meter.create_histogram(
            "weight_drift",
            description="Absolute deviation of each strategy weight from equal weight",
        ),
    }

    # --- Redis ---
    redis = Redis.from_url(config.redis_url, decode_responses=False)
    consumer = StreamConsumer(redis, "trades:executed", "learning-engine", "worker-1")

    # --- Components ---
    evaluator = TradeEvaluator()
    adjuster = WeightAdjuster(config)

    logger.info("Consuming from trades:executed")

    # Graceful shutdown on SIGTERM/SIGINT
    shutdown_event = asyncio.Event()
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, shutdown_event.set)

    # Completes as soon as a shutdown signal has been received.
    shutdown_waiter = asyncio.ensure_future(shutdown_event.wait())
    pending_read = None

    # --- Consume loop ---
    try:
        stream = consumer.consume().__aiter__()

        async def _next_message():
            """Await the next ``(message_id, data)`` item from the stream."""
            return await stream.__anext__()

        while not shutdown_event.is_set():
            # NOTE(review): assumes consume() only yields when a message
            # arrives, so the read must be raced against the shutdown event
            # for an idle service to stop -- confirm against StreamConsumer.
            pending_read = asyncio.ensure_future(_next_message())
            await asyncio.wait(
                {pending_read, shutdown_waiter},
                return_when=asyncio.FIRST_COMPLETED,
            )

            if not pending_read.done():
                # Shutdown won the race; the read is cancelled in ``finally``.
                break

            try:
                _msg_id, data = pending_read.result()
            except StopAsyncIteration:
                # The stream ended on its own.
                break

            if shutdown_event.is_set():
                # Same as before: a message that arrives after shutdown was
                # requested is not processed.
                break

            try:
                trade = TradeExecution.model_validate(data)

                if trade.status.value != "FILLED":
                    logger.debug("Skipping non-filled trade: %s", trade.status.value)
                    continue

                adjustments = await process_trade(trade, redis, evaluator, adjuster, counters)
                if adjustments:
                    logger.info(
                        "Made %d weight adjustment(s) for %s",
                        len(adjustments),
                        trade.ticker,
                    )
            except Exception:
                logger.exception("Error processing trade execution: %s", data)
    finally:
        # Cancel whatever is still waiting and let it unwind before the
        # Redis connection it may be using is closed.
        leftovers = [
            task
            for task in (pending_read, shutdown_waiter)
            if task is not None and not task.done()
        ]
        for task in leftovers:
            task.cancel()
        if leftovers:
            await asyncio.gather(*leftovers, return_exceptions=True)

        await redis.aclose()
        logger.info("Learning engine stopped gracefully")
|
|
|
|
|
|
def main() -> None:
    """Command-line entry point: run the service loop until it returns."""
    service = run()
    asyncio.run(service)
|
|
|
|
|
|
# Start the service only when this module is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
|