Fix: Security, reliability, and code quality improvements from PR review
Critical Security Fixes: - Fix command injection vulnerability in Windows shims (beadboard.cmd, bb.cmd) - Added path validation to block traversal (.. and root-relative paths) - Added quotes around env var to prevent command injection Reliability Fixes: - Fix agent cache null safety bug - Fixed callBdAgentShow() to check for cache misses (null check, expiration) - Fixed getCachedAgent to properly return entry.data or null - Fix null body crashes in mail ack route - Added null check before casting body to object - Returns 400 error instead of 500 for invalid requests BD Compliance Fixes: - Fix read-issues to use BD audit record path - Ensures all writes go through bd audit record - Maintains watcher/SSE parity and Dolt commit tracking Code Quality Fixes: - Fix path canonicalization violations - Use canonicalizeWindowsPath() and windowsPathKey() from pathing module - Prevents Windows edge cases and ensures machine-reproducible paths - Fix typo: mobile-fronted → mobile-frontend - Pin GitHub Actions tags - softprops/action-gh-release@v1 → specific commit hash - Register pr14 test in package.json (already registered) Testing: - Refactor broad exception handlers in Python scripts - Replace except Exception: with specific exceptions - Allows KeyboardInterrupt and SystemExit to propagate correctly - All tests passing
This commit is contained in:
parent
d54e4f3311
commit
ce4700849b
15 changed files with 2995 additions and 756 deletions
|
|
@ -16,17 +16,20 @@ from pathlib import Path
|
|||
|
||||
class SandboxViolation(Exception):
|
||||
"""Raised when code attempts to violate sandbox security."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class MaxIterationsError(Exception):
|
||||
"""Raised when max iterations exceeded."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Cost budget exceeded
|
||||
class CostBudgetExceededError(RuntimeError):
|
||||
"""Raised when cost budget is exceeded."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
|
|
@ -35,29 +38,129 @@ class CostBudgetExceededError(RuntimeError):
|
|||
|
||||
# Allowed built-ins for sandbox
|
||||
ALLOWED_BUILTINS = {
|
||||
'abs', 'all', 'any', 'ascii', 'bin', 'bool', 'bytearray', 'bytes',
|
||||
'callable', 'chr', 'classmethod', 'complex', 'delattr', 'dict',
|
||||
'dir', 'divmod', 'enumerate', 'filter', 'float', 'format', 'frozenset',
|
||||
'getattr', 'globals', 'hasattr', 'hash', 'help', 'hex', 'id', 'input',
|
||||
'int', 'isinstance', 'issubclass', 'iter', 'len', 'list', 'locals',
|
||||
'map', 'max', 'memoryview', 'min', 'next', 'object', 'oct', 'ord',
|
||||
'pow', 'print', 'property', 'range', 'repr', 'reversed',
|
||||
'round', 'set', 'setattr', 'slice', 'sorted', 'staticmethod', 'str',
|
||||
'sum', 'super', 'tuple', 'type', 'vars', 'zip', '__build_class__',
|
||||
'__name__', 'True', 'False', 'None', 'Exception', 'TypeError',
|
||||
'ValueError', 'KeyError', 'IndexError', 'AttributeError', 'RuntimeError',
|
||||
'StopIteration', 'ArithmeticError', 'LookupError', 'AssertionError',
|
||||
'NotImplementedError', 'ZeroDivisionError', 'OverflowError',
|
||||
"abs",
|
||||
"all",
|
||||
"any",
|
||||
"ascii",
|
||||
"bin",
|
||||
"bool",
|
||||
"bytearray",
|
||||
"bytes",
|
||||
"callable",
|
||||
"chr",
|
||||
"classmethod",
|
||||
"complex",
|
||||
"delattr",
|
||||
"dict",
|
||||
"dir",
|
||||
"divmod",
|
||||
"enumerate",
|
||||
"filter",
|
||||
"float",
|
||||
"format",
|
||||
"frozenset",
|
||||
"getattr",
|
||||
"globals",
|
||||
"hasattr",
|
||||
"hash",
|
||||
"help",
|
||||
"hex",
|
||||
"id",
|
||||
"input",
|
||||
"int",
|
||||
"isinstance",
|
||||
"issubclass",
|
||||
"iter",
|
||||
"len",
|
||||
"list",
|
||||
"locals",
|
||||
"map",
|
||||
"max",
|
||||
"memoryview",
|
||||
"min",
|
||||
"next",
|
||||
"object",
|
||||
"oct",
|
||||
"ord",
|
||||
"pow",
|
||||
"print",
|
||||
"property",
|
||||
"range",
|
||||
"repr",
|
||||
"reversed",
|
||||
"round",
|
||||
"set",
|
||||
"setattr",
|
||||
"slice",
|
||||
"sorted",
|
||||
"staticmethod",
|
||||
"str",
|
||||
"sum",
|
||||
"super",
|
||||
"tuple",
|
||||
"type",
|
||||
"vars",
|
||||
"zip",
|
||||
"__build_class__",
|
||||
"__name__",
|
||||
"True",
|
||||
"False",
|
||||
"None",
|
||||
"Exception",
|
||||
"TypeError",
|
||||
"ValueError",
|
||||
"KeyError",
|
||||
"IndexError",
|
||||
"AttributeError",
|
||||
"RuntimeError",
|
||||
"StopIteration",
|
||||
"ArithmeticError",
|
||||
"LookupError",
|
||||
"AssertionError",
|
||||
"NotImplementedError",
|
||||
"ZeroDivisionError",
|
||||
"OverflowError",
|
||||
}
|
||||
|
||||
# Blocked imports/modules
|
||||
BLOCKED_MODULES = {
|
||||
'os', 'sys', 'subprocess', 'socket', 'urllib', 'http', 'ftplib',
|
||||
'smtplib', 'telnetlib', 'poplib', 'imaplib', 'nntplib', 'ssl',
|
||||
'email', 'xmlrpc', 'concurrent.futures.process', 'multiprocessing',
|
||||
'ctypes', 'cffi', 'mmap', 'resource', 'posix', 'nt', 'pwd', 'grp',
|
||||
'spwd', 'crypt', 'termios', 'tty', 'pty', 'fcntl', 'msvcrt',
|
||||
'winreg', '_winapi', 'select', 'selectors', 'asyncio.subprocess',
|
||||
"os",
|
||||
"sys",
|
||||
"subprocess",
|
||||
"socket",
|
||||
"urllib",
|
||||
"http",
|
||||
"ftplib",
|
||||
"smtplib",
|
||||
"telnetlib",
|
||||
"poplib",
|
||||
"imaplib",
|
||||
"nntplib",
|
||||
"ssl",
|
||||
"email",
|
||||
"xmlrpc",
|
||||
"concurrent.futures.process",
|
||||
"multiprocessing",
|
||||
"ctypes",
|
||||
"cffi",
|
||||
"mmap",
|
||||
"resource",
|
||||
"posix",
|
||||
"nt",
|
||||
"pwd",
|
||||
"grp",
|
||||
"spwd",
|
||||
"crypt",
|
||||
"termios",
|
||||
"tty",
|
||||
"pty",
|
||||
"fcntl",
|
||||
"msvcrt",
|
||||
"winreg",
|
||||
"_winapi",
|
||||
"select",
|
||||
"selectors",
|
||||
"asyncio.subprocess",
|
||||
}
|
||||
|
||||
# Allowed modules that get redirected to mocks
|
||||
|
|
@ -66,11 +169,11 @@ ALLOWED_MODULES = set()
|
|||
|
||||
def safe_import(name, globals=None, locals=None, fromlist=(), level=0):
|
||||
"""Safe import function that only allows specific modules."""
|
||||
base_module = name.split('.')[0] if name else ''
|
||||
base_module = name.split(".")[0] if name else ""
|
||||
# Allow sys import (mocked in sandbox)
|
||||
if base_module == 'sys':
|
||||
if globals and 'sys' in globals:
|
||||
return globals['sys']
|
||||
if base_module == "sys":
|
||||
if globals and "sys" in globals:
|
||||
return globals["sys"]
|
||||
raise ImportError("Mock sys not found in sandbox")
|
||||
if base_module in ALLOWED_MODULES:
|
||||
if globals and base_module in globals:
|
||||
|
|
@ -81,98 +184,120 @@ def safe_import(name, globals=None, locals=None, fromlist=(), level=0):
|
|||
|
||||
# Blocked attribute names that could be used for sandbox escape
|
||||
BLOCKED_ATTRIBUTES = {
|
||||
'__class__', '__bases__', '__subclasses__', '__base__',
|
||||
'__mro__', '__globals__', '__code__', '__func__', '__self__',
|
||||
'__module__', '__dict__', '__closure__', '__defaults__',
|
||||
'__kwdefaults__', '__getattribute__', '__setattr__',
|
||||
"__class__",
|
||||
"__bases__",
|
||||
"__subclasses__",
|
||||
"__base__",
|
||||
"__mro__",
|
||||
"__globals__",
|
||||
"__code__",
|
||||
"__func__",
|
||||
"__self__",
|
||||
"__module__",
|
||||
"__dict__",
|
||||
"__closure__",
|
||||
"__defaults__",
|
||||
"__kwdefaults__",
|
||||
"__getattribute__",
|
||||
"__setattr__",
|
||||
}
|
||||
|
||||
|
||||
class SandboxVisitor(ast.NodeVisitor):
|
||||
"""AST visitor to check for sandbox violations."""
|
||||
|
||||
|
||||
def __init__(self, allowed_paths: Optional[list] = None):
|
||||
self.allowed_paths = allowed_paths or []
|
||||
self.violations = []
|
||||
|
||||
|
||||
def visit_Import(self, node):
|
||||
for alias in node.names:
|
||||
module = alias.name.split('.')[0]
|
||||
module = alias.name.split(".")[0]
|
||||
# Allow 'sys' import (redirected to mock in sandbox)
|
||||
if module == 'sys':
|
||||
if module == "sys":
|
||||
continue
|
||||
if module in BLOCKED_MODULES and module not in ALLOWED_MODULES:
|
||||
self.violations.append(f"Import of '{module}' is not allowed")
|
||||
self.generic_visit(node)
|
||||
|
||||
|
||||
def visit_ImportFrom(self, node):
|
||||
if node.module:
|
||||
module = node.module.split('.')[0]
|
||||
module = node.module.split(".")[0]
|
||||
# Allow 'sys' import (redirected to mock in sandbox)
|
||||
if module == 'sys':
|
||||
if module == "sys":
|
||||
return
|
||||
if module in BLOCKED_MODULES and module not in ALLOWED_MODULES:
|
||||
self.violations.append(f"Import from '{module}' is not allowed")
|
||||
self.generic_visit(node)
|
||||
|
||||
|
||||
def visit_Delete(self, node):
|
||||
"""Block deletion of builtins attributes."""
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Attribute):
|
||||
if self._is_builtins_access(target.value):
|
||||
self.violations.append("Deletion of __builtins__ attributes is not allowed")
|
||||
self.violations.append(
|
||||
"Deletion of __builtins__ attributes is not allowed"
|
||||
)
|
||||
if isinstance(target, ast.Subscript):
|
||||
if self._is_builtins_access(target.value):
|
||||
self.violations.append("Deletion of __builtins__ attributes is not allowed")
|
||||
self.violations.append(
|
||||
"Deletion of __builtins__ attributes is not allowed"
|
||||
)
|
||||
self.generic_visit(node)
|
||||
|
||||
|
||||
def visit_Call(self, node):
|
||||
# Check for eval/exec/compile
|
||||
if isinstance(node.func, ast.Name):
|
||||
if node.func.id in ('eval', 'exec', 'compile'):
|
||||
if node.func.id in ("eval", "exec", "compile"):
|
||||
self.violations.append(f"Use of '{node.func.id}()' is not allowed")
|
||||
# Check for __import__
|
||||
if isinstance(node.func, ast.Name) and node.func.id == '__import__':
|
||||
if isinstance(node.func, ast.Name) and node.func.id == "__import__":
|
||||
self.violations.append("Use of '__import__()' is not allowed")
|
||||
# Check for open()
|
||||
if isinstance(node.func, ast.Name) and node.func.id == 'open':
|
||||
if isinstance(node.func, ast.Name) and node.func.id == "open":
|
||||
self.violations.append("Use of 'open()' is not allowed")
|
||||
|
||||
|
||||
# Check for getattr/setattr on __builtins__
|
||||
if isinstance(node.func, ast.Name) and node.func.id == 'getattr':
|
||||
if isinstance(node.func, ast.Name) and node.func.id == "getattr":
|
||||
if node.args and self._is_builtins_access(node.args[0]):
|
||||
self.violations.append("getattr on __builtins__ is not allowed")
|
||||
if isinstance(node.func, ast.Name) and node.func.id == 'setattr':
|
||||
if isinstance(node.func, ast.Name) and node.func.id == "setattr":
|
||||
if node.args and self._is_builtins_access(node.args[0]):
|
||||
self.violations.append("setattr on __builtins__ is not allowed")
|
||||
if isinstance(node.func, ast.Name) and node.func.id == 'delattr':
|
||||
if isinstance(node.func, ast.Name) and node.func.id == "delattr":
|
||||
if node.args and self._is_builtins_access(node.args[0]):
|
||||
self.violations.append("delattr on __builtins__ is not allowed")
|
||||
|
||||
|
||||
self.generic_visit(node)
|
||||
|
||||
|
||||
def visit_BinOp(self, node):
|
||||
"""Check for large memory allocations via string/list multiplication."""
|
||||
if isinstance(node.op, ast.Mult):
|
||||
# Check for patterns like "x" * (1024 * 1024 * 100)
|
||||
# Try to evaluate the size statically
|
||||
try:
|
||||
if isinstance(node.left, ast.Constant) and isinstance(node.left.value, str):
|
||||
if isinstance(node.left, ast.Constant) and isinstance(
|
||||
node.left.value, str
|
||||
):
|
||||
if isinstance(node.right, ast.Constant):
|
||||
size = len(node.left.value) * node.right.value
|
||||
if size > 10 * 1024 * 1024: # 10MB limit
|
||||
raise MemoryError(f"String multiplication would create {size} bytes, exceeding 10MB limit")
|
||||
raise MemoryError(
|
||||
f"String multiplication would create {size} bytes, exceeding 10MB limit"
|
||||
)
|
||||
elif isinstance(node.right, ast.BinOp):
|
||||
# Try to evaluate binary expression
|
||||
size = len(node.left.value) * self._eval_const_expr(node.right)
|
||||
if size > 10 * 1024 * 1024: # 10MB limit
|
||||
raise MemoryError(f"String multiplication would create {size} bytes, exceeding 10MB limit")
|
||||
raise MemoryError(
|
||||
f"String multiplication would create {size} bytes, exceeding 10MB limit"
|
||||
)
|
||||
except MemoryError:
|
||||
raise # Re-raise MemoryError
|
||||
except Exception:
|
||||
except (ValueError, TypeError, AttributeError):
|
||||
pass # Can't evaluate statically, let it run and catch at runtime
|
||||
self.generic_visit(node)
|
||||
|
||||
|
||||
def _eval_const_expr(self, node):
|
||||
"""Try to evaluate a constant expression statically."""
|
||||
if isinstance(node, ast.Constant):
|
||||
|
|
@ -187,36 +312,52 @@ class SandboxVisitor(ast.NodeVisitor):
|
|||
if isinstance(node.op, ast.Sub):
|
||||
return left - right
|
||||
raise ValueError("Cannot evaluate expression")
|
||||
|
||||
|
||||
def visit_Attribute(self, node):
|
||||
"""Check for dangerous attribute access like __class__, __bases__, etc."""
|
||||
if node.attr in BLOCKED_ATTRIBUTES:
|
||||
self.violations.append(f"Access to '{node.attr}' is not allowed")
|
||||
self.generic_visit(node)
|
||||
|
||||
|
||||
def visit_Subscript(self, node):
|
||||
"""Check for builtins subscript access like globals()['__builtins__']['__import__']."""
|
||||
# Check for globals()['__builtins__'] or locals()['__builtins__']
|
||||
if isinstance(node.value, ast.Call):
|
||||
if isinstance(node.value.func, ast.Name) and node.value.func.id in ('globals', 'locals'):
|
||||
if isinstance(node.slice, ast.Constant) and node.slice.value == '__builtins__':
|
||||
self.violations.append("globals()/locals()['__builtins__'] manipulation is not allowed")
|
||||
elif hasattr(node.slice, 's') and node.slice.s == '__builtins__': # Python < 3.8 compatibility
|
||||
self.violations.append("globals()/locals()['__builtins__'] manipulation is not allowed")
|
||||
if isinstance(node.value.func, ast.Name) and node.value.func.id in (
|
||||
"globals",
|
||||
"locals",
|
||||
):
|
||||
if (
|
||||
isinstance(node.slice, ast.Constant)
|
||||
and node.slice.value == "__builtins__"
|
||||
):
|
||||
self.violations.append(
|
||||
"globals()/locals()['__builtins__'] manipulation is not allowed"
|
||||
)
|
||||
elif (
|
||||
hasattr(node.slice, "s") and node.slice.s == "__builtins__"
|
||||
): # Python < 3.8 compatibility
|
||||
self.violations.append(
|
||||
"globals()/locals()['__builtins__'] manipulation is not allowed"
|
||||
)
|
||||
self.generic_visit(node)
|
||||
|
||||
|
||||
def _is_builtins_access(self, node):
|
||||
"""Check if a node represents access to __builtins__."""
|
||||
if isinstance(node, ast.Name) and node.id == '__builtins__':
|
||||
if isinstance(node, ast.Name) and node.id == "__builtins__":
|
||||
return True
|
||||
if isinstance(node, ast.Call):
|
||||
if isinstance(node.func, ast.Name) and node.func.id in ('globals', 'locals'):
|
||||
if isinstance(node.func, ast.Name) and node.func.id in (
|
||||
"globals",
|
||||
"locals",
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class MemoryLimitException(RuntimeError):
|
||||
"""Raised when memory limit is exceeded."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
|
|
@ -224,14 +365,14 @@ class MemoryLimitException(RuntimeError):
|
|||
def check_safety(code: str) -> list:
|
||||
"""Check code for sandbox violations."""
|
||||
# Pre-check for null bytes and other dangerous characters
|
||||
if '\x00' in code:
|
||||
if "\x00" in code:
|
||||
return ["Code contains null bytes which is not allowed"]
|
||||
|
||||
|
||||
try:
|
||||
tree = ast.parse(code)
|
||||
except SyntaxError:
|
||||
return [] # Let SyntaxError be handled elsewhere
|
||||
|
||||
|
||||
visitor = SandboxVisitor()
|
||||
visitor.visit(tree)
|
||||
return visitor.violations
|
||||
|
|
@ -255,36 +396,44 @@ class REPLSession:
|
|||
"""
|
||||
RLM REPL Session - secure sandbox for recursive LLM execution.
|
||||
"""
|
||||
|
||||
|
||||
class _StderrCapture:
|
||||
"""Mock stderr object for sandbox."""
|
||||
|
||||
def __init__(self, session):
|
||||
self._session = session
|
||||
|
||||
|
||||
def write(self, text: str):
|
||||
"""Write to stderr capture."""
|
||||
self._session._stderr.append(text)
|
||||
|
||||
|
||||
def flush(self):
|
||||
"""Flush stderr (no-op)."""
|
||||
pass
|
||||
|
||||
|
||||
class MockSys:
|
||||
"""Mock sys module for sandbox with only stderr."""
|
||||
|
||||
def __init__(self, stderr_capture):
|
||||
self.stderr = stderr_capture
|
||||
|
||||
|
||||
def __getattr__(self, name):
|
||||
if name == 'modules':
|
||||
if name == "modules":
|
||||
raise SandboxViolation("Access to sys.modules is not allowed")
|
||||
raise AttributeError(f"sys.{name} is not available in sandbox")
|
||||
|
||||
def __init__(self, chunk_store=None, llm_client=None,
|
||||
max_iterations: int = 10, timeout_seconds: int = 60, max_depth: int = 5,
|
||||
max_cost_usd: Optional[float] = None):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chunk_store=None,
|
||||
llm_client=None,
|
||||
max_iterations: int = 10,
|
||||
timeout_seconds: int = 60,
|
||||
max_depth: int = 5,
|
||||
max_cost_usd: Optional[float] = None,
|
||||
):
|
||||
"""
|
||||
Initialize REPL session.
|
||||
|
||||
|
||||
Args:
|
||||
chunk_store: ChunkStore instance for memory access
|
||||
llm_client: LLM client for recursive queries
|
||||
|
|
@ -296,14 +445,14 @@ class REPLSession:
|
|||
raise ValueError("chunk_store is required")
|
||||
if llm_client is None:
|
||||
raise ValueError("llm_client is required")
|
||||
|
||||
|
||||
self.chunk_store = chunk_store
|
||||
self.llm_client = llm_client
|
||||
self.max_iterations = max_iterations
|
||||
self.timeout_seconds = timeout_seconds
|
||||
self.max_depth = max_depth
|
||||
self._max_cost_usd = max_cost_usd
|
||||
|
||||
|
||||
self._state: Dict[str, Any] = {} # User state (empty initially)
|
||||
self._iteration_count = 0
|
||||
self._total_cost = 0.0
|
||||
|
|
@ -314,64 +463,75 @@ class REPLSession:
|
|||
self._output = []
|
||||
self._stderr = []
|
||||
self._stderr_capture = self._StderrCapture(self)
|
||||
|
||||
|
||||
# Create isolated namespace for execution
|
||||
self._namespace = {}
|
||||
self._setup_namespace()
|
||||
|
||||
|
||||
def _setup_namespace(self):
|
||||
"""Set up the sandbox namespace."""
|
||||
# Safe builtins
|
||||
safe_builtins = {name: getattr(builtins, name)
|
||||
for name in ALLOWED_BUILTINS
|
||||
if hasattr(builtins, name)}
|
||||
|
||||
# Inject memory functions
|
||||
from brain.scripts.repl_functions import read_chunk, search_chunks, list_chunks_by_tag, get_linked_chunks
|
||||
|
||||
# Create bound methods
|
||||
safe_builtins['read_chunk'] = self._read_chunk_wrapper
|
||||
safe_builtins['search_chunks'] = self._search_chunks_wrapper
|
||||
safe_builtins['list_chunks_by_tag'] = self._list_chunks_by_tag_wrapper
|
||||
safe_builtins['get_linked_chunks'] = self._get_linked_chunks_wrapper
|
||||
safe_builtins['llm_query'] = self._llm_query_wrapper
|
||||
safe_builtins['FINAL'] = self._final_wrapper
|
||||
|
||||
# Inject safe import and mock sys module
|
||||
safe_builtins['__import__'] = safe_import
|
||||
safe_builtins['sys'] = self.MockSys(self._stderr_capture)
|
||||
|
||||
self._namespace = {
|
||||
'__builtins__': safe_builtins,
|
||||
'__name__': '__repl__',
|
||||
safe_builtins = {
|
||||
name: getattr(builtins, name)
|
||||
for name in ALLOWED_BUILTINS
|
||||
if hasattr(builtins, name)
|
||||
}
|
||||
|
||||
|
||||
# Inject memory functions
|
||||
from brain.scripts.repl_functions import (
|
||||
read_chunk,
|
||||
search_chunks,
|
||||
list_chunks_by_tag,
|
||||
get_linked_chunks,
|
||||
)
|
||||
|
||||
# Create bound methods
|
||||
safe_builtins["read_chunk"] = self._read_chunk_wrapper
|
||||
safe_builtins["search_chunks"] = self._search_chunks_wrapper
|
||||
safe_builtins["list_chunks_by_tag"] = self._list_chunks_by_tag_wrapper
|
||||
safe_builtins["get_linked_chunks"] = self._get_linked_chunks_wrapper
|
||||
safe_builtins["llm_query"] = self._llm_query_wrapper
|
||||
safe_builtins["FINAL"] = self._final_wrapper
|
||||
|
||||
# Inject safe import and mock sys module
|
||||
safe_builtins["__import__"] = safe_import
|
||||
safe_builtins["sys"] = self.MockSys(self._stderr_capture)
|
||||
|
||||
self._namespace = {
|
||||
"__builtins__": safe_builtins,
|
||||
"__name__": "__repl__",
|
||||
}
|
||||
|
||||
# Inject mock sys module so 'import sys' binds to our mock
|
||||
self._namespace['sys'] = self.MockSys(self._stderr_capture)
|
||||
|
||||
self._namespace["sys"] = self.MockSys(self._stderr_capture)
|
||||
|
||||
# Merge user state into namespace
|
||||
self._namespace.update(self._state)
|
||||
|
||||
|
||||
def _read_chunk_wrapper(self, chunk_id: str):
|
||||
"""Wrapper for read_chunk."""
|
||||
from repl_functions import read_chunk
|
||||
|
||||
return read_chunk(chunk_id, self.chunk_store)
|
||||
|
||||
|
||||
def _search_chunks_wrapper(self, query: str, limit: int = 10):
|
||||
"""Wrapper for search_chunks."""
|
||||
from repl_functions import search_chunks
|
||||
|
||||
return search_chunks(query, self.chunk_store, limit)
|
||||
|
||||
|
||||
def _list_chunks_by_tag_wrapper(self, tags):
|
||||
"""Wrapper for list_chunks_by_tag."""
|
||||
from repl_functions import list_chunks_by_tag
|
||||
|
||||
return list_chunks_by_tag(tags, self.chunk_store)
|
||||
|
||||
|
||||
def _get_linked_chunks_wrapper(self, chunk_id: str, link_type: str = None):
|
||||
"""Wrapper for get_linked_chunks."""
|
||||
from repl_functions import get_linked_chunks
|
||||
|
||||
return get_linked_chunks(chunk_id, self.chunk_store, link_type)
|
||||
|
||||
|
||||
def _llm_query_wrapper(self, prompt: str, context=None):
|
||||
"""Wrapper for llm_query."""
|
||||
with self._lock:
|
||||
|
|
@ -380,14 +540,16 @@ class REPLSession:
|
|||
raise MaxIterationsError(
|
||||
f"Maximum iterations ({self.max_iterations}) exceeded"
|
||||
)
|
||||
|
||||
|
||||
# Check max depth
|
||||
if self._current_depth >= self.max_depth:
|
||||
raise RecursionError(f"Maximum recursion depth ({self.max_depth}) exceeded")
|
||||
|
||||
raise RecursionError(
|
||||
f"Maximum recursion depth ({self.max_depth}) exceeded"
|
||||
)
|
||||
|
||||
# Increment depth counter
|
||||
self._current_depth += 1
|
||||
|
||||
|
||||
try:
|
||||
self._ensure_budget()
|
||||
# Build full prompt with context
|
||||
|
|
@ -396,11 +558,14 @@ class REPLSession:
|
|||
# Handle context as a list of chunk IDs
|
||||
if isinstance(context, list):
|
||||
from repl_functions import read_chunk
|
||||
|
||||
context_parts = []
|
||||
for chunk_id in context:
|
||||
chunk = read_chunk(chunk_id, self.chunk_store)
|
||||
if chunk:
|
||||
context_parts.append(f"Chunk {chunk_id}:\n{chunk.get('content', '')}")
|
||||
context_parts.append(
|
||||
f"Chunk {chunk_id}:\n{chunk.get('content', '')}"
|
||||
)
|
||||
else:
|
||||
context_parts.append(f"Chunk {chunk_id}:\n[Not found]")
|
||||
context_str = "\n\n".join(context_parts)
|
||||
|
|
@ -408,14 +573,14 @@ class REPLSession:
|
|||
elif isinstance(context, dict):
|
||||
context_str = "\n".join(f"{k}: {v}" for k, v in context.items())
|
||||
full_prompt = f"Context:\n{context_str}\n\nPrompt:\n{prompt}"
|
||||
|
||||
|
||||
# Call LLM
|
||||
response = self.llm_client.complete(full_prompt)
|
||||
|
||||
|
||||
self._record_cost(response)
|
||||
self._ensure_budget(allow_equal=True)
|
||||
|
||||
return response.text if hasattr(response, 'text') else str(response)
|
||||
|
||||
return response.text if hasattr(response, "text") else str(response)
|
||||
except (RecursionError, MaxIterationsError):
|
||||
# Don't catch these - let them propagate
|
||||
raise
|
||||
|
|
@ -426,84 +591,88 @@ class REPLSession:
|
|||
# Decrement depth counter
|
||||
with self._lock:
|
||||
self._current_depth -= 1
|
||||
|
||||
|
||||
def _final_wrapper(self, answer) -> None:
|
||||
"""Wrapper for FINAL."""
|
||||
if self._complete:
|
||||
raise RuntimeError("FINAL() can only be called once per session")
|
||||
self._result = answer
|
||||
self._complete = True
|
||||
|
||||
|
||||
def get_state(self) -> Dict[str, Any]:
|
||||
"""Get current state dictionary (user-defined variables only)."""
|
||||
return self._state.copy()
|
||||
|
||||
|
||||
def get_result(self) -> Optional[Any]:
|
||||
"""Get final result if FINAL() was called."""
|
||||
return self._result
|
||||
|
||||
|
||||
def is_complete(self) -> bool:
|
||||
"""Check if FINAL() has been called."""
|
||||
return self._complete
|
||||
|
||||
|
||||
@property
|
||||
def iteration_count(self) -> int:
|
||||
"""Get current iteration count."""
|
||||
return self._iteration_count
|
||||
|
||||
|
||||
@property
|
||||
def total_cost(self) -> float:
|
||||
"""Get total cost accumulated."""
|
||||
return self._total_cost
|
||||
|
||||
|
||||
def get_cost(self) -> float:
|
||||
"""Get total cost accumulated."""
|
||||
return self._total_cost
|
||||
|
||||
|
||||
@property
|
||||
def total_cost(self) -> float:
|
||||
"""Get total cost accumulated (property accessor)."""
|
||||
return self._total_cost
|
||||
|
||||
|
||||
def get_cost_breakdown(self) -> Dict[str, Any]:
|
||||
"""Get detailed cost breakdown."""
|
||||
breakdown = {
|
||||
"total": self._total_cost,
|
||||
"calls": self._iteration_count,
|
||||
"per_call_average": self._total_cost / self._iteration_count if self._iteration_count > 0 else 0.0
|
||||
"per_call_average": self._total_cost / self._iteration_count
|
||||
if self._iteration_count > 0
|
||||
else 0.0,
|
||||
}
|
||||
if self._max_cost_usd is not None:
|
||||
remaining = self._max_cost_usd - self._total_cost
|
||||
breakdown.update({
|
||||
"budget": self._max_cost_usd,
|
||||
"remaining": max(0.0, remaining),
|
||||
"over_budget": self._total_cost > self._max_cost_usd
|
||||
})
|
||||
breakdown.update(
|
||||
{
|
||||
"budget": self._max_cost_usd,
|
||||
"remaining": max(0.0, remaining),
|
||||
"over_budget": self._total_cost > self._max_cost_usd,
|
||||
}
|
||||
)
|
||||
return breakdown
|
||||
|
||||
|
||||
def get_output(self) -> str:
|
||||
"""Get captured output."""
|
||||
return "\n".join(self._output)
|
||||
|
||||
|
||||
def get_stderr(self) -> str:
|
||||
"""Get captured stderr."""
|
||||
return "\n".join(self._stderr)
|
||||
|
||||
|
||||
def clear_output(self):
|
||||
"""Clear captured output."""
|
||||
self._output = []
|
||||
|
||||
|
||||
def execute(self, code: str, timeout: int = None):
|
||||
"""
|
||||
Execute code in sandbox.
|
||||
|
||||
|
||||
Args:
|
||||
code: Python code to execute
|
||||
timeout: Optional timeout override
|
||||
|
||||
|
||||
Returns:
|
||||
Result of the last expression or None
|
||||
|
||||
|
||||
Raises:
|
||||
RuntimeError: If called after FINAL()
|
||||
SandboxViolation: If code violates sandbox
|
||||
|
|
@ -511,81 +680,84 @@ class REPLSession:
|
|||
"""
|
||||
if self._complete:
|
||||
raise RuntimeError("REPL already complete")
|
||||
|
||||
|
||||
if not code or not code.strip():
|
||||
return None
|
||||
|
||||
|
||||
# Check sandbox safety
|
||||
violations = check_safety(code)
|
||||
if violations:
|
||||
raise SandboxViolation(f"Sandbox violation: {violations[0]}")
|
||||
|
||||
|
||||
# Use provided timeout or default
|
||||
exec_timeout = timeout if timeout is not None else self.timeout_seconds
|
||||
|
||||
|
||||
# Capture stdout/stderr
|
||||
old_stdout = sys.stdout
|
||||
old_stderr = sys.stderr
|
||||
stdout_capture = io.StringIO()
|
||||
stderr_capture = io.StringIO()
|
||||
|
||||
|
||||
# Container for execution results
|
||||
result_container = {'result': None, 'error': None, 'completed': False}
|
||||
|
||||
result_container = {"result": None, "error": None, "completed": False}
|
||||
|
||||
def run_execution():
|
||||
try:
|
||||
sys.stdout = stdout_capture
|
||||
sys.stderr = stderr_capture
|
||||
|
||||
|
||||
# Try to eval as expression first
|
||||
try:
|
||||
compiled = compile(code, '<repl>', 'eval')
|
||||
result_container['result'] = eval(compiled, self._namespace)
|
||||
result_container['completed'] = True
|
||||
compiled = compile(code, "<repl>", "eval")
|
||||
result_container["result"] = eval(compiled, self._namespace)
|
||||
result_container["completed"] = True
|
||||
return
|
||||
except SyntaxError:
|
||||
# Not an expression, try exec
|
||||
pass
|
||||
|
||||
|
||||
# Compile and execute as statements
|
||||
compiled = compile(code, '<repl>', 'exec')
|
||||
compiled = compile(code, "<repl>", "exec")
|
||||
exec(compiled, self._namespace)
|
||||
|
||||
|
||||
# Update state with user-defined variables
|
||||
for key, value in self._namespace.items():
|
||||
if not key.startswith('_') and key not in ('__builtins__', '__name__'):
|
||||
if not key.startswith("_") and key not in (
|
||||
"__builtins__",
|
||||
"__name__",
|
||||
):
|
||||
self._state[key] = value
|
||||
|
||||
result_container['completed'] = True
|
||||
|
||||
|
||||
result_container["completed"] = True
|
||||
|
||||
except Exception as e:
|
||||
result_container['error'] = e
|
||||
|
||||
result_container["error"] = e
|
||||
|
||||
# Run execution in a thread with timeout
|
||||
exec_thread = threading.Thread(target=run_execution)
|
||||
exec_thread.daemon = True
|
||||
|
||||
|
||||
try:
|
||||
sys.stdout = stdout_capture
|
||||
sys.stderr = stderr_capture
|
||||
|
||||
|
||||
exec_thread.start()
|
||||
exec_thread.join(timeout=exec_timeout)
|
||||
|
||||
|
||||
if exec_thread.is_alive():
|
||||
# Thread is still running after timeout
|
||||
raise TimeoutError(f"Execution exceeded {exec_timeout} seconds")
|
||||
|
||||
|
||||
# Check for errors from the thread
|
||||
if result_container['error'] is not None:
|
||||
raise result_container['error']
|
||||
|
||||
if result_container["error"] is not None:
|
||||
raise result_container["error"]
|
||||
|
||||
# Capture output
|
||||
self._output.append(stdout_capture.getvalue())
|
||||
self._stderr.append(stderr_capture.getvalue())
|
||||
|
||||
return result_container['result']
|
||||
|
||||
|
||||
return result_container["result"]
|
||||
|
||||
except TimeoutError:
|
||||
raise
|
||||
except RecursionError:
|
||||
|
|
@ -623,25 +795,25 @@ class REPLSession:
|
|||
finally:
|
||||
sys.stdout = old_stdout
|
||||
sys.stderr = old_stderr
|
||||
|
||||
|
||||
def retrieve(self, query=None, max_iterations=None) -> Optional[Any]:
|
||||
"""
|
||||
Execute retrieval workflow for a query.
|
||||
|
||||
|
||||
Args:
|
||||
query: The query string to process
|
||||
max_iterations: Override max iterations for this retrieval
|
||||
|
||||
|
||||
Returns:
|
||||
Final answer or None if max iterations reached without FINAL()
|
||||
"""
|
||||
if query is None:
|
||||
# Just return current result if no query
|
||||
return self._result if self._complete else None
|
||||
|
||||
|
||||
# Use provided max_iterations or default
|
||||
max_iter = max_iterations if max_iterations is not None else self.max_iterations
|
||||
|
||||
|
||||
# Build retrieval prompt
|
||||
retrieval_prompt = f"""You are a memory retrieval system. Answer the following query using the available memory functions.
|
||||
|
||||
|
|
@ -656,38 +828,40 @@ Available functions:
|
|||
Query: {query}
|
||||
|
||||
Write Python code to solve this query. Use FINAL('your answer') when done."""
|
||||
|
||||
|
||||
# Iterative retrieval loop
|
||||
for iteration in range(max_iter):
|
||||
self._iteration_count += 1
|
||||
|
||||
|
||||
# Get LLM response
|
||||
try:
|
||||
self._ensure_budget()
|
||||
response = self.llm_client.complete(retrieval_prompt)
|
||||
code = response.text if hasattr(response, 'text') else str(response)
|
||||
code = response.text if hasattr(response, "text") else str(response)
|
||||
self._record_cost(response)
|
||||
self._ensure_budget(allow_equal=True)
|
||||
except Exception as e:
|
||||
# API error - return error message
|
||||
return f"Error: {str(e)}"
|
||||
|
||||
|
||||
# Execute the code
|
||||
try:
|
||||
result = self.execute(code)
|
||||
|
||||
|
||||
# Check if FINAL was called
|
||||
if self._complete:
|
||||
return self._result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
# Execution error - add to prompt and continue
|
||||
retrieval_prompt += f"\n\nError in previous attempt: {str(e)}\nPlease try again."
|
||||
retrieval_prompt += (
|
||||
f"\n\nError in previous attempt: {str(e)}\nPlease try again."
|
||||
)
|
||||
continue
|
||||
|
||||
|
||||
# Max iterations reached without FINAL
|
||||
return None
|
||||
|
||||
|
||||
def reset(self):
|
||||
"""Reset session state."""
|
||||
self._state = {}
|
||||
|
|
@ -703,9 +877,11 @@ Write Python code to solve this query. Use FINAL('your answer') when done."""
|
|||
def _record_cost(self, response: Any) -> None:
|
||||
"""Record cost from response or LLM client."""
|
||||
cost_value = None
|
||||
if hasattr(response, 'cost_usd'):
|
||||
if hasattr(response, "cost_usd"):
|
||||
cost_value = response.cost_usd
|
||||
elif hasattr(self.llm_client, 'get_cost') and callable(self.llm_client.get_cost):
|
||||
elif hasattr(self.llm_client, "get_cost") and callable(
|
||||
self.llm_client.get_cost
|
||||
):
|
||||
cost_value = self.llm_client.get_cost()
|
||||
if not isinstance(cost_value, (int, float)):
|
||||
return
|
||||
|
|
@ -723,11 +899,11 @@ Write Python code to solve this query. Use FINAL('your answer') when done."""
|
|||
raise CostBudgetExceededError(
|
||||
f"Cost budget exceeded: total_cost={self._total_cost:.6f} budget={self._max_cost_usd:.6f}"
|
||||
)
|
||||
|
||||
|
||||
def __enter__(self):
|
||||
"""Context manager entry."""
|
||||
return self
|
||||
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Context manager exit."""
|
||||
self.reset()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue