""" LLM Query Wrapper (D2.1) Provides a standardized interface for LLM calls with retry logic and cost tracking. """ from dataclasses import dataclass import os import time from typing import Any, Dict, List, Optional @dataclass class LLMResponse: """Response object with usage metadata.""" text: str input_tokens: int output_tokens: int total_tokens: int cost_usd: float latency_ms: int provider: str model: str class LLMError(RuntimeError): """Base error for LLM failures.""" def __init__(self, message: str, provider: str, retries: int, is_transient: bool = False): super().__init__(message) self.provider = provider self.retries = retries self.is_transient = is_transient class LLMTransientError(LLMError): """Retryable LLM error.""" def __init__(self, message: str, provider: str = "unknown", retries: int = 0): super().__init__(message, provider=provider, retries=retries, is_transient=True) class LLMPermanentError(LLMError): """Non-retryable LLM error.""" def __init__(self, message: str, provider: str = "unknown", retries: int = 0): super().__init__(message, provider=provider, retries=retries, is_transient=False) class LLMBudgetExceededError(LLMError): """Raised when LLM budget is exceeded.""" def __init__(self, message: str, provider: str = "unknown", retries: int = 0): super().__init__(message, provider=provider, retries=retries, is_transient=False) class LLMClient: """Standardized LLM client with retry and usage tracking.""" _DEFAULT_MODELS = { "openai": "gpt-4o-mini", "anthropic": "claude-3-5-sonnet-20240620", "local": "local", "mock": "mock" } _ENV_KEYS = { "openai": "OPENAI_API_KEY", "anthropic": "ANTHROPIC_API_KEY" } _DEFAULT_RATES = { "openai": {"input": 5.0, "output": 15.0}, "anthropic": {"input": 3.0, "output": 15.0}, "local": {"input": 0.0, "output": 0.0}, "mock": {"input": 0.0, "output": 0.0} } def __init__( self, provider: str, api_key: Optional[str] = None, model: Optional[str] = None, max_retries: int = 3, backoff_base: float = 1.0, sleep_fn=time.sleep, mock_sequence: Optional[List[Any]] = None, rate_table: Optional[Dict[str, Dict[str, float]]] = None, max_cost_usd: Optional[float] = None ): self.provider = provider.lower() if self.provider not in self._DEFAULT_MODELS: raise ValueError(f"Unsupported provider: {provider}") self.api_key = api_key or self._load_api_key() if self.provider in self._ENV_KEYS and not self.api_key: raise ValueError(f"API key required for provider '{self.provider}'") self.model = model or self._DEFAULT_MODELS[self.provider] self.max_retries = max_retries self.backoff_base = backoff_base self.sleep_fn = sleep_fn self._mock_sequence = list(mock_sequence) if mock_sequence is not None else [] self._rate_table = rate_table or self._DEFAULT_RATES self._max_cost_usd = max_cost_usd self._usage = { "calls": 0, "input_tokens": 0, "output_tokens": 0, "total_tokens": 0, "total_cost_usd": 0.0 } def _load_api_key(self) -> Optional[str]: env_key = self._ENV_KEYS.get(self.provider) if env_key: return os.getenv(env_key) return None def _count_tokens(self, text: str) -> int: if not text: return 0 return max(1, len(text) // 4) def _calculate_cost(self, input_tokens: int, output_tokens: int) -> float: rates = self._rate_table.get(self.provider, {"input": 0.0, "output": 0.0}) input_cost = (input_tokens / 1000.0) * rates.get("input", 0.0) output_cost = (output_tokens / 1000.0) * rates.get("output", 0.0) return input_cost + output_cost def _is_transient_error(self, error: Exception) -> bool: if isinstance(error, LLMTransientError): return True message = str(error).lower() return any(keyword in message for keyword in ("rate limit", "timeout", "temporarily")) def _ensure_budget(self, allow_equal: bool = False) -> None: if self._max_cost_usd is None: return total_cost = self._usage["total_cost_usd"] if allow_equal: over_budget = total_cost > self._max_cost_usd else: over_budget = total_cost >= self._max_cost_usd if over_budget: raise LLMBudgetExceededError( f"Cost budget exceeded: total_cost={total_cost:.6f} budget={self._max_cost_usd:.6f}", provider=self.provider ) def _mock_complete(self, prompt: str) -> str: if self._mock_sequence: next_item = self._mock_sequence.pop(0) if isinstance(next_item, Exception): raise next_item return str(next_item) return prompt def _complete_provider(self, prompt: str, **kwargs) -> str: if self.provider == "mock": return self._mock_complete(prompt) if self.provider == "local": return prompt raise LLMPermanentError(f"Provider '{self.provider}' not implemented", provider=self.provider) def complete(self, prompt: str, **kwargs) -> LLMResponse: retries = 0 start = time.perf_counter() while True: try: self._ensure_budget() text = self._complete_provider(prompt, **kwargs) input_tokens = self._count_tokens(prompt) output_tokens = self._count_tokens(text) total_tokens = input_tokens + output_tokens cost_usd = self._calculate_cost(input_tokens, output_tokens) latency_ms = max(1, int((time.perf_counter() - start) * 1000)) response = LLMResponse( text=text, input_tokens=input_tokens, output_tokens=output_tokens, total_tokens=total_tokens, cost_usd=cost_usd, latency_ms=latency_ms, provider=self.provider, model=self.model ) self._record_usage(response) self._ensure_budget(allow_equal=True) return response except Exception as exc: if isinstance(exc, LLMBudgetExceededError): raise if not self._is_transient_error(exc): raise LLMError( str(exc), provider=self.provider, retries=0, is_transient=False ) from exc if retries >= self.max_retries: raise LLMError( str(exc), provider=self.provider, retries=retries, is_transient=True ) from exc sleep_seconds = self.backoff_base * (2 ** retries) self.sleep_fn(sleep_seconds) retries += 1 def _record_usage(self, response: LLMResponse) -> None: self._usage["calls"] += 1 self._usage["input_tokens"] += response.input_tokens self._usage["output_tokens"] += response.output_tokens self._usage["total_tokens"] += response.total_tokens self._usage["total_cost_usd"] += response.cost_usd def get_cost(self) -> float: return float(self._usage["total_cost_usd"]) def get_usage_stats(self) -> Dict[str, Any]: return dict(self._usage) def get_budget_status(self) -> Dict[str, Any]: total = float(self._usage["total_cost_usd"]) budget = self._max_cost_usd remaining = None if budget is None else max(0.0, budget - total) return { "total_cost_usd": total, "budget_usd": budget, "remaining_usd": remaining, "over_budget": budget is not None and total > budget }