openai-compat: add /v1/chat/completions endpoint
All checks were successful
ci/woodpecker/push/woodpecker Pipeline was successful
All checks were successful
ci/woodpecker/push/woodpecker Pipeline was successful
Adds an OpenAI-compatible chat-completions endpoint so that existing
OpenAI-API clients (fire-planner's examples/llm_extract.py and others)
can target this service without being rewritten.
Behaviour:
- POST /v1/chat/completions accepts the OpenAI chat-completions request
shape (model, messages, max_tokens?, temperature?, stream?).
- Reuses the existing Bearer auth from /execute.
- Synthesises a single prompt body from system+user messages
("System instructions:\n... --- Request:\n...") so the agent treats
them as the user's request rather than seeing raw JSON.
- Internally shares the execution path with /execute by extracting
_invoke_claude_subprocess(). Holds execution_lock for the duration;
returns 503 (not 409) when busy, since OpenAI callers have no
job-id model to retry against.
- Returns the OpenAI chat-completions envelope with the final
assistant text extracted from `claude -p --output-format json`
(falls back to raw stdout if parsing fails).
- stream=true -> 400 {"detail": "streaming not supported"} (FastAPI's HTTPException envelope).
- Underlying failure (non-zero exit, timeout, exception) -> 503
{"error": "execution failed", "detail": "<one line>"}.
Model -> agent mapping is hardcoded to `recruiter-triage` for all
models for v1 (broadest tool surface among current agents). Budget
is hardcoded to $2.00/call; timeout 900s. Revisit when a true
general-purpose agent lands.
Tests: 9 new tests covering happy path, streaming rejection, missing
auth, wrong token, job failure, empty messages, JSON-parse fallback,
prompt synthesis, and busy-503. All 20 tests (11 existing + 9 new)
pass; ruff clean.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
191ed5dd87
commit
07dcfca333
2 changed files with 469 additions and 29 deletions
234
app/main.py
234
app/main.py
|
|
@ -1,18 +1,37 @@
|
|||
import asyncio
|
||||
import hmac
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from subprocess import PIPE
|
||||
from typing import Any, Literal
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Header
|
||||
from pydantic import BaseModel
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
app = FastAPI(title="Claude Agent Service")
|
||||
|
||||
API_TOKEN = os.environ.get("API_BEARER_TOKEN", "")
|
||||
WORKSPACE_DIR = os.environ.get("WORKSPACE_DIR", "/workspace/infra")
|
||||
|
||||
# OpenAI compat: model -> agent mapping. v1 keeps it dead simple — all models
|
||||
# route to the most general agent we have. `recruiter-triage` has the broadest
|
||||
# tool surface (WebSearch, WebFetch, Read, Grep, Glob, Bash); the alternative
|
||||
# (`beads-task-runner`) is locked to read-only `bd` verbs which would fail
|
||||
# arbitrary OpenAI-API callers. Revisit when a true general-purpose agent
|
||||
# lands in `agents/`.
|
||||
MODEL_TO_AGENT: dict[str, str] = {
|
||||
"claude-haiku-4-5": "recruiter-triage",
|
||||
"claude-sonnet-4-6": "recruiter-triage",
|
||||
"claude-opus-4-7": "recruiter-triage",
|
||||
}
|
||||
AGENT_DEFAULT = "recruiter-triage"
|
||||
OPENAI_COMPAT_BUDGET_USD = 2.0
|
||||
OPENAI_COMPAT_TIMEOUT_SECONDS = 900
|
||||
|
||||
jobs: dict[str, dict] = {}
|
||||
execution_lock = asyncio.Lock()
|
||||
|
||||
|
|
@ -25,6 +44,21 @@ class ExecuteRequest(BaseModel):
|
|||
metadata: dict | None = None
|
||||
|
||||
|
||||
# One turn of an OpenAI-style chat conversation, as sent in the `messages`
# array of a chat-completions request.
class ChatMessage(BaseModel):
    # Only these three roles are accepted; any other value fails validation.
    role: Literal["system", "user", "assistant"]
    # Text body of the turn.
    content: str
|
||||
|
||||
|
||||
# Request body for POST /v1/chat/completions, mirroring the OpenAI
# chat-completions request shape.
class ChatCompletionsRequest(BaseModel):
    # Client-supplied model name; looked up in MODEL_TO_AGENT by the endpoint.
    model: str
    # At least one message is required (an empty list fails validation).
    messages: list[ChatMessage] = Field(..., min_length=1)
    # Accepted for OpenAI compatibility.
    max_tokens: int | None = None
    # Accepted for OpenAI compatibility.
    temperature: float | None = None
    # Streaming is rejected by the endpoint; the field exists so the request
    # can be refused explicitly.
    stream: bool = False
    # Tolerate (and ignore) other OpenAI fields rather than 422-ing on them.
    model_config = {"extra": "allow"}
|
||||
|
||||
|
||||
def verify_token(authorization: str | None):
|
||||
# Reject everything when the service is unconfigured. compare_digest("", "")
|
||||
# returns True, so without this guard an empty API_TOKEN would happily
|
||||
|
|
@ -47,38 +81,60 @@ async def run_git_sync():
|
|||
await proc.wait()
|
||||
|
||||
|
||||
async def _invoke_claude_subprocess(
    prompt: str,
    agent: str,
    max_budget_usd: float,
) -> dict[str, Any]:
    """Run the claude CLI once and return a result dict.

    The caller is responsible for holding `execution_lock` for the duration —
    this helper does not touch the lock or the `jobs` dict, so it can be
    shared by both the background `/execute` path and the synchronous
    `/v1/chat/completions` path.

    Args:
        prompt: Prompt text passed to the CLI as its final positional argument.
        agent: Agent name passed via `--agent`.
        max_budget_usd: Spend cap passed via `--max-budget-usd`.

    Returns:
        A dict with `exit_code` (the process return code), `output` (the
        decoded stdout lines, in order) and `stderr` (the decoded stderr).

    If this coroutine is cancelled (for example by an enclosing
    `asyncio.wait_for` timeout) or fails while reading, the child process is
    killed and reaped before the exception propagates, so no orphaned CLI
    keeps running after the caller releases the lock.
    """
    await run_git_sync()

    cmd = [
        "claude", "-p",
        "--agent", agent,
        "--dangerously-skip-permissions",
        "--max-budget-usd", str(max_budget_usd),
        "--output-format", "json",
        prompt,
    ]

    proc = await asyncio.create_subprocess_exec(
        *cmd,
        cwd=WORKSPACE_DIR,
        stdout=PIPE,
        stderr=PIPE,
    )

    async def _read_stdout_lines() -> list[str]:
        return [line.decode() async for line in proc.stdout]

    try:
        # Drain both pipes concurrently. Reading stdout to EOF before touching
        # stderr can deadlock: once the child fills the stderr pipe it blocks
        # on write and never closes stdout.
        output_lines, stderr = await asyncio.gather(
            _read_stdout_lines(),
            proc.stderr.read(),
        )
        await proc.wait()
    finally:
        # Still running here means we are unwinding from a cancellation or an
        # error — terminate the child rather than leaking it.
        if proc.returncode is None:
            try:
                proc.kill()
            except ProcessLookupError:
                # The process exited between the check and the kill.
                pass
            await proc.wait()

    return {
        "exit_code": proc.returncode,
        "output": output_lines,
        "stderr": stderr.decode(),
    }
|
||||
|
||||
|
||||
async def run_agent(job_id: str, request: ExecuteRequest):
|
||||
try:
|
||||
await run_git_sync()
|
||||
|
||||
cmd = [
|
||||
"claude", "-p",
|
||||
"--agent", request.agent,
|
||||
"--dangerously-skip-permissions",
|
||||
"--max-budget-usd", str(request.max_budget_usd),
|
||||
"--output-format", "json",
|
||||
request.prompt,
|
||||
]
|
||||
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
cwd=WORKSPACE_DIR,
|
||||
stdout=PIPE,
|
||||
stderr=PIPE,
|
||||
result = await _invoke_claude_subprocess(
|
||||
request.prompt, request.agent, request.max_budget_usd,
|
||||
)
|
||||
|
||||
output_lines = []
|
||||
async for line in proc.stdout:
|
||||
output_lines.append(line.decode())
|
||||
|
||||
stderr = await proc.stderr.read()
|
||||
await proc.wait()
|
||||
|
||||
jobs[job_id].update({
|
||||
"status": "completed" if proc.returncode == 0 else "failed",
|
||||
"exit_code": proc.returncode,
|
||||
"output": output_lines,
|
||||
"stderr": stderr.decode(),
|
||||
"status": "completed" if result["exit_code"] == 0 else "failed",
|
||||
"exit_code": result["exit_code"],
|
||||
"output": result["output"],
|
||||
"stderr": result["stderr"],
|
||||
"finished_at": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
except asyncio.TimeoutError:
|
||||
|
|
@ -89,6 +145,59 @@ async def run_agent(job_id: str, request: ExecuteRequest):
|
|||
execution_lock.release()
|
||||
|
||||
|
||||
def _extract_assistant_text(output_lines: list[str]) -> str:
    """Return the assistant's final message from `claude -p --output-format json`.

    The stdout lines are joined and parsed as one JSON document. When that
    document is an object, the first non-empty string found under `result`,
    `content` or `text` (checked in that order) is returned. In every other
    case — unparseable output, a non-object document, or none of those keys
    holding a non-empty string — the stripped raw stdout is returned so the
    caller still receives whatever the CLI printed.
    """
    text = "".join(output_lines).strip()
    if text:
        try:
            payload = json.loads(text)
        except json.JSONDecodeError:
            payload = None
        if isinstance(payload, dict):
            candidates = map(payload.get, ("result", "content", "text"))
            found = next(
                (value for value in candidates if isinstance(value, str) and value),
                None,
            )
            if found is not None:
                return found
    return text
|
||||
|
||||
|
||||
def _one_line(text: str, limit: int = 200) -> str:
    """Squash `text` onto one line and cap its length for use in response bodies.

    Every run of whitespace (including newlines) becomes a single space and
    leading/trailing whitespace is dropped; the result is then cut to at most
    `limit` characters.
    """
    words = text.split()
    collapsed = " ".join(words)
    if len(collapsed) <= limit:
        return collapsed
    return collapsed[:limit]
|
||||
|
||||
|
||||
def _synthesise_prompt(messages: list[ChatMessage]) -> str:
    """Collapse OpenAI chat messages into the single prompt string sent to claude.

    All system messages are gathered (in order) under a "System instructions:"
    heading and all user messages under a "Request:" heading; the two sections
    are separated by a `---` rule. Assistant messages are deliberately left
    out: `claude -p` is stateless, and replaying earlier replies as user text
    would confuse the agent.

    Returns an empty string when there are neither system nor user messages
    (for example a conversation made up solely of assistant turns).
    """
    by_role: dict[str, list[str]] = {"system": [], "user": []}
    for message in messages:
        if message.role in by_role:
            by_role[message.role].append(message.content)

    headings = (
        ("system", "System instructions:\n"),
        ("user", "Request:\n"),
    )
    blocks = [
        heading + "\n\n".join(by_role[role])
        for role, heading in headings
        if by_role[role]
    ]
    # Joining an empty list yields "", which is the no-content result.
    return "\n\n---\n\n".join(blocks)
|
||||
|
||||
|
||||
@app.get("/health")
async def health():
    """Liveness probe that also reports whether an agent run holds the lock."""
    busy = execution_lock.locked()
    return {"status": "ok", "busy": busy}
|
||||
|
|
@ -134,3 +243,70 @@ async def get_job(
|
|||
if job_id not in jobs:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
return jobs[job_id]
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
async def chat_completions(
    request: ChatCompletionsRequest,
    authorization: str | None = Header(default=None),
):
    """Serve an OpenAI-style chat completion by running the claude CLI once.

    The request's system and user messages are flattened into one prompt, the
    agent is chosen from `MODEL_TO_AGENT` (falling back to `AGENT_DEFAULT`),
    and the CLI runs synchronously while `execution_lock` is held.

    Responses:
        200: an OpenAI `chat.completion` envelope. The `usage` counters are
            always zero; token usage is not tracked here.
        400: `stream` was requested, or the messages contain no system or
            user turn to build a prompt from.
        503: the agent is busy, timed out, raised, or exited non-zero; the
            body is `{"error": "execution failed", "detail": "<one line>"}`.

    Authentication is delegated to `verify_token`.
    """
    verify_token(authorization)

    if request.stream:
        raise HTTPException(status_code=400, detail="streaming not supported")

    agent = MODEL_TO_AGENT.get(request.model, AGENT_DEFAULT)
    prompt = _synthesise_prompt(request.messages)

    # `min_length=1` on the model does not rule out a conversation made only
    # of assistant turns, which synthesises to an empty prompt. Reject it
    # instead of spending an agent run on nothing.
    if not prompt:
        raise HTTPException(
            status_code=400,
            detail="messages must include at least one system or user message",
        )

    def _execution_failed(detail: str) -> JSONResponse:
        # Single place that builds the 503 envelope used by every failure path.
        return JSONResponse(
            status_code=503,
            content={"error": "execution failed", "detail": detail},
        )

    # OpenAI callers have no job id to poll, so refuse immediately rather than
    # queueing. There is no await between this check and taking the lock, so
    # another request cannot slip in between them.
    if execution_lock.locked():
        return _execution_failed("agent is busy")

    async with execution_lock:
        try:
            result = await asyncio.wait_for(
                _invoke_claude_subprocess(prompt, agent, OPENAI_COMPAT_BUDGET_USD),
                timeout=OPENAI_COMPAT_TIMEOUT_SECONDS,
            )
        except asyncio.TimeoutError:
            return _execution_failed("agent timed out")
        except Exception as exc:
            # Boundary handler: any failure to launch or read the CLI becomes
            # a 503. Fall back to the exception type when it has no message.
            return _execution_failed(_one_line(str(exc)) or type(exc).__name__)

    if result["exit_code"] != 0:
        detail = _one_line(result.get("stderr") or "") or f"exit {result['exit_code']}"
        return _execution_failed(detail)

    content = _extract_assistant_text(result["output"])
    completion_id = "chatcmpl-" + uuid.uuid4().hex[:24]

    return {
        "id": completion_id,
        "object": "chat.completion",
        "created": int(time.time()),
        "model": request.model,
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": content},
            "finish_reason": "stop",
        }],
        "usage": {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
        },
    }
|
||||
|
|
|
|||
264
tests/test_openai_compat.py
Normal file
264
tests/test_openai_compat.py
Normal file
|
|
@ -0,0 +1,264 @@
|
|||
"""Tests for the OpenAI-compatible /v1/chat/completions endpoint."""
|
||||
import json
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from app import main as app_main
|
||||
from app.main import app
|
||||
|
||||
|
||||
@pytest.fixture
def auth_header():
    """Headers carrying the bearer token that the authenticated tests send."""
    headers = {"Authorization": "Bearer test-token"}
    return headers
|
||||
|
||||
|
||||
class _AsyncLineIter:
    """Async iterator over a fixed list of byte lines.

    Stands in for `proc.stdout` of an `asyncio.subprocess` process so code
    under test can consume it with `async for`.
    """

    def __init__(self, lines: list[bytes]):
        # Snapshot the input so later mutation by the caller has no effect.
        self._pending = iter(list(lines))

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            return next(self._pending)
        except StopIteration:
            raise StopAsyncIteration from None
|
||||
|
||||
|
||||
def _mock_subprocess_returning(output: bytes, returncode: int = 0):
    """Create a stand-in for the process object from asyncio.create_subprocess_exec.

    `output` is split on newlines; each non-empty piece is served from
    `stdout` with its trailing newline restored. `stderr.read()` yields empty
    bytes, and both `wait()` and `returncode` report `returncode`.
    """
    stdout_lines = []
    for piece in output.split(b"\n"):
        if piece:
            stdout_lines.append(piece + b"\n")

    stderr_stub = AsyncMock()
    stderr_stub.read = AsyncMock(return_value=b"")

    fake_proc = AsyncMock()
    fake_proc.stdout = _AsyncLineIter(stdout_lines)
    fake_proc.stderr = stderr_stub
    fake_proc.wait = AsyncMock(return_value=returncode)
    fake_proc.returncode = returncode
    return fake_proc
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_completions_happy_path(auth_header):
    """A valid request comes back as a complete OpenAI chat.completion envelope."""
    cli_payload = {
        "type": "result",
        "subtype": "success",
        "is_error": False,
        "result": "Paris is the capital of France.",
        "total_cost_usd": 0.001,
        "num_turns": 1,
        "session_id": "abc123",
    }
    fake_proc = _mock_subprocess_returning(
        json.dumps(cli_payload).encode(), returncode=0,
    )
    request_body = {
        "model": "claude-haiku-4-5",
        "messages": [
            {"role": "system", "content": "You are concise."},
            {"role": "user", "content": "Capital of France?"},
        ],
    }

    with (
        patch("app.main.asyncio.create_subprocess_exec", return_value=fake_proc),
        patch("app.main.run_git_sync", new_callable=AsyncMock),
    ):
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test",
        ) as http:
            resp = await http.post(
                "/v1/chat/completions", json=request_body, headers=auth_header,
            )

    assert resp.status_code == 200, resp.text
    envelope = resp.json()

    # Top-level envelope fields.
    assert envelope["object"] == "chat.completion"
    assert envelope["id"].startswith("chatcmpl-")
    assert envelope["model"] == "claude-haiku-4-5"
    assert "created" in envelope
    assert isinstance(envelope["created"], int)

    # Exactly one choice carrying the assistant's text.
    assert len(envelope["choices"]) == 1
    first = envelope["choices"][0]
    assert first["index"] == 0
    assert first["finish_reason"] == "stop"
    assert first["message"]["role"] == "assistant"
    assert first["message"]["content"] == "Paris is the capital of France."

    # Usage block must expose all three counters.
    assert "usage" in envelope
    for counter in ("prompt_tokens", "completion_tokens", "total_tokens"):
        assert counter in envelope["usage"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_completions_rejects_streaming(auth_header):
    """Requests with stream=true are refused with a 400 that names the reason."""
    request_body = {
        "model": "claude-haiku-4-5",
        "messages": [{"role": "user", "content": "hi"}],
        "stream": True,
    }
    async with AsyncClient(
        transport=ASGITransport(app=app), base_url="http://test",
    ) as http:
        resp = await http.post(
            "/v1/chat/completions", json=request_body, headers=auth_header,
        )

    assert resp.status_code == 400
    serialised = json.dumps(resp.json()).lower()
    assert "streaming not supported" in serialised
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_completions_requires_auth():
    """A request without any bearer token gets a 401, just like /execute."""
    request_body = {
        "model": "claude-haiku-4-5",
        "messages": [{"role": "user", "content": "hi"}],
    }
    async with AsyncClient(
        transport=ASGITransport(app=app), base_url="http://test",
    ) as http:
        resp = await http.post("/v1/chat/completions", json=request_body)

    assert resp.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_completions_wrong_bearer_token():
    """A bearer token that does not match is rejected with a 401 as well."""
    request_body = {
        "model": "claude-haiku-4-5",
        "messages": [{"role": "user", "content": "hi"}],
    }
    bad_credentials = {"Authorization": "Bearer wrong"}
    async with AsyncClient(
        transport=ASGITransport(app=app), base_url="http://test",
    ) as http:
        resp = await http.post(
            "/v1/chat/completions", json=request_body, headers=bad_credentials,
        )

    assert resp.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_completions_returns_503_on_job_failure(auth_header):
    """A non-zero exit from the claude subprocess surfaces as a 503 error body."""
    failing_proc = _mock_subprocess_returning(b"", returncode=42)
    failing_proc.stderr.read = AsyncMock(return_value=b"boom")
    request_body = {
        "model": "claude-haiku-4-5",
        "messages": [{"role": "user", "content": "trigger fail"}],
    }

    with (
        patch("app.main.asyncio.create_subprocess_exec", return_value=failing_proc),
        patch("app.main.run_git_sync", new_callable=AsyncMock),
    ):
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test",
        ) as http:
            resp = await http.post(
                "/v1/chat/completions", json=request_body, headers=auth_header,
            )

    assert resp.status_code == 503
    error_body = resp.json()
    assert error_body.get("error") == "execution failed"
    assert "detail" in error_body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_completions_rejects_empty_messages(auth_header):
    """An empty `messages` list is a client error (400 or 422)."""
    request_body = {"model": "claude-haiku-4-5", "messages": []}
    async with AsyncClient(
        transport=ASGITransport(app=app), base_url="http://test",
    ) as http:
        resp = await http.post(
            "/v1/chat/completions", json=request_body, headers=auth_header,
        )

    assert resp.status_code in (400, 422)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_completions_falls_back_when_no_json_result(auth_header):
    """Stdout that is not JSON is passed through as the assistant content."""
    fake_proc = _mock_subprocess_returning(b"plain non-json output", returncode=0)
    request_body = {
        "model": "claude-haiku-4-5",
        "messages": [{"role": "user", "content": "hi"}],
    }

    with (
        patch("app.main.asyncio.create_subprocess_exec", return_value=fake_proc),
        patch("app.main.run_git_sync", new_callable=AsyncMock),
    ):
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test",
        ) as http:
            resp = await http.post(
                "/v1/chat/completions", json=request_body, headers=auth_header,
            )

    assert resp.status_code == 200
    assistant_text = resp.json()["choices"][0]["message"]["content"]
    assert "plain non-json output" in assistant_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_completions_concats_system_and_user_messages(auth_header):
    """Both the system and the user content reach claude inside the prompt argument."""
    seen: dict = {}
    cli_stdout = json.dumps(
        {"type": "result", "result": "ok", "is_error": False},
    ).encode()

    async def recording_subprocess(*args, **kwargs):
        # Remember the argv the service built so it can be inspected afterwards.
        seen["args"] = args
        return _mock_subprocess_returning(cli_stdout, returncode=0)

    request_body = {
        "model": "claude-haiku-4-5",
        "messages": [
            {"role": "system", "content": "SYSTEM-MARKER"},
            {"role": "user", "content": "USER-MARKER"},
        ],
    }

    with (
        patch("app.main.asyncio.create_subprocess_exec", side_effect=recording_subprocess),
        patch("app.main.run_git_sync", new_callable=AsyncMock),
    ):
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test",
        ) as http:
            resp = await http.post(
                "/v1/chat/completions", json=request_body, headers=auth_header,
            )

    assert resp.status_code == 200
    # The prompt is the last positional argument on the command line.
    sent_prompt = seen["args"][-1]
    assert "SYSTEM-MARKER" in sent_prompt
    assert "USER-MARKER" in sent_prompt
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_completions_returns_503_when_agent_busy(auth_header):
    """While the execution lock is held elsewhere the endpoint answers 503."""
    request_body = {
        "model": "claude-haiku-4-5",
        "messages": [{"role": "user", "content": "hi"}],
    }

    # Hold the lock ourselves to simulate a run already in progress.
    await app_main.execution_lock.acquire()
    try:
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test",
        ) as http:
            resp = await http.post(
                "/v1/chat/completions", json=request_body, headers=auth_header,
            )
    finally:
        app_main.execution_lock.release()

    assert resp.status_code == 503
    error_body = resp.json()
    assert error_body.get("error") == "execution failed"
|
||||
Loading…
Add table
Add a link
Reference in a new issue