diff --git a/app/main.py b/app/main.py
index a77b370..0efb37f 100644
--- a/app/main.py
+++ b/app/main.py
@@ -1,18 +1,37 @@
 import asyncio
 import hmac
+import json
 import os
+import time
 import uuid
 from datetime import datetime, timezone
 from subprocess import PIPE
+from typing import Any, Literal
 
 from fastapi import FastAPI, HTTPException, Header
-from pydantic import BaseModel
+from fastapi.responses import JSONResponse
+from pydantic import BaseModel, Field
 
 app = FastAPI(title="Claude Agent Service")
 
 API_TOKEN = os.environ.get("API_BEARER_TOKEN", "")
 WORKSPACE_DIR = os.environ.get("WORKSPACE_DIR", "/workspace/infra")
 
+# OpenAI compat: model -> agent mapping. v1 keeps it dead simple — all models
+# route to the most general agent we have. `recruiter-triage` has the broadest
+# tool surface (WebSearch, WebFetch, Read, Grep, Glob, Bash); the alternative
+# (`beads-task-runner`) is locked to read-only `bd` verbs which would fail
+# arbitrary OpenAI-API callers. Revisit when a true general-purpose agent
+# lands in `agents/`.
+MODEL_TO_AGENT: dict[str, str] = {
+    "claude-haiku-4-5": "recruiter-triage",
+    "claude-sonnet-4-6": "recruiter-triage",
+    "claude-opus-4-7": "recruiter-triage",
+}
+AGENT_DEFAULT = "recruiter-triage"
+OPENAI_COMPAT_BUDGET_USD = 2.0
+OPENAI_COMPAT_TIMEOUT_SECONDS = 900
+
 jobs: dict[str, dict] = {}
 execution_lock = asyncio.Lock()
 
@@ -25,6 +44,21 @@ class ExecuteRequest(BaseModel):
     metadata: dict | None = None
 
 
+class ChatMessage(BaseModel):
+    role: Literal["system", "user", "assistant"]
+    content: str
+
+
+class ChatCompletionsRequest(BaseModel):
+    model: str
+    messages: list[ChatMessage] = Field(..., min_length=1)
+    max_tokens: int | None = None
+    temperature: float | None = None
+    stream: bool = False
+    # Tolerate (and ignore) other OpenAI fields rather than 422-ing on them.
+    model_config = {"extra": "allow"}
+
+
 def verify_token(authorization: str | None):
     # Reject everything when the service is unconfigured. compare_digest("", "")
     # returns True, so without this guard an empty API_TOKEN would happily
@@ -47,38 +81,85 @@ async def run_git_sync():
     await proc.wait()
 
 
+async def _invoke_claude_subprocess(
+    prompt: str,
+    agent: str,
+    max_budget_usd: float,
+) -> dict[str, Any]:
+    """Run the claude CLI once and return a result dict.
+
+    The caller is responsible for holding `execution_lock` for the duration —
+    this helper does not touch the lock or the `jobs` dict, so it can be
+    shared by both the background `/execute` path and the synchronous
+    `/v1/chat/completions` path.
+
+    Returns a dict with `exit_code` (the process return code), `output`
+    (decoded stdout lines) and `stderr` (decoded stderr text). If the call
+    is cancelled or a stream read fails, the child process is killed and
+    the exception propagates.
+    """
+    await run_git_sync()
+
+    cmd = [
+        "claude", "-p",
+        "--agent", agent,
+        "--dangerously-skip-permissions",
+        "--max-budget-usd", str(max_budget_usd),
+        "--output-format", "json",
+        prompt,
+    ]
+
+    proc = await asyncio.create_subprocess_exec(
+        *cmd,
+        cwd=WORKSPACE_DIR,
+        stdout=PIPE,
+        stderr=PIPE,
+        # `--output-format json` can print the whole result on one line; the
+        # default 64 KiB StreamReader limit would make the line iteration
+        # below raise ValueError on long answers.
+        limit=16 * 1024 * 1024,
+    )
+
+    # Drain stderr concurrently with stdout. Reading them back to back can
+    # deadlock: once the stderr buffer fills, the child blocks on write and
+    # never closes stdout.
+    stderr_task = asyncio.ensure_future(proc.stderr.read())
+    output_lines: list[str] = []
+    try:
+        async for line in proc.stdout:
+            output_lines.append(line.decode())
+        stderr = await stderr_task
+        await proc.wait()
+    except BaseException:
+        # Cancelled (e.g. by the `asyncio.wait_for` timeout on the OpenAI
+        # path) or a stream error: kill the CLI so it does not keep running
+        # after the caller releases `execution_lock`.
+        stderr_task.cancel()
+        if proc.returncode is None:
+            try:
+                proc.kill()
+            except ProcessLookupError:
+                pass  # exited between the check and the kill
+            await proc.wait()
+        raise
+
+    return {
+        "exit_code": proc.returncode,
+        "output": output_lines,
+        "stderr": stderr.decode(),
+    }
+
+
 async def run_agent(job_id: str, request: ExecuteRequest):
     try:
