diff --git a/app/main.py b/app/main.py index 7ceb3c0..1547332 100644 --- a/app/main.py +++ b/app/main.py @@ -4,6 +4,7 @@ import json import os import time import uuid +from contextlib import asynccontextmanager from datetime import datetime, timezone from subprocess import PIPE from typing import Any, Literal @@ -15,7 +16,26 @@ from pydantic import BaseModel, Field app = FastAPI(title="Claude Agent Service") API_TOKEN = os.environ.get("API_BEARER_TOKEN", "") -WORKSPACE_DIR = os.environ.get("WORKSPACE_DIR", "/workspace/infra") + +# Warm base clone, populated by the init container. Each job clones from this +# into its own dir under JOBS_DIR so concurrent calls never share a working +# tree (no git index.lock contention, no clobbered edits). +BASE_DIR = os.environ.get("WORKSPACE_DIR", "/workspace/infra") +JOBS_DIR = os.environ.get("JOBS_DIR", "/workspace/jobs") +GIT_CRYPT_KEY = os.environ.get("GIT_CRYPT_KEY", "/secrets/git-crypt/key") + +# Concurrency. MAX_CONCURRENCY caps simultaneous claude runs ("soft-unbounded" +# — a high default rather than a tight limit); excess calls queue FIFO rather +# than being rejected. MAX_QUEUE_DEPTH is a safety valve so a runaway burst +# can't pin unbounded memory: past it, callers are turned away (429/503). +MAX_CONCURRENCY = int(os.environ.get("MAX_CONCURRENCY", "10")) +MAX_QUEUE_DEPTH = int(os.environ.get("MAX_QUEUE_DEPTH", "100")) +# Completed jobs are evicted from the in-memory registry past this age so the +# dict doesn't grow without bound. +JOB_TTL_SECONDS = int(os.environ.get("JOB_TTL_SECONDS", "3600")) +# Bursts share one base fetch rather than serialising a network round-trip per +# job behind the git lock. +FETCH_DEBOUNCE_SECONDS = int(os.environ.get("FETCH_DEBOUNCE_SECONDS", "15")) # OpenAI compat: model selection is per-request so callers can pick # Haiku/Sonnet/Opus to control cost. 
The agent is fixed — `recruiter-triage` @@ -44,8 +64,18 @@ OPENAI_COMPAT_AGENT = "recruiter-triage" OPENAI_COMPAT_BUDGET_USD = 2.0 OPENAI_COMPAT_TIMEOUT_SECONDS = 900 +_TERMINAL_STATUSES = frozenset({"completed", "failed", "timeout", "error"}) + jobs: dict[str, dict] = {} -execution_lock = asyncio.Lock() + +# Concurrency primitives. The semaphore bounds simultaneous executions; the git +# lock is held only for the fast per-job workspace setup/teardown (fetch + +# local clone + unlock + rm), NOT for the agent run itself. +execution_semaphore = asyncio.Semaphore(MAX_CONCURRENCY) +git_lock = asyncio.Lock() +inflight_active = 0 +inflight_queued = 0 +_last_fetch_epoch = 0.0 class ExecuteRequest(BaseModel): @@ -87,34 +117,139 @@ def verify_token(authorization: str | None): raise HTTPException(status_code=401, detail="Invalid token") -async def run_git_sync(): +def _now_iso() -> str: + return datetime.now(timezone.utc).isoformat() + + +def _reserve_queue_slot() -> bool: + """Admit a call into the queue, or refuse it if the queue is saturated. + + Returns False when active + queued already fills MAX_QUEUE_DEPTH — the + caller should then turn the request away (429/503). + """ + global inflight_queued + if inflight_active + inflight_queued >= MAX_QUEUE_DEPTH: + return False + inflight_queued += 1 + return True + + +@asynccontextmanager +async def _execution_slot(): + """Hold one concurrency permit for the duration of an agent run. + + The caller must have reserved a queue slot via `_reserve_queue_slot()` + first; this moves it from queued -> active on acquire and always releases. + """ + global inflight_active, inflight_queued + acquired = False + try: + await execution_semaphore.acquire() + acquired = True + inflight_queued -= 1 + inflight_active += 1 + yield + finally: + if acquired: + inflight_active -= 1 + execution_semaphore.release() + else: + # Cancelled while still waiting in the queue. 
+ inflight_queued -= 1 + + +def _evict_old_jobs() -> None: + now = time.time() + stale = [ + jid for jid, job in jobs.items() + if job.get("status") in _TERMINAL_STATUSES + and now - job.get("finished_epoch", now) > JOB_TTL_SECONDS + ] + for jid in stale: + jobs.pop(jid, None) + + +async def _run(*cmd: str, cwd: str | None = None, timeout: float | None = None, + check: bool = True, capture: bool = False) -> tuple[int, str]: + """Run a subprocess (no shell), optionally capturing stdout. Raises on + non-zero unless `check=False`. Used for the git/git-crypt/rm steps of + per-job workspace setup.""" proc = await asyncio.create_subprocess_exec( - "git", "pull", "--rebase", - cwd=WORKSPACE_DIR, - stdout=PIPE, stderr=PIPE, + *cmd, cwd=cwd, stdout=PIPE, stderr=PIPE, ) - await proc.wait() + try: + out, err = await asyncio.wait_for(proc.communicate(), timeout=timeout) + except asyncio.TimeoutError: + proc.kill() + await proc.wait() + raise + rc = proc.returncode or 0 + if check and rc != 0: + raise RuntimeError(f"{cmd[0]} failed ({rc}): {err.decode(errors='replace')[:200]}") + return rc, (out.decode(errors="replace") if capture else "") + + +async def _refresh_base() -> None: + """Pull the base clone up to origin/master, debounced so a burst of jobs + shares one fetch. Failures are tolerated — jobs run against the last good + base rather than wedging on a transient network blip.""" + global _last_fetch_epoch + now = time.time() + if now - _last_fetch_epoch < FETCH_DEBOUNCE_SECONDS: + return + _last_fetch_epoch = now + await _run("git", "-C", BASE_DIR, "fetch", "origin", "--prune", + timeout=120, check=False) + await _run("git", "-C", BASE_DIR, "reset", "--hard", "origin/master", + check=False) + + +async def prepare_workspace(job_id: str) -> str: + """Create an isolated git checkout for one job and return its path. + + A local clone of the warm base hardlinks the object store (near-free) and + carries only tracked files (no stale .terraform). 
The git lock is held just + for this fast setup, never for the agent run. + """ + job_dir = os.path.join(JOBS_DIR, job_id) + async with git_lock: + await _refresh_base() + await _run("git", "clone", "--local", BASE_DIR, job_dir) + rc, base_origin = await _run( + "git", "-C", BASE_DIR, "remote", "get-url", "origin", + check=False, capture=True, + ) + if rc == 0 and base_origin.strip(): + await _run("git", "-C", job_dir, "remote", "set-url", "origin", + base_origin.strip(), check=False) + if GIT_CRYPT_KEY and os.path.exists(GIT_CRYPT_KEY): + await _run("git-crypt", "unlock", GIT_CRYPT_KEY, cwd=job_dir, check=False) + return job_dir + + +async def cleanup_workspace(path: str | None) -> None: + if not path: + return + await _run("rm", "-rf", path, check=False) async def _invoke_claude_subprocess( prompt: str, agent: str, max_budget_usd: float, + workspace: str, model: str | None = None, ) -> dict[str, Any]: - """Run the claude CLI once and return a result dict. + """Run the claude CLI once in `workspace` and return a result dict. - The caller is responsible for holding `execution_lock` for the duration — - this helper does not touch the lock or the `jobs` dict, so it can be - shared by both the background `/execute` path and the synchronous - `/v1/chat/completions` path. + Holds no lock and does not touch the `jobs` dict, so it is shared by both + the background `/execute` path and the synchronous `/v1/chat/completions` + path. The caller provides an isolated `workspace` (one per job) as cwd. `model`, when provided, becomes `--model ` on the claude CLI. This overrides whatever `model:` is set in the agent's frontmatter so the OpenAI-compat path can pick Haiku/Sonnet/Opus per-request. 
""" - await run_git_sync() - cmd = [ "claude", "-p", "--agent", agent, @@ -128,11 +263,13 @@ async def _invoke_claude_subprocess( proc = await asyncio.create_subprocess_exec( *cmd, - cwd=WORKSPACE_DIR, + cwd=workspace, stdout=PIPE, stderr=PIPE, ) + # stdout=PIPE / stderr=PIPE guarantee both streams are present. + assert proc.stdout is not None and proc.stderr is not None output_lines: list[str] = [] async for line in proc.stdout: output_lines.append(line.decode()) @@ -147,24 +284,49 @@ async def _invoke_claude_subprocess( } -async def run_agent(job_id: str, request: ExecuteRequest): +async def _run_execute_job(job_id: str, request: ExecuteRequest): + """Background worker for /execute: waits for a slot (queued), then runs the + agent in an isolated workspace. The timeout covers execution only, never + the time spent waiting in the queue.""" + workspace = None try: - result = await _invoke_claude_subprocess( - request.prompt, request.agent, request.max_budget_usd, - ) - jobs[job_id].update({ - "status": "completed" if result["exit_code"] == 0 else "failed", - "exit_code": result["exit_code"], - "output": result["output"], - "stderr": result["stderr"], - "finished_at": datetime.now(timezone.utc).isoformat(), - }) + async with _execution_slot(): + jobs[job_id]["status"] = "running" + jobs[job_id]["started_at"] = _now_iso() + workspace = await prepare_workspace(job_id) + result = await asyncio.wait_for( + _invoke_claude_subprocess( + request.prompt, request.agent, request.max_budget_usd, workspace, + ), + timeout=request.timeout_seconds, + ) + jobs[job_id].update({ + "status": "completed" if result["exit_code"] == 0 else "failed", + "exit_code": result["exit_code"], + "output": result["output"], + "stderr": result["stderr"], + "finished_at": _now_iso(), + "finished_epoch": time.time(), + }) except asyncio.TimeoutError: - jobs[job_id].update({"status": "timeout"}) + jobs[job_id].update({ + "status": "timeout", + "finished_at": _now_iso(), + "finished_epoch": time.time(), 
+ }) except Exception as exc: - jobs[job_id].update({"status": "error", "error": str(exc)}) + jobs[job_id].update({ + "status": "error", + "error": str(exc), + "finished_at": _now_iso(), + "finished_epoch": time.time(), + }) finally: - execution_lock.release() + try: + await cleanup_workspace(workspace) + except Exception: + pass + _evict_old_jobs() def _extract_assistant_text(output_lines: list[str]) -> str: @@ -222,7 +384,13 @@ def _synthesise_prompt(messages: list[ChatMessage]) -> str: @app.get("/health") async def health(): - return {"status": "ok", "busy": execution_lock.locked()} + return { + "status": "ok", + "busy": inflight_active >= MAX_CONCURRENCY, + "active": inflight_active, + "queued": inflight_queued, + "capacity": MAX_CONCURRENCY, + } @app.post("/execute", status_code=202) @@ -232,28 +400,21 @@ async def execute( ): verify_token(authorization) - if execution_lock.locked(): - raise HTTPException(status_code=409, detail="Agent is busy") - - await execution_lock.acquire() + if not _reserve_queue_slot(): + raise HTTPException(status_code=429, detail="Queue full") job_id = uuid.uuid4().hex[:12] jobs[job_id] = { - "status": "running", + "status": "queued", "prompt": request.prompt, "agent": request.agent, - "started_at": datetime.now(timezone.utc).isoformat(), + "created_at": _now_iso(), "metadata": request.metadata, } - asyncio.create_task( - asyncio.wait_for( - run_agent(job_id, request), - timeout=request.timeout_seconds, - ) - ) + asyncio.create_task(_run_execute_job(job_id, request)) - return {"job_id": job_id, "status": "running"} + return {"job_id": job_id, "status": "queued"} @app.get("/jobs/{job_id}") @@ -289,33 +450,39 @@ async def chat_completions( prompt = _synthesise_prompt(request.messages) - if execution_lock.locked(): + if not _reserve_queue_slot(): return JSONResponse( status_code=503, - content={"error": "execution failed", "detail": "agent is busy"}, + content={"error": "execution failed", "detail": "queue full"}, ) - await 
execution_lock.acquire() + chat_id = uuid.uuid4().hex[:12] + workspace = None try: - try: + async with _execution_slot(): + workspace = await prepare_workspace(chat_id) result = await asyncio.wait_for( _invoke_claude_subprocess( - prompt, OPENAI_COMPAT_AGENT, OPENAI_COMPAT_BUDGET_USD, model=model, + prompt, OPENAI_COMPAT_AGENT, OPENAI_COMPAT_BUDGET_USD, + workspace, model=model, ), timeout=OPENAI_COMPAT_TIMEOUT_SECONDS, ) - except asyncio.TimeoutError: - return JSONResponse( - status_code=503, - content={"error": "execution failed", "detail": "agent timed out"}, - ) - except Exception as exc: - return JSONResponse( - status_code=503, - content={"error": "execution failed", "detail": _one_line(str(exc))}, - ) + except asyncio.TimeoutError: + return JSONResponse( + status_code=503, + content={"error": "execution failed", "detail": "agent timed out"}, + ) + except Exception as exc: + return JSONResponse( + status_code=503, + content={"error": "execution failed", "detail": _one_line(str(exc))}, + ) finally: - execution_lock.release() + try: + await cleanup_workspace(workspace) + except Exception: + pass if result["exit_code"] != 0: detail = _one_line(result.get("stderr") or "") or f"exit {result['exit_code']}" diff --git a/docs/2026-06-02-parallel-execution-design.md b/docs/2026-06-02-parallel-execution-design.md new file mode 100644 index 0000000..7ef0070 --- /dev/null +++ b/docs/2026-06-02-parallel-execution-design.md @@ -0,0 +1,124 @@ +# Parallel, independent execution — design + +**Date:** 2026-06-02 +**Status:** approved, in implementation +**Scope:** `claude-agent-service` — remove the single-flight execution lock so +multiple agent calls run concurrently, each in its own isolated workspace. + +## Problem + +Today a single global `asyncio.Lock` (`execution_lock`) serializes **every** +agent invocation: + +- `POST /execute` returns `409 Agent is busy` when a job is in flight. +- `POST /v1/chat/completions` returns `503 agent is busy` likewise. 
+- All calls run `claude -p` with `cwd=/workspace/infra` — one shared working + tree, `git pull --rebase`'d before each call. + +The lock exists because two `claude -p` processes in the *same* working tree +would clobber each other's file edits and git state (`.git/index.lock` +contention, racing `git pull --rebase`). + +## Goal + +Run calls **in parallel**, each **fully independent** of the others, without +the git/file collisions that the lock currently prevents — on a single pod +(`replicas=1`), keeping the in-memory job registry coherent for `/jobs/{id}` +polling. + +## Design + +### Workspace isolation — per-job local clone + +Each job gets its **own git checkout** so file edits and git operations never +touch another job's state: + +1. A warm **base clone** lives at `/workspace/base` (created by the existing + init container; renamed from `/workspace/infra`), git-crypt-unlocked. +2. Per job, under a short-held `git_lock`: + - Debounced `git fetch origin && git reset --hard origin/master` on the base + (skipped if fetched within `FETCH_DEBOUNCE_SECONDS`) so bursts share one + network fetch. + - `git clone --local /workspace/base /workspace/jobs/{job_id}` — objects are + hardlinked (near-free disk, no `.terraform` carried since clone takes + tracked content only). + - Re-point `origin` to the GitHub URL and `git-crypt unlock $GIT_CRYPT_KEY` in the + job dir. +3. The job runs `claude -p` with `cwd=/workspace/jobs/{job_id}` **holding no lock**. +4. `finally` → `rm -rf /workspace/jobs/{job_id}`. + +`git_lock` is held only for the fast setup (~<2 s; the `rm -rf` teardown takes no lock); execution is fully +parallel. Rejected alternatives: **git worktree** (shares one `.git` → agents +that `git commit`/`pull` still contend — not truly independent) and **`cp -a`** +(copies accumulated `.terraform` provider caches → disk blowup). + +Distinct `cwd` per job also isolates Claude CLI per-project state +(`~/.claude/projects/{project}/`). The long-lived `CLAUDE_CODE_OAUTH_TOKEN` +avoids credential-file write races in the shared `~/.claude`. 
+ +### Concurrency model + +- `execution_semaphore = asyncio.Semaphore(MAX_CONCURRENCY)` replaces + `execution_lock`. Default **`MAX_CONCURRENCY=10`** ("soft-unbounded"). +- Requests beyond the limit **queue FIFO** (asyncio fairness) — they are not + rejected. +- `MAX_QUEUE_DEPTH` safety valve (default **100**): if `active + queued` has already reached + it, reject (`429` on `/execute`, `503` on chat) to bound memory. +- An `_execution_slot()` async context manager wraps acquire/release and keeps + `inflight_active` / `inflight_queued` counters for `/health`. + +### Endpoint behavior + +| Endpoint | Before | After | +|---|---|---| +| `POST /execute` | `202` or `409` busy | `202` always (unless queue full → `429`); job `status="queued"` until a slot frees, then `running`. **Timeout clock starts on execution, not queue-wait.** | +| `POST /v1/chat/completions` | `200` or `503` busy | **queues** for a slot (caller waits; the 900 s timeout covers execution only, not queue-wait); still `503` on execution failure/timeout or if queue full | +| `GET /jobs/{id}` | unchanged | unchanged (can now report `queued`) | +| `GET /health` | `{status, busy=lock.locked()}` | `{status, busy=(active>=capacity), active, queued, capacity}` — keeps BeadBoard `/api/agent-status` + beads-dispatcher working | + +### Housekeeping + +- **Job eviction**: completed/failed/timeout/error jobs older than + `JOB_TTL_SECONDS` (default 3600) are evicted; the in-memory `jobs` dict + currently grows unbounded and parallelism increases churn. +- Pod restart still loses in-flight jobs (pre-existing; out of scope — no + shared store, matching the in-pod decision). + +### Infra (`infra/stacks/claude-agent-service/main.tf`) + +- Mount the existing `git-crypt-key` configmap into the **main container** + (today only the init container has it) — needed for per-job unlock. +- Pod memory: request `2Gi`, limit `12Gi` (Burstable, tier-aux); CPU request + `1`, no CPU limit. Fits node2/3/5 headroom (~22–26 GB free). +- Wire `MAX_CONCURRENCY` env. 
Rename init-container clone target to + `/workspace/base`; `WORKSPACE_DIR`→ base path. +- `replicas=1`, `Recreate` unchanged. + +## Blast radius (verified) + +All callers handle the busy responses gracefully or fail safely, so removing +them is safe: + +- **n8n DIUN** (`/execute`) — rate-limited 5/6h, no retry; 409 was rare. +- **payslip-ingest** (`/execute`+poll) — 90× retry; big win from parallelism. +- **recruiter-responder** (`/execute`+poll) — returns `busy`, OpenClaw retries. +- **fire-planner** (`/v1/chat/completions`) — client-side semaphore; can be + relaxed after this. +- **BeadBoard** (`/execute`) — UI shows busy via `/api/agent-status` (`/health`). +- **beads-dispatcher** CronJob — gates on `/health` busy; 2-min tick. + +## Testing (TDD) + +Rewrite `test_execute_respects_sequential_lock` and +`test_chat_completions_returns_503_when_agent_busy` (they encode the removed +behavior). New tests: two concurrent `/execute` both run; safety-queue at +`MAX_CONCURRENCY=2`; concurrent chat-completions both run; `/health` capacity +fields; per-job distinct workspace `cwd`; timeout excludes queue-wait; job +eviction; queue-depth rejection. An autouse fixture resets semaphore + counters ++ jobs between tests. + +## Docs to update (same change) + +`infra/docs/architecture/automated-upgrades.md`, +`infra/docs/runbooks/beads-auto-dispatch.md`, `infra/AGENTS.md`, root +`CLAUDE.md` — all currently describe "sequential / single-slot". diff --git a/tests/conftest.py b/tests/conftest.py index 6df2255..b08a72f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,45 @@ +import asyncio import os os.environ.setdefault("API_BEARER_TOKEN", "test-token") os.environ.setdefault("WORKSPACE_DIR", "/tmp/test-workspace") + +import pytest + +from app import main as app_main + + +@pytest.fixture(autouse=True) +def _reset_execution_state(): + """Reset concurrency state between tests. 
+ + A fresh semaphore per test avoids the "bound to a different event loop" + error (pytest-asyncio uses a new loop per function), and clearing the + counters/jobs keeps tests independent. + """ + app_main.jobs.clear() + app_main.inflight_active = 0 + app_main.inflight_queued = 0 + app_main.execution_semaphore = asyncio.Semaphore(app_main.MAX_CONCURRENCY) + app_main._last_fetch_epoch = 0.0 + app_main.MAX_QUEUE_DEPTH = int(os.environ.get("MAX_QUEUE_DEPTH", "100")) + yield + + +@pytest.fixture +def drain(): + """Wait for all background /execute jobs to finish. + + Tests that fire `/execute` must drain before leaving the `patch(...)` + context — otherwise a background task resumes after the mocks are torn + down, spawns a real subprocess during loop teardown, and deadlocks the + asyncio child-watcher. + """ + async def _drain(timeout: float = 3.0): + loop = asyncio.get_event_loop() + deadline = loop.time() + timeout + while app_main.inflight_active or app_main.inflight_queued: + if loop.time() > deadline: + break + await asyncio.sleep(0.01) + return _drain diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py new file mode 100644 index 0000000..ad3766c --- /dev/null +++ b/tests/test_concurrency.py @@ -0,0 +1,223 @@ +"""Tests for parallel, independent execution. + +These exercise the post-lock behavior: multiple agent calls run concurrently, +each in its own workspace, with a bounded semaphore + FIFO queue instead of a +single-flight lock. 
+""" +import asyncio + +import pytest +from unittest.mock import AsyncMock, patch +from httpx import ASGITransport, AsyncClient + +from app import main as app_main +from app.main import app + + +@pytest.fixture +def auth_header(): + return {"Authorization": "Bearer test-token"} + + +class _BlockingStdout: + """async-iterable stdout that blocks on first read until `release` is set, + then ends with no output — mimics a long-running `claude -p`.""" + + def __init__(self, release: asyncio.Event): + self._release = release + self._done = False + + def __aiter__(self): + return self + + async def __anext__(self): + if self._done: + raise StopAsyncIteration + await self._release.wait() + self._done = True + raise StopAsyncIteration + + +class ConcurrencyProbe: + """Tracks how many mock subprocesses have started, and gates their exit.""" + + def __init__(self): + self.started = 0 + self.release = asyncio.Event() + + def factory(self): + async def make(*args, **kwargs): + self.started += 1 + mock = AsyncMock() + mock.stdout = _BlockingStdout(self.release) + mock.stderr = AsyncMock() + mock.stderr.read = AsyncMock(return_value=b"") + mock.wait = AsyncMock(return_value=0) + mock.returncode = 0 + return mock + return make + + async def wait_started(self, n: int, timeout: float = 2.0): + deadline = asyncio.get_event_loop().time() + timeout + while self.started < n: + if asyncio.get_event_loop().time() > deadline: + break + await asyncio.sleep(0.01) + + +def _patch_workspace(): + """Patch the per-job workspace seams so no real git runs.""" + return ( + patch("app.main.prepare_workspace", new=AsyncMock(return_value="/tmp/ws")), + patch("app.main.cleanup_workspace", new=AsyncMock()), + ) + + +@pytest.mark.asyncio +async def test_execute_does_not_return_409_when_a_job_is_running(auth_header, drain): + """A second /execute must NOT be rejected with 409 while one is in flight.""" + probe = ConcurrencyProbe() + pw, cw = _patch_workspace() + with pw, cw, 
patch("app.main.asyncio.create_subprocess_exec", side_effect=probe.factory()): + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + r1 = await client.post("/execute", json={"prompt": "a", "agent": "x"}, headers=auth_header) + await probe.wait_started(1) + r2 = await client.post("/execute", json={"prompt": "b", "agent": "y"}, headers=auth_header) + probe.release.set() + await drain() + assert r1.status_code == 202 + assert r2.status_code == 202 + + +@pytest.mark.asyncio +async def test_two_execute_jobs_run_concurrently(auth_header, drain): + """Two /execute jobs run their subprocesses at the same time (not serialized).""" + probe = ConcurrencyProbe() + pw, cw = _patch_workspace() + with pw, cw, patch("app.main.asyncio.create_subprocess_exec", side_effect=probe.factory()): + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + await client.post("/execute", json={"prompt": "a", "agent": "x"}, headers=auth_header) + await client.post("/execute", json={"prompt": "b", "agent": "y"}, headers=auth_header) + await probe.wait_started(2) + both_running = probe.started >= 2 + probe.release.set() + await drain() + assert both_running, "both jobs should have started before either finished" + + +@pytest.mark.asyncio +async def test_safety_queue_blocks_beyond_capacity(auth_header, drain): + """With capacity=1, the 2nd job is accepted but stays queued until a slot frees.""" + app_main.execution_semaphore = asyncio.Semaphore(1) + probe = ConcurrencyProbe() + pw, cw = _patch_workspace() + with pw, cw, patch("app.main.asyncio.create_subprocess_exec", side_effect=probe.factory()): + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + r1 = await client.post("/execute", json={"prompt": "a", "agent": "x"}, headers=auth_header) + await probe.wait_started(1) + r2 = await client.post("/execute", 
json={"prompt": "b", "agent": "y"}, headers=auth_header) + # Give the 2nd task a chance to (not) start — capacity is 1. + await asyncio.sleep(0.05) + only_one_started = probe.started == 1 + job2 = (await client.get(f"/jobs/{r2.json()['job_id']}", headers=auth_header)).json() + probe.release.set() + await drain() + assert r1.status_code == 202 + assert r2.status_code == 202 + assert only_one_started, "2nd job must wait while capacity is full" + assert job2["status"] == "queued" + + +@pytest.mark.asyncio +async def test_two_chat_completions_run_concurrently(auth_header): + """Concurrent /v1/chat/completions both run — no 503 busy.""" + probe = ConcurrencyProbe() + pw, cw = _patch_workspace() + with pw, cw, patch("app.main.asyncio.create_subprocess_exec", side_effect=probe.factory()): + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + payload = {"model": "haiku", "messages": [{"role": "user", "content": "hi"}]} + t1 = asyncio.create_task(client.post("/v1/chat/completions", json=payload, headers=auth_header)) + t2 = asyncio.create_task(client.post("/v1/chat/completions", json=payload, headers=auth_header)) + await probe.wait_started(2) + both_running = probe.started >= 2 + probe.release.set() + r1, r2 = await asyncio.gather(t1, t2) + assert both_running, "both chat calls should run concurrently" + assert r1.status_code == 200 + assert r2.status_code == 200 + + +@pytest.mark.asyncio +async def test_health_reports_capacity_fields(): + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + body = (await client.get("/health")).json() + assert body["status"] == "ok" + assert body["busy"] is False + assert body["active"] == 0 + assert body["queued"] == 0 + assert body["capacity"] == app_main.MAX_CONCURRENCY + + +@pytest.mark.asyncio +async def test_each_job_gets_distinct_workspace(auth_header, drain): + """prepare_workspace is called per job 
with the job id, yielding distinct cwds.""" + seen_job_ids = [] + + async def fake_prepare(job_id): + seen_job_ids.append(job_id) + return f"/tmp/ws/{job_id}" + + probe = ConcurrencyProbe() + with patch("app.main.prepare_workspace", side_effect=fake_prepare), \ + patch("app.main.cleanup_workspace", new=AsyncMock()), \ + patch("app.main.asyncio.create_subprocess_exec", side_effect=probe.factory()): + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + await client.post("/execute", json={"prompt": "a", "agent": "x"}, headers=auth_header) + await client.post("/execute", json={"prompt": "b", "agent": "y"}, headers=auth_header) + await probe.wait_started(2) + probe.release.set() + await drain() + assert len(set(seen_job_ids)) == 2, "each job should prepare its own workspace" + + +@pytest.mark.asyncio +async def test_queue_depth_rejection(auth_header, drain): + """Beyond MAX_QUEUE_DEPTH, /execute is rejected with 429.""" + app_main.execution_semaphore = asyncio.Semaphore(1) + app_main.MAX_QUEUE_DEPTH = 2 + probe = ConcurrencyProbe() + pw, cw = _patch_workspace() + with pw, cw, patch("app.main.asyncio.create_subprocess_exec", side_effect=probe.factory()): + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + r1 = await client.post("/execute", json={"prompt": "a", "agent": "x"}, headers=auth_header) + await probe.wait_started(1) + r2 = await client.post("/execute", json={"prompt": "b", "agent": "y"}, headers=auth_header) + r3 = await client.post("/execute", json={"prompt": "c", "agent": "z"}, headers=auth_header) + probe.release.set() + await drain() + assert r1.status_code == 202 # active + assert r2.status_code == 202 # queued + assert r3.status_code == 429 # over depth + + +def test_evict_old_jobs_drops_finished_past_ttl(): + """Completed jobs older than JOB_TTL are evicted; running/queued are kept.""" + import time + app_main.jobs.clear() 
+ now = time.time() + app_main.jobs["old"] = {"status": "completed", "finished_epoch": now - 99999} + app_main.jobs["fresh"] = {"status": "completed", "finished_epoch": now} + app_main.jobs["running"] = {"status": "running"} + app_main.jobs["queued"] = {"status": "queued"} + app_main._evict_old_jobs() + assert "old" not in app_main.jobs + assert "fresh" in app_main.jobs + assert "running" in app_main.jobs + assert "queued" in app_main.jobs diff --git a/tests/test_main.py b/tests/test_main.py index cd13a65..3975cec 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,4 +1,3 @@ -import asyncio from unittest.mock import AsyncMock, patch, MagicMock import pytest @@ -65,7 +64,7 @@ async def test_execute_rejects_missing_prompt(auth_header): @pytest.mark.asyncio -async def test_execute_starts_job(auth_header): +async def test_execute_starts_job(auth_header, drain): mock_process = AsyncMock() mock_process.stdout = AsyncMock() mock_process.stdout.__aiter__ = MagicMock(return_value=iter([])) @@ -75,7 +74,8 @@ async def test_execute_starts_job(auth_header): mock_process.returncode = 0 with patch("app.main.asyncio.create_subprocess_exec", return_value=mock_process): - with patch("app.main.run_git_sync", new_callable=AsyncMock): + with patch("app.main.prepare_workspace", new=AsyncMock(return_value="/tmp/ws")), \ + patch("app.main.cleanup_workspace", new=AsyncMock()): transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: response = await client.post( @@ -83,10 +83,11 @@ async def test_execute_starts_job(auth_header): json={"prompt": "test prompt", "agent": "test-agent"}, headers=auth_header, ) + await drain() assert response.status_code == 202 body = response.json() assert "job_id" in body - assert body["status"] == "running" + assert body["status"] == "queued" @pytest.mark.asyncio @@ -98,7 +99,7 @@ async def test_get_job_not_found(auth_header): @pytest.mark.asyncio -async def 
test_execute_stores_metadata_on_job(auth_header): +async def test_execute_stores_metadata_on_job(auth_header, drain): mock_process = AsyncMock() mock_process.stdout = AsyncMock() mock_process.stdout.__aiter__ = MagicMock(return_value=iter([])) @@ -110,7 +111,8 @@ async def test_execute_stores_metadata_on_job(auth_header): metadata = {"task_id": "code-xyz", "source": "beadboard"} with patch("app.main.asyncio.create_subprocess_exec", return_value=mock_process): - with patch("app.main.run_git_sync", new_callable=AsyncMock): + with patch("app.main.prepare_workspace", new=AsyncMock(return_value="/tmp/ws")), \ + patch("app.main.cleanup_workspace", new=AsyncMock()): transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: response = await client.post( @@ -126,55 +128,11 @@ async def test_execute_stores_metadata_on_job(auth_header): job_id = response.json()["job_id"] job_response = await client.get(f"/jobs/{job_id}", headers=auth_header) + await drain() assert job_response.status_code == 200 assert job_response.json()["metadata"] == metadata -@pytest.mark.asyncio -async def test_execute_respects_sequential_lock(auth_header): - hold_event = asyncio.Event() - release_event = asyncio.Event() - - async def slow_subprocess(*args, **kwargs): - mock = AsyncMock() - mock.stdout = AsyncMock() - - async def slow_iter(): - hold_event.set() - await release_event.wait() - return - yield # noqa: F841 - unreachable yield makes this an async generator - - mock.stdout.__aiter__ = MagicMock(side_effect=slow_iter) - mock.stderr = AsyncMock() - mock.stderr.read = AsyncMock(return_value=b"") - mock.wait = AsyncMock(return_value=0) - mock.returncode = 0 - return mock - - with patch("app.main.asyncio.create_subprocess_exec", side_effect=slow_subprocess): - with patch("app.main.run_git_sync", new_callable=AsyncMock): - transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://test") as client: - task1 
= asyncio.create_task(client.post( - "/execute", - json={"prompt": "first", "agent": "agent1"}, - headers=auth_header, - )) - await hold_event.wait() - - response2 = await client.post( - "/execute", - json={"prompt": "second", "agent": "agent2"}, - headers=auth_header, - ) - assert response2.status_code == 409 - - release_event.set() - response1 = await task1 - assert response1.status_code == 202 - - @pytest.mark.asyncio async def test_execute_rejects_empty_api_token_header(): # When the service is booted without an API_BEARER_TOKEN (misconfiguration), @@ -192,7 +150,7 @@ async def test_execute_rejects_empty_api_token_header(): @pytest.mark.asyncio -async def test_execute_accepts_correct_bearer_token(): +async def test_execute_accepts_correct_bearer_token(drain): mock_process = AsyncMock() mock_process.stdout = AsyncMock() mock_process.stdout.__aiter__ = MagicMock(return_value=iter([])) @@ -203,7 +161,8 @@ async def test_execute_accepts_correct_bearer_token(): with patch.object(app_main, "API_TOKEN", "secret"): with patch("app.main.asyncio.create_subprocess_exec", return_value=mock_process): - with patch("app.main.run_git_sync", new_callable=AsyncMock): + with patch("app.main.prepare_workspace", new=AsyncMock(return_value="/tmp/ws")), \ + patch("app.main.cleanup_workspace", new=AsyncMock()): transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: response = await client.post( @@ -211,4 +170,5 @@ async def test_execute_accepts_correct_bearer_token(): json={"prompt": "test", "agent": "test-agent"}, headers={"Authorization": "Bearer secret"}, ) + await drain() assert response.status_code == 202 diff --git a/tests/test_openai_compat.py b/tests/test_openai_compat.py index 45bf1f6..3441972 100644 --- a/tests/test_openai_compat.py +++ b/tests/test_openai_compat.py @@ -5,7 +5,6 @@ from unittest.mock import AsyncMock, patch import pytest from httpx import ASGITransport, AsyncClient -from app import main as app_main 
from app.main import app @@ -61,7 +60,8 @@ async def test_chat_completions_happy_path(auth_header): mock_proc = _mock_subprocess_returning(cli_output, returncode=0) with patch("app.main.asyncio.create_subprocess_exec", return_value=mock_proc), \ - patch("app.main.run_git_sync", new_callable=AsyncMock): + patch("app.main.prepare_workspace", new=AsyncMock(return_value="/tmp/ws")), \ + patch("app.main.cleanup_workspace", new=AsyncMock()): transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: response = await client.post( @@ -154,7 +154,8 @@ async def test_chat_completions_returns_503_on_job_failure(auth_header): mock_proc.stderr.read = AsyncMock(return_value=b"boom") with patch("app.main.asyncio.create_subprocess_exec", return_value=mock_proc), \ - patch("app.main.run_git_sync", new_callable=AsyncMock): + patch("app.main.prepare_workspace", new=AsyncMock(return_value="/tmp/ws")), \ + patch("app.main.cleanup_workspace", new=AsyncMock()): transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: response = await client.post( @@ -193,7 +194,8 @@ async def test_chat_completions_falls_back_when_no_json_result(auth_header): mock_proc = _mock_subprocess_returning(b"plain non-json output", returncode=0) with patch("app.main.asyncio.create_subprocess_exec", return_value=mock_proc), \ - patch("app.main.run_git_sync", new_callable=AsyncMock): + patch("app.main.prepare_workspace", new=AsyncMock(return_value="/tmp/ws")), \ + patch("app.main.cleanup_workspace", new=AsyncMock()): transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: response = await client.post( @@ -222,7 +224,8 @@ async def test_chat_completions_concats_system_and_user_messages(auth_header): ) with patch("app.main.asyncio.create_subprocess_exec", side_effect=fake_subprocess), \ - patch("app.main.run_git_sync", new_callable=AsyncMock): + 
patch("app.main.prepare_workspace", new=AsyncMock(return_value="/tmp/ws")), \ + patch("app.main.cleanup_workspace", new=AsyncMock()): transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: response = await client.post( @@ -242,28 +245,6 @@ async def test_chat_completions_concats_system_and_user_messages(auth_header): assert "USER-MARKER" in prompt_arg -@pytest.mark.asyncio -async def test_chat_completions_returns_503_when_agent_busy(auth_header): - """If the agent is already busy, return 503.""" - await app_main.execution_lock.acquire() - try: - transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://test") as client: - response = await client.post( - "/v1/chat/completions", - json={ - "model": "haiku", - "messages": [{"role": "user", "content": "hi"}], - }, - headers=auth_header, - ) - finally: - app_main.execution_lock.release() - assert response.status_code == 503 - body = response.json() - assert body.get("error") == "execution failed" - - async def _capture_subprocess_args( auth_header: dict, payload: dict, @@ -283,7 +264,8 @@ async def _capture_subprocess_args( ) with patch("app.main.asyncio.create_subprocess_exec", side_effect=fake_subprocess), \ - patch("app.main.run_git_sync", new_callable=AsyncMock): + patch("app.main.prepare_workspace", new=AsyncMock(return_value="/tmp/ws")), \ + patch("app.main.cleanup_workspace", new=AsyncMock()): transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: response = await client.post(