trading/services/learning_engine/main.py

304 lines
9.7 KiB
Python

"""Learning Engine service -- main entry point.
Consumes ``trades:executed`` from Redis Streams, evaluates closed positions,
attributes credit to contributing strategies, adjusts strategy weights via
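a multi-armed bandit approach, and caches the updated weights in Redis.
Each adjustment is logged and returned; durable storage of adjustment
records for auditability is not implemented in this module yet.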
"""
from __future__ import annotations
import asyncio
import json
import logging
from datetime import datetime, timezone
from uuid import UUID
from redis.asyncio import Redis
from services.learning_engine.config import LearningEngineConfig
from services.learning_engine.evaluator import TradeEvaluator
from services.learning_engine.weight_adjuster import WeightAdjuster
from shared.redis_streams import StreamConsumer
from shared.schemas.learning import TradeOutcomeSchema, WeightAdjustment
from shared.schemas.trading import OrderSide, TradeExecution
from shared.telemetry import setup_telemetry
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)

# Redis key holding the JSON-encoded {strategy_name: weight} mapping.
_STRATEGY_WEIGHTS_KEY: str = "strategy:weights"
async def _load_strategy_weights(redis: Redis) -> dict[str, float]:
    """Return the cached strategy weights, or equal defaults if none are cached.

    The value stored under ``_STRATEGY_WEIGHTS_KEY`` is decoded as JSON. When
    the key is absent (or holds an empty value), a near-equal split across the
    ``momentum``, ``mean_reversion`` and ``news_driven`` strategies is
    returned instead.
    """
    cached = await redis.get(_STRATEGY_WEIGHTS_KEY)
    if not cached:
        # Nothing cached yet: start the default strategies on equal footing
        # (0.333 + 0.333 + 0.334 == 1.0).
        return {
            "momentum": 0.333,
            "mean_reversion": 0.333,
            "news_driven": 0.334,
        }
    return json.loads(cached)
async def _save_strategy_weights(redis: Redis, weights: dict[str, float]) -> None:
    """Write *weights* to the Redis weight cache as a JSON document.

    Replaces whatever is currently stored under ``_STRATEGY_WEIGHTS_KEY``.
    """
    payload = json.dumps(weights)
    await redis.set(_STRATEGY_WEIGHTS_KEY, payload)
async def _find_opening_trade(
    redis: Redis,
    ticker: str,
    closing_side: OrderSide,
) -> dict | None:
    """Fetch the stored opening-trade record for *ticker*, if there is one.

    Reads field *ticker* of the ``positions:history`` Redis hash and decodes it
    as JSON. *closing_side* is accepted for interface compatibility but does
    not influence the lookup.

    Returns:
        The decoded record, or ``None`` when no (non-empty) record is stored.
    """
    stored = await redis.hget("positions:history", ticker)
    return json.loads(stored) if stored else None
async def _store_opening_trade(
    redis: Redis,
    ticker: str,
    trade_data: dict,
) -> None:
    """Record *trade_data* as the open-position entry for *ticker*.

    The record is JSON-encoded into field *ticker* of the ``positions:history``
    hash, replacing any entry already stored for that ticker.
    """
    encoded = json.dumps(trade_data)
    await redis.hset("positions:history", ticker, encoded)
async def _clear_opening_trade(redis: Redis, ticker: str) -> None:
    """Delete the open-position entry for *ticker* from ``positions:history``."""
    hash_name = "positions:history"
    await redis.hdel(hash_name, ticker)
def _is_position_close(trade: TradeExecution, opening: dict | None) -> bool:
    """Report whether *trade* closes the position described by *opening*.

    A close needs a stored opening record whose ``side`` is the opposite of
    ``trade.side``: a SELL against a BUY entry, or a BUY against a SELL entry.
    Returns ``False`` when there is no opening record or the sides match.
    """
    if opening is None:
        return False
    entry_side = opening.get("side", "")
    closes_long = trade.side == OrderSide.SELL and entry_side == OrderSide.BUY.value
    closes_short = trade.side == OrderSide.BUY and entry_side == OrderSide.SELL.value
    return closes_long or closes_short
# Quantities at or below this magnitude are treated as zero when a closing fill
# is split into the closed part, any leftover entry quantity, and any excess
# that opens an opposite position. This guards against float residue (for
# example 0.3 - 0.1 - 0.2 == 5.55e-17) leaving a phantom open position behind.
_QTY_TOLERANCE = 1e-9


def _opening_record(trade: TradeExecution, qty: float) -> dict:
    """Build the ``positions:history`` entry for a position opened by *trade* with *qty*."""
    return {
        "trade_id": str(trade.trade_id),
        "side": trade.side.value,
        "price": trade.price,
        "qty": qty,
        "timestamp": trade.timestamp.isoformat(),
        "strategy_sources": [],  # would come from signal
    }


def _merge_opening(opening: dict, trade: TradeExecution) -> dict:
    """Fold a same-side *trade* into the stored *opening* entry.

    The entry price becomes the quantity-weighted average of the stored entry
    and the new fill, and the quantities are summed. The stored trade id,
    timestamp and strategy sources are kept, so hold duration is still measured
    from the first fill.
    """
    prev_qty = opening.get("qty", 0)
    prev_price = opening.get("price", trade.price)
    total_qty = prev_qty + trade.qty
    if total_qty > 0:
        avg_price = (prev_price * prev_qty + trade.price * trade.qty) / total_qty
    else:
        avg_price = trade.price
    return {**opening, "price": avg_price, "qty": total_qty}


def _hold_duration_seconds(entry_time_str: str, exit_time: datetime) -> float:
    """Return the seconds from the stored ISO entry timestamp to *exit_time*.

    Returns 0.0 in these cases:

    * the timestamp is missing or malformed;
    * the timestamp cannot be subtracted from *exit_time* (e.g. naive vs.
      timezone-aware datetimes).
    """
    if not entry_time_str:
        return 0.0
    try:
        entry_time = datetime.fromisoformat(entry_time_str)
        return (exit_time - entry_time).total_seconds()
    except (ValueError, TypeError):
        return 0.0


async def _settle_position(
    redis: Redis,
    trade: TradeExecution,
    opening: dict,
    entry_qty: float,
    closed_qty: float,
) -> None:
    """Update ``positions:history`` after *trade* closed *closed_qty* of *opening*.

    * Partial close (entry larger than the fill): keep the entry, with the
      leftover quantity, so later fills still close against it.
    * Over-close (fill larger than the entry): the excess opens a new position
      on ``trade.side`` at ``trade.price``.
    * Exact close: remove the entry.
    """
    leftover = entry_qty - closed_qty
    excess = trade.qty - closed_qty
    if leftover > _QTY_TOLERANCE:
        await _store_opening_trade(redis, trade.ticker, {**opening, "qty": leftover})
    elif excess > _QTY_TOLERANCE:
        await _store_opening_trade(redis, trade.ticker, _opening_record(trade, excess))
    else:
        await _clear_opening_trade(redis, trade.ticker)


async def process_trade(
    trade: TradeExecution,
    redis: Redis,
    evaluator: TradeEvaluator,
    adjuster: WeightAdjuster,
    counters: dict,
) -> list[WeightAdjustment]:
    """Process a single trade execution for ``trade.ticker``.

    Position bookkeeping lives in the ``positions:history`` Redis hash.

    * No stored entry, or an entry on the same side: the trade opens (or adds
      to) the position and no learning happens. Same-side fills are merged
      into a quantity-weighted average entry price instead of overwriting it.
    * Entry on the opposite side: the trade closes ``min(trade.qty, entry
      qty)`` of the position, and the learning pipeline runs:

      1. Evaluate the closed quantity's outcome (P&L, ROI).
      2. Attribute credit to the entry's ``strategy_sources``.
      3. Adjust weights for strategies with enough recorded trades.
      4. If any weight changed, normalize all weights, cache them in Redis,
         and record each strategy's drift from an equal split.
      5. Update the stored entry: keep any unclosed remainder, open a new
         position with any excess quantity, or remove it on an exact close.

    Args:
        trade: The execution to process.
        redis: Async Redis client used for position and weight state.
        evaluator: Computes the trade outcome and per-strategy rewards.
        adjuster: Tracks per-strategy trades/rewards and computes new weights.
        counters: Telemetry instruments; must provide ``"adjustments_made"``
            (``.add``) and ``"weight_drift"`` (``.record``).

    Returns:
        The weight adjustments made. The list is empty when the trade only
        opened or extended a position, or when no strategy qualified. The
        adjustments are logged and returned here, not persisted.
    """
    adjustments: list[WeightAdjustment] = []

    opening = await _find_opening_trade(redis, trade.ticker, trade.side)
    if not _is_position_close(trade, opening):
        if opening is not None and opening.get("side") == trade.side.value:
            # Adding to an existing position: keep the original entry details
            # and average in the new fill so later P&L covers the full size.
            record = _merge_opening(opening, trade)
        else:
            record = _opening_record(trade, trade.qty)
        await _store_opening_trade(redis, trade.ticker, record)
        return adjustments

    # --- Position close detected ---
    entry_price = opening["price"]
    entry_qty = opening.get("qty", trade.qty)
    strategy_sources = opening.get("strategy_sources", [])
    opening_side = opening.get("side", "")
    direction_sign = 1.0 if opening_side == OrderSide.BUY.value else -1.0
    # Only the overlapping quantity is realised; any remainder is settled below.
    closed_qty = min(trade.qty, entry_qty)
    hold_duration_seconds = _hold_duration_seconds(
        opening.get("timestamp", ""), trade.timestamp
    )

    # Step 1: Evaluate trade
    outcome = evaluator.evaluate_trade(
        trade_id=trade.trade_id,
        entry_price=entry_price,
        exit_price=trade.price,
        qty=closed_qty,
        direction_sign=direction_sign,
        hold_duration_seconds=max(hold_duration_seconds, 0.0),
    )
    logger.info(
        "Trade outcome: %s PnL=%.2f ROI=%.2f%% profitable=%s",
        trade.ticker,
        outcome.realized_pnl,
        outcome.roi_pct,
        outcome.was_profitable,
    )

    # Step 2: Attribute credit and record it per strategy
    rewards = evaluator.attribute_credit(outcome, strategy_sources)
    for strategy_name, reward in rewards.items():
        adjuster.record_trade(strategy_name)
        adjuster.record_reward(strategy_name, reward)

    # Step 3: Adjust weights for strategies with enough trades
    weights = await _load_strategy_weights(redis)
    any_adjusted = False
    for strategy_name, reward in rewards.items():
        if not adjuster.should_adjust(strategy_name):
            logger.debug(
                "Strategy %s has %d trades (need %d) -- skipping adjustment",
                strategy_name,
                adjuster.trade_counts.get(strategy_name, 0),
                adjuster.config.min_trades_before_adjustment,
            )
            continue
        old_weight = weights.get(strategy_name, adjuster.config.weight_floor)
        decayed_reward = adjuster.get_decayed_reward(strategy_name)
        new_weight = adjuster.adjust_weight(old_weight, decayed_reward)
        weights[strategy_name] = new_weight
        any_adjusted = True
        adjustments.append(
            WeightAdjustment(
                strategy_id=UUID(int=0),  # placeholder -- DB would assign real ID
                strategy_name=strategy_name,
                old_weight=old_weight,
                new_weight=new_weight,
                reason=f"bandit_adjustment roi={outcome.roi_pct:.2f}%",
                reward_signal=reward,
                timestamp=datetime.now(timezone.utc),
            )
        )
        counters["adjustments_made"].add(1)
        logger.info(
            "Weight adjusted: %s %.4f -> %.4f (reward=%.4f)",
            strategy_name,
            old_weight,
            new_weight,
            reward,
        )

    # Step 4: Normalize, cache, and track drift from an equal split
    if any_adjusted:
        weights = adjuster.normalize_weights(weights)
        await _save_strategy_weights(redis, weights)
        equal_weight = 1.0 / len(weights) if weights else 0.0
        for name, weight in weights.items():
            counters["weight_drift"].record(abs(weight - equal_weight), {"strategy": name})

    # Step 5: Settle the stored entry (partial close / flip / full close)
    await _settle_position(redis, trade, opening, entry_qty, closed_qty)
    return adjustments
async def run(config: LearningEngineConfig | None = None) -> None:
    """Main service loop.

    Connects to Redis and initialises the evaluator and weight adjuster. It
    then continuously consumes ``trades:executed`` and feeds FILLED executions
    through :func:`process_trade`. A failure while handling one message is
    logged, and the loop moves on to the next message.

    Args:
        config: Service configuration. When omitted, a default
            ``LearningEngineConfig`` is built.
    """
    if config is None:
        config = LearningEngineConfig()
    logging.basicConfig(level=config.log_level)
    logger.info("Starting Learning Engine service")

    # --- Telemetry ---
    meter = setup_telemetry("learning-engine", config.otel_metrics_port)
    counters = {
        "adjustments_made": meter.create_counter(
            "adjustments_made",
            description="Total strategy weight adjustments performed",
        ),
        "weight_drift": meter.create_histogram(
            "weight_drift",
            description="Absolute deviation of each strategy weight from equal weight",
        ),
    }

    # --- Redis ---
    # The client is used as an async context manager so its connection pool is
    # closed when the consume loop exits or the task is cancelled.
    async with Redis.from_url(config.redis_url, decode_responses=False) as redis:
        consumer = StreamConsumer(redis, "trades:executed", "learning-engine", "worker-1")

        # --- Components ---
        evaluator = TradeEvaluator()
        adjuster = WeightAdjuster(config)
        logger.info("Consuming from trades:executed")

        # --- Consume loop ---
        async for _msg_id, data in consumer.consume():
            try:
                trade = TradeExecution.model_validate(data)
                if trade.status.value != "FILLED":
                    logger.debug("Skipping non-filled trade: %s", trade.status.value)
                    continue
                adjustments = await process_trade(trade, redis, evaluator, adjuster, counters)
                if adjustments:
                    logger.info(
                        "Made %d weight adjustment(s) for %s",
                        len(adjustments),
                        trade.ticker,
                    )
            except Exception:
                # Service boundary: one bad message must not stop consumption.
                logger.exception("Error processing trade execution: %s", data)
def main() -> None:
    """Command-line entry point: run the service loop with the default configuration."""
    service = run()
    asyncio.run(service)


if __name__ == "__main__":
    main()