"""Learning Engine service -- main entry point. Consumes ``trades:executed`` from Redis Streams, evaluates closed positions, attributes credit to contributing strategies, adjusts strategy weights via a multi-armed bandit approach, and stores all adjustments for auditability. """ from __future__ import annotations import asyncio import json import logging import signal from datetime import datetime, timezone from uuid import UUID from redis.asyncio import Redis from sqlalchemy import select from services.learning_engine.config import LearningEngineConfig from services.learning_engine.evaluator import TradeEvaluator from services.learning_engine.weight_adjuster import WeightAdjuster from shared.db import create_db from shared.models.trading import Strategy from shared.redis_streams import StreamConsumer from shared.schemas.learning import TradeOutcomeSchema, WeightAdjustment from shared.schemas.trading import OrderSide, TradeExecution from shared.telemetry import setup_telemetry logger = logging.getLogger(__name__) # Redis key for cached strategy weights _STRATEGY_WEIGHTS_KEY = "strategy:weights" async def _load_strategy_weights(redis: Redis) -> dict[str, float]: """Load current strategy weights from Redis cache. Falls back to equal weights for the three default strategies if no cached weights exist. """ raw = await redis.get(_STRATEGY_WEIGHTS_KEY) if raw: return json.loads(raw) # Default equal weights for all 9 strategies (matches seed_strategies.py) return { "momentum": 0.111, "mean_reversion": 0.111, "news_driven": 0.111, "value": 0.111, "macd_crossover": 0.111, "bollinger_breakout": 0.111, "vwap": 0.111, "liquidity": 0.112, "ma_stack": 0.111, } async def _save_strategy_weights(redis: Redis, weights: dict[str, float]) -> None: """Persist strategy weights to Redis cache.""" await redis.set(_STRATEGY_WEIGHTS_KEY, json.dumps(weights)) async def _find_opening_trade( redis: Redis, ticker: str, closing_side: OrderSide, ) -> dict | None: """Look up the opening trade for a position close. Searches the ``positions:history`` Redis hash for the ticker. Returns the stored entry data or None if not found. """ raw = await redis.hget("positions:history", ticker) if raw: return json.loads(raw) return None async def _store_opening_trade( redis: Redis, ticker: str, trade_data: dict, ) -> None: """Store a trade as the opening trade for a ticker.""" await redis.hset("positions:history", ticker, json.dumps(trade_data)) async def _clear_opening_trade(redis: Redis, ticker: str) -> None: """Clear the stored opening trade after position close.""" await redis.hdel("positions:history", ticker) def _is_position_close(trade: TradeExecution, opening: dict | None) -> bool: """Determine if a trade closes a position. A trade closes a position if there is an existing opening trade on the opposite side for the same ticker. """ if opening is None: return False opening_side = opening.get("side", "") if trade.side == OrderSide.SELL and opening_side == OrderSide.BUY.value: return True if trade.side == OrderSide.BUY and opening_side == OrderSide.SELL.value: return True return False async def process_trade( trade: TradeExecution, redis: Redis, evaluator: TradeEvaluator, adjuster: WeightAdjuster, counters: dict, strategy_id_lookup: dict[str, UUID] | None = None, ) -> list[WeightAdjustment]: """Process a single trade execution. If the trade closes a position: 1. Evaluate the trade outcome (P&L, ROI) 2. Attribute credit to contributing strategies 3. Adjust weights for strategies with enough trades 4. Normalize all weights 5. Store adjustments and update cached weights Returns a list of weight adjustments made (empty if none). """ adjustments: list[WeightAdjustment] = [] # Look up opening trade opening = await _find_opening_trade(redis, trade.ticker, trade.side) if not _is_position_close(trade, opening): # This is an opening trade -- store it for later reference await _store_opening_trade( redis, trade.ticker, { "trade_id": str(trade.trade_id), "side": trade.side.value, "price": trade.price, "qty": trade.qty, "timestamp": trade.timestamp.isoformat(), "strategy_sources": trade.strategy_sources, }, ) return adjustments # --- Position close detected --- entry_price = opening["price"] entry_qty = opening.get("qty", trade.qty) entry_time_str = opening.get("timestamp", "") strategy_sources = opening.get("strategy_sources", []) # Determine direction sign opening_side = opening.get("side", "") direction_sign = 1.0 if opening_side == OrderSide.BUY.value else -1.0 # Compute hold duration hold_duration_seconds = 0.0 if entry_time_str: try: entry_time = datetime.fromisoformat(entry_time_str) hold_duration_seconds = (trade.timestamp - entry_time).total_seconds() except (ValueError, TypeError): hold_duration_seconds = 0.0 # Step 1: Evaluate trade outcome = evaluator.evaluate_trade( trade_id=trade.trade_id, entry_price=entry_price, exit_price=trade.price, qty=min(trade.qty, entry_qty), direction_sign=direction_sign, hold_duration_seconds=max(hold_duration_seconds, 0.0), ) logger.info( "Trade outcome: %s PnL=%.2f ROI=%.2f%% profitable=%s", trade.ticker, outcome.realized_pnl, outcome.roi_pct, outcome.was_profitable, ) # Step 2: Attribute credit rewards = evaluator.attribute_credit(outcome, strategy_sources) # Record trades and rewards for each strategy for strategy_name, reward in rewards.items(): adjuster.record_trade(strategy_name) adjuster.record_reward(strategy_name, reward) # Step 3: Load current weights weights = await _load_strategy_weights(redis) # Step 4: Adjust weights for strategies with enough trades any_adjusted = False for strategy_name, reward in rewards.items(): if not adjuster.should_adjust(strategy_name): logger.debug( "Strategy %s has %d trades (need %d) -- skipping adjustment", strategy_name, adjuster.trade_counts.get(strategy_name, 0), adjuster.config.min_trades_before_adjustment, ) continue old_weight = weights.get(strategy_name, adjuster.config.weight_floor) decayed_reward = adjuster.get_decayed_reward(strategy_name) new_weight = adjuster.adjust_weight(old_weight, decayed_reward) weights[strategy_name] = new_weight any_adjusted = True sid = (strategy_id_lookup or {}).get(strategy_name, UUID(int=0)) adjustment = WeightAdjustment( strategy_id=sid, strategy_name=strategy_name, old_weight=old_weight, new_weight=new_weight, reason=f"bandit_adjustment roi={outcome.roi_pct:.2f}%", reward_signal=reward, timestamp=datetime.now(timezone.utc), ) adjustments.append(adjustment) counters["adjustments_made"].add(1) logger.info( "Weight adjusted: %s %.4f -> %.4f (reward=%.4f)", strategy_name, old_weight, new_weight, reward, ) # Step 5: Normalize weights if any_adjusted: weights = adjuster.normalize_weights(weights) await _save_strategy_weights(redis, weights) # Track weight drift for name, weight in weights.items(): default = 1.0 / len(weights) drift = abs(weight - default) counters["weight_drift"].record(drift, {"strategy": name}) # Clean up opening trade await _clear_opening_trade(redis, trade.ticker) return adjustments async def run(config: LearningEngineConfig | None = None) -> None: """Main service loop. Connects to Redis, initialises evaluator and weight adjuster, then continuously consumes from ``trades:executed`` and processes closed positions through the learning pipeline. """ if config is None: config = LearningEngineConfig() logging.basicConfig(level=config.log_level) logger.info("Starting Learning Engine service") # --- Telemetry --- meter = setup_telemetry("learning-engine", config.otel_metrics_port) counters = { "adjustments_made": meter.create_counter( "adjustments_made", description="Total strategy weight adjustments performed", ), "weight_drift": meter.create_histogram( "weight_drift", description="Absolute deviation of each strategy weight from equal weight", ), } # --- Redis --- redis = Redis.from_url(config.redis_url, decode_responses=False) consumer = StreamConsumer(redis, "trades:executed", "learning-engine", "worker-1") # --- Components --- evaluator = TradeEvaluator() adjuster = WeightAdjuster(config) # --- Load strategy name -> UUID lookup from DB --- strategy_id_lookup: dict[str, UUID] = {} try: _engine, session_factory = create_db(config) async with session_factory() as session: result = await session.execute(select(Strategy)) for s in result.scalars().all(): strategy_id_lookup[s.name] = s.id await _engine.dispose() logger.info("Loaded %d strategy IDs from DB", len(strategy_id_lookup)) except Exception: logger.exception("Failed to load strategy IDs — using fallback UUID(int=0)") logger.info("Consuming from trades:executed") # Graceful shutdown on SIGTERM/SIGINT shutdown_event = asyncio.Event() loop = asyncio.get_running_loop() for sig in (signal.SIGTERM, signal.SIGINT): loop.add_signal_handler(sig, shutdown_event.set) # --- Consume loop --- try: async for _msg_id, data in consumer.consume(): if shutdown_event.is_set(): break try: trade = TradeExecution.model_validate(data) if trade.status.value != "FILLED": logger.debug("Skipping non-filled trade: %s", trade.status.value) continue adjustments = await process_trade( trade, redis, evaluator, adjuster, counters, strategy_id_lookup ) if adjustments: logger.info( "Made %d weight adjustment(s) for %s", len(adjustments), trade.ticker, ) except Exception: logger.exception("Error processing trade execution: %s", data) finally: await redis.aclose() logger.info("Learning engine stopped gracefully") def main() -> None: """CLI entry point.""" asyncio.run(run()) if __name__ == "__main__": main()