""" RLM-MEM - REPL Environment Tests (D1.3) Linus-style rigorous tests for the RLM REPL sandbox. Run: python brain/scripts/test_repl.py """ import unittest from unittest.mock import Mock, patch, call, MagicMock import tempfile import shutil import threading import time import sys import io import contextlib from pathlib import Path from typing import Optional, Dict, Any, List # Import the modules under test (will be created in D1.3) try: from repl_environment import REPLSession, FINAL, llm_query, SandboxViolation from repl_functions import read_chunk, search_chunks, list_chunks_by_tag, get_linked_chunks except ImportError: # Placeholder for when modules don't exist yet REPLSession = None FINAL = None llm_query = None SandboxViolation = None read_chunk = None search_chunks = None list_chunks_by_tag = None get_linked_chunks = None # Skip all tests if REPL module doesn't exist yet @unittest.skipIf(REPLSession is None, "REPL Environment not yet implemented") class TestREPLInitialization(unittest.TestCase): """Test REPL setup and configuration.""" def setUp(self): self.temp_dir = tempfile.mkdtemp() self.base_path = Path(self.temp_dir) / "brain" / "memory" # Mock ChunkStore self.mock_store = Mock() self.mock_store.base_path = self.base_path # Mock LLM client self.mock_llm = Mock() self.mock_llm.complete = Mock(return_value="FINAL('test answer')") def tearDown(self): shutil.rmtree(self.temp_dir, ignore_errors=True) def test_requires_chunk_store(self): """Should fail fast if ChunkStore not provided.""" with self.assertRaises((ValueError, TypeError)): REPLSession(chunk_store=None, llm_client=self.mock_llm) def test_requires_llm_client(self): """Should fail fast if LLM client not provided.""" with self.assertRaises((ValueError, TypeError)): REPLSession(chunk_store=self.mock_store, llm_client=None) def test_initial_state_empty(self): """Fresh REPL should have empty state.""" repl = REPLSession(chunk_store=self.mock_store, llm_client=self.mock_llm) self.assertEqual(repl.get_state(), {}) self.assertIsNone(repl.get_result()) self.assertEqual(repl.iteration_count, 0) def test_initialization_with_config(self): """Should accept configuration parameters.""" repl = REPLSession( chunk_store=self.mock_store, llm_client=self.mock_llm, max_iterations=5, timeout_seconds=30 ) self.assertEqual(repl.max_iterations, 5) self.assertEqual(repl.timeout_seconds, 30) @unittest.skipIf(REPLSession is None, "REPL Environment not yet implemented") class TestSafeExecution(unittest.TestCase): """Test Python sandboxing - CRITICAL for security.""" def setUp(self): self.temp_dir = tempfile.mkdtemp() self.mock_store = Mock() self.mock_store.base_path = Path(self.temp_dir) self.mock_llm = Mock() self.repl = REPLSession( chunk_store=self.mock_store, llm_client=self.mock_llm ) def tearDown(self): shutil.rmtree(self.temp_dir, ignore_errors=True) def test_blocks_import(self): """Should block __import__ attempts.""" # Malicious: __import__('os').system('rm -rf /') with self.assertRaises(SandboxViolation): self.repl.execute('__import__("os")') def test_blocks_import_statement(self): """Should block import statements.""" with self.assertRaises(SandboxViolation): self.repl.execute('import os') def test_blocks_from_import(self): """Should block from...import statements.""" with self.assertRaises(SandboxViolation): self.repl.execute('from os import system') def test_blocks_open(self): """Should block file open attempts.""" # Malicious: open('/etc/passwd').read() with self.assertRaises(SandboxViolation): self.repl.execute('open("/etc/passwd")') def test_blocks_file_builtin(self): """Should block file() builtin if Python 2 style.""" result = self.repl.execute('file("/etc/passwd")') # In Python 3, file() doesn't exist so it's a NameError # Should be caught and returned as error string self.assertIn("name", str(result).lower()) def test_blocks_exec(self): """Should block exec() calls.""" with self.assertRaises(SandboxViolation): self.repl.execute('exec("import os")') def test_blocks_eval(self): """Should block eval() calls.""" with self.assertRaises(SandboxViolation): self.repl.execute('eval("1 + 1")') def test_blocks_compile(self): """Should block compile() calls.""" with self.assertRaises(SandboxViolation): self.repl.execute('compile("pass", "", "exec")') def test_blocks_subprocess(self): """Should block subprocess imports and calls.""" with self.assertRaises(SandboxViolation): self.repl.execute('import subprocess; subprocess.call(["ls"])') def test_blocks_sys_modules_manipulation(self): """Should block sys.modules manipulation.""" with self.assertRaises(SandboxViolation): self.repl.execute('import sys; sys.modules["os"] = None') def test_allows_safe_builtins(self): """Should allow len(), str(), list(), dict().""" result = self.repl.execute('len("hello")') self.assertEqual(result, 5) result = self.repl.execute('str(42)') self.assertEqual(result, "42") result = self.repl.execute('list([1, 2, 3])') self.assertEqual(result, [1, 2, 3]) result = self.repl.execute('dict(a=1, b=2)') self.assertEqual(result, {"a": 1, "b": 2}) def test_allows_safe_math(self): """Should allow basic math operations.""" result = self.repl.execute('2 + 2 * 10') self.assertEqual(result, 22) result = self.repl.execute('max([1, 5, 3])') self.assertEqual(result, 5) def test_allows_string_operations(self): """Should allow string methods.""" result = self.repl.execute('"hello world".upper()') self.assertEqual(result, "HELLO WORLD") result = self.repl.execute('"a,b,c".split(",")') self.assertEqual(result, ["a", "b", "c"]) def test_path_traversal_in_code(self): """Should prevent path traversal in any code execution.""" # Even if disguised as string manipulation with self.assertRaises(SandboxViolation): self.repl.execute('open(".." + "/" * 10 + "etc/passwd")') @unittest.skipIf(REPLSession is None, "REPL Environment not yet implemented") class TestREPLFunctions(unittest.TestCase): """Test functions exposed to LLM.""" def setUp(self): self.temp_dir = tempfile.mkdtemp() # Create mock store with test data self.mock_store = Mock() self.mock_chunk = Mock() self.mock_chunk.id = "chunk-2026-02-10-abc123" self.mock_chunk.content = "Test chunk content" self.mock_chunk.type = "note" self.mock_chunk.tags = ["test", "important"] self.mock_chunk.metadata.confidence = 0.9 self.mock_chunk.links.context_of = [] self.mock_chunk.links.related_to = ["chunk-2026-02-10-def456"] self.mock_store.get_chunk = Mock(return_value=self.mock_chunk) self.mock_store.list_chunks = Mock(return_value=[ "chunk-2026-02-10-abc123", "chunk-2026-02-10-def456" ]) self.mock_llm = Mock() self.repl = REPLSession( chunk_store=self.mock_store, llm_client=self.mock_llm ) def tearDown(self): shutil.rmtree(self.temp_dir, ignore_errors=True) def test_read_chunk_returns_dict(self): """read_chunk() should return chunk as dict.""" result = self.repl.execute('read_chunk("chunk-2026-02-10-abc123")') self.assertIsInstance(result, dict) self.assertEqual(result["id"], "chunk-2026-02-10-abc123") self.assertEqual(result["content"], "Test chunk content") self.assertIn("tags", result) def test_read_chunk_nonexistent_returns_none(self): """read_chunk() for missing chunk should return None, not crash.""" self.mock_store.get_chunk = Mock(return_value=None) result = self.repl.execute('read_chunk("chunk-nonexistent")') self.assertIsNone(result) def test_read_chunk_invalid_id(self): """read_chunk() should validate chunk ID format.""" result = self.repl.execute('read_chunk("../../../etc/passwd")') # Should return None or raise specific error, not attempt file access self.assertIsNone(result) def test_search_chunks_returns_list(self): """search_chunks() should return list of chunk IDs.""" # Setup mock to return some chunks self.mock_store.search_chunks = Mock(return_value=[ "chunk-2026-02-10-abc123", "chunk-2026-02-10-def456" ]) result = self.repl.execute('search_chunks("test query")') self.assertIsInstance(result, list) self.assertEqual(len(result), 2) self.assertIn("chunk-2026-02-10-abc123", result) def test_search_chunks_empty_result(self): """search_chunks() should return empty list when no matches.""" self.mock_store.search_chunks = Mock(return_value=[]) result = self.repl.execute('search_chunks("nonexistent")') self.assertEqual(result, []) def test_list_chunks_by_tag(self): """list_chunks_by_tag() should filter by tag.""" self.mock_store.list_chunks = Mock(return_value=[ "chunk-2026-02-10-abc123" ]) result = self.repl.execute('list_chunks_by_tag("important")') self.assertIsInstance(result, list) self.mock_store.list_chunks.assert_called_with(tags=["important"]) def test_list_chunks_by_multiple_tags(self): """list_chunks_by_tag() should support multiple tags.""" self.repl.execute('list_chunks_by_tag(["test", "important"])') self.mock_store.list_chunks.assert_called_with(tags=["test", "important"]) def test_get_linked_chunks(self): """get_linked_chunks() should follow links.""" linked_chunk = Mock() linked_chunk.id = "chunk-2026-02-10-def456" linked_chunk.content = "Linked content" self.mock_store.get_chunk = Mock(side_effect=[self.mock_chunk, linked_chunk, None]) result = self.repl.execute('get_linked_chunks("chunk-2026-02-10-abc123")') self.assertIsInstance(result, list) @unittest.skipIf(REPLSession is None, "REPL Environment not yet implemented") class TestLLMQuery(unittest.TestCase): """Test recursive llm_query() function.""" def setUp(self): self.temp_dir = tempfile.mkdtemp() self.mock_store = Mock() self.mock_store.base_path = Path(self.temp_dir) # Mock LLM client self.mock_llm = Mock() self.mock_llm.complete = Mock(return_value="FINAL('recursive result')") self.mock_llm.get_cost = Mock(return_value=0.001) self.repl = REPLSession( chunk_store=self.mock_store, llm_client=self.mock_llm, max_depth=3 ) def tearDown(self): shutil.rmtree(self.temp_dir, ignore_errors=True) def test_makes_api_call(self): """llm_query() should call LLM client with prompt.""" self.repl.execute('llm_query("Analyze this")') self.mock_llm.complete.assert_called() call_args = self.mock_llm.complete.call_args self.assertIn("Analyze this", str(call_args)) def test_passes_context_chunks(self): """llm_query() should include context chunk contents.""" context = ["chunk-2026-02-10-abc123", "chunk-2026-02-10-def456"] self.repl.execute(f'llm_query("Analyze", context={context})') call_args = self.mock_llm.complete.call_args # Should have passed context to LLM self.assertIn("chunk", str(call_args).lower()) def test_tracks_cost(self): """llm_query() should update cost tracking.""" initial_cost = self.repl.total_cost self.repl.execute('llm_query("Test query")') self.assertGreater(self.repl.total_cost, initial_cost) def test_handles_api_error(self): """llm_query() should handle API failures gracefully.""" self.mock_llm.complete = Mock(side_effect=Exception("API Error: Rate limited")) result = self.repl.execute('llm_query("Test")') # Should return error info, not crash self.assertIn("error", str(result).lower()) def test_respects_max_depth(self): """llm_query() should fail if recursion too deep.""" # Simulate deep recursion self.repl._current_depth = 3 with self.assertRaises((RecursionError, RuntimeError)): self.repl.execute('llm_query("Deep call")') def test_increments_depth_counter(self): """Each llm_query should increment and decrement depth counter.""" self.repl.execute('llm_query("Test")') # After execution, depth should be back to 0 self.assertEqual(self.repl._current_depth, 0) @unittest.skipIf(REPLSession is None, "REPL Environment not yet implemented") class TestFinalTermination(unittest.TestCase): """Test FINAL() termination condition.""" def setUp(self): self.temp_dir = tempfile.mkdtemp() self.mock_store = Mock() self.mock_llm = Mock() self.repl = REPLSession( chunk_store=self.mock_store, llm_client=self.mock_llm ) def tearDown(self): shutil.rmtree(self.temp_dir, ignore_errors=True) def test_final_sets_result(self): """FINAL('answer') should set result and signal completion.""" self.repl.execute("FINAL('my answer')") self.assertEqual(self.repl.get_result(), "my answer") self.assertTrue(self.repl.is_complete()) def test_final_with_complex_answer(self): """FINAL should handle complex answer types.""" complex_answer = {"key": "value", "list": [1, 2, 3]} self.repl.execute(f"FINAL({complex_answer})") result = self.repl.get_result() self.assertEqual(result, complex_answer) def test_final_stops_iteration(self): """After FINAL(), REPL should stop executing.""" self.repl.execute("FINAL('done')") # Trying to execute more should raise with self.assertRaises(RuntimeError): self.repl.execute("print('after final')") def test_retrieve_returns_final_answer(self): """retrieve() should return answer passed to FINAL().""" self.repl.execute("FINAL('the answer is 42')") result = self.repl.retrieve() self.assertEqual(result, "the answer is 42") def test_retrieve_before_final_raises(self): """retrieve() before FINAL should raise or return None.""" result = self.repl.retrieve() self.assertIsNone(result) @unittest.skipIf(REPLSession is None, "REPL Environment not yet implemented") class TestStatePersistence(unittest.TestCase): """Test variable persistence across iterations.""" def setUp(self): self.temp_dir = tempfile.mkdtemp() self.mock_store = Mock() self.mock_llm = Mock() self.repl = REPLSession( chunk_store=self.mock_store, llm_client=self.mock_llm ) def tearDown(self): shutil.rmtree(self.temp_dir, ignore_errors=True) def test_variables_persist(self): """Variables set in iteration 1 should be available in iteration 2.""" self.repl.execute('x = 42') result = self.repl.execute('x * 2') self.assertEqual(result, 84) def test_variables_across_multiple_iterations(self): """Variables should persist across many iterations.""" for i in range(5): self.repl.execute(f'counter = {i}') result = self.repl.execute('counter') self.assertEqual(result, 4) def test_data_structures_persist(self): """Complex data structures should persist.""" self.repl.execute('data = {"key": [1, 2, 3], "nested": {"a": "b"}}') result = self.repl.execute('data["nested"]["a"]') self.assertEqual(result, "b") def test_output_captured(self): """print() output should be captured and accessible.""" self.repl.execute('print("hello world")') output = self.repl.get_output() self.assertIn("hello world", output) def test_stderr_captured(self): """stderr should be captured separately.""" self.repl.execute('import sys; sys.stderr.write("error message")') stderr = self.repl.get_stderr() self.assertIn("error message", stderr) def test_clear_output(self): """clear_output() should reset captured output.""" self.repl.execute('print("before")') self.repl.clear_output() self.repl.execute('print("after")') output = self.repl.get_output() self.assertNotIn("before", output) self.assertIn("after", output) @unittest.skipIf(REPLSession is None, "REPL Environment not yet implemented") class TestRetrieveWorkflow(unittest.TestCase): """Test full RLM retrieval workflow.""" def setUp(self): self.temp_dir = tempfile.mkdtemp() # Setup mock store with test chunks self.mock_store = Mock() self.mock_store.list_chunks = Mock(return_value=[ "chunk-2026-02-10-abc123", "chunk-2026-02-10-def456" ]) chunk1 = Mock() chunk1.id = "chunk-2026-02-10-abc123" chunk1.content = "User likes Python" chunk1.tags = ["preference", "python"] chunk2 = Mock() chunk2.id = "chunk-2026-02-10-def456" chunk2.content = "User prefers TypeScript" chunk2.tags = ["preference", "typescript"] self.mock_store.get_chunk = Mock(side_effect=lambda x: chunk1 if x == "chunk-2026-02-10-abc123" else chunk2 if x == "chunk-2026-02-10-def456" else None) self.mock_store.search_chunks = Mock(return_value=["chunk-2026-02-10-abc123"]) self.mock_llm = Mock() self.repl = REPLSession( chunk_store=self.mock_store, llm_client=self.mock_llm, max_iterations=5 ) def tearDown(self): shutil.rmtree(self.temp_dir, ignore_errors=True) def test_single_iteration_success(self): """Simple query answered in one iteration.""" # LLM calls FINAL() immediately self.mock_llm.complete = Mock(return_value="FINAL('Python')") result = self.repl.retrieve("What language does the user like?") self.assertEqual(result, "Python") self.assertEqual(self.repl.iteration_count, 1) def test_multi_iteration_success(self): """Query requiring multiple llm_query() calls.""" # First iteration: search chunks # Second iteration: FINAL(answer) responses = [ "candidates = search_chunks('Python'); read_chunk(candidates[0])", "FINAL('User likes Python')" ] self.mock_llm.complete = Mock(side_effect=responses) result = self.repl.retrieve("What does the user like?") self.assertEqual(result, "User likes Python") self.assertEqual(self.repl.iteration_count, 2) def test_retrieve_tracks_cost(self): """retrieve() should track LLM cost.""" response = Mock() response.text = "FINAL('Python')" response.cost_usd = 0.005 self.mock_llm.complete = Mock(return_value=response) result = self.repl.retrieve("What language does the user like?") self.assertEqual(result, "Python") self.assertEqual(self.repl.total_cost, 0.005) def test_max_iterations_timeout(self): """Should return None if max_iterations reached without FINAL().""" # LLM never calls FINAL() self.mock_llm.complete = Mock(return_value="print('still thinking')") result = self.repl.retrieve("Complex query", max_iterations=3) self.assertIsNone(result) self.assertEqual(self.repl.iteration_count, 3) def test_no_chunks_found(self): """Should handle case where no relevant chunks exist.""" self.mock_store.search_chunks = Mock(return_value=[]) self.mock_llm.complete = Mock(return_value="FINAL(None)") result = self.repl.retrieve("Query with no matches") self.assertIsNone(result) @unittest.skipIf(REPLSession is None, "REPL Environment not yet implemented") class TestEdgeCases(unittest.TestCase): """Edge cases and adversarial inputs.""" def setUp(self): self.temp_dir = tempfile.mkdtemp() self.mock_store = Mock() self.mock_store.base_path = Path(self.temp_dir) self.mock_llm = Mock() self.repl = REPLSession( chunk_store=self.mock_store, llm_client=self.mock_llm ) def tearDown(self): shutil.rmtree(self.temp_dir, ignore_errors=True) def test_empty_code(self): """Executing empty code should not crash.""" result = self.repl.execute("") self.assertIsNone(result) def test_whitespace_only_code(self): """Executing whitespace-only code should not crash.""" result = self.repl.execute(" \n\t ") self.assertIsNone(result) def test_very_long_code(self): """Very long Python code should be handled.""" # 100+ lines long_code = "\n".join([f"x{i} = {i}" for i in range(100)]) long_code += "\nresult = sum([x{} for x in range(100)])" result = self.repl.execute(long_code) # Should complete without error self.assertIsNotNone(result) def test_unicode_in_code(self): """Unicode in code or output should work.""" result = self.repl.execute('emoji = "🎉🚀💻"') self.assertIsNone(result) # Assignment returns None result = self.repl.execute('emoji') self.assertEqual(result, "🎉🚀💻") def test_unicode_in_variables(self): """Unicode variable names should work.""" result = self.repl.execute('变量 = "hello"') result = self.repl.execute('变量') self.assertEqual(result, "hello") def test_syntax_error(self): """Syntax errors should be caught and reported.""" result = self.repl.execute('if True print("missing colon")') self.assertIn("syntax", str(result).lower()) def test_runtime_error(self): """Runtime errors should be caught and reported.""" result = self.repl.execute('1 / 0') self.assertIn("zero", str(result).lower()) def test_name_error(self): """Name errors should be caught and reported.""" result = self.repl.execute('undefined_variable') self.assertIn("name", str(result).lower()) def test_attribute_error(self): """Attribute errors should be caught and reported.""" result = self.repl.execute('"string".nonexistent_method()') self.assertIn("attribute", str(result).lower()) def test_infinite_loop_timeout(self): """Infinite loops should be terminated.""" start_time = time.time() with self.assertRaises((TimeoutError, RuntimeError)): self.repl.execute('while True: pass', timeout=1) elapsed = time.time() - start_time self.assertLess(elapsed, 3) # Should timeout quickly def test_memory_exhaustion_prevention(self): """Should prevent memory exhaustion from large allocations.""" with self.assertRaises((MemoryError, RuntimeError)): self.repl.execute('x = "x" * (1024 * 1024 * 100)') # 100MB string def test_recursion_limit(self): """Deep recursion should be caught.""" with self.assertRaises((RecursionError, RuntimeError)): self.repl.execute(''' def recurse(n): return recurse(n + 1) recurse(0) ''') def test_special_characters_in_strings(self): """Special characters in strings should be handled.""" special = 'special = "\\n\\t\\r\\x00\\xff"' self.repl.execute(special) result = self.repl.execute('len(special)') self.assertEqual(result, 5) # \n, \t, \r, \x00, \xff = 5 chars def test_very_long_string(self): """Very long strings should be handled.""" self.repl.execute('long_str = "x" * 10000') result = self.repl.execute('len(long_str)') self.assertEqual(result, 10000) @unittest.skipIf(REPLSession is None, "REPL Environment not yet implemented") class TestSecurity(unittest.TestCase): """Security tests - sandbox escape attempts.""" def setUp(self): self.temp_dir = tempfile.mkdtemp() self.mock_store = Mock() self.mock_store.base_path = Path(self.temp_dir) self.mock_llm = Mock() self.repl = REPLSession( chunk_store=self.mock_store, llm_client=self.mock_llm ) def tearDown(self): shutil.rmtree(self.temp_dir, ignore_errors=True) def test_blocks_getattr_exploitation(self): """Should block getattr exploitation for builtins.""" with self.assertRaises(SandboxViolation): self.repl.execute('getattr(__builtins__, "__import__")("os")') def test_blocks_globals_manipulation(self): """Should block globals() manipulation.""" with self.assertRaises(SandboxViolation): self.repl.execute('globals()["__builtins__"]["__import__"]("os")') def test_blocks_locals_manipulation(self): """Should block locals() manipulation.""" with self.assertRaises(SandboxViolation): self.repl.execute('locals()["__builtins__"]["__import__"]("os")') def test_blocks_class_bases_exploit(self): """Should block class base exploitation.""" with self.assertRaises(SandboxViolation): self.repl.execute('().__class__.__bases__[0].__subclasses__()') def test_blocks_code_object_creation(self): """Should block direct code object manipulation.""" with self.assertRaises(SandboxViolation): self.repl.execute('type(compile("1", "", "eval"))(0,0,0,0,0,0,b"\x00")') def test_blocks_del_builtins(self): """Should prevent deletion of safety mechanisms.""" with self.assertRaises((SandboxViolation, TypeError)): self.repl.execute('del __builtins__.open') def test_blocks_setattr_on_builtins(self): """Should block setattr on builtins.""" with self.assertRaises(SandboxViolation): self.repl.execute('setattr(__builtins__, "evil", lambda: None)') @unittest.skipIf(REPLSession is None, "REPL Environment not yet implemented") class TestConcurrency(unittest.TestCase): """Test thread safety.""" def setUp(self): self.temp_dir = tempfile.mkdtemp() self.mock_store = Mock() self.mock_store.base_path = Path(self.temp_dir) self.mock_llm = Mock() self.mock_llm.complete = Mock(return_value="FINAL('result')") def tearDown(self): shutil.rmtree(self.temp_dir, ignore_errors=True) def test_isolated_instances(self): """Multiple REPL instances should not interfere.""" repl1 = REPLSession(chunk_store=self.mock_store, llm_client=self.mock_llm) repl2 = REPLSession(chunk_store=self.mock_store, llm_client=self.mock_llm) repl1.execute('x = 42') repl2.execute('x = 99') result1 = repl1.execute('x') result2 = repl2.execute('x') self.assertEqual(result1, 42) self.assertEqual(result2, 99) def test_concurrent_execution(self): """Concurrent execution in different instances should be safe.""" results = [] errors = [] def worker(instance_id): try: repl = REPLSession( chunk_store=self.mock_store, llm_client=self.mock_llm ) repl.execute(f'instance = {instance_id}') result = repl.execute('instance') results.append((instance_id, result)) except Exception as e: errors.append((instance_id, str(e))) threads = [ threading.Thread(target=worker, args=(i,)) for i in range(5) ] for t in threads: t.start() for t in threads: t.join() self.assertEqual(len(errors), 0) self.assertEqual(len(results), 5) @unittest.skipIf(REPLSession is None, "REPL Environment not yet implemented") class TestCostTracking(unittest.TestCase): """Test cost tracking functionality.""" def setUp(self): self.temp_dir = tempfile.mkdtemp() self.mock_store = Mock() self.mock_llm = Mock() self.mock_llm.complete = Mock(return_value="FINAL('answer')") self.mock_llm.get_cost = Mock(return_value=0.002) self.repl = REPLSession( chunk_store=self.mock_store, llm_client=self.mock_llm ) def tearDown(self): shutil.rmtree(self.temp_dir, ignore_errors=True) def test_initial_cost_zero(self): """Initial cost should be zero.""" self.assertEqual(self.repl.total_cost, 0) def test_cost_accumulates(self): """Cost should accumulate across llm_query calls.""" self.repl.execute('llm_query("q1")') self.repl.execute('llm_query("q2")') self.assertEqual(self.repl.total_cost, 0.004) def test_budget_exceeded(self): """Should signal when budget is exceeded.""" budgeted_repl = REPLSession( chunk_store=self.mock_store, llm_client=self.mock_llm, max_cost_usd=0.003 ) budgeted_repl.execute('llm_query("q1")') result = budgeted_repl.execute('llm_query("q2")') self.assertIn("budget", str(result).lower()) self.assertGreater(budgeted_repl.total_cost, 0.003) def test_get_cost_breakdown(self): """Should provide cost breakdown.""" self.repl.execute('llm_query("test")') breakdown = self.repl.get_cost_breakdown() self.assertIn("total", breakdown) self.assertIn("calls", breakdown) @unittest.skipIf(REPLSession is None, "REPL Environment not yet implemented") class TestContextManagement(unittest.TestCase): """Test REPL context management.""" def setUp(self): self.temp_dir = tempfile.mkdtemp() self.mock_store = Mock() self.mock_llm = Mock() self.repl = REPLSession( chunk_store=self.mock_store, llm_client=self.mock_llm ) def tearDown(self): shutil.rmtree(self.temp_dir, ignore_errors=True) def test_context_manager(self): """Should work as context manager.""" with REPLSession(self.mock_store, self.mock_llm) as repl: repl.execute('x = 42') self.assertEqual(repl.execute('x'), 42) def test_reset_clears_state(self): """reset() should clear all state.""" self.repl.execute('x = 42') self.repl.execute('FINAL("done")') self.repl.reset() self.assertEqual(self.repl.get_state(), {}) self.assertIsNone(self.repl.get_result()) self.assertFalse(self.repl.is_complete()) self.assertEqual(self.repl.iteration_count, 0) # Mock implementations for testing the test structure itself class MockREPLSession: """Mock REPL for validating test structure before implementation.""" def __init__(self, chunk_store, llm_client, max_iterations=10, timeout_seconds=60, max_depth=5): if chunk_store is None: raise ValueError("chunk_store is required") if llm_client is None: raise ValueError("llm_client is required") self.chunk_store = chunk_store self.llm_client = llm_client self.max_iterations = max_iterations self.timeout_seconds = timeout_seconds self.max_depth = max_depth self._state = {} self._result = None self._complete = False self._iteration_count = 0 self._total_cost = 0.0 self._output = [] self._stderr = [] self._current_depth = 0 def get_state(self): return self._state.copy() def get_result(self): return self._result def is_complete(self): return self._complete @property def iteration_count(self): return self._iteration_count @property def total_cost(self): return self._total_cost def execute(self, code, timeout=None): """Mock execute - just validates structure.""" if self._complete: raise RuntimeError("REPL already complete") if not code or not code.strip(): return None self._iteration_count += 1 # Check for FINAL if code.strip().startswith("FINAL("): self._result = eval(code.strip()[6:-1]) self._complete = True return self._result return None def retrieve(self, query=None, max_iterations=None): if self._complete: return self._result return None def reset(self): self._state = {} self._result = None self._complete = False self._iteration_count = 0 def get_output(self): return "\n".join(self._output) def clear_output(self): self._output = [] class TestMockStructure(unittest.TestCase): """Verify the test structure itself works.""" def test_mock_initialization(self): """Mock should initialize properly.""" mock_store = Mock() mock_llm = Mock() repl = MockREPLSession(mock_store, mock_llm) self.assertEqual(repl.get_state(), {}) self.assertIsNone(repl.get_result()) def test_mock_final(self): """Mock should handle FINAL.""" mock_store = Mock() mock_llm = Mock() repl = MockREPLSession(mock_store, mock_llm) repl.execute('FINAL("answer")') self.assertEqual(repl.get_result(), "answer") self.assertTrue(repl.is_complete()) def run_tests(): """Run all tests with verbose output.""" loader = unittest.TestLoader() suite = unittest.TestSuite() # Add all test classes suite.addTests(loader.loadTestsFromTestCase(TestMockStructure)) suite.addTests(loader.loadTestsFromTestCase(TestREPLInitialization)) suite.addTests(loader.loadTestsFromTestCase(TestSafeExecution)) suite.addTests(loader.loadTestsFromTestCase(TestREPLFunctions)) suite.addTests(loader.loadTestsFromTestCase(TestLLMQuery)) suite.addTests(loader.loadTestsFromTestCase(TestFinalTermination)) suite.addTests(loader.loadTestsFromTestCase(TestStatePersistence)) suite.addTests(loader.loadTestsFromTestCase(TestRetrieveWorkflow)) suite.addTests(loader.loadTestsFromTestCase(TestEdgeCases)) suite.addTests(loader.loadTestsFromTestCase(TestSecurity)) suite.addTests(loader.loadTestsFromTestCase(TestConcurrency)) suite.addTests(loader.loadTestsFromTestCase(TestCostTracking)) suite.addTests(loader.loadTestsFromTestCase(TestContextManagement)) runner = unittest.TextTestRunner(verbosity=2) result = runner.run(suite) return result.wasSuccessful() if __name__ == "__main__": # Check if REPL is implemented if REPLSession is None: print("=" * 70) print("REPL Environment not yet implemented (D1.3)") print("=" * 70) print("\nTests defined (will run when REPL is implemented):") print(" - TestREPLInitialization: 4 tests") print(" - TestSafeExecution: 14 tests (security critical)") print(" - TestREPLFunctions: 8 tests") print(" - TestLLMQuery: 6 tests") print(" - TestFinalTermination: 5 tests") print(" - TestStatePersistence: 6 tests") print(" - TestRetrieveWorkflow: 4 tests") print(" - TestEdgeCases: 14 tests") print(" - TestSecurity: 7 tests") print(" - TestConcurrency: 2 tests") print(" - TestCostTracking: 3 tests") print(" - TestContextManagement: 2 tests") print("\nTotal: 75 tests ready to run") print("=" * 70) exit(0) else: success = run_tests() exit(0 if success else 1)