-        await run_git_sync()
-
-        cmd = [
-            "claude", "-p",
-            "--agent", request.agent,
-            "--dangerously-skip-permissions",
-            "--max-budget-usd", str(request.max_budget_usd),
-            "--output-format", "json",
-            request.prompt,
-        ]
-
-        proc = await asyncio.create_subprocess_exec(
-            *cmd,
-            cwd=WORKSPACE_DIR,
-            stdout=PIPE,
-            stderr=PIPE,
+        result = await _invoke_claude_subprocess(
+            request.prompt, request.agent, request.max_budget_usd,
         )
-
-        output_lines = []
-        async for line in proc.stdout:
-            output_lines.append(line.decode())
-
-        stderr = await proc.stderr.read()
-        await proc.wait()
-
         jobs[job_id].update({
-            "status": "completed" if proc.returncode == 0 else "failed",
-            "exit_code": proc.returncode,
-            "output": output_lines,
-            "stderr": stderr.decode(),
+            "status": "completed" if result["exit_code"] == 0 else "failed",
+            "exit_code": result["exit_code"],
+            "output": result["output"],
+            "stderr": result["stderr"],
             "finished_at": datetime.now(timezone.utc).isoformat(),
         })
     except asyncio.TimeoutError:
@@ -89,6 +170,59 @@ async def run_agent(job_id: str, request: ExecuteRequest):
         execution_lock.release()
 
 
+def _extract_assistant_text(output_lines: list[str]) -> str:
+    """Pull the final assistant text out of `claude -p --output-format json`.
+
+    The CLI emits a single JSON object on stdout (possibly across multiple
+    lines if it pretty-prints) with a `result` field holding the final
+    assistant message. If parsing fails for any reason, fall back to the
+    raw concatenation so callers always get *something* useful.
+    """
+    raw = "".join(output_lines).strip()
+    if not raw:
+        return ""
+    try:
+        parsed = json.loads(raw)
+    except json.JSONDecodeError:
+        return raw
+    if isinstance(parsed, dict):
+        for key in ("result", "content", "text"):
+            value = parsed.get(key)
+            if isinstance(value, str) and value:
+                return value
+    return raw
+
+
+def _one_line(text: str, limit: int = 200) -> str:
+    """Collapse multi-line text to a single line, truncated for response bodies."""
+    flat = " ".join(text.split())
+    return flat[:limit]
+
+
+def _synthesise_prompt(messages: list[ChatMessage]) -> str:
+    """Flatten OpenAI chat messages into a single prompt body.
+
+    System messages are surfaced as preamble; user messages become the
+    actual request. Multiple user turns are concatenated in order so a
+    short multi-turn back-and-forth still works (this is a stateless
+    completion — we don't replay prior assistant turns).
+    """
+    system_parts = [m.content for m in messages if m.role == "system"]
+    user_parts = [m.content for m in messages if m.role == "user"]
+    # Assistant messages from prior turns are intentionally NOT injected —
+    # claude `-p` is stateless and replaying them as user text would
+    # confuse the agent.
+    sections: list[str] = []
+    if system_parts:
+        sections.append("System instructions:\n" + "\n\n".join(system_parts))
+    if user_parts:
+        sections.append("Request:\n" + "\n\n".join(user_parts))
+    if not sections:
+        # Defensive — pydantic min_length=1 should already prevent this.
+        return ""
+    return "\n\n---\n\n".join(sections)
+
+
 @app.get("/health")
 async def health():
     return {"status": "ok", "busy": execution_lock.locked()}
@@ -134,3 +268,86 @@ async def get_job(
     if job_id not in jobs:
         raise HTTPException(status_code=404, detail="Job not found")
     return jobs[job_id]
+
+
+@app.post("/v1/chat/completions")
+async def chat_completions(
+    request: ChatCompletionsRequest,
+    authorization: str | None = Header(default=None),
+):
+    """Serve an OpenAI-style, non-streaming chat completion via the claude CLI.
+
+    Authentication is delegated to `verify_token`. Responds 400 when
+    `stream` is true or when the messages yield an empty prompt (no system
+    or user turns), and 503 with `{"error": "execution failed"}` when the
+    agent is busy, times out, raises, or exits non-zero. `max_tokens` and
+    `temperature` are accepted but not forwarded; `usage` counts are
+    always reported as 0.
+    """
+    verify_token(authorization)
+
+    if request.stream:
+        raise HTTPException(status_code=400, detail="streaming not supported")
+
+    agent = MODEL_TO_AGENT.get(request.model, AGENT_DEFAULT)
+    prompt = _synthesise_prompt(request.messages)
+    if not prompt:
+        # Only assistant turns were supplied, so there is nothing to ask the
+        # agent: reject instead of spending a CLI run on an empty prompt.
+        raise HTTPException(
+            status_code=400,
+            detail="messages must include at least one system or user message",
+        )
+
+    if execution_lock.locked():
+        return JSONResponse(
+            status_code=503,
+            content={"error": "execution failed", "detail": "agent is busy"},
+        )
+
+    await execution_lock.acquire()
+    try:
+        try:
+            result = await asyncio.wait_for(
+                _invoke_claude_subprocess(prompt, agent, OPENAI_COMPAT_BUDGET_USD),
+                timeout=OPENAI_COMPAT_TIMEOUT_SECONDS,
+            )
+        except asyncio.TimeoutError:
+            return JSONResponse(
+                status_code=503,
+                content={"error": "execution failed", "detail": "agent timed out"},
+            )
+        except Exception as exc:
+            return JSONResponse(
+                status_code=503,
+                content={"error": "execution failed", "detail": _one_line(str(exc))},
+            )
+    finally:
+        execution_lock.release()
+
+    if result["exit_code"] != 0:
+        detail = _one_line(result.get("stderr") or "") or f"exit {result['exit_code']}"
+        return JSONResponse(
+            status_code=503,
+            content={"error": "execution failed", "detail": detail},
+        )
+
+    content = _extract_assistant_text(result["output"])
+    completion_id = "chatcmpl-" + uuid.uuid4().hex[:24]
+
+    return {
+        "id": completion_id,
+        "object": "chat.completion",
+        "created": int(time.time()),
+        "model": request.model,
+        "choices": [{
+            "index": 0,
+            "message": {"role": "assistant", "content": content},
+            "finish_reason": "stop",
+        }],
+        "usage": {
+            "prompt_tokens": 0,
+            "completion_tokens": 0,
+            "total_tokens": 0,
+        },
+    }
diff --git a/tests/test_openai_compat.py b/tests/test_openai_compat.py
new file mode 100644
index 0000000..d269912
--- /dev/null
+++ b/tests/test_openai_compat.py
@@ -0,0 +1,283 @@
+"""Tests for the OpenAI-compatible /v1/chat/completions endpoint."""
+import json
+from unittest.mock import AsyncMock, patch
+
+import pytest
+from httpx import ASGITransport, AsyncClient
+
+from app import main as app_main
+from app.main import app
+
+
+@pytest.fixture
+def auth_header():
+    return {"Authorization": "Bearer test-token"}
+
+
+class _AsyncLineIter:
+    """Real async iterator over a list of bytes lines — mimics
+    `proc.stdout` from `asyncio.subprocess`."""
+
+    def __init__(self, lines: list[bytes]):
+        self._lines = list(lines)
+        self._i = 0
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        if self._i >= len(self._lines):
+            raise StopAsyncIteration
+        line = self._lines[self._i]
+        self._i += 1
+        return line
+
+
+def _mock_subprocess_returning(output: bytes, returncode: int = 0):
+    """Build an AsyncMock that mimics asyncio.create_subprocess_exec."""
+    mock_process = AsyncMock()
+    lines = [chunk + b"\n" for chunk in output.split(b"\n") if chunk]
+    mock_process.stdout = _AsyncLineIter(lines)
+    mock_process.stderr = AsyncMock()
+    mock_process.stderr.read = AsyncMock(return_value=b"")
+    mock_process.wait = AsyncMock(return_value=returncode)
+    mock_process.returncode = returncode
+    return mock_process
+
+
+@pytest.mark.asyncio
+async def test_chat_completions_happy_path(auth_header):
+    """Happy path: messages in, OpenAI-shape response out."""
+    cli_output = json.dumps({
+        "type": "result",
+        "subtype": "success",
+        "is_error": False,
+        "result": "Paris is the capital of France.",
+        "total_cost_usd": 0.001,
+        "num_turns": 1,
+        "session_id": "abc123",
+    }).encode()
+
+    mock_proc = _mock_subprocess_returning(cli_output, returncode=0)
+
+    with patch("app.main.asyncio.create_subprocess_exec", return_value=mock_proc), \
+         patch("app.main.run_git_sync", new_callable=AsyncMock):
+        transport = ASGITransport(app=app)
+        async with AsyncClient(transport=transport, base_url="http://test") as client:
+            response = await client.post(
+                "/v1/chat/completions",
+                json={
+                    "model": "claude-haiku-4-5",
+                    "messages": [
+                        {"role": "system", "content": "You are concise."},
+                        {"role": "user", "content": "Capital of France?"},
+                    ],
+                },
+                headers=auth_header,
+            )
+
+    assert response.status_code == 200, response.text
+    body = response.json()
+
+    assert body["object"] == "chat.completion"
+    assert body["id"].startswith("chatcmpl-")
+    assert body["model"] == "claude-haiku-4-5"
+    assert "created" in body
+    assert isinstance(body["created"], int)
+
+    assert len(body["choices"]) == 1
+    choice = body["choices"][0]
+    assert choice["index"] == 0
+    assert choice["finish_reason"] == "stop"
+    assert choice["message"]["role"] == "assistant"
+    assert choice["message"]["content"] == "Paris is the capital of France."
+
+    assert "usage" in body
+    for key in ("prompt_tokens", "completion_tokens", "total_tokens"):
+        assert key in body["usage"]
+
+
+@pytest.mark.asyncio
+async def test_chat_completions_rejects_streaming(auth_header):
+    """stream=true is not supported and must 400 with a clear message."""
+    transport = ASGITransport(app=app)
+    async with AsyncClient(transport=transport, base_url="http://test") as client:
+        response = await client.post(
+            "/v1/chat/completions",
+            json={
+                "model": "claude-haiku-4-5",
+                "messages": [{"role": "user", "content": "hi"}],
+                "stream": True,
+            },
+            headers=auth_header,
+        )
+    assert response.status_code == 400
+    body = response.json()
+    assert "streaming not supported" in json.dumps(body).lower()
+
+
+@pytest.mark.asyncio
+async def test_chat_completions_requires_auth():
+    """Missing bearer token must 401, identical to /execute."""
+    transport = ASGITransport(app=app)
+    async with AsyncClient(transport=transport, base_url="http://test") as client:
+        response = await client.post(
+            "/v1/chat/completions",
+            json={
+                "model": "claude-haiku-4-5",
+                "messages": [{"role": "user", "content": "hi"}],
+            },
+        )
+    assert response.status_code == 401
+
+
+@pytest.mark.asyncio
+async def test_chat_completions_wrong_bearer_token():
+    """A wrong bearer token must also 401."""
+    transport = ASGITransport(app=app)
+    async with AsyncClient(transport=transport, base_url="http://test") as client:
+        response = await client.post(
+            "/v1/chat/completions",
+            json={
+                "model": "claude-haiku-4-5",
+                "messages": [{"role": "user", "content": "hi"}],
+            },
+            headers={"Authorization": "Bearer wrong"},
+        )
+    assert response.status_code == 401
+
+
+@pytest.mark.asyncio
+async def test_chat_completions_returns_503_on_job_failure(auth_header):
+    """If the underlying claude subprocess exits non-zero, return 503."""
+    mock_proc = _mock_subprocess_returning(b"", returncode=42)
+    mock_proc.stderr.read = AsyncMock(return_value=b"boom")
+
+    with patch("app.main.asyncio.create_subprocess_exec", return_value=mock_proc), \
+         patch("app.main.run_git_sync", new_callable=AsyncMock):
+        transport = ASGITransport(app=app)
+        async with AsyncClient(transport=transport, base_url="http://test") as client:
+            response = await client.post(
+                "/v1/chat/completions",
+                json={
+                    "model": "claude-haiku-4-5",
+                    "messages": [{"role": "user", "content": "trigger fail"}],
+                },
+                headers=auth_header,
+            )
+    assert response.status_code == 503
+    body = response.json()
+    assert body.get("error") == "execution failed"
+    assert "detail" in body
+
+
+@pytest.mark.asyncio
+async def test_chat_completions_rejects_empty_messages(auth_header):
+    """`messages` must be a non-empty list."""
+    transport = ASGITransport(app=app)
+    async with AsyncClient(transport=transport, base_url="http://test") as client:
+        response = await client.post(
+            "/v1/chat/completions",
+            json={
+                "model": "claude-haiku-4-5",
+                "messages": [],
+            },
+            headers=auth_header,
+        )
+    assert response.status_code in (400, 422)
+
+
+@pytest.mark.asyncio
+async def test_chat_completions_falls_back_when_no_json_result(auth_header):
+    """If stdout is not parseable JSON, fall back to raw concatenation."""
+    mock_proc = _mock_subprocess_returning(b"plain non-json output", returncode=0)
+
+    with patch("app.main.asyncio.create_subprocess_exec", return_value=mock_proc), \
+         patch("app.main.run_git_sync", new_callable=AsyncMock):
+        transport = ASGITransport(app=app)
+        async with AsyncClient(transport=transport, base_url="http://test") as client:
+            response = await client.post(
+                "/v1/chat/completions",
+                json={
+                    "model": "claude-haiku-4-5",
+                    "messages": [{"role": "user", "content": "hi"}],
+                },
+                headers=auth_header,
+            )
+    assert response.status_code == 200
+    content = response.json()["choices"][0]["message"]["content"]
+    assert "plain non-json output" in content
+
+
+@pytest.mark.asyncio
+async def test_chat_completions_concats_system_and_user_messages(auth_header):
+    """The synthesised prompt passed to claude must include both system and user content."""
+    captured: dict = {}
+
+    async def fake_subprocess(*args, **kwargs):
+        captured["args"] = args
+        return _mock_subprocess_returning(
+            json.dumps({"type": "result", "result": "ok", "is_error": False}).encode(),
+            returncode=0,
+        )
+
+    with patch("app.main.asyncio.create_subprocess_exec", side_effect=fake_subprocess), \
+         patch("app.main.run_git_sync", new_callable=AsyncMock):
+        transport = ASGITransport(app=app)
+        async with AsyncClient(transport=transport, base_url="http://test") as client:
+            response = await client.post(
+                "/v1/chat/completions",
+                json={
+                    "model": "claude-haiku-4-5",
+                    "messages": [
+                        {"role": "system", "content": "SYSTEM-MARKER"},
+                        {"role": "user", "content": "USER-MARKER"},
+                    ],
+                },
+                headers=auth_header,
+            )
+    assert response.status_code == 200
+    prompt_arg = captured["args"][-1]
+    assert "SYSTEM-MARKER" in prompt_arg
+    assert "USER-MARKER" in prompt_arg
+
+
+@pytest.mark.asyncio
+async def test_chat_completions_returns_503_when_agent_busy(auth_header):
+    """If the agent is already busy, return 503."""
+    await app_main.execution_lock.acquire()
+    try:
+        transport = ASGITransport(app=app)
+        async with AsyncClient(transport=transport, base_url="http://test") as client:
+            response = await client.post(
+                "/v1/chat/completions",
+                json={
+                    "model": "claude-haiku-4-5",
+                    "messages": [{"role": "user", "content": "hi"}],
+                },
+                headers=auth_header,
+            )
+    finally:
+        app_main.execution_lock.release()
+    assert response.status_code == 503
+    body = response.json()
+    assert body.get("error") == "execution failed"
+
+
+@pytest.mark.asyncio
+async def test_chat_completions_rejects_messages_without_prompt(auth_header):
+    """Assistant-only messages produce no prompt: 400 without running the CLI."""
+    with patch("app.main.asyncio.create_subprocess_exec") as mock_exec, \
+         patch("app.main.run_git_sync", new_callable=AsyncMock):
+        transport = ASGITransport(app=app)
+        async with AsyncClient(transport=transport, base_url="http://test") as client:
+            response = await client.post(
+                "/v1/chat/completions",
+                json={
+                    "model": "claude-haiku-4-5",
+                    "messages": [{"role": "assistant", "content": "hi"}],
+                },
+                headers=auth_header,
+            )
+    assert response.status_code == 400
+    mock_exec.assert_not_called()