734 lines
29 KiB
Python
734 lines
29 KiB
Python
"""
|
|
RLM-MEM - REPL Environment (D1.3)
|
|
RLM-style external memory REPL with secure sandbox execution.
|
|
"""
|
|
|
|
import ast
|
|
import builtins
|
|
import threading
|
|
import time
|
|
import io
|
|
import sys
|
|
from contextlib import contextmanager
|
|
from typing import Any, Dict, Optional, Callable
|
|
from pathlib import Path
|
|
|
|
|
|
class SandboxViolation(Exception):
|
|
"""Raised when code attempts to violate sandbox security."""
|
|
pass
|
|
|
|
|
|
class MaxIterationsError(Exception):
|
|
"""Raised when max iterations exceeded."""
|
|
pass
|
|
|
|
|
|
# Cost budget exceeded
|
|
class CostBudgetExceededError(RuntimeError):
|
|
"""Raised when cost budget is exceeded."""
|
|
pass
|
|
|
|
|
|
# Use built-in TimeoutError
|
|
|
|
|
|
# Allowed built-ins for sandbox
|
|
ALLOWED_BUILTINS = {
|
|
'abs', 'all', 'any', 'ascii', 'bin', 'bool', 'bytearray', 'bytes',
|
|
'callable', 'chr', 'classmethod', 'complex', 'delattr', 'dict',
|
|
'dir', 'divmod', 'enumerate', 'filter', 'float', 'format', 'frozenset',
|
|
'getattr', 'globals', 'hasattr', 'hash', 'help', 'hex', 'id', 'input',
|
|
'int', 'isinstance', 'issubclass', 'iter', 'len', 'list', 'locals',
|
|
'map', 'max', 'memoryview', 'min', 'next', 'object', 'oct', 'ord',
|
|
'pow', 'print', 'property', 'range', 'repr', 'reversed',
|
|
'round', 'set', 'setattr', 'slice', 'sorted', 'staticmethod', 'str',
|
|
'sum', 'super', 'tuple', 'type', 'vars', 'zip', '__build_class__',
|
|
'__name__', 'True', 'False', 'None', 'Exception', 'TypeError',
|
|
'ValueError', 'KeyError', 'IndexError', 'AttributeError', 'RuntimeError',
|
|
'StopIteration', 'ArithmeticError', 'LookupError', 'AssertionError',
|
|
'NotImplementedError', 'ZeroDivisionError', 'OverflowError',
|
|
}
|
|
|
|
# Blocked imports/modules
|
|
BLOCKED_MODULES = {
|
|
'os', 'sys', 'subprocess', 'socket', 'urllib', 'http', 'ftplib',
|
|
'smtplib', 'telnetlib', 'poplib', 'imaplib', 'nntplib', 'ssl',
|
|
'email', 'xmlrpc', 'concurrent.futures.process', 'multiprocessing',
|
|
'ctypes', 'cffi', 'mmap', 'resource', 'posix', 'nt', 'pwd', 'grp',
|
|
'spwd', 'crypt', 'termios', 'tty', 'pty', 'fcntl', 'msvcrt',
|
|
'winreg', '_winapi', 'select', 'selectors', 'asyncio.subprocess',
|
|
}
|
|
|
|
# Allowed modules that get redirected to mocks
|
|
ALLOWED_MODULES = set()
|
|
|
|
|
|
def safe_import(name, globals=None, locals=None, fromlist=(), level=0):
|
|
"""Safe import function that only allows specific modules."""
|
|
base_module = name.split('.')[0] if name else ''
|
|
# Allow sys import (mocked in sandbox)
|
|
if base_module == 'sys':
|
|
if globals and 'sys' in globals:
|
|
return globals['sys']
|
|
raise ImportError("Mock sys not found in sandbox")
|
|
if base_module in ALLOWED_MODULES:
|
|
if globals and base_module in globals:
|
|
return globals[base_module]
|
|
raise ImportError(f"Mock {name} not found in sandbox")
|
|
raise ImportError(f"Import of '{name}' is not allowed in sandbox")
|
|
|
|
|
|
# Blocked attribute names that could be used for sandbox escape
|
|
BLOCKED_ATTRIBUTES = {
|
|
'__class__', '__bases__', '__subclasses__', '__base__',
|
|
'__mro__', '__globals__', '__code__', '__func__', '__self__',
|
|
'__module__', '__dict__', '__closure__', '__defaults__',
|
|
'__kwdefaults__', '__getattribute__', '__setattr__',
|
|
}
|
|
|
|
|
|
class SandboxVisitor(ast.NodeVisitor):
|
|
"""AST visitor to check for sandbox violations."""
|
|
|
|
def __init__(self, allowed_paths: Optional[list] = None):
|
|
self.allowed_paths = allowed_paths or []
|
|
self.violations = []
|
|
|
|
def visit_Import(self, node):
|
|
for alias in node.names:
|
|
module = alias.name.split('.')[0]
|
|
# Allow 'sys' import (redirected to mock in sandbox)
|
|
if module == 'sys':
|
|
continue
|
|
if module in BLOCKED_MODULES and module not in ALLOWED_MODULES:
|
|
self.violations.append(f"Import of '{module}' is not allowed")
|
|
self.generic_visit(node)
|
|
|
|
def visit_ImportFrom(self, node):
|
|
if node.module:
|
|
module = node.module.split('.')[0]
|
|
# Allow 'sys' import (redirected to mock in sandbox)
|
|
if module == 'sys':
|
|
return
|
|
if module in BLOCKED_MODULES and module not in ALLOWED_MODULES:
|
|
self.violations.append(f"Import from '{module}' is not allowed")
|
|
self.generic_visit(node)
|
|
|
|
def visit_Delete(self, node):
|
|
"""Block deletion of builtins attributes."""
|
|
for target in node.targets:
|
|
if isinstance(target, ast.Attribute):
|
|
if self._is_builtins_access(target.value):
|
|
self.violations.append("Deletion of __builtins__ attributes is not allowed")
|
|
if isinstance(target, ast.Subscript):
|
|
if self._is_builtins_access(target.value):
|
|
self.violations.append("Deletion of __builtins__ attributes is not allowed")
|
|
self.generic_visit(node)
|
|
|
|
def visit_Call(self, node):
|
|
# Check for eval/exec/compile
|
|
if isinstance(node.func, ast.Name):
|
|
if node.func.id in ('eval', 'exec', 'compile'):
|
|
self.violations.append(f"Use of '{node.func.id}()' is not allowed")
|
|
# Check for __import__
|
|
if isinstance(node.func, ast.Name) and node.func.id == '__import__':
|
|
self.violations.append("Use of '__import__()' is not allowed")
|
|
# Check for open()
|
|
if isinstance(node.func, ast.Name) and node.func.id == 'open':
|
|
self.violations.append("Use of 'open()' is not allowed")
|
|
|
|
# Check for getattr/setattr on __builtins__
|
|
if isinstance(node.func, ast.Name) and node.func.id == 'getattr':
|
|
if node.args and self._is_builtins_access(node.args[0]):
|
|
self.violations.append("getattr on __builtins__ is not allowed")
|
|
if isinstance(node.func, ast.Name) and node.func.id == 'setattr':
|
|
if node.args and self._is_builtins_access(node.args[0]):
|
|
self.violations.append("setattr on __builtins__ is not allowed")
|
|
if isinstance(node.func, ast.Name) and node.func.id == 'delattr':
|
|
if node.args and self._is_builtins_access(node.args[0]):
|
|
self.violations.append("delattr on __builtins__ is not allowed")
|
|
|
|
self.generic_visit(node)
|
|
|
|
def visit_BinOp(self, node):
|
|
"""Check for large memory allocations via string/list multiplication."""
|
|
if isinstance(node.op, ast.Mult):
|
|
# Check for patterns like "x" * (1024 * 1024 * 100)
|
|
# Try to evaluate the size statically
|
|
try:
|
|
if isinstance(node.left, ast.Constant) and isinstance(node.left.value, str):
|
|
if isinstance(node.right, ast.Constant):
|
|
size = len(node.left.value) * node.right.value
|
|
if size > 10 * 1024 * 1024: # 10MB limit
|
|
raise MemoryError(f"String multiplication would create {size} bytes, exceeding 10MB limit")
|
|
elif isinstance(node.right, ast.BinOp):
|
|
# Try to evaluate binary expression
|
|
size = len(node.left.value) * self._eval_const_expr(node.right)
|
|
if size > 10 * 1024 * 1024: # 10MB limit
|
|
raise MemoryError(f"String multiplication would create {size} bytes, exceeding 10MB limit")
|
|
except MemoryError:
|
|
raise # Re-raise MemoryError
|
|
except Exception:
|
|
pass # Can't evaluate statically, let it run and catch at runtime
|
|
self.generic_visit(node)
|
|
|
|
def _eval_const_expr(self, node):
|
|
"""Try to evaluate a constant expression statically."""
|
|
if isinstance(node, ast.Constant):
|
|
return node.value
|
|
if isinstance(node, ast.BinOp):
|
|
left = self._eval_const_expr(node.left)
|
|
right = self._eval_const_expr(node.right)
|
|
if isinstance(node.op, ast.Mult):
|
|
return left * right
|
|
if isinstance(node.op, ast.Add):
|
|
return left + right
|
|
if isinstance(node.op, ast.Sub):
|
|
return left - right
|
|
raise ValueError("Cannot evaluate expression")
|
|
|
|
def visit_Attribute(self, node):
|
|
"""Check for dangerous attribute access like __class__, __bases__, etc."""
|
|
if node.attr in BLOCKED_ATTRIBUTES:
|
|
self.violations.append(f"Access to '{node.attr}' is not allowed")
|
|
self.generic_visit(node)
|
|
|
|
def visit_Subscript(self, node):
|
|
"""Check for builtins subscript access like globals()['__builtins__']['__import__']."""
|
|
# Check for globals()['__builtins__'] or locals()['__builtins__']
|
|
if isinstance(node.value, ast.Call):
|
|
if isinstance(node.value.func, ast.Name) and node.value.func.id in ('globals', 'locals'):
|
|
if isinstance(node.slice, ast.Constant) and node.slice.value == '__builtins__':
|
|
self.violations.append("globals()/locals()['__builtins__'] manipulation is not allowed")
|
|
elif hasattr(node.slice, 's') and node.slice.s == '__builtins__': # Python < 3.8 compatibility
|
|
self.violations.append("globals()/locals()['__builtins__'] manipulation is not allowed")
|
|
self.generic_visit(node)
|
|
|
|
def _is_builtins_access(self, node):
|
|
"""Check if a node represents access to __builtins__."""
|
|
if isinstance(node, ast.Name) and node.id == '__builtins__':
|
|
return True
|
|
if isinstance(node, ast.Call):
|
|
if isinstance(node.func, ast.Name) and node.func.id in ('globals', 'locals'):
|
|
return True
|
|
return False
|
|
|
|
|
|
class MemoryLimitException(RuntimeError):
|
|
"""Raised when memory limit is exceeded."""
|
|
pass
|
|
|
|
|
|
# Module-level check_safety function
|
|
def check_safety(code: str) -> list:
|
|
"""Check code for sandbox violations."""
|
|
# Pre-check for null bytes and other dangerous characters
|
|
if '\x00' in code:
|
|
return ["Code contains null bytes which is not allowed"]
|
|
|
|
try:
|
|
tree = ast.parse(code)
|
|
except SyntaxError:
|
|
return [] # Let SyntaxError be handled elsewhere
|
|
|
|
visitor = SandboxVisitor()
|
|
visitor.visit(tree)
|
|
return visitor.violations
|
|
|
|
|
|
# Standalone llm_query function for import compatibility
|
|
def llm_query(prompt: str, context: Dict[str, Any] = None) -> str:
|
|
"""
|
|
Standalone llm_query function.
|
|
Note: This is a placeholder - use REPLSession.llm_query() for actual queries.
|
|
"""
|
|
raise RuntimeError("llm_query must be called from a REPLSession instance")
|
|
|
|
|
|
def FINAL(answer) -> None:
|
|
"""Signal that REPL has reached final answer."""
|
|
raise RuntimeError("FINAL() must be called from within a REPL session")
|
|
|
|
|
|
class REPLSession:
|
|
"""
|
|
RLM REPL Session - secure sandbox for recursive LLM execution.
|
|
"""
|
|
|
|
class _StderrCapture:
|
|
"""Mock stderr object for sandbox."""
|
|
def __init__(self, session):
|
|
self._session = session
|
|
|
|
def write(self, text: str):
|
|
"""Write to stderr capture."""
|
|
self._session._stderr.append(text)
|
|
|
|
def flush(self):
|
|
"""Flush stderr (no-op)."""
|
|
pass
|
|
|
|
class MockSys:
|
|
"""Mock sys module for sandbox with only stderr."""
|
|
def __init__(self, stderr_capture):
|
|
self.stderr = stderr_capture
|
|
|
|
def __getattr__(self, name):
|
|
if name == 'modules':
|
|
raise SandboxViolation("Access to sys.modules is not allowed")
|
|
raise AttributeError(f"sys.{name} is not available in sandbox")
|
|
|
|
def __init__(self, chunk_store=None, llm_client=None,
|
|
max_iterations: int = 10, timeout_seconds: int = 60, max_depth: int = 5,
|
|
max_cost_usd: Optional[float] = None):
|
|
"""
|
|
Initialize REPL session.
|
|
|
|
Args:
|
|
chunk_store: ChunkStore instance for memory access
|
|
llm_client: LLM client for recursive queries
|
|
max_iterations: Maximum recursive iterations allowed
|
|
timeout_seconds: Execution timeout
|
|
max_depth: Maximum recursion depth
|
|
"""
|
|
if chunk_store is None:
|
|
raise ValueError("chunk_store is required")
|
|
if llm_client is None:
|
|
raise ValueError("llm_client is required")
|
|
|
|
self.chunk_store = chunk_store
|
|
self.llm_client = llm_client
|
|
self.max_iterations = max_iterations
|
|
self.timeout_seconds = timeout_seconds
|
|
self.max_depth = max_depth
|
|
self._max_cost_usd = max_cost_usd
|
|
|
|
self._state: Dict[str, Any] = {} # User state (empty initially)
|
|
self._iteration_count = 0
|
|
self._total_cost = 0.0
|
|
self._current_depth = 0
|
|
self._result = None
|
|
self._complete = False
|
|
self._lock = threading.RLock()
|
|
self._output = []
|
|
self._stderr = []
|
|
self._stderr_capture = self._StderrCapture(self)
|
|
|
|
# Create isolated namespace for execution
|
|
self._namespace = {}
|
|
self._setup_namespace()
|
|
|
|
def _setup_namespace(self):
|
|
"""Set up the sandbox namespace."""
|
|
# Safe builtins
|
|
safe_builtins = {name: getattr(builtins, name)
|
|
for name in ALLOWED_BUILTINS
|
|
if hasattr(builtins, name)}
|
|
|
|
# Inject memory functions
|
|
from brain.scripts.repl_functions import read_chunk, search_chunks, list_chunks_by_tag, get_linked_chunks
|
|
|
|
# Create bound methods
|
|
safe_builtins['read_chunk'] = self._read_chunk_wrapper
|
|
safe_builtins['search_chunks'] = self._search_chunks_wrapper
|
|
safe_builtins['list_chunks_by_tag'] = self._list_chunks_by_tag_wrapper
|
|
safe_builtins['get_linked_chunks'] = self._get_linked_chunks_wrapper
|
|
safe_builtins['llm_query'] = self._llm_query_wrapper
|
|
safe_builtins['FINAL'] = self._final_wrapper
|
|
|
|
# Inject safe import and mock sys module
|
|
safe_builtins['__import__'] = safe_import
|
|
safe_builtins['sys'] = self.MockSys(self._stderr_capture)
|
|
|
|
self._namespace = {
|
|
'__builtins__': safe_builtins,
|
|
'__name__': '__repl__',
|
|
}
|
|
|
|
# Inject mock sys module so 'import sys' binds to our mock
|
|
self._namespace['sys'] = self.MockSys(self._stderr_capture)
|
|
|
|
# Merge user state into namespace
|
|
self._namespace.update(self._state)
|
|
|
|
def _read_chunk_wrapper(self, chunk_id: str):
|
|
"""Wrapper for read_chunk."""
|
|
from repl_functions import read_chunk
|
|
return read_chunk(chunk_id, self.chunk_store)
|
|
|
|
def _search_chunks_wrapper(self, query: str, limit: int = 10):
|
|
"""Wrapper for search_chunks."""
|
|
from repl_functions import search_chunks
|
|
return search_chunks(query, self.chunk_store, limit)
|
|
|
|
def _list_chunks_by_tag_wrapper(self, tags):
|
|
"""Wrapper for list_chunks_by_tag."""
|
|
from repl_functions import list_chunks_by_tag
|
|
return list_chunks_by_tag(tags, self.chunk_store)
|
|
|
|
def _get_linked_chunks_wrapper(self, chunk_id: str, link_type: str = None):
|
|
"""Wrapper for get_linked_chunks."""
|
|
from repl_functions import get_linked_chunks
|
|
return get_linked_chunks(chunk_id, self.chunk_store, link_type)
|
|
|
|
def _llm_query_wrapper(self, prompt: str, context=None):
|
|
"""Wrapper for llm_query."""
|
|
with self._lock:
|
|
self._iteration_count += 1
|
|
if self._iteration_count > self.max_iterations:
|
|
raise MaxIterationsError(
|
|
f"Maximum iterations ({self.max_iterations}) exceeded"
|
|
)
|
|
|
|
# Check max depth
|
|
if self._current_depth >= self.max_depth:
|
|
raise RecursionError(f"Maximum recursion depth ({self.max_depth}) exceeded")
|
|
|
|
# Increment depth counter
|
|
self._current_depth += 1
|
|
|
|
try:
|
|
self._ensure_budget()
|
|
# Build full prompt with context
|
|
full_prompt = prompt
|
|
if context:
|
|
# Handle context as a list of chunk IDs
|
|
if isinstance(context, list):
|
|
from repl_functions import read_chunk
|
|
context_parts = []
|
|
for chunk_id in context:
|
|
chunk = read_chunk(chunk_id, self.chunk_store)
|
|
if chunk:
|
|
context_parts.append(f"Chunk {chunk_id}:\n{chunk.get('content', '')}")
|
|
else:
|
|
context_parts.append(f"Chunk {chunk_id}:\n[Not found]")
|
|
context_str = "\n\n".join(context_parts)
|
|
full_prompt = f"Context:\n{context_str}\n\nPrompt:\n{prompt}"
|
|
elif isinstance(context, dict):
|
|
context_str = "\n".join(f"{k}: {v}" for k, v in context.items())
|
|
full_prompt = f"Context:\n{context_str}\n\nPrompt:\n{prompt}"
|
|
|
|
# Call LLM
|
|
response = self.llm_client.complete(full_prompt)
|
|
|
|
self._record_cost(response)
|
|
self._ensure_budget(allow_equal=True)
|
|
|
|
return response.text if hasattr(response, 'text') else str(response)
|
|
except (RecursionError, MaxIterationsError):
|
|
# Don't catch these - let them propagate
|
|
raise
|
|
except Exception as e:
|
|
# Handle API errors gracefully
|
|
return f"Error: {str(e)}"
|
|
finally:
|
|
# Decrement depth counter
|
|
with self._lock:
|
|
self._current_depth -= 1
|
|
|
|
def _final_wrapper(self, answer) -> None:
|
|
"""Wrapper for FINAL."""
|
|
if self._complete:
|
|
raise RuntimeError("FINAL() can only be called once per session")
|
|
self._result = answer
|
|
self._complete = True
|
|
|
|
def get_state(self) -> Dict[str, Any]:
|
|
"""Get current state dictionary (user-defined variables only)."""
|
|
return self._state.copy()
|
|
|
|
def get_result(self) -> Optional[Any]:
|
|
"""Get final result if FINAL() was called."""
|
|
return self._result
|
|
|
|
def is_complete(self) -> bool:
|
|
"""Check if FINAL() has been called."""
|
|
return self._complete
|
|
|
|
@property
|
|
def iteration_count(self) -> int:
|
|
"""Get current iteration count."""
|
|
return self._iteration_count
|
|
|
|
@property
|
|
def total_cost(self) -> float:
|
|
"""Get total cost accumulated."""
|
|
return self._total_cost
|
|
|
|
def get_cost(self) -> float:
|
|
"""Get total cost accumulated."""
|
|
return self._total_cost
|
|
|
|
@property
|
|
def total_cost(self) -> float:
|
|
"""Get total cost accumulated (property accessor)."""
|
|
return self._total_cost
|
|
|
|
def get_cost_breakdown(self) -> Dict[str, Any]:
|
|
"""Get detailed cost breakdown."""
|
|
breakdown = {
|
|
"total": self._total_cost,
|
|
"calls": self._iteration_count,
|
|
"per_call_average": self._total_cost / self._iteration_count if self._iteration_count > 0 else 0.0
|
|
}
|
|
if self._max_cost_usd is not None:
|
|
remaining = self._max_cost_usd - self._total_cost
|
|
breakdown.update({
|
|
"budget": self._max_cost_usd,
|
|
"remaining": max(0.0, remaining),
|
|
"over_budget": self._total_cost > self._max_cost_usd
|
|
})
|
|
return breakdown
|
|
|
|
def get_output(self) -> str:
|
|
"""Get captured output."""
|
|
return "\n".join(self._output)
|
|
|
|
def get_stderr(self) -> str:
|
|
"""Get captured stderr."""
|
|
return "\n".join(self._stderr)
|
|
|
|
def clear_output(self):
|
|
"""Clear captured output."""
|
|
self._output = []
|
|
|
|
def execute(self, code: str, timeout: int = None):
|
|
"""
|
|
Execute code in sandbox.
|
|
|
|
Args:
|
|
code: Python code to execute
|
|
timeout: Optional timeout override
|
|
|
|
Returns:
|
|
Result of the last expression or None
|
|
|
|
Raises:
|
|
RuntimeError: If called after FINAL()
|
|
SandboxViolation: If code violates sandbox
|
|
TimeoutError: If execution times out
|
|
"""
|
|
if self._complete:
|
|
raise RuntimeError("REPL already complete")
|
|
|
|
if not code or not code.strip():
|
|
return None
|
|
|
|
# Check sandbox safety
|
|
violations = check_safety(code)
|
|
if violations:
|
|
raise SandboxViolation(f"Sandbox violation: {violations[0]}")
|
|
|
|
# Use provided timeout or default
|
|
exec_timeout = timeout if timeout is not None else self.timeout_seconds
|
|
|
|
# Capture stdout/stderr
|
|
old_stdout = sys.stdout
|
|
old_stderr = sys.stderr
|
|
stdout_capture = io.StringIO()
|
|
stderr_capture = io.StringIO()
|
|
|
|
# Container for execution results
|
|
result_container = {'result': None, 'error': None, 'completed': False}
|
|
|
|
def run_execution():
|
|
try:
|
|
sys.stdout = stdout_capture
|
|
sys.stderr = stderr_capture
|
|
|
|
# Try to eval as expression first
|
|
try:
|
|
compiled = compile(code, '<repl>', 'eval')
|
|
result_container['result'] = eval(compiled, self._namespace)
|
|
result_container['completed'] = True
|
|
return
|
|
except SyntaxError:
|
|
# Not an expression, try exec
|
|
pass
|
|
|
|
# Compile and execute as statements
|
|
compiled = compile(code, '<repl>', 'exec')
|
|
exec(compiled, self._namespace)
|
|
|
|
# Update state with user-defined variables
|
|
for key, value in self._namespace.items():
|
|
if not key.startswith('_') and key not in ('__builtins__', '__name__'):
|
|
self._state[key] = value
|
|
|
|
result_container['completed'] = True
|
|
|
|
except Exception as e:
|
|
result_container['error'] = e
|
|
|
|
# Run execution in a thread with timeout
|
|
exec_thread = threading.Thread(target=run_execution)
|
|
exec_thread.daemon = True
|
|
|
|
try:
|
|
sys.stdout = stdout_capture
|
|
sys.stderr = stderr_capture
|
|
|
|
exec_thread.start()
|
|
exec_thread.join(timeout=exec_timeout)
|
|
|
|
if exec_thread.is_alive():
|
|
# Thread is still running after timeout
|
|
raise TimeoutError(f"Execution exceeded {exec_timeout} seconds")
|
|
|
|
# Check for errors from the thread
|
|
if result_container['error'] is not None:
|
|
raise result_container['error']
|
|
|
|
# Capture output
|
|
self._output.append(stdout_capture.getvalue())
|
|
self._stderr.append(stderr_capture.getvalue())
|
|
|
|
return result_container['result']
|
|
|
|
except TimeoutError:
|
|
raise
|
|
except RecursionError:
|
|
# Let RecursionError propagate for depth limit testing
|
|
raise
|
|
except SandboxViolation:
|
|
# Let SandboxViolation propagate for security tests
|
|
raise
|
|
except SyntaxError as e:
|
|
error_msg = f"Syntax error: {e}"
|
|
self._output.append(error_msg)
|
|
return error_msg
|
|
except ZeroDivisionError as e:
|
|
error_msg = f"Zero division error: {e}"
|
|
self._output.append(error_msg)
|
|
return error_msg
|
|
except NameError as e:
|
|
# Return NameError as string for undefined name tests
|
|
error_msg = f"Name error: {e}"
|
|
self._output.append(error_msg)
|
|
return error_msg
|
|
except AttributeError as e:
|
|
error_msg = f"Attribute error: {e}"
|
|
self._output.append(error_msg)
|
|
return error_msg
|
|
except MemoryError as e:
|
|
error_msg = f"Memory error: {e}"
|
|
self._output.append(error_msg)
|
|
return error_msg
|
|
except Exception as e:
|
|
# Other exceptions - return as error string
|
|
error_msg = f"Runtime error: {e}"
|
|
self._output.append(error_msg)
|
|
return error_msg
|
|
finally:
|
|
sys.stdout = old_stdout
|
|
sys.stderr = old_stderr
|
|
|
|
def retrieve(self, query=None, max_iterations=None) -> Optional[Any]:
|
|
"""
|
|
Execute retrieval workflow for a query.
|
|
|
|
Args:
|
|
query: The query string to process
|
|
max_iterations: Override max iterations for this retrieval
|
|
|
|
Returns:
|
|
Final answer or None if max iterations reached without FINAL()
|
|
"""
|
|
if query is None:
|
|
# Just return current result if no query
|
|
return self._result if self._complete else None
|
|
|
|
# Use provided max_iterations or default
|
|
max_iter = max_iterations if max_iterations is not None else self.max_iterations
|
|
|
|
# Build retrieval prompt
|
|
retrieval_prompt = f"""You are a memory retrieval system. Answer the following query using the available memory functions.
|
|
|
|
Available functions:
|
|
- read_chunk(chunk_id): Read a chunk by ID
|
|
- search_chunks(query, limit=10): Search for chunks
|
|
- list_chunks_by_tag(tag): List chunks with a tag
|
|
- get_linked_chunks(chunk_id, link_type=None): Get linked chunks
|
|
- llm_query(prompt, context=None): Ask LLM for help
|
|
- FINAL(answer): Call when you have the final answer
|
|
|
|
Query: {query}
|
|
|
|
Write Python code to solve this query. Use FINAL('your answer') when done."""
|
|
|
|
# Iterative retrieval loop
|
|
for iteration in range(max_iter):
|
|
self._iteration_count += 1
|
|
|
|
# Get LLM response
|
|
try:
|
|
self._ensure_budget()
|
|
response = self.llm_client.complete(retrieval_prompt)
|
|
code = response.text if hasattr(response, 'text') else str(response)
|
|
self._record_cost(response)
|
|
self._ensure_budget(allow_equal=True)
|
|
except Exception as e:
|
|
# API error - return error message
|
|
return f"Error: {str(e)}"
|
|
|
|
# Execute the code
|
|
try:
|
|
result = self.execute(code)
|
|
|
|
# Check if FINAL was called
|
|
if self._complete:
|
|
return self._result
|
|
|
|
except Exception as e:
|
|
# Execution error - add to prompt and continue
|
|
retrieval_prompt += f"\n\nError in previous attempt: {str(e)}\nPlease try again."
|
|
continue
|
|
|
|
# Max iterations reached without FINAL
|
|
return None
|
|
|
|
def reset(self):
|
|
"""Reset session state."""
|
|
self._state = {}
|
|
self._iteration_count = 0
|
|
self._total_cost = 0.0
|
|
self._current_depth = 0
|
|
self._result = None
|
|
self._complete = False
|
|
self._output = []
|
|
self._stderr = []
|
|
self._setup_namespace()
|
|
|
|
def _record_cost(self, response: Any) -> None:
|
|
"""Record cost from response or LLM client."""
|
|
cost_value = None
|
|
if hasattr(response, 'cost_usd'):
|
|
cost_value = response.cost_usd
|
|
elif hasattr(self.llm_client, 'get_cost') and callable(self.llm_client.get_cost):
|
|
cost_value = self.llm_client.get_cost()
|
|
if not isinstance(cost_value, (int, float)):
|
|
return
|
|
self._total_cost += float(cost_value)
|
|
|
|
def _ensure_budget(self, allow_equal: bool = False) -> None:
|
|
"""Ensure cost budget has not been exceeded."""
|
|
if self._max_cost_usd is None:
|
|
return
|
|
if allow_equal:
|
|
over_budget = self._total_cost > self._max_cost_usd
|
|
else:
|
|
over_budget = self._total_cost >= self._max_cost_usd
|
|
if over_budget:
|
|
raise CostBudgetExceededError(
|
|
f"Cost budget exceeded: total_cost={self._total_cost:.6f} budget={self._max_cost_usd:.6f}"
|
|
)
|
|
|
|
def __enter__(self):
|
|
"""Context manager entry."""
|
|
return self
|
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
"""Context manager exit."""
|
|
self.reset()
|
|
return False
|