Critical Security Fixes: - Fix command injection vulnerability in Windows shims (beadboard.cmd, bb.cmd) - Added path validation to block traversal (.. and root-relative paths) - Added quotes around env var to prevent command injection Reliability Fixes: - Fix agent cache null safety bug - Fixed callBdAgentShow() to check for cache misses (null check, expiration) - Fixed getCachedAgent to properly return entry.data or null - Fix null body crashes in mail ack route - Added null check before casting body to object - Returns 400 error instead of 500 for invalid requests BD Compliance Fixes: - Fix read-issues to use BD audit record path - Ensures all writes go through bd audit record - Maintains watcher/SSE parity and Dolt commit tracking Code Quality Fixes: - Fix path canonicalization violations - Use canonicalizeWindowsPath() and windowsPathKey() from pathing module - Prevents Windows edge cases and ensures machine-reproducible paths - Fix typo: mobile-fronted → mobile-frontend - Pin GitHub Actions tags - softprops/action-gh-release@v1 → specific commit hash - Register pr14 test in package.json (already registered) Testing: - Refactor broad exception handlers in Python scripts - Replace except Exception: with specific exceptions - Allows KeyboardInterrupt and SystemExit to propagate correctly - All tests passing
910 lines
30 KiB
Python
910 lines
30 KiB
Python
"""
|
|
RLM-MEM - REPL Environment (D1.3)
|
|
RLM-style external memory REPL with secure sandbox execution.
|
|
"""
|
|
|
|
import ast
|
|
import builtins
|
|
import threading
|
|
import time
|
|
import io
|
|
import sys
|
|
from contextlib import contextmanager
|
|
from typing import Any, Dict, Optional, Callable
|
|
from pathlib import Path
|
|
|
|
|
|
class SandboxViolation(Exception):
    """Signal that sandboxed code tried to break a sandbox restriction."""
|
|
|
|
|
|
class MaxIterationsError(Exception):
    """Signal that a session used up its allowed number of LLM iterations."""
|
|
|
|
|
|
# Cost budget exceeded
|
|
class CostBudgetExceededError(RuntimeError):
    """Signal that accumulated LLM spend reached or exceeded the session budget."""
|
|
|
|
|
|
# Use built-in TimeoutError
|
|
|
|
|
|
# Allowed built-ins for sandbox.
#
# ``help`` and ``input`` are deliberately absent: ``help("<module name>")``
# imports arbitrary modules through pydoc (bypassing ``safe_import``), and
# ``input()`` blocks the worker thread on the host process's stdin.
ALLOWED_BUILTINS = {
    # Functions and types
    "abs",
    "all",
    "any",
    "ascii",
    "bin",
    "bool",
    "bytearray",
    "bytes",
    "callable",
    "chr",
    "classmethod",
    "complex",
    "delattr",
    "dict",
    "dir",
    "divmod",
    "enumerate",
    "filter",
    "float",
    "format",
    "frozenset",
    "getattr",
    "globals",
    "hasattr",
    "hash",
    "hex",
    "id",
    "int",
    "isinstance",
    "issubclass",
    "iter",
    "len",
    "list",
    "locals",
    "map",
    "max",
    "memoryview",
    "min",
    "next",
    "object",
    "oct",
    "ord",
    "pow",
    "print",
    "property",
    "range",
    "repr",
    "reversed",
    "round",
    "set",
    "setattr",
    "slice",
    "sorted",
    "staticmethod",
    "str",
    "sum",
    "super",
    "tuple",
    "type",
    "vars",
    "zip",
    # Needed for `class` statements inside the sandbox
    "__build_class__",
    # Constants / module name
    "__name__",
    "True",
    "False",
    "None",
    # Exceptions sandboxed code can raise or catch by name. NameError and
    # ImportError are needed because safe_import and undefined names raise
    # them; without the names, `except ImportError:` itself fails.
    "Exception",
    "TypeError",
    "ValueError",
    "KeyError",
    "IndexError",
    "AttributeError",
    "RuntimeError",
    "StopIteration",
    "ArithmeticError",
    "LookupError",
    "AssertionError",
    "NotImplementedError",
    "ZeroDivisionError",
    "OverflowError",
    "NameError",
    "ImportError",
    "RecursionError",
    "MemoryError",
    "UnicodeError",
}
|
|
|
|
# Modules whose import is reported by the static check (SandboxVisitor).
# Grouped by the capability each one would hand to sandboxed code.
BLOCKED_MODULES = {
    # Operating system / interpreter access
    "os", "sys", "posix", "nt", "resource",
    # Process creation
    "subprocess", "multiprocessing", "concurrent.futures.process", "asyncio.subprocess",
    # Networking and internet protocols
    "socket", "ssl", "select", "selectors", "urllib", "http", "xmlrpc", "email",
    "ftplib", "smtplib", "telnetlib", "poplib", "imaplib", "nntplib",
    # Native code and raw memory
    "ctypes", "cffi", "mmap",
    # Unix user/group databases and terminals
    "pwd", "grp", "spwd", "crypt", "termios", "tty", "pty", "fcntl",
    # Windows internals
    "msvcrt", "winreg", "_winapi",
}
|
|
|
|
# Top-level module names that safe_import() resolves to a same-named mock
# object already present in the sandbox namespace. Empty by default, so apart
# from the specially handled ``sys`` nothing is importable.
ALLOWED_MODULES: "set[str]" = set()
|
|
|
|
|
|
def safe_import(name, globals=None, locals=None, fromlist=(), level=0):
    """``__import__`` replacement installed in the sandbox's builtins.

    Only ``sys`` and top-level names listed in ``ALLOWED_MODULES`` can be
    imported. They resolve to mock objects the session placed in the sandbox:
    the importing frame's globals are searched first, then its
    ``__builtins__`` mapping. That way rebinding or deleting the global name
    (``sys = 5`` / ``del sys``) does not break a later ``import sys``.
    Everything else raises ImportError.

    Args:
        name: Module name being imported (possibly dotted).
        globals: Globals of the importing frame.
        locals: Unused; accepted for ``__import__`` compatibility.
        fromlist: Unused; accepted for ``__import__`` compatibility.
        level: Relative-import level. Only absolute imports (0) are allowed,
            so ``from .sys import x`` cannot resolve to a top-level mock.

    Returns:
        The mock object registered for the module.

    Raises:
        ImportError: For relative imports, disallowed modules, or when the
            expected mock is not present in the sandbox.
    """
    if level:
        raise ImportError(f"Relative import of '{name}' is not allowed in sandbox")

    base_module = name.split(".")[0] if name else ""

    # 'sys' is always allowed (mocked in sandbox); other names must be listed.
    if base_module == "sys" or base_module in ALLOWED_MODULES:
        sandbox_builtins = globals.get("__builtins__") if isinstance(globals, dict) else None
        for namespace in (globals, sandbox_builtins):
            if isinstance(namespace, dict) and base_module in namespace:
                return namespace[base_module]
        label = "sys" if base_module == "sys" else name
        raise ImportError(f"Mock {label} not found in sandbox")

    raise ImportError(f"Import of '{name}' is not allowed in sandbox")
