openai-compat: pass --model from request through to claude -p
Replaces the MODEL_TO_AGENT dict (which only mapped model -> agent and ignored the model itself) with a SUPPORTED_MODELS allowlist plus a per-request --model CLI flag. Callers can now pick Haiku/Sonnet/Opus per request to control cost; unknown model IDs return HTTP 400 along with the supported list; a missing model defaults to claude-sonnet-4-6 (mid-tier). The --model CLI flag overrides whatever `model:` is set in the agent's frontmatter, so recruiter-triage's `model: sonnet` no longer pins every request to Sonnet. Verified with claude CLI 2.1.153 that the bare-form IDs (claude-haiku-4-5, claude-sonnet-4-6, claude-opus-4-7) are accepted without date suffixes — confirmed via the modelUsage keys in the JSON output. Six new tests cover: default routing, haiku/sonnet/opus pass-through, the unsupported-model 400 response shape, and the response.model echo.
This commit is contained in:
parent
07dcfca333
commit
7baa66d994
2 changed files with 165 additions and 17 deletions
56
app/main.py
56
app/main.py
|
|
@ -17,18 +17,19 @@ app = FastAPI(title="Claude Agent Service")
|
|||
API_TOKEN = os.environ.get("API_BEARER_TOKEN", "")
|
||||
WORKSPACE_DIR = os.environ.get("WORKSPACE_DIR", "/workspace/infra")
|
||||
|
||||
# OpenAI compat: model -> agent mapping. v1 keeps it dead simple — all models
|
||||
# route to the most general agent we have. `recruiter-triage` has the broadest
|
||||
# tool surface (WebSearch, WebFetch, Read, Grep, Glob, Bash); the alternative
|
||||
# (`beads-task-runner`) is locked to read-only `bd` verbs which would fail
|
||||
# arbitrary OpenAI-API callers. Revisit when a true general-purpose agent
|
||||
# lands in `agents/`.
|
||||
MODEL_TO_AGENT: dict[str, str] = {
|
||||
"claude-haiku-4-5": "recruiter-triage",
|
||||
"claude-sonnet-4-6": "recruiter-triage",
|
||||
"claude-opus-4-7": "recruiter-triage",
|
||||
}
|
||||
AGENT_DEFAULT = "recruiter-triage"
|
||||
# OpenAI compat: model selection is per-request so callers can pick
|
||||
# Haiku/Sonnet/Opus to control cost. The agent is fixed — `recruiter-triage`
|
||||
# has the broadest tool surface (WebSearch, WebFetch, Read, Grep, Glob, Bash);
|
||||
# the alternative (`beads-task-runner`) is locked to read-only `bd` verbs which
|
||||
# would fail arbitrary OpenAI-API callers. The model on the agent's frontmatter
|
||||
# is overridden by the `--model` CLI flag we pass per-request.
|
||||
SUPPORTED_MODELS: frozenset[str] = frozenset({
|
||||
"claude-haiku-4-5",
|
||||
"claude-sonnet-4-6",
|
||||
"claude-opus-4-7",
|
||||
})
|
||||
DEFAULT_MODEL = "claude-sonnet-4-6"
|
||||
OPENAI_COMPAT_AGENT = "recruiter-triage"
|
||||
OPENAI_COMPAT_BUDGET_USD = 2.0
|
||||
OPENAI_COMPAT_TIMEOUT_SECONDS = 900
|
||||
|
||||
|
|
@ -50,7 +51,10 @@ class ChatMessage(BaseModel):
|
|||
|
||||
|
||||
class ChatCompletionsRequest(BaseModel):
|
||||
model: str
|
||||
# `model` is optional: callers that omit it get DEFAULT_MODEL. We still
|
||||
# validate the explicit value against SUPPORTED_MODELS at the route level
|
||||
# so we can return a structured 400 listing the allowed IDs.
|
||||
model: str | None = None
|
||||
messages: list[ChatMessage] = Field(..., min_length=1)
|
||||
max_tokens: int | None = None
|
||||
temperature: float | None = None
|
||||
|
|
@ -85,6 +89,7 @@ async def _invoke_claude_subprocess(
|
|||
prompt: str,
|
||||
agent: str,
|
||||
max_budget_usd: float,
|
||||
model: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Run the claude CLI once and return a result dict.
|
||||
|
||||
|
|
@ -92,6 +97,10 @@ async def _invoke_claude_subprocess(
|
|||
this helper does not touch the lock or the `jobs` dict, so it can be
|
||||
shared by both the background `/execute` path and the synchronous
|
||||
`/v1/chat/completions` path.
|
||||
|
||||
`model`, when provided, becomes `--model <id>` on the claude CLI. This
|
||||
overrides whatever `model:` is set in the agent's frontmatter so the
|
||||
OpenAI-compat path can pick Haiku/Sonnet/Opus per-request.
|
||||
"""
|
||||
await run_git_sync()
|
||||
|
||||
|
|
@ -101,8 +110,10 @@ async def _invoke_claude_subprocess(
|
|||
"--dangerously-skip-permissions",
|
||||
"--max-budget-usd", str(max_budget_usd),
|
||||
"--output-format", "json",
|
||||
prompt,
|
||||
]
|
||||
if model is not None:
|
||||
cmd.extend(["--model", model])
|
||||
cmd.append(prompt)
|
||||
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
|
|
@ -255,7 +266,16 @@ async def chat_completions(
|
|||
if request.stream:
|
||||
raise HTTPException(status_code=400, detail="streaming not supported")
|
||||
|
||||
agent = MODEL_TO_AGENT.get(request.model, AGENT_DEFAULT)
|
||||
model = request.model if request.model is not None else DEFAULT_MODEL
|
||||
if model not in SUPPORTED_MODELS:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"error": "unsupported model",
|
||||
"supported": sorted(SUPPORTED_MODELS),
|
||||
},
|
||||
)
|
||||
|
||||
prompt = _synthesise_prompt(request.messages)
|
||||
|
||||
if execution_lock.locked():
|
||||
|
|
@ -268,7 +288,9 @@ async def chat_completions(
|
|||
try:
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
_invoke_claude_subprocess(prompt, agent, OPENAI_COMPAT_BUDGET_USD),
|
||||
_invoke_claude_subprocess(
|
||||
prompt, OPENAI_COMPAT_AGENT, OPENAI_COMPAT_BUDGET_USD, model=model,
|
||||
),
|
||||
timeout=OPENAI_COMPAT_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
|
|
@ -298,7 +320,7 @@ async def chat_completions(
|
|||
"id": completion_id,
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": request.model,
|
||||
"model": model,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": content},
|
||||
|
|
|
|||
|
|
@ -262,3 +262,129 @@ async def test_chat_completions_returns_503_when_agent_busy(auth_header):
|
|||
assert response.status_code == 503
|
||||
body = response.json()
|
||||
assert body.get("error") == "execution failed"
|
||||
|
||||
|
||||
async def _capture_subprocess_args(
    auth_header: dict,
    payload: dict,
) -> tuple[int, dict, tuple]:
    """Send `payload` to /v1/chat/completions and report how the CLI was invoked.

    `asyncio.create_subprocess_exec` (as seen from `app.main`) is swapped for
    a recorder that stores its positional arguments and hands back whatever
    `_mock_subprocess_returning` builds for a successful JSON result with
    return code 0. `run_git_sync` is replaced with an `AsyncMock` for the
    duration of the request.

    Returns:
        A `(status_code, json_body, argv)` triple. `argv` is the empty tuple
        when the subprocess was never spawned.
    """
    seen: dict = {}
    canned_stdout = json.dumps(
        {"type": "result", "result": "ok", "is_error": False}
    ).encode()

    async def _record_and_succeed(*argv, **_kwargs):
        # Keep only the positional argv; keyword arguments are not inspected.
        seen["args"] = argv
        return _mock_subprocess_returning(canned_stdout, returncode=0)

    exec_patch = patch(
        "app.main.asyncio.create_subprocess_exec",
        side_effect=_record_and_succeed,
    )
    sync_patch = patch("app.main.run_git_sync", new_callable=AsyncMock)

    with exec_patch, sync_patch:
        async with AsyncClient(
            transport=ASGITransport(app=app),
            base_url="http://test",
        ) as client:
            reply = await client.post(
                "/v1/chat/completions",
                json=payload,
                headers=auth_header,
            )

    return reply.status_code, reply.json(), seen.get("args", ())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_completions_routes_haiku_to_claude_cli(auth_header):
    """Requesting `claude-haiku-4-5` puts `--model claude-haiku-4-5` on the CLI argv."""
    wanted = "claude-haiku-4-5"
    request_body = {
        "model": wanted,
        "messages": [{"role": "user", "content": "hi"}],
    }

    status, _body, argv = await _capture_subprocess_args(auth_header, request_body)

    assert status == 200
    assert "--model" in argv
    # The model id must be the token immediately following the flag.
    flag_at = argv.index("--model")
    assert argv[flag_at + 1] == wanted
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_completions_routes_sonnet_to_claude_cli(auth_header):
    """Requesting `claude-sonnet-4-6` puts `--model claude-sonnet-4-6` on the CLI argv."""
    wanted = "claude-sonnet-4-6"
    request_body = {
        "model": wanted,
        "messages": [{"role": "user", "content": "hi"}],
    }

    status, _body, argv = await _capture_subprocess_args(auth_header, request_body)

    assert status == 200
    assert "--model" in argv
    # The model id must be the token immediately following the flag.
    flag_at = argv.index("--model")
    assert argv[flag_at + 1] == wanted
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_completions_routes_opus_to_claude_cli(auth_header):
    """Requesting `claude-opus-4-7` puts `--model claude-opus-4-7` on the CLI argv."""
    wanted = "claude-opus-4-7"
    request_body = {
        "model": wanted,
        "messages": [{"role": "user", "content": "hi"}],
    }

    status, _body, argv = await _capture_subprocess_args(auth_header, request_body)

    assert status == 200
    assert "--model" in argv
    # The model id must be the token immediately following the flag.
    flag_at = argv.index("--model")
    assert argv[flag_at + 1] == wanted
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_completions_uses_default_model_when_field_missing(auth_header):
    """A request with no `model` key still yields `--model claude-sonnet-4-6` on the argv."""
    # Deliberately no "model" key in the request body.
    request_body = {"messages": [{"role": "user", "content": "hi"}]}

    status, _body, argv = await _capture_subprocess_args(auth_header, request_body)

    assert status == 200
    assert "--model" in argv
    flag_at = argv.index("--model")
    assert argv[flag_at + 1] == "claude-sonnet-4-6"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_completions_rejects_unknown_model(auth_header):
    """An unrecognised model id gets a 400 whose body names the error and lists the supported ids."""
    bad_request = {
        "model": "gpt-4o",
        "messages": [{"role": "user", "content": "hi"}],
    }

    async with AsyncClient(
        transport=ASGITransport(app=app),
        base_url="http://test",
    ) as client:
        response = await client.post(
            "/v1/chat/completions",
            json=bad_request,
            headers=auth_header,
        )

    assert response.status_code == 400

    body = response.json()
    assert body.get("error") == "unsupported model"
    assert "supported" in body

    supported = body["supported"]
    assert isinstance(supported, list)
    # Every model id the service advertises must appear in the list.
    for model_id in ("claude-haiku-4-5", "claude-sonnet-4-6", "claude-opus-4-7"):
        assert model_id in supported
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_completions_response_model_echoes_default_when_missing(auth_header):
    """With `model` left out of the request, the response's `model` field is `claude-sonnet-4-6`."""
    request_body = {"messages": [{"role": "user", "content": "hi"}]}

    status, body, _argv = await _capture_subprocess_args(auth_header, request_body)

    assert status == 200
    assert body["model"] == "claude-sonnet-4-6"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue