[Feat] Supports Anthropic Messages count_tokens API (#35588)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
@@ -8,6 +8,8 @@ from fastapi import APIRouter, Depends, FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from vllm.entrypoints.anthropic.protocol import (
|
||||
AnthropicCountTokensRequest,
|
||||
AnthropicCountTokensResponse,
|
||||
AnthropicError,
|
||||
AnthropicErrorResponse,
|
||||
AnthropicMessagesRequest,
|
||||
@@ -31,6 +33,18 @@ def messages(request: Request) -> AnthropicServingMessages:
|
||||
return request.app.state.anthropic_serving_messages
|
||||
|
||||
|
||||
def translate_error_response(response: ErrorResponse) -> JSONResponse:
|
||||
anthropic_error = AnthropicErrorResponse(
|
||||
error=AnthropicError(
|
||||
type=response.error.type,
|
||||
message=response.error.message,
|
||||
)
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=response.error.code, content=anthropic_error.model_dump()
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/messages",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
@@ -44,17 +58,6 @@ def messages(request: Request) -> AnthropicServingMessages:
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_messages(request: AnthropicMessagesRequest, raw_request: Request):
|
||||
def translate_error_response(response: ErrorResponse) -> JSONResponse:
|
||||
anthropic_error = AnthropicErrorResponse(
|
||||
error=AnthropicError(
|
||||
type=response.error.type,
|
||||
message=response.error.message,
|
||||
)
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=response.error.code, content=anthropic_error.model_dump()
|
||||
)
|
||||
|
||||
handler = messages(raw_request)
|
||||
if handler is None:
|
||||
base_server = raw_request.app.state.openai_serving_tokenization
|
||||
@@ -88,5 +91,46 @@ async def create_messages(request: AnthropicMessagesRequest, raw_request: Reques
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/messages/count_tokens",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.OK.value: {"model": AnthropicCountTokensResponse},
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse},
|
||||
HTTPStatus.NOT_FOUND.value: {"model": AnthropicErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": AnthropicErrorResponse},
|
||||
},
|
||||
)
|
||||
@load_aware_call
|
||||
@with_cancellation
|
||||
async def count_tokens(request: AnthropicCountTokensRequest, raw_request: Request):
|
||||
handler = messages(raw_request)
|
||||
if handler is None:
|
||||
base_server = raw_request.app.state.openai_serving_tokenization
|
||||
error = base_server.create_error_response(
|
||||
message="The model does not support Messages API"
|
||||
)
|
||||
return translate_error_response(error)
|
||||
|
||||
try:
|
||||
response = await handler.count_tokens(request, raw_request)
|
||||
except Exception as e:
|
||||
logger.exception("Error in count_tokens: %s", e)
|
||||
return JSONResponse(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
|
||||
content=AnthropicErrorResponse(
|
||||
error=AnthropicError(
|
||||
type="internal_error",
|
||||
message=str(e),
|
||||
)
|
||||
).model_dump(),
|
||||
)
|
||||
|
||||
if isinstance(response, ErrorResponse):
|
||||
return translate_error_response(response)
|
||||
|
||||
return JSONResponse(content=response.model_dump(exclude_none=True))
|
||||
|
||||
|
||||
def attach_router(app: FastAPI):
|
||||
app.include_router(router)
|
||||
|
||||
@@ -175,3 +175,33 @@ class AnthropicMessagesResponse(BaseModel):
|
||||
def model_post_init(self, __context):
|
||||
if not self.id:
|
||||
self.id = f"msg_{int(time.time() * 1000)}"
|
||||
|
||||
|
||||
class AnthropicContextManagement(BaseModel):
|
||||
"""Context management information for token counting."""
|
||||
|
||||
original_input_tokens: int
|
||||
|
||||
|
||||
class AnthropicCountTokensRequest(BaseModel):
|
||||
"""Anthropic messages.count_tokens request"""
|
||||
|
||||
model: str
|
||||
messages: list[AnthropicMessage]
|
||||
system: str | list[AnthropicContentBlock] | None = None
|
||||
tool_choice: AnthropicToolChoice | None = None
|
||||
tools: list[AnthropicTool] | None = None
|
||||
|
||||
@field_validator("model")
|
||||
@classmethod
|
||||
def validate_model(cls, v):
|
||||
if not v:
|
||||
raise ValueError("Model is required")
|
||||
return v
|
||||
|
||||
|
||||
class AnthropicCountTokensResponse(BaseModel):
|
||||
"""Anthropic messages.count_tokens response"""
|
||||
|
||||
input_tokens: int
|
||||
context_management: AnthropicContextManagement | None = None
|
||||
|
||||
@@ -17,6 +17,9 @@ from fastapi import Request
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.anthropic.protocol import (
|
||||
AnthropicContentBlock,
|
||||
AnthropicContextManagement,
|
||||
AnthropicCountTokensRequest,
|
||||
AnthropicCountTokensResponse,
|
||||
AnthropicDelta,
|
||||
AnthropicError,
|
||||
AnthropicMessagesRequest,
|
||||
@@ -109,135 +112,202 @@ class AnthropicServingMessages(OpenAIServingChat):
|
||||
|
||||
@classmethod
|
||||
def _convert_anthropic_to_openai_request(
|
||||
cls, anthropic_request: AnthropicMessagesRequest
|
||||
cls, anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest
|
||||
) -> ChatCompletionRequest:
|
||||
"""Convert Anthropic message format to OpenAI format"""
|
||||
openai_messages = []
|
||||
openai_messages: list[dict[str, Any]] = []
|
||||
|
||||
# Add system message if provided
|
||||
if anthropic_request.system:
|
||||
if isinstance(anthropic_request.system, str):
|
||||
openai_messages.append(
|
||||
{"role": "system", "content": anthropic_request.system}
|
||||
)
|
||||
else:
|
||||
system_prompt = ""
|
||||
for block in anthropic_request.system:
|
||||
if block.type == "text" and block.text:
|
||||
system_prompt += block.text
|
||||
openai_messages.append({"role": "system", "content": system_prompt})
|
||||
cls._convert_system_message(anthropic_request, openai_messages)
|
||||
cls._convert_messages(anthropic_request.messages, openai_messages)
|
||||
req = cls._build_base_request(anthropic_request, openai_messages)
|
||||
cls._handle_streaming_options(req, anthropic_request)
|
||||
cls._convert_tool_choice(anthropic_request, req)
|
||||
cls._convert_tools(anthropic_request, req)
|
||||
return req
|
||||
|
||||
for msg in anthropic_request.messages:
|
||||
@classmethod
|
||||
def _convert_system_message(
|
||||
cls,
|
||||
anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
|
||||
openai_messages: list[dict[str, Any]],
|
||||
) -> None:
|
||||
"""Convert Anthropic system message to OpenAI format"""
|
||||
if not anthropic_request.system:
|
||||
return
|
||||
|
||||
if isinstance(anthropic_request.system, str):
|
||||
openai_messages.append(
|
||||
{"role": "system", "content": anthropic_request.system}
|
||||
)
|
||||
else:
|
||||
system_prompt = ""
|
||||
for block in anthropic_request.system:
|
||||
if block.type == "text" and block.text:
|
||||
system_prompt += block.text
|
||||
openai_messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
@classmethod
|
||||
def _convert_messages(
|
||||
cls, messages: list, openai_messages: list[dict[str, Any]]
|
||||
) -> None:
|
||||
"""Convert Anthropic messages to OpenAI format"""
|
||||
for msg in messages:
|
||||
openai_msg: dict[str, Any] = {"role": msg.role} # type: ignore
|
||||
|
||||
if isinstance(msg.content, str):
|
||||
openai_msg["content"] = msg.content
|
||||
else:
|
||||
# Handle complex content blocks
|
||||
content_parts: list[dict[str, Any]] = []
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
reasoning_parts: list[str] = []
|
||||
|
||||
for block in msg.content:
|
||||
if block.type == "text" and block.text:
|
||||
content_parts.append({"type": "text", "text": block.text})
|
||||
elif block.type == "image" and block.source:
|
||||
image_url = cls._convert_image_source_to_url(block.source)
|
||||
content_parts.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_url},
|
||||
}
|
||||
)
|
||||
elif block.type == "thinking" and block.thinking is not None:
|
||||
reasoning_parts.append(block.thinking)
|
||||
elif block.type == "tool_use":
|
||||
# Convert tool use to function call format
|
||||
tool_call = {
|
||||
"id": block.id or f"call_{int(time.time())}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": block.name or "",
|
||||
"arguments": json.dumps(block.input or {}),
|
||||
},
|
||||
}
|
||||
tool_calls.append(tool_call)
|
||||
elif block.type == "tool_result":
|
||||
if msg.role == "user":
|
||||
# Parse tool_result content which can be
|
||||
# a string or a list of content blocks
|
||||
# (text, image, etc.)
|
||||
tool_text = ""
|
||||
tool_image_urls: list[str] = []
|
||||
if isinstance(block.content, str):
|
||||
tool_text = block.content
|
||||
elif isinstance(block.content, list):
|
||||
text_parts: list[str] = []
|
||||
for item in block.content:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
item_type = item.get("type")
|
||||
if item_type == "text":
|
||||
text_parts.append(item.get("text", ""))
|
||||
elif item_type == "image":
|
||||
source = item.get("source", {})
|
||||
url = cls._convert_image_source_to_url(source)
|
||||
if url:
|
||||
tool_image_urls.append(url)
|
||||
tool_text = "\n".join(text_parts)
|
||||
openai_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": block.tool_use_id or "",
|
||||
"content": tool_text or "",
|
||||
}
|
||||
)
|
||||
# OpenAI tool messages only support string
|
||||
# content, so inject images from tool
|
||||
# results as a follow-up user message
|
||||
if tool_image_urls:
|
||||
openai_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [ # type: ignore[dict-item]
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": img},
|
||||
}
|
||||
for img in tool_image_urls
|
||||
],
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Assistant tool result becomes regular text
|
||||
tool_result_text = (
|
||||
str(block.content) if block.content else ""
|
||||
)
|
||||
content_parts.append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"Tool result: {tool_result_text}",
|
||||
}
|
||||
)
|
||||
|
||||
if reasoning_parts:
|
||||
openai_msg["reasoning"] = "".join(reasoning_parts)
|
||||
|
||||
# Add tool calls to the message if any
|
||||
if tool_calls:
|
||||
openai_msg["tool_calls"] = tool_calls # type: ignore
|
||||
|
||||
# Add content parts if any
|
||||
if content_parts:
|
||||
if len(content_parts) == 1 and content_parts[0]["type"] == "text":
|
||||
openai_msg["content"] = content_parts[0]["text"]
|
||||
else:
|
||||
openai_msg["content"] = content_parts # type: ignore
|
||||
elif not tool_calls and not reasoning_parts:
|
||||
continue
|
||||
cls._convert_message_content(msg, openai_msg, openai_messages)
|
||||
|
||||
openai_messages.append(openai_msg)
|
||||
|
||||
req = ChatCompletionRequest(
|
||||
@classmethod
|
||||
def _convert_message_content(
|
||||
cls,
|
||||
msg,
|
||||
openai_msg: dict[str, Any],
|
||||
openai_messages: list[dict[str, Any]],
|
||||
) -> None:
|
||||
"""Convert complex message content blocks"""
|
||||
content_parts: list[dict[str, Any]] = []
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
reasoning_parts: list[str] = []
|
||||
|
||||
for block in msg.content:
|
||||
cls._convert_block(
|
||||
block,
|
||||
msg.role,
|
||||
content_parts,
|
||||
tool_calls,
|
||||
reasoning_parts,
|
||||
openai_messages,
|
||||
)
|
||||
|
||||
if reasoning_parts:
|
||||
openai_msg["reasoning"] = "".join(reasoning_parts)
|
||||
|
||||
if tool_calls:
|
||||
openai_msg["tool_calls"] = tool_calls # type: ignore
|
||||
|
||||
if content_parts:
|
||||
if len(content_parts) == 1 and content_parts[0]["type"] == "text":
|
||||
openai_msg["content"] = content_parts[0]["text"]
|
||||
else:
|
||||
openai_msg["content"] = content_parts # type: ignore
|
||||
elif not tool_calls and not reasoning_parts:
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def _convert_block(
|
||||
cls,
|
||||
block,
|
||||
role: str,
|
||||
content_parts: list[dict[str, Any]],
|
||||
tool_calls: list[dict[str, Any]],
|
||||
reasoning_parts: list[str],
|
||||
openai_messages: list[dict[str, Any]],
|
||||
) -> None:
|
||||
"""Convert individual content block"""
|
||||
if block.type == "text" and block.text:
|
||||
content_parts.append({"type": "text", "text": block.text})
|
||||
elif block.type == "image" and block.source:
|
||||
image_url = cls._convert_image_source_to_url(block.source)
|
||||
content_parts.append({"type": "image_url", "image_url": {"url": image_url}})
|
||||
elif block.type == "thinking" and block.thinking is not None:
|
||||
reasoning_parts.append(block.thinking)
|
||||
elif block.type == "tool_use":
|
||||
cls._convert_tool_use_block(block, tool_calls)
|
||||
elif block.type == "tool_result":
|
||||
cls._convert_tool_result_block(block, role, openai_messages, content_parts)
|
||||
|
||||
@classmethod
|
||||
def _convert_tool_use_block(cls, block, tool_calls: list[dict[str, Any]]) -> None:
|
||||
"""Convert tool_use block to OpenAI function call format"""
|
||||
tool_call = {
|
||||
"id": block.id or f"call_{int(time.time())}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": block.name or "",
|
||||
"arguments": json.dumps(block.input or {}),
|
||||
},
|
||||
}
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
@classmethod
|
||||
def _convert_tool_result_block(
|
||||
cls,
|
||||
block,
|
||||
role: str,
|
||||
openai_messages: list[dict[str, Any]],
|
||||
content_parts: list[dict[str, Any]],
|
||||
) -> None:
|
||||
"""Convert tool_result block to OpenAI format"""
|
||||
if role == "user":
|
||||
cls._convert_user_tool_result(block, openai_messages)
|
||||
else:
|
||||
tool_result_text = str(block.content) if block.content else ""
|
||||
content_parts.append(
|
||||
{"type": "text", "text": f"Tool result: {tool_result_text}"}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _convert_user_tool_result(
|
||||
cls, block, openai_messages: list[dict[str, Any]]
|
||||
) -> None:
|
||||
"""Convert user tool_result with text and image support"""
|
||||
tool_text = ""
|
||||
tool_image_urls: list[str] = []
|
||||
|
||||
if isinstance(block.content, str):
|
||||
tool_text = block.content
|
||||
elif isinstance(block.content, list):
|
||||
text_parts: list[str] = []
|
||||
for item in block.content:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
item_type = item.get("type")
|
||||
if item_type == "text":
|
||||
text_parts.append(item.get("text", ""))
|
||||
elif item_type == "image":
|
||||
source = item.get("source", {})
|
||||
url = cls._convert_image_source_to_url(source)
|
||||
if url:
|
||||
tool_image_urls.append(url)
|
||||
tool_text = "\n".join(text_parts)
|
||||
|
||||
openai_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": block.tool_use_id or "",
|
||||
"content": tool_text or "",
|
||||
}
|
||||
)
|
||||
|
||||
if tool_image_urls:
|
||||
openai_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [ # type: ignore[dict-item]
|
||||
{"type": "image_url", "image_url": {"url": img}}
|
||||
for img in tool_image_urls
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _build_base_request(
|
||||
cls,
|
||||
anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
|
||||
openai_messages: list[dict[str, Any]],
|
||||
) -> ChatCompletionRequest:
|
||||
"""Build base ChatCompletionRequest"""
|
||||
if isinstance(anthropic_request, AnthropicCountTokensRequest):
|
||||
return ChatCompletionRequest(
|
||||
model=anthropic_request.model,
|
||||
messages=openai_messages,
|
||||
)
|
||||
|
||||
return ChatCompletionRequest(
|
||||
model=anthropic_request.model,
|
||||
messages=openai_messages,
|
||||
max_tokens=anthropic_request.max_tokens,
|
||||
@@ -248,19 +318,38 @@ class AnthropicServingMessages(OpenAIServingChat):
|
||||
top_k=anthropic_request.top_k,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _handle_streaming_options(
|
||||
cls,
|
||||
req: ChatCompletionRequest,
|
||||
anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
|
||||
) -> None:
|
||||
"""Handle streaming configuration"""
|
||||
if isinstance(anthropic_request, AnthropicCountTokensRequest):
|
||||
return
|
||||
if anthropic_request.stream:
|
||||
req.stream = anthropic_request.stream
|
||||
req.stream_options = StreamOptions.validate(
|
||||
req.stream_options = StreamOptions.model_validate(
|
||||
{"include_usage": True, "continuous_usage_stats": True}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _convert_tool_choice(
|
||||
cls,
|
||||
anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
|
||||
req: ChatCompletionRequest,
|
||||
) -> None:
|
||||
"""Convert Anthropic tool_choice to OpenAI format"""
|
||||
if anthropic_request.tool_choice is None:
|
||||
req.tool_choice = None
|
||||
elif anthropic_request.tool_choice.type == "auto":
|
||||
return
|
||||
|
||||
tool_choice_type = anthropic_request.tool_choice.type
|
||||
if tool_choice_type == "auto":
|
||||
req.tool_choice = "auto"
|
||||
elif anthropic_request.tool_choice.type == "any":
|
||||
elif tool_choice_type == "any":
|
||||
req.tool_choice = "required"
|
||||
elif anthropic_request.tool_choice.type == "tool":
|
||||
elif tool_choice_type == "tool":
|
||||
req.tool_choice = ChatCompletionNamedToolChoiceParam.model_validate(
|
||||
{
|
||||
"type": "function",
|
||||
@@ -268,9 +357,17 @@ class AnthropicServingMessages(OpenAIServingChat):
|
||||
}
|
||||
)
|
||||
|
||||
tools = []
|
||||
@classmethod
|
||||
def _convert_tools(
|
||||
cls,
|
||||
anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
|
||||
req: ChatCompletionRequest,
|
||||
) -> None:
|
||||
"""Convert Anthropic tools to OpenAI format"""
|
||||
if anthropic_request.tools is None:
|
||||
return req
|
||||
return
|
||||
|
||||
tools = []
|
||||
for tool in anthropic_request.tools:
|
||||
tools.append(
|
||||
ChatCompletionToolsParam.model_validate(
|
||||
@@ -284,10 +381,10 @@ class AnthropicServingMessages(OpenAIServingChat):
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
if req.tool_choice is None:
|
||||
req.tool_choice = "auto"
|
||||
req.tools = tools
|
||||
return req
|
||||
|
||||
async def create_messages(
|
||||
self,
|
||||
@@ -670,3 +767,31 @@ class AnthropicServingMessages(OpenAIServingChat):
|
||||
data = error_response.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "error")
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def count_tokens(
|
||||
self,
|
||||
request: AnthropicCountTokensRequest,
|
||||
raw_request: Request | None = None,
|
||||
) -> AnthropicCountTokensResponse | ErrorResponse:
|
||||
"""Implements Anthropic's messages.count_tokens endpoint."""
|
||||
chat_req = self._convert_anthropic_to_openai_request(request)
|
||||
result = await self.render_chat_request(chat_req)
|
||||
if isinstance(result, ErrorResponse):
|
||||
return result
|
||||
|
||||
_, engine_prompts = result
|
||||
|
||||
input_tokens = sum( # type: ignore
|
||||
len(prompt["prompt_token_ids"]) # type: ignore[typeddict-item, misc]
|
||||
for prompt in engine_prompts
|
||||
if "prompt_token_ids" in prompt
|
||||
)
|
||||
|
||||
response = AnthropicCountTokensResponse(
|
||||
input_tokens=input_tokens,
|
||||
context_management=AnthropicContextManagement(
|
||||
original_input_tokens=input_tokens
|
||||
),
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
Reference in New Issue
Block a user