|
|
|
|
|
|
# Blocked attribute names that could be used for sandbox escape.
# SandboxVisitor rejects ``obj.<name>`` for every entry at parse time. Keep
# names that ordinary code needs (``__name__``, ``__init__``, ...) out of it.
BLOCKED_ATTRIBUTES = {
    # Class-hierarchy walking (object -> __subclasses__() -> arbitrary classes)
    "__class__",
    "__bases__",
    "__subclasses__",
    "__base__",
    "__mro__",
    # Function / bound-method internals leading to module globals or code
    "__globals__",
    "__code__",
    "__func__",
    "__self__",
    "__module__",
    "__closure__",
    "__defaults__",
    "__kwdefaults__",
    # On Python 3.10+ every function exposes the real builtins dict here,
    # e.g. __import__.__builtins__["__import__"] would be the host importer.
    "__builtins__",
    # Raw attribute / descriptor machinery
    "__dict__",
    "__getattribute__",
    "__setattr__",
    "__get__",
    # Pickle hooks can hand back module-level callables
    "__reduce__",
    "__reduce_ex__",
    # Frame access through tracebacks, generators, coroutines and async
    # generators; a frame's f_back/f_globals reach the host interpreter.
    "__traceback__",
    "tb_frame",
    "f_back",
    "f_builtins",
    "f_code",
    "f_globals",
    "f_locals",
    "gi_frame",
    "gi_code",
    "cr_frame",
    "cr_code",
    "ag_frame",
    "ag_code",
}
|
|
|
|
|
|
class SandboxVisitor(ast.NodeVisitor):
    """AST visitor that statically screens code for sandbox violations.

    Violations are collected as human-readable strings in ``self.violations``.
    An obviously oversized constant repetition (e.g. ``"x" * 10**9``) is not
    collected: :meth:`visit_BinOp` raises ``MemoryError`` immediately.
    """

    # Largest statically computable repetition allowed: characters/bytes of a
    # str/bytes literal, or elements of a list/tuple display (10 MiB).
    _MAX_STATIC_REPEAT = 10 * 1024 * 1024
    # Cap on the bit length of integers produced by constant folding, so that
    # folding ``10 ** 10 ** 10`` cannot itself exhaust memory or CPU.
    _MAX_FOLD_BITS = 4096
    # Calls whose result is a namespace mapping that can hold __builtins__.
    _NAMESPACE_CALLS = ("globals", "locals", "vars")
    # Builtins taking (obj, attribute_name, ...) that can reach __builtins__.
    _ATTR_FUNCS = ("getattr", "setattr", "delattr")

    def __init__(self, allowed_paths: Optional[list] = None):
        # allowed_paths is kept for API compatibility; no check reads it yet.
        self.allowed_paths = allowed_paths or []
        self.violations = []

    # -- imports -------------------------------------------------------------

    def _blocked_module(self, dotted_name):
        """Return the first blocked prefix of *dotted_name*, or None.

        Every prefix ("a", "a.b", "a.b.c") is checked so that dotted entries in
        BLOCKED_MODULES such as "concurrent.futures.process" take effect.
        """
        parts = dotted_name.split(".")
        for end in range(1, len(parts) + 1):
            prefix = ".".join(parts[:end])
            if prefix in BLOCKED_MODULES and prefix not in ALLOWED_MODULES:
                return prefix
        return None

    def visit_Import(self, node):
        for alias in node.names:
            # Allow 'sys' import (redirected to mock in sandbox)
            if alias.name.split(".")[0] == "sys":
                continue
            blocked = self._blocked_module(alias.name)
            if blocked is not None:
                self.violations.append(f"Import of '{blocked}' is not allowed")
        self.generic_visit(node)

    def visit_ImportFrom(self, node):
        if node.module:
            # Allow 'sys' import (redirected to mock in sandbox)
            if node.module.split(".")[0] == "sys":
                return
            blocked = self._blocked_module(node.module)
            if blocked is None:
                # "from a.b import c" may import the submodule a.b.c.
                for alias in node.names:
                    blocked = self._blocked_module(f"{node.module}.{alias.name}")
                    if blocked is not None:
                        break
            if blocked is not None:
                self.violations.append(f"Import from '{blocked}' is not allowed")
        self.generic_visit(node)

    # -- builtins tampering ---------------------------------------------------

    def visit_Delete(self, node):
        """Block deletion of builtins attributes or namespace entries."""
        for target in node.targets:
            if isinstance(target, (ast.Attribute, ast.Subscript)) and self._is_builtins_access(
                target.value
            ):
                self.violations.append(
                    "Deletion of __builtins__ attributes is not allowed"
                )
        self.generic_visit(node)

    def visit_Call(self, node):
        """Flag dangerous calls made through a plain name."""
        func = node.func
        if isinstance(func, ast.Name):
            name = func.id
            if name in ("eval", "exec", "compile"):
                self.violations.append(f"Use of '{name}()' is not allowed")
            elif name == "__import__":
                self.violations.append("Use of '__import__()' is not allowed")
            elif name == "open":
                self.violations.append("Use of 'open()' is not allowed")
            elif name in self._ATTR_FUNCS:
                if node.args and self._is_builtins_access(node.args[0]):
                    self.violations.append(f"{name} on __builtins__ is not allowed")
                # getattr(x, "__class__") is the same as x.__class__.
                if (
                    len(node.args) >= 2
                    and isinstance(node.args[1], ast.Constant)
                    and isinstance(node.args[1].value, str)
                    and node.args[1].value in BLOCKED_ATTRIBUTES
                ):
                    self.violations.append(
                        f"Access to '{node.args[1].value}' is not allowed"
                    )
        self.generic_visit(node)

    # -- memory bombs ---------------------------------------------------------

    def visit_BinOp(self, node):
        """Check for large memory allocations via sequence multiplication."""
        if isinstance(node.op, ast.Mult):
            # Repetition works in either operand order: "x" * n and n * "x".
            self._check_static_repeat(node.left, node.right)
            self._check_static_repeat(node.right, node.left)
        self.generic_visit(node)

    def _check_static_repeat(self, seq_node, count_node):
        """Raise MemoryError if ``seq_node * count_node`` is provably too big."""
        if isinstance(seq_node, ast.Constant) and isinstance(seq_node.value, (str, bytes)):
            length = len(seq_node.value)
            is_text = True
        elif isinstance(seq_node, (ast.List, ast.Tuple)):
            length = len(seq_node.elts)
            is_text = False
        else:
            return
        try:
            count = self._eval_const_expr(count_node)
        except (ValueError, TypeError, AttributeError, OverflowError):
            return  # Can't evaluate statically, let it run and catch at runtime
        if isinstance(count, bool) or not isinstance(count, (int, float)):
            return
        size = length * count
        if size > self._MAX_STATIC_REPEAT:
            if is_text:
                raise MemoryError(
                    f"String multiplication would create {size} bytes, exceeding 10MB limit"
                )
            raise MemoryError(
                f"Sequence multiplication would create {size} elements, "
                f"exceeding the {self._MAX_STATIC_REPEAT} element limit"
            )

    def _eval_const_expr(self, node):
        """Fold a numeric constant expression built from ``+``, ``-``, ``*``, ``**``.

        Raises:
            ValueError: If *node* is not a purely numeric constant expression
                of a supported shape.
        """
        if isinstance(node, ast.Constant):
            # Only fold numbers: folding str/bytes operands would perform the
            # very repetition this check exists to prevent.
            if isinstance(node.value, (int, float)):
                return node.value
            raise ValueError("Cannot evaluate expression")
        if isinstance(node, ast.BinOp):
            left = self._eval_const_expr(node.left)
            right = self._eval_const_expr(node.right)
            if isinstance(node.op, ast.Mult):
                return left * right
            if isinstance(node.op, ast.Add):
                return left + right
            if isinstance(node.op, ast.Sub):
                return left - right
            if isinstance(node.op, ast.Pow):
                return self._bounded_pow(left, right)
        raise ValueError("Cannot evaluate expression")

    def _bounded_pow(self, base, exponent):
        """Compute ``base ** exponent`` for ints without materialising huge results.

        Returns ``float('inf')`` (or ``-inf`` for an odd power of a negative
        base) when the result would be wider than ``_MAX_FOLD_BITS`` bits.

        Raises:
            ValueError: For non-integer operands or a negative exponent.
        """
        if not (isinstance(base, int) and isinstance(exponent, int)) or exponent < 0:
            raise ValueError("Cannot evaluate expression")
        if abs(base) > 1 and base.bit_length() * exponent > self._MAX_FOLD_BITS:
            negative = base < 0 and exponent % 2 == 1
            return float("-inf") if negative else float("inf")
        return base ** exponent

    # -- attribute / subscript access -----------------------------------------

    def visit_Attribute(self, node):
        """Check for dangerous attribute access like __class__, __bases__, etc."""
        if node.attr in BLOCKED_ATTRIBUTES:
            self.violations.append(f"Access to '{node.attr}' is not allowed")
        self.generic_visit(node)

    def visit_Subscript(self, node):
        """Check for builtins subscript access like globals()['__builtins__']['__import__']."""
        value = node.value
        if (
            isinstance(value, ast.Call)
            and isinstance(value.func, ast.Name)
            and value.func.id in self._NAMESPACE_CALLS
        ):
            key = node.slice
            # Python 3.8 wraps subscripts in ast.Index; 3.9+ stores the
            # expression directly. Compare by name to avoid deprecated ast.Index.
            if type(key).__name__ == "Index":
                key = key.value
            if isinstance(key, ast.Constant) and key.value == "__builtins__":
                self.violations.append(
                    "globals()/locals()['__builtins__'] manipulation is not allowed"
                )
        self.generic_visit(node)

    def _is_builtins_access(self, node):
        """Return True if *node* is ``__builtins__`` or a globals()/locals()/vars() call."""
        if isinstance(node, ast.Name) and node.id == "__builtins__":
            return True
        return (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id in self._NAMESPACE_CALLS
        )
|
|
|
|
|
|
class MemoryLimitException(RuntimeError):
    """Signal that a memory limit was exceeded.

    NOTE(review): nothing in this file raises it; presumably kept for external
    callers — confirm before removing.
    """
|
|
|
|
|
|
# Module-level check_safety function
|
|
def check_safety(code: str) -> list:
    """Statically screen *code* for sandbox violations.

    Args:
        code: Source text to inspect.

    Returns:
        A list of violation messages. It is empty when nothing was found, and
        also when the code does not parse; the SyntaxError is left for the
        caller to hit when it compiles the code.

    Raises:
        MemoryError: Propagated from SandboxVisitor for an obviously oversized
            constant repetition.
    """
    # NUL characters are rejected before parsing.
    if code.find("\x00") != -1:
        return ["Code contains null bytes which is not allowed"]

    try:
        tree = ast.parse(code)
    except SyntaxError:
        tree = None
    if tree is None:
        return []

    checker = SandboxVisitor()
    checker.visit(tree)
    return checker.violations
|
|
|
|
|
|
# Standalone llm_query function for import compatibility
|
|
def llm_query(prompt: str, context: Optional[Dict[str, Any]] = None) -> str:
    """Module-level placeholder kept so ``llm_query`` can be imported.

    It always raises. Inside a session, sandboxed code's ``llm_query`` is
    bound to ``REPLSession._llm_query_wrapper`` instead.

    Raises:
        RuntimeError: Always.
    """
    message = "llm_query must be called from a REPLSession instance"
    raise RuntimeError(message)
