Fix: Security, reliability, and code quality improvements from PR review

Critical Security Fixes:
- Fix command injection vulnerability in Windows shims (beadboard.cmd, bb.cmd)
  - Added path validation to block traversal (.. and root-relative paths)
  - Added quotes around env var to prevent command injection

Reliability Fixes:
- Fix agent cache null safety bug
  - Fixed callBdAgentShow() to check for cache misses (null check, expiration)
  - Fixed getCachedAgent to properly return entry.data or null
- Fix null body crashes in mail ack route
  - Added null check before casting body to object
  - Returns 400 error instead of 500 for invalid requests

BD Compliance Fixes:
- Fix read-issues to use BD audit record path
  - Ensures all writes go through bd audit record
  - Maintains watcher/SSE parity and Dolt commit tracking

Code Quality Fixes:
- Fix path canonicalization violations
  - Use canonicalizeWindowsPath() and windowsPathKey() from pathing module
  - Prevents Windows edge cases and ensures machine-reproducible paths
- Fix typo: mobile-fronted → mobile-frontend
- Pin GitHub Actions tags
  - softprops/action-gh-release@v1 → specific commit hash
- Verify pr14 test is registered in package.json (it already was; no change needed)

Testing:
- Refactor broad exception handlers in Python scripts
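  - Replace `except Exception:` with the specific exception types each handler expects
  - Unexpected errors now propagate instead of being silently swallowed (KeyboardInterrupt and SystemExit derive from BaseException, so `except Exception:` never caught them)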
  - All tests passing
This commit is contained in:
zenchantlive 2026-03-05 16:33:10 -08:00
parent d54e4f3311
commit ce4700849b
15 changed files with 2995 additions and 756 deletions

View file

@ -16,17 +16,20 @@ from pathlib import Path
class SandboxViolation(Exception):
    """Signals that submitted code tried to break a sandbox security rule."""
class MaxIterationsError(Exception):
    """Signals that the configured maximum number of iterations was exceeded."""
# Raised by the cost accounting when spending goes past the configured budget.
class CostBudgetExceededError(RuntimeError):
    """Signals that the session's cost budget has been exceeded."""
@ -35,29 +38,129 @@ class CostBudgetExceededError(RuntimeError):
# Names from the ``builtins`` module that sandboxed code is allowed to see.
# Anything not listed here (``open``, ``eval``, ``exec``, ...) is simply
# absent from the sandbox's builtins mapping.
ALLOWED_BUILTINS = {
    # Functions and types
    "abs", "all", "any", "ascii", "bin", "bool", "bytearray", "bytes",
    "callable", "chr", "classmethod", "complex", "delattr", "dict", "dir",
    "divmod", "enumerate", "filter", "float", "format", "frozenset",
    "getattr", "globals", "hasattr", "hash", "help", "hex", "id", "input",
    "int", "isinstance", "issubclass", "iter", "len", "list", "locals",
    "map", "max", "memoryview", "min", "next", "object", "oct", "ord",
    "pow", "print", "property", "range", "repr", "reversed", "round",
    "set", "setattr", "slice", "sorted", "staticmethod", "str", "sum",
    "super", "tuple", "type", "vars", "zip",
    # Interpreter hooks needed for ``class`` statements and module identity
    "__build_class__", "__name__",
    # Constants
    "True", "False", "None",
    # Exception types
    "Exception", "TypeError", "ValueError", "KeyError", "IndexError",
    "AttributeError", "RuntimeError", "StopIteration", "ArithmeticError",
    "LookupError", "AssertionError", "NotImplementedError",
    "ZeroDivisionError", "OverflowError",
}
# Module names the sandbox refuses to import.
BLOCKED_MODULES = {
    # Interpreter, process and OS access
    "os", "sys", "subprocess", "multiprocessing",
    "concurrent.futures.process", "asyncio.subprocess",
    # Network protocols and transports
    "socket", "ssl", "urllib", "http", "ftplib", "smtplib", "telnetlib",
    "poplib", "imaplib", "nntplib", "email", "xmlrpc",
    # Native code and raw memory
    "ctypes", "cffi", "mmap",
    # Platform-specific system interfaces
    "resource", "posix", "nt", "pwd", "grp", "spwd", "crypt", "termios",
    "tty", "pty", "fcntl", "msvcrt", "winreg", "_winapi",
    # I/O multiplexing
    "select", "selectors",
}
# Allowed modules that get redirected to mocks
@ -66,11 +169,11 @@ ALLOWED_MODULES = set()
def safe_import(name, globals=None, locals=None, fromlist=(), level=0):
"""Safe import function that only allows specific modules."""
base_module = name.split('.')[0] if name else ''
base_module = name.split(".")[0] if name else ""
# Allow sys import (mocked in sandbox)
if base_module == 'sys':
if globals and 'sys' in globals:
return globals['sys']
if base_module == "sys":
if globals and "sys" in globals:
return globals["sys"]
raise ImportError("Mock sys not found in sandbox")
if base_module in ALLOWED_MODULES:
if globals and base_module in globals:
@ -81,98 +184,120 @@ def safe_import(name, globals=None, locals=None, fromlist=(), level=0):
# Dunder attributes that give access to object internals and could be
# chained together to escape the sandbox; any access to them is rejected.
BLOCKED_ATTRIBUTES = {
    # Class hierarchy walking
    "__class__", "__base__", "__bases__", "__mro__", "__subclasses__",
    # Function and method internals
    "__globals__", "__code__", "__func__", "__self__", "__closure__",
    "__defaults__", "__kwdefaults__",
    # Namespace and attribute machinery
    "__module__", "__dict__", "__getattribute__", "__setattr__",
}
class SandboxVisitor(ast.NodeVisitor):
"""AST visitor to check for sandbox violations."""
def __init__(self, allowed_paths: Optional[list] = None):
self.allowed_paths = allowed_paths or []
self.violations = []
def visit_Import(self, node):
    """Record a violation for every ``import X`` that names a blocked module.

    Each imported name is checked prefix by prefix (``a``, ``a.b``,
    ``a.b.c``) against ``BLOCKED_MODULES``. Checking only the first
    component, as was done before, meant dotted entries such as
    ``asyncio.subprocess`` could never match. For a top-level blocked module
    the message is unchanged.

    Args:
        node: The ``ast.Import`` node being visited.
    """
    for alias in node.names:
        parts = alias.name.split(".")
        root = parts[0]
        # 'sys' is allowed here because the sandbox redirects it to a mock.
        if root == "sys":
            continue
        # A root package that is explicitly allowed overrides the block list.
        if root in ALLOWED_MODULES:
            continue
        for depth in range(1, len(parts) + 1):
            candidate = ".".join(parts[:depth])
            if candidate in BLOCKED_MODULES:
                self.violations.append(f"Import of '{candidate}' is not allowed")
                break
    self.generic_visit(node)
def visit_ImportFrom(self, node):
    """Record a violation for ``from X import ...`` on a blocked module.

    The source module is checked prefix by prefix against
    ``BLOCKED_MODULES``. If none of the prefixes is blocked, each imported
    name is also checked as ``<module>.<name>`` so that
    ``from asyncio import subprocess`` is caught just like
    ``import asyncio.subprocess``. Previously only the first component of
    the module was compared, so dotted block-list entries never matched.
    Relative imports without a module name are not checked.

    Args:
        node: The ``ast.ImportFrom`` node being visited.
    """
    if node.module:
        parts = node.module.split(".")
        root = parts[0]
        # 'sys' is redirected to a mock in the sandbox; as before, the node
        # is accepted outright and its children are not visited.
        if root == "sys":
            return
        if root not in ALLOWED_MODULES:
            blocked = None
            for depth in range(1, len(parts) + 1):
                candidate = ".".join(parts[:depth])
                if candidate in BLOCKED_MODULES:
                    blocked = candidate
                    break
            if blocked is None:
                # The blocked submodule may be named by the imported alias.
                for alias in node.names:
                    candidate = f"{node.module}.{alias.name}"
                    if candidate in BLOCKED_MODULES:
                        blocked = candidate
                        break
            if blocked is not None:
                self.violations.append(f"Import from '{blocked}' is not allowed")
    self.generic_visit(node)
def visit_Delete(self, node):
    """Reject ``del`` statements that remove something from the builtins.

    A violation is recorded for each attribute or subscript target whose
    base object is recognised by ``_is_builtins_access``.
    """
    for victim in node.targets:
        # Only ``del obj.attr`` and ``del obj[key]`` are of interest.
        if not isinstance(victim, (ast.Attribute, ast.Subscript)):
            continue
        if self._is_builtins_access(victim.value):
            self.violations.append(
                "Deletion of __builtins__ attributes is not allowed"
            )
    self.generic_visit(node)
def visit_Call(self, node):
    """Reject calls to dangerous builtins made through a plain name.

    Flags ``eval``/``exec``/``compile``, ``__import__`` and ``open``
    unconditionally, and ``getattr``/``setattr``/``delattr`` when their first
    argument is recognised by ``_is_builtins_access``. Calls through
    anything other than a bare name (e.g. ``obj.open()``) are not flagged
    here.
    """
    callee = node.func
    if isinstance(callee, ast.Name):
        name = callee.id
        if name in ("eval", "exec", "compile"):
            self.violations.append(f"Use of '{name}()' is not allowed")
        elif name == "__import__":
            self.violations.append("Use of '__import__()' is not allowed")
        elif name == "open":
            self.violations.append("Use of 'open()' is not allowed")
        elif name in ("getattr", "setattr", "delattr"):
            # Attribute helpers are only a problem when aimed at builtins.
            if node.args and self._is_builtins_access(node.args[0]):
                self.violations.append(f"{name} on __builtins__ is not allowed")
    self.generic_visit(node)
def visit_BinOp(self, node):
    """Statically reject string multiplications that would exceed 10MB.

    Only ``<string literal> * <constant or constant expression>`` is
    analysed. Anything that cannot be evaluated statically is left for
    runtime.

    Raises:
        MemoryError: If the computed result size is larger than 10MB.
    """
    limit = 10 * 1024 * 1024  # 10MB limit
    lhs, rhs = node.left, node.right
    if isinstance(node.op, ast.Mult):
        try:
            if isinstance(lhs, ast.Constant) and isinstance(lhs.value, str):
                size = None
                if isinstance(rhs, ast.Constant):
                    size = len(lhs.value) * rhs.value
                elif isinstance(rhs, ast.BinOp):
                    # e.g. "x" * (1024 * 1024 * 100)
                    size = len(lhs.value) * self._eval_const_expr(rhs)
                if size is not None and size > limit:
                    raise MemoryError(
                        f"String multiplication would create {size} bytes, exceeding 10MB limit"
                    )
        except MemoryError:
            # The limit check itself must not be swallowed below.
            raise
        except (ValueError, TypeError, AttributeError):
            # Not statically evaluable; let it run and be caught at runtime.
            pass
    self.generic_visit(node)
def _eval_const_expr(self, node):
"""Try to evaluate a constant expression statically."""
if isinstance(node, ast.Constant):
@ -187,36 +312,52 @@ class SandboxVisitor(ast.NodeVisitor):
if isinstance(node.op, ast.Sub):
return left - right
raise ValueError("Cannot evaluate expression")
def visit_Attribute(self, node):
    """Reject access to any attribute listed in ``BLOCKED_ATTRIBUTES``.

    This covers introspection dunders such as ``__class__`` or
    ``__bases__``.
    """
    attr_name = node.attr
    blocked = attr_name in BLOCKED_ATTRIBUTES
    if blocked:
        self.violations.append(f"Access to '{attr_name}' is not allowed")
    self.generic_visit(node)
def visit_Subscript(self, node):
    """Reject ``globals()['__builtins__']`` / ``locals()['__builtins__']``.

    Only a literal ``'__builtins__'`` key applied directly to a call of
    ``globals`` or ``locals`` is flagged.

    The previous compatibility branch read the deprecated ``.s`` alias of
    ``ast.Constant``, which emits a ``DeprecationWarning`` on newer Python
    versions. It also never matched older ASTs, where the key is wrapped in
    an ``ast.Index`` node. The wrapper is now unwrapped explicitly and only
    ``.value`` is read.

    Args:
        node: The ``ast.Subscript`` node being visited.
    """
    target = node.value
    if (
        isinstance(target, ast.Call)
        and isinstance(target.func, ast.Name)
        and target.func.id in ("globals", "locals")
    ):
        key = node.slice
        # Python 3.8 wraps the index expression in ast.Index(value=...).
        # The class is matched by name so no deprecated AST class is touched.
        if type(key).__name__ == "Index":
            key = key.value
        if isinstance(key, ast.Constant) and key.value == "__builtins__":
            self.violations.append(
                "globals()/locals()['__builtins__'] manipulation is not allowed"
            )
    self.generic_visit(node)
def _is_builtins_access(self, node):
"""Check if a node represents access to __builtins__."""
if isinstance(node, ast.Name) and node.id == '__builtins__':
if isinstance(node, ast.Name) and node.id == "__builtins__":
return True
if isinstance(node, ast.Call):
if isinstance(node.func, ast.Name) and node.func.id in ('globals', 'locals'):
if isinstance(node.func, ast.Name) and node.func.id in (
"globals",
"locals",
):
return True
return False
class MemoryLimitException(RuntimeError):
    """Signals that the memory limit has been exceeded."""
@ -224,14 +365,14 @@ class MemoryLimitException(RuntimeError):
def check_safety(code: str) -> list:
    """Scan source text and return the list of sandbox violations found.

    Args:
        code: Python source to inspect.

    Returns:
        A list of human-readable violation messages. The list is empty when
        nothing is flagged, and also when the code does not parse: syntax
        errors are deliberately left for the caller to report.
    """
    # Null bytes are rejected up front, before any parsing is attempted.
    if "\x00" in code:
        return ["Code contains null bytes which is not allowed"]

    try:
        tree = ast.parse(code)
    except SyntaxError:
        return []  # Let SyntaxError be handled elsewhere
    else:
        checker = SandboxVisitor()
        checker.visit(tree)
        return checker.violations
@ -255,36 +396,44 @@ class REPLSession:
"""
RLM REPL Session - secure sandbox for recursive LLM execution.
"""
class _StderrCapture:
"""Mock stderr object for sandbox."""
def __init__(self, session):
self._session = session
def write(self, text: str):
"""Write to stderr capture."""
self._session._stderr.append(text)
def flush(self):
"""Flush stderr (no-op)."""
pass
class MockSys:
    """Replacement for the ``sys`` module that exposes nothing but ``stderr``."""

    def __init__(self, stderr_capture):
        self.stderr = stderr_capture

    def __getattr__(self, name):
        """Refuse every attribute that was not set explicitly.

        Only called for names that normal lookup did not find, so ``stderr``
        never reaches this method.

        Raises:
            SandboxViolation: For ``sys.modules``.
            AttributeError: For any other name.
        """
        if name != "modules":
            raise AttributeError(f"sys.{name} is not available in sandbox")
        raise SandboxViolation("Access to sys.modules is not allowed")
def __init__(self, chunk_store=None, llm_client=None,
max_iterations: int = 10, timeout_seconds: int = 60, max_depth: int = 5,
max_cost_usd: Optional[float] = None):
def __init__(
self,
chunk_store=None,
llm_client=None,
max_iterations: int = 10,
timeout_seconds: int = 60,
max_depth: int = 5,
max_cost_usd: Optional[float] = None,
):
"""
Initialize REPL session.
Args:
chunk_store: ChunkStore instance for memory access
llm_client: LLM client for recursive queries
@ -296,14 +445,14 @@ class REPLSession:
raise ValueError("chunk_store is required")
if llm_client is None:
raise ValueError("llm_client is required")
self.chunk_store = chunk_store
self.llm_client = llm_client
self.max_iterations = max_iterations
self.timeout_seconds = timeout_seconds
self.max_depth = max_depth
self._max_cost_usd = max_cost_usd
self._state: Dict[str, Any] = {} # User state (empty initially)
self._iteration_count = 0
self._total_cost = 0.0
@ -314,64 +463,75 @@ class REPLSession:
self._output = []
self._stderr = []
self._stderr_capture = self._StderrCapture(self)
# Create isolated namespace for execution
self._namespace = {}
self._setup_namespace()
def _setup_namespace(self):
    """Build the globals dictionary that sandboxed code executes against.

    The namespace holds:
      * ``__builtins__``: the allow-listed builtins that exist on this
        interpreter, plus the session's memory helpers (``read_chunk``,
        ``search_chunks``, ``list_chunks_by_tag``, ``get_linked_chunks``),
        ``llm_query``, ``FINAL``, the restricted ``__import__`` and a mock
        ``sys``;
      * ``__name__``: fixed to ``'__repl__'``;
      * ``sys``: the mock module, so ``import sys`` binds to it;
      * every variable currently stored in the user state.

    The previous version imported four helpers from
    ``brain.scripts.repl_functions`` without using them. The wrappers import
    what they need themselves, so that import has been dropped. A single
    mock ``sys`` object is now shared by the builtins entry and the
    namespace entry, instead of two separate instances.
    """
    # Names listed in ALLOWED_BUILTINS but missing on this interpreter are
    # skipped rather than raising.
    safe_builtins = {
        name: getattr(builtins, name)
        for name in ALLOWED_BUILTINS
        if hasattr(builtins, name)
    }

    # Session-bound helpers exposed to sandboxed code as if they were builtins.
    safe_builtins["read_chunk"] = self._read_chunk_wrapper
    safe_builtins["search_chunks"] = self._search_chunks_wrapper
    safe_builtins["list_chunks_by_tag"] = self._list_chunks_by_tag_wrapper
    safe_builtins["get_linked_chunks"] = self._get_linked_chunks_wrapper
    safe_builtins["llm_query"] = self._llm_query_wrapper
    safe_builtins["FINAL"] = self._final_wrapper

    # Restricted import hook and the mock sys module.
    mock_sys = self.MockSys(self._stderr_capture)
    safe_builtins["__import__"] = safe_import
    safe_builtins["sys"] = mock_sys

    self._namespace = {
        "__builtins__": safe_builtins,
        "__name__": "__repl__",
    }

    # 'import sys' resolves through the namespace, so the mock goes here too.
    self._namespace["sys"] = mock_sys

    # Carry over variables defined by earlier executions.
    self._namespace.update(self._state)
def _read_chunk_wrapper(self, chunk_id: str):
    """Sandbox-facing ``read_chunk``, bound to this session's chunk store."""
    from repl_functions import read_chunk as _read_chunk

    store = self.chunk_store
    return _read_chunk(chunk_id, store)
def _search_chunks_wrapper(self, query: str, limit: int = 10):
    """Sandbox-facing ``search_chunks``, bound to this session's chunk store."""
    from repl_functions import search_chunks as _search_chunks

    store = self.chunk_store
    return _search_chunks(query, store, limit)
def _list_chunks_by_tag_wrapper(self, tags):
    """Sandbox-facing ``list_chunks_by_tag``, bound to this session's chunk store."""
    from repl_functions import list_chunks_by_tag as _list_chunks_by_tag

    store = self.chunk_store
    return _list_chunks_by_tag(tags, store)
def _get_linked_chunks_wrapper(self, chunk_id: str, link_type: str = None):
    """Sandbox-facing ``get_linked_chunks``, bound to this session's chunk store."""
    from repl_functions import get_linked_chunks as _get_linked_chunks

    store = self.chunk_store
    return _get_linked_chunks(chunk_id, store, link_type)
def _llm_query_wrapper(self, prompt: str, context=None):
"""Wrapper for llm_query."""
with self._lock:
@ -380,14 +540,16 @@ class REPLSession:
raise MaxIterationsError(
f"Maximum iterations ({self.max_iterations}) exceeded"
)
# Check max depth
if self._current_depth >= self.max_depth:
raise RecursionError(f"Maximum recursion depth ({self.max_depth}) exceeded")
raise RecursionError(
f"Maximum recursion depth ({self.max_depth}) exceeded"
)
# Increment depth counter
self._current_depth += 1
try:
self._ensure_budget()
# Build full prompt with context
@ -396,11 +558,14 @@ class REPLSession:
# Handle context as a list of chunk IDs
if isinstance(context, list):
from repl_functions import read_chunk
context_parts = []
for chunk_id in context:
chunk = read_chunk(chunk_id, self.chunk_store)
if chunk:
context_parts.append(f"Chunk {chunk_id}:\n{chunk.get('content', '')}")
context_parts.append(
f"Chunk {chunk_id}:\n{chunk.get('content', '')}"
)
else:
context_parts.append(f"Chunk {chunk_id}:\n[Not found]")
context_str = "\n\n".join(context_parts)
@ -408,14 +573,14 @@ class REPLSession:
elif isinstance(context, dict):
context_str = "\n".join(f"{k}: {v}" for k, v in context.items())
full_prompt = f"Context:\n{context_str}\n\nPrompt:\n{prompt}"
# Call LLM
response = self.llm_client.complete(full_prompt)
self._record_cost(response)
self._ensure_budget(allow_equal=True)
return response.text if hasattr(response, 'text') else str(response)
return response.text if hasattr(response, "text") else str(response)
except (RecursionError, MaxIterationsError):
# Don't catch these - let them propagate
raise
@ -426,84 +591,88 @@ class REPLSession:
# Decrement depth counter
with self._lock:
self._current_depth -= 1
def _final_wrapper(self, answer) -> None:
"""Wrapper for FINAL."""
if self._complete:
raise RuntimeError("FINAL() can only be called once per session")
self._result = answer
self._complete = True
def get_state(self) -> Dict[str, Any]:
    """Return a shallow copy of the user-defined variables.

    Changing the returned dictionary does not affect the session's state.
    """
    snapshot = self._state.copy()
    return snapshot
def get_result(self) -> Optional[Any]:
    """Return the stored final result (the value kept in ``_result``)."""
    final_answer = self._result
    return final_answer
def is_complete(self) -> bool:
    """Report whether the session has been marked complete."""
    done = self._complete
    return done
@property
def iteration_count(self) -> int:
    """Number of iterations counted so far in this session."""
    count = self._iteration_count
    return count
@property
def total_cost(self) -> float:
    """Total cost accumulated by this session.

    This property used to be defined twice with identical bodies, so the
    second definition silently replaced the first. Only one is kept.
    """
    return self._total_cost

def get_cost(self) -> float:
    """Return the total cost accumulated (method form of ``total_cost``)."""
    return self._total_cost
def get_cost_breakdown(self) -> Dict[str, Any]:
    """Summarise spending for this session.

    Returns:
        A dictionary with ``total``, ``calls`` and ``per_call_average``
        (``0.0`` when no calls have been counted). When a budget is
        configured it also contains ``budget``, ``remaining`` (never
        negative) and ``over_budget``.
    """
    spent = self._total_cost
    calls = self._iteration_count

    # Guard the division: no calls means no meaningful average.
    if calls > 0:
        average = spent / calls
    else:
        average = 0.0

    report = {
        "total": spent,
        "calls": calls,
        "per_call_average": average,
    }

    budget = self._max_cost_usd
    if budget is None:
        return report

    report["budget"] = budget
    report["remaining"] = max(0.0, budget - spent)
    report["over_budget"] = spent > budget
    return report
def get_output(self) -> str:
    """Return all captured output entries joined by newlines."""
    separator = "\n"
    return separator.join(self._output)
def get_stderr(self) -> str:
    """Return all captured stderr entries joined by newlines."""
    separator = "\n"
    return separator.join(self._stderr)
def clear_output(self):
    """Discard captured output by rebinding ``_output`` to a new empty list.

    The captured stderr list is left as it is.
    """
    self._output = list()
def execute(self, code: str, timeout: int = None):
"""
Execute code in sandbox.
Args:
code: Python code to execute
timeout: Optional timeout override
Returns:
Result of the last expression or None
Raises:
RuntimeError: If called after FINAL()
SandboxViolation: If code violates sandbox
@ -511,81 +680,84 @@ class REPLSession:
"""
if self._complete:
raise RuntimeError("REPL already complete")
if not code or not code.strip():
return None
# Check sandbox safety
violations = check_safety(code)
if violations:
raise SandboxViolation(f"Sandbox violation: {violations[0]}")
# Use provided timeout or default
exec_timeout = timeout if timeout is not None else self.timeout_seconds
# Capture stdout/stderr
old_stdout = sys.stdout
old_stderr = sys.stderr
stdout_capture = io.StringIO()
stderr_capture = io.StringIO()
# Container for execution results
result_container = {'result': None, 'error': None, 'completed': False}
result_container = {"result": None, "error": None, "completed": False}
def run_execution():
try:
sys.stdout = stdout_capture
sys.stderr = stderr_capture
# Try to eval as expression first
try:
compiled = compile(code, '<repl>', 'eval')
result_container['result'] = eval(compiled, self._namespace)
result_container['completed'] = True
compiled = compile(code, "<repl>", "eval")
result_container["result"] = eval(compiled, self._namespace)
result_container["completed"] = True
return
except SyntaxError:
# Not an expression, try exec
pass
# Compile and execute as statements
compiled = compile(code, '<repl>', 'exec')
compiled = compile(code, "<repl>", "exec")
exec(compiled, self._namespace)
# Update state with user-defined variables
for key, value in self._namespace.items():
if not key.startswith('_') and key not in ('__builtins__', '__name__'):
if not key.startswith("_") and key not in (
"__builtins__",
"__name__",
):
self._state[key] = value
result_container['completed'] = True
result_container["completed"] = True
except Exception as e:
result_container['error'] = e
result_container["error"] = e
# Run execution in a thread with timeout
exec_thread = threading.Thread(target=run_execution)
exec_thread.daemon = True
try:
sys.stdout = stdout_capture
sys.stderr = stderr_capture
exec_thread.start()
exec_thread.join(timeout=exec_timeout)
if exec_thread.is_alive():
# Thread is still running after timeout
raise TimeoutError(f"Execution exceeded {exec_timeout} seconds")
# Check for errors from the thread
if result_container['error'] is not None:
raise result_container['error']
if result_container["error"] is not None:
raise result_container["error"]
# Capture output
self._output.append(stdout_capture.getvalue())
self._stderr.append(stderr_capture.getvalue())
return result_container['result']
return result_container["result"]
except TimeoutError:
raise
except RecursionError:
@ -623,25 +795,25 @@ class REPLSession:
finally:
sys.stdout = old_stdout
sys.stderr = old_stderr
def retrieve(self, query=None, max_iterations=None) -> Optional[Any]:
"""
Execute retrieval workflow for a query.
Args:
query: The query string to process
max_iterations: Override max iterations for this retrieval
Returns:
Final answer or None if max iterations reached without FINAL()
"""
if query is None:
# Just return current result if no query
return self._result if self._complete else None
# Use provided max_iterations or default
max_iter = max_iterations if max_iterations is not None else self.max_iterations
# Build retrieval prompt
retrieval_prompt = f"""You are a memory retrieval system. Answer the following query using the available memory functions.
@ -656,38 +828,40 @@ Available functions:
Query: {query}
Write Python code to solve this query. Use FINAL('your answer') when done."""
# Iterative retrieval loop
for iteration in range(max_iter):
self._iteration_count += 1
# Get LLM response
try:
self._ensure_budget()
response = self.llm_client.complete(retrieval_prompt)
code = response.text if hasattr(response, 'text') else str(response)
code = response.text if hasattr(response, "text") else str(response)
self._record_cost(response)
self._ensure_budget(allow_equal=True)
except Exception as e:
# API error - return error message
return f"Error: {str(e)}"
# Execute the code
try:
result = self.execute(code)
# Check if FINAL was called
if self._complete:
return self._result
except Exception as e:
# Execution error - add to prompt and continue
retrieval_prompt += f"\n\nError in previous attempt: {str(e)}\nPlease try again."
retrieval_prompt += (
f"\n\nError in previous attempt: {str(e)}\nPlease try again."
)
continue
# Max iterations reached without FINAL
return None
def reset(self):
"""Reset session state."""
self._state = {}
@ -703,9 +877,11 @@ Write Python code to solve this query. Use FINAL('your answer') when done."""
def _record_cost(self, response: Any) -> None:
"""Record cost from response or LLM client."""
cost_value = None
if hasattr(response, 'cost_usd'):
if hasattr(response, "cost_usd"):
cost_value = response.cost_usd
elif hasattr(self.llm_client, 'get_cost') and callable(self.llm_client.get_cost):
elif hasattr(self.llm_client, "get_cost") and callable(
self.llm_client.get_cost
):
cost_value = self.llm_client.get_cost()
if not isinstance(cost_value, (int, float)):
return
@ -723,11 +899,11 @@ Write Python code to solve this query. Use FINAL('your answer') when done."""
raise CostBudgetExceededError(
f"Cost budget exceeded: total_cost={self._total_cost:.6f} budget={self._max_cost_usd:.6f}"
)
def __enter__(self):
"""Context manager entry."""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit."""
self.reset()