|
|
|
|
|
|
def FINAL(answer) -> None:
    """Module-level placeholder for the sandbox's ``FINAL(answer)``.

    It always raises. Inside a session, ``FINAL`` is bound to
    ``REPLSession._final_wrapper`` instead.

    Raises:
        RuntimeError: Always.
    """
    message = "FINAL() must be called from within a REPL session"
    raise RuntimeError(message)
|
|
|
|
|
|
class REPLSession:
    """
    RLM REPL Session - secure sandbox for recursive LLM execution.

    Code given to :meth:`execute` is first screened by :func:`check_safety`.
    It then runs in a worker thread against an isolated namespace whose
    ``__builtins__`` holds only ``ALLOWED_BUILTINS``, with ``getattr``,
    ``setattr``, ``delattr`` and ``vars`` replaced by guarded versions. The
    namespace also provides the memory helpers, ``llm_query``, ``FINAL``,
    ``safe_import`` as ``__import__``, and a mock ``sys`` module.

    NOTE(review): this is an in-process, best-effort sandbox, not a hard
    security boundary. A worker thread that exceeds the timeout is only
    abandoned (it is a daemon thread that is never stopped), and
    ``sys.stdout``/``sys.stderr`` are swapped process-wide while code runs.
    """

    class _StderrCapture:
        """Stand-in for ``sys.stderr`` inside the sandbox.

        Sandboxed code can reach this object as ``sys.stderr``. It therefore
        holds only the session's stderr list, never the session itself, which
        would expose the LLM client, budget fields and namespace.
        """

        def __init__(self, session):
            # REPLSession.reset() clears this list in place, so the reference
            # stays valid for the session's whole lifetime.
            self._buffer = session._stderr

        def write(self, text: str):
            """Write to stderr capture."""
            self._buffer.append(text)

        def flush(self):
            """Flush stderr (no-op)."""
            pass

    class MockSys:
        """Mock sys module for sandbox with only stderr."""

        def __init__(self, stderr_capture):
            self.stderr = stderr_capture

        def __getattr__(self, name):
            # Only reached for attributes not set in __init__ (anything but stderr).
            if name == "modules":
                raise SandboxViolation("Access to sys.modules is not allowed")
            raise AttributeError(f"sys.{name} is not available in sandbox")

    def __init__(
        self,
        chunk_store=None,
        llm_client=None,
        max_iterations: int = 10,
        timeout_seconds: int = 60,
        max_depth: int = 5,
        max_cost_usd: Optional[float] = None,
    ):
        """
        Initialize REPL session.

        Args:
            chunk_store: ChunkStore instance for memory access (required).
            llm_client: LLM client for recursive queries (required); its
                ``complete(prompt)`` method is called for every LLM request.
            max_iterations: Maximum iterations checked by ``llm_query``.
                ``retrieve()`` rounds also increment the same counter.
            timeout_seconds: Default execution timeout for :meth:`execute`.
            max_depth: Maximum number of simultaneously active ``llm_query`` calls.
            max_cost_usd: Optional spending cap; ``None`` disables the check.

        Raises:
            ValueError: If ``chunk_store`` or ``llm_client`` is missing.
        """
        if chunk_store is None:
            raise ValueError("chunk_store is required")
        if llm_client is None:
            raise ValueError("llm_client is required")

        self.chunk_store = chunk_store
        self.llm_client = llm_client
        self.max_iterations = max_iterations
        self.timeout_seconds = timeout_seconds
        self.max_depth = max_depth
        self._max_cost_usd = max_cost_usd

        self._state: Dict[str, Any] = {}  # User state (empty initially)
        self._iteration_count = 0
        self._total_cost = 0.0
        self._current_depth = 0
        self._result = None
        self._complete = False
        self._lock = threading.RLock()
        self._output = []
        self._stderr = []
        self._stderr_capture = self._StderrCapture(self)
        self._mock_sys = None  # set by _setup_namespace()

        # Create isolated namespace for execution
        self._namespace = {}
        self._setup_namespace()

    # -- namespace construction ---------------------------------------------

    def _setup_namespace(self):
        """(Re)build the isolated namespace used by :meth:`execute`.

        ``__builtins__`` is a fresh dict holding:

        - the allowed builtins, with the attribute/introspection ones swapped
          for guarded versions;
        - the memory helpers bound to this session;
        - ``llm_query`` and ``FINAL``;
        - ``safe_import`` as ``__import__``;
        - a MockSys instance.

        The same MockSys is also bound to the global name ``sys``.
        """
        safe_builtins = {
            name: getattr(builtins, name)
            for name in ALLOWED_BUILTINS
            if hasattr(builtins, name)
        }

        # Dynamic attribute access must honour the same BLOCKED_ATTRIBUTES
        # policy the AST check applies to literal `obj.attr` syntax. Without
        # this, getattr(x, "__cl" + "ass__") or vars(type)[...] bypass it.
        guarded = {
            "getattr": self._guarded_getattr,
            "setattr": self._guarded_setattr,
            "delattr": self._guarded_delattr,
            "vars": self._guarded_vars,
        }
        for name, func in guarded.items():
            if name in safe_builtins:
                safe_builtins[name] = func

        # Memory helpers and control functions, bound to this session.
        safe_builtins["read_chunk"] = self._read_chunk_wrapper
        safe_builtins["search_chunks"] = self._search_chunks_wrapper
        safe_builtins["list_chunks_by_tag"] = self._list_chunks_by_tag_wrapper
        safe_builtins["get_linked_chunks"] = self._get_linked_chunks_wrapper
        safe_builtins["llm_query"] = self._llm_query_wrapper
        safe_builtins["FINAL"] = self._final_wrapper

        # Inject safe import and mock sys module. One MockSys instance serves
        # both `import sys` (resolved by safe_import) and the global `sys`.
        safe_builtins["__import__"] = safe_import
        self._mock_sys = self.MockSys(self._stderr_capture)
        safe_builtins["sys"] = self._mock_sys

        self._namespace = {
            "__builtins__": safe_builtins,
            "__name__": "__repl__",
            "sys": self._mock_sys,
        }

        # Merge user state into namespace
        self._namespace.update(self._state)

    # -- guarded builtins (run inside the sandbox) ----------------------------

    @staticmethod
    def _checked_attr_name(name):
        """Return *name* normalised to a plain ``str``; reject blocked names.

        Non-str values are passed through unchanged, so the real builtin
        raises its usual TypeError.

        Raises:
            SandboxViolation: If *name* is listed in BLOCKED_ATTRIBUTES.
        """
        if isinstance(name, str):
            # str.__str__ returns a plain str even for str subclasses, so a
            # subclass overriding __eq__/__hash__ cannot dodge the lookup.
            name = str.__str__(name)
            if name in BLOCKED_ATTRIBUTES:
                raise SandboxViolation(f"Access to '{name}' is not allowed")
        return name

    @staticmethod
    def _guarded_getattr(obj, name, *default):
        """Sandbox ``getattr`` that refuses BLOCKED_ATTRIBUTES names."""
        return getattr(obj, REPLSession._checked_attr_name(name), *default)

    @staticmethod
    def _guarded_setattr(obj, name, value):
        """Sandbox ``setattr`` that refuses BLOCKED_ATTRIBUTES names."""
        setattr(obj, REPLSession._checked_attr_name(name), value)

    @staticmethod
    def _guarded_delattr(obj, name):
        """Sandbox ``delattr`` that refuses BLOCKED_ATTRIBUTES names."""
        delattr(obj, REPLSession._checked_attr_name(name))

    @staticmethod
    def _guarded_vars(*args):
        """Sandbox ``vars``; class namespaces come back without blocked names.

        Without filtering, ``vars(type)["__subclasses__"]`` or
        ``vars(type(func))["__globals__"]`` would hand raw descriptors to
        sandboxed code. For classes a filtered ``dict`` copy is returned
        instead of the read-only mappingproxy.
        """
        if not args:
            # Zero-argument vars() means "the caller's locals": resolve it in
            # the sandboxed caller's frame, not in this helper's frame.
            # NOTE(review): on 3.13+ function frames return a proxy, not a dict.
            return sys._getframe(1).f_locals
        if len(args) == 1 and isinstance(args[0], type):
            return {
                key: value
                for key, value in vars(args[0]).items()
                if key not in BLOCKED_ATTRIBUTES
            }
        return vars(*args)

    # -- memory helpers --------------------------------------------------------

    @staticmethod
    def _repl_functions():
        """Import and return the module implementing the memory helpers.

        The script layout (``repl_functions`` on sys.path) is tried first, then
        the package layout (``brain.scripts.repl_functions``). The lookup
        happens at call time.
        """
        try:
            import repl_functions
        except ModuleNotFoundError as exc:
            if exc.name != "repl_functions":
                raise  # repl_functions exists but one of its imports is missing
            from brain.scripts import repl_functions
        return repl_functions

    def _read_chunk_wrapper(self, chunk_id: str):
        """Sandbox ``read_chunk(chunk_id)`` using this session's chunk store."""
        return self._repl_functions().read_chunk(chunk_id, self.chunk_store)

    def _search_chunks_wrapper(self, query: str, limit: int = 10):
        """Sandbox ``search_chunks(query, limit)`` using this session's chunk store."""
        return self._repl_functions().search_chunks(query, self.chunk_store, limit)

    def _list_chunks_by_tag_wrapper(self, tags):
        """Sandbox ``list_chunks_by_tag(tags)`` using this session's chunk store."""
        return self._repl_functions().list_chunks_by_tag(tags, self.chunk_store)

    def _get_linked_chunks_wrapper(self, chunk_id: str, link_type: Optional[str] = None):
        """Sandbox ``get_linked_chunks(chunk_id, link_type)`` using this session's store."""
        return self._repl_functions().get_linked_chunks(
            chunk_id, self.chunk_store, link_type
        )

    # -- llm_query / FINAL ------------------------------------------------------

    def _build_prompt(self, prompt, context):
        """Prefix *prompt* with rendered *context*.

        A list is treated as chunk IDs whose ``content`` is inlined. A dict is
        rendered as ``key: value`` lines. Any other or falsy context leaves
        *prompt* unchanged.
        """
        if not context:
            return prompt
        if isinstance(context, list):
            read_chunk = self._repl_functions().read_chunk
            context_parts = []
            for chunk_id in context:
                chunk = read_chunk(chunk_id, self.chunk_store)
                if chunk:
                    context_parts.append(
                        f"Chunk {chunk_id}:\n{chunk.get('content', '')}"
                    )
                else:
                    context_parts.append(f"Chunk {chunk_id}:\n[Not found]")
            context_str = "\n\n".join(context_parts)
        elif isinstance(context, dict):
            context_str = "\n".join(f"{k}: {v}" for k, v in context.items())
        else:
            return prompt
        return f"Context:\n{context_str}\n\nPrompt:\n{prompt}"

    def _llm_query_wrapper(self, prompt: str, context=None):
        """Sandbox ``llm_query``: ask the LLM client, optionally with context.

        Args:
            prompt: Question for the LLM.
            context: Optional list of chunk IDs or dict (see _build_prompt).

        Returns:
            ``response.text`` when the response has it, otherwise
            ``str(response)``. Any other exception becomes an ``"Error: ..."``
            string; this includes CostBudgetExceededError, so an exhausted
            budget stops further LLM calls without raising.

        Raises:
            MaxIterationsError: When the iteration budget is spent.
            RecursionError: When ``max_depth`` queries are already in flight.
        """
        with self._lock:
            self._iteration_count += 1
            if self._iteration_count > self.max_iterations:
                raise MaxIterationsError(
                    f"Maximum iterations ({self.max_iterations}) exceeded"
                )
            if self._current_depth >= self.max_depth:
                raise RecursionError(
                    f"Maximum recursion depth ({self.max_depth}) exceeded"
                )
            self._current_depth += 1

        try:
            self._ensure_budget()
            full_prompt = self._build_prompt(prompt, context)
            response = self.llm_client.complete(full_prompt)
            self._record_cost(response)
            self._ensure_budget(allow_equal=True)
            return response.text if hasattr(response, "text") else str(response)
        except (RecursionError, MaxIterationsError):
            # Don't catch these - let them propagate
            raise
        except Exception as e:
            # Handle API errors gracefully
            return f"Error: {str(e)}"
        finally:
            with self._lock:
                self._current_depth -= 1

    def _final_wrapper(self, answer) -> None:
        """Sandbox ``FINAL(answer)``: record the answer and complete the session.

        Raises:
            RuntimeError: If FINAL() was already called in this session.
        """
        if self._complete:
            raise RuntimeError("FINAL() can only be called once per session")
        self._result = answer
        self._complete = True

    # -- accessors ----------------------------------------------------------------

    def get_state(self) -> Dict[str, Any]:
        """Get a copy of the user-defined variables (names not starting with '_')."""
        return self._state.copy()

    def get_result(self) -> Optional[Any]:
        """Get final result if FINAL() was called."""
        return self._result

    def is_complete(self) -> bool:
        """Check if FINAL() has been called."""
        return self._complete

    @property
    def iteration_count(self) -> int:
        """Get current iteration count."""
        return self._iteration_count

    @property
    def total_cost(self) -> float:
        """Get total cost accumulated (property accessor)."""
        return self._total_cost

    def get_cost(self) -> float:
        """Get total cost accumulated."""
        return self._total_cost

    def get_cost_breakdown(self) -> Dict[str, Any]:
        """Get a cost summary, including budget fields when a budget is set."""
        breakdown = {
            "total": self._total_cost,
            "calls": self._iteration_count,
            "per_call_average": self._total_cost / self._iteration_count
            if self._iteration_count > 0
            else 0.0,
        }
        if self._max_cost_usd is not None:
            remaining = self._max_cost_usd - self._total_cost
            breakdown.update(
                {
                    "budget": self._max_cost_usd,
                    "remaining": max(0.0, remaining),
                    "over_budget": self._total_cost > self._max_cost_usd,
                }
            )
        return breakdown

    def get_output(self) -> str:
        """Get captured output."""
        return "\n".join(self._output)

    def get_stderr(self) -> str:
        """Get captured stderr."""
        return "\n".join(self._stderr)

    def clear_output(self):
        """Clear captured output."""
        self._output = []

    # -- execution --------------------------------------------------------------

    def execute(self, code: str, timeout: Optional[float] = None):
        """
        Execute code in sandbox.

        Args:
            code: Python code to execute.
            timeout: Optional timeout override in seconds.

        Returns:
            The value of ``code`` when it is a single expression, or ``None``
            for statements and blank input. When execution fails with an
            ordinary exception, an error string such as ``"Name error: ..."``
            is returned and also appended to the output log.

        Raises:
            RuntimeError: If called after FINAL().
            SandboxViolation: If code violates the sandbox, either statically
                or at run time through the guarded builtins.
            TimeoutError: If execution times out. The worker thread is
                abandoned, not stopped.
            RecursionError: If raised while executing.
            MemoryError: If check_safety rejects an oversized constant repetition.
        """
        if self._complete:
            raise RuntimeError("REPL already complete")

        if not code or not code.strip():
            return None

        # Check sandbox safety
        violations = check_safety(code)
        if violations:
            raise SandboxViolation(f"Sandbox violation: {violations[0]}")

        # Use provided timeout or default
        exec_timeout = timeout if timeout is not None else self.timeout_seconds

        old_stdout = sys.stdout
        old_stderr = sys.stderr
        stdout_capture = io.StringIO()
        stderr_capture = io.StringIO()
        result_container = {"result": None, "error": None}

        def run_execution():
            # sys.stdout/sys.stderr are already redirected by the caller below.
            try:
                result_container["result"] = self._run_code(code)
            except Exception as e:
                result_container["error"] = e

        exec_thread = threading.Thread(target=run_execution, daemon=True)

        try:
            sys.stdout = stdout_capture
            sys.stderr = stderr_capture

            exec_thread.start()
            exec_thread.join(timeout=exec_timeout)

            if exec_thread.is_alive():
                # Thread is still running after timeout
                raise TimeoutError(f"Execution exceeded {exec_timeout} seconds")

            if result_container["error"] is not None:
                raise result_container["error"]

            self._output.append(stdout_capture.getvalue())
            self._stderr.append(stderr_capture.getvalue())
            return result_container["result"]

        except (TimeoutError, RecursionError, SandboxViolation):
            raise
        except SyntaxError as e:
            return self._record_failure(f"Syntax error: {e}", stdout_capture, stderr_capture)
        except ZeroDivisionError as e:
            return self._record_failure(f"Zero division error: {e}", stdout_capture, stderr_capture)
        except NameError as e:
            return self._record_failure(f"Name error: {e}", stdout_capture, stderr_capture)
        except AttributeError as e:
            return self._record_failure(f"Attribute error: {e}", stdout_capture, stderr_capture)
        except MemoryError as e:
            return self._record_failure(f"Memory error: {e}", stdout_capture, stderr_capture)
        except Exception as e:
            # Other exceptions - return as error string
            return self._record_failure(f"Runtime error: {e}", stdout_capture, stderr_capture)
        finally:
            sys.stdout = old_stdout
            sys.stderr = old_stderr

    def _run_code(self, code):
        """Evaluate *code* as an expression if it is one, otherwise exec it.

        Only the eval-mode *compile* is guarded by ``except SyntaxError``. A
        SyntaxError raised while evaluating must not cause the same code (and
        its side effects, e.g. LLM calls) to run a second time via exec.
        The user-visible state is synced afterwards, even on failure.
        """
        try:
            try:
                compiled = compile(code, "<repl>", "eval")
            except SyntaxError:
                compiled = None  # Not an expression; run as statements
            if compiled is not None:
                return eval(compiled, self._namespace)
            exec(compile(code, "<repl>", "exec"), self._namespace)
            return None
        finally:
            self._sync_state()

    def _sync_state(self):
        """Mirror user-visible namespace variables into ``self._state``.

        Names starting with '_' and the injected mock ``sys`` are skipped.
        Variables deleted in the sandbox disappear from the state as well.
        """
        self._state = {
            key: value
            for key, value in self._namespace.items()
            if not key.startswith("_") and value is not self._mock_sys
        }

    def _record_failure(self, message, stdout_capture, stderr_capture):
        """Log a failed execution and return *message*.

        Non-empty stdout/stderr captured before the failure is kept ahead of
        the error message instead of being discarded.
        """
        printed = stdout_capture.getvalue()
        if printed:
            self._output.append(printed)
        warned = stderr_capture.getvalue()
        if warned:
            self._stderr.append(warned)
        self._output.append(message)
        return message

    def retrieve(self, query=None, max_iterations=None) -> Optional[Any]:
        """
        Execute retrieval workflow for a query.

        Each round asks the LLM for code and runs it with :meth:`execute`.
        Exceptions raised by execute are appended to the prompt for the next
        round.

        Args:
            query: The query string to process.
            max_iterations: Override max iterations for this retrieval.

        Returns:
            The final answer when FINAL() is called (or was already called),
            an ``"Error: ..."`` string if the LLM call or budget check fails,
            or None if all rounds finish without FINAL().
        """
        if query is None:
            # Just return current result if no query
            return self._result if self._complete else None

        if self._complete:
            # FINAL() already ran: execute() would refuse every attempt, so
            # asking the LLM again would only spend budget.
            return self._result

        max_iter = max_iterations if max_iterations is not None else self.max_iterations

        retrieval_prompt = f"""You are a memory retrieval system. Answer the following query using the available memory functions.

Available functions:
- read_chunk(chunk_id): Read a chunk by ID
- search_chunks(query, limit=10): Search for chunks
- list_chunks_by_tag(tag): List chunks with a tag
- get_linked_chunks(chunk_id, link_type=None): Get linked chunks
- llm_query(prompt, context=None): Ask LLM for help
- FINAL(answer): Call when you have the final answer

Query: {query}

Write Python code to solve this query. Use FINAL('your answer') when done."""

        for _ in range(max_iter):
            self._iteration_count += 1

            try:
                self._ensure_budget()
                response = self.llm_client.complete(retrieval_prompt)
                code = response.text if hasattr(response, "text") else str(response)
                self._record_cost(response)
                self._ensure_budget(allow_equal=True)
            except Exception as e:
                # API error (or exhausted budget) - return error message
                return f"Error: {str(e)}"

            try:
                self.execute(code)
                if self._complete:
                    return self._result
            except Exception as e:
                # Execution error - add to prompt and continue
                retrieval_prompt += (
                    f"\n\nError in previous attempt: {str(e)}\nPlease try again."
                )

        # Max iterations reached without FINAL
        return None

    def reset(self):
        """Reset session state and rebuild the sandbox namespace."""
        self._state = {}
        self._iteration_count = 0
        self._total_cost = 0.0
        self._current_depth = 0
        self._result = None
        self._complete = False
        self._output = []
        # Clear in place: the sandbox's sys.stderr holds a reference to this list.
        self._stderr.clear()
        self._setup_namespace()

    # -- cost accounting ---------------------------------------------------------

    def _record_cost(self, response: Any) -> None:
        """Add the cost of one LLM call to the running total.

        Uses ``response.cost_usd`` if present, else ``llm_client.get_cost()``.
        Non-numeric values are ignored.

        NOTE(review): if ``llm_client.get_cost()`` returns the client's
        cumulative spend rather than the last call's, adding it per call
        double-counts — confirm against the client implementation.
        """
        cost_value = None
        if hasattr(response, "cost_usd"):
            cost_value = response.cost_usd
        elif hasattr(self.llm_client, "get_cost") and callable(
            self.llm_client.get_cost
        ):
            cost_value = self.llm_client.get_cost()
        if not isinstance(cost_value, (int, float)):
            return
        self._total_cost += float(cost_value)

    def _ensure_budget(self, allow_equal: bool = False) -> None:
        """Raise CostBudgetExceededError when the budget is exhausted.

        Args:
            allow_equal: If True, spending exactly the budget is tolerated
                (used after a call); otherwise reaching it already fails
                (used before a call).
        """
        if self._max_cost_usd is None:
            return
        if allow_equal:
            over_budget = self._total_cost > self._max_cost_usd
        else:
            over_budget = self._total_cost >= self._max_cost_usd
        if over_budget:
            raise CostBudgetExceededError(
                f"Cost budget exceeded: total_cost={self._total_cost:.6f} budget={self._max_cost_usd:.6f}"
            )

    # -- context manager ---------------------------------------------------------

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: reset the session; exceptions are not suppressed."""
        self.reset()
        return False