diff --git a/vllm/entrypoints/anthropic/api_router.py b/vllm/entrypoints/anthropic/api_router.py
index 1494dd7e5..2b65fff50 100644
--- a/vllm/entrypoints/anthropic/api_router.py
+++ b/vllm/entrypoints/anthropic/api_router.py
@@ -8,6 +8,8 @@ from fastapi import APIRouter, Depends, FastAPI, Request
 from fastapi.responses import JSONResponse, StreamingResponse
 
 from vllm.entrypoints.anthropic.protocol import (
+    AnthropicCountTokensRequest,
+    AnthropicCountTokensResponse,
     AnthropicError,
     AnthropicErrorResponse,
     AnthropicMessagesRequest,
@@ -31,6 +33,18 @@ def messages(request: Request) -> AnthropicServingMessages:
     return request.app.state.anthropic_serving_messages
 
 
+def translate_error_response(response: ErrorResponse) -> JSONResponse:
+    anthropic_error = AnthropicErrorResponse(
+        error=AnthropicError(
+            type=response.error.type,
+            message=response.error.message,
+        )
+    )
+    return JSONResponse(
+        status_code=response.error.code, content=anthropic_error.model_dump()
+    )
+
+
 @router.post(
     "/v1/messages",
     dependencies=[Depends(validate_json_request)],
@@ -44,17 +58,6 @@ def messages(request: Request) -> AnthropicServingMessages:
 @with_cancellation
 @load_aware_call
 async def create_messages(request: AnthropicMessagesRequest, raw_request: Request):
-    def translate_error_response(response: ErrorResponse) -> JSONResponse:
-        anthropic_error = AnthropicErrorResponse(
-            error=AnthropicError(
-                type=response.error.type,
-                message=response.error.message,
-            )
-        )
-        return JSONResponse(
-            status_code=response.error.code, content=anthropic_error.model_dump()
-        )
-
     handler = messages(raw_request)
     if handler is None:
         base_server = raw_request.app.state.openai_serving_tokenization
@@ -88,5 +91,46 @@ async def create_messages(request: AnthropicMessagesRequest, raw_request: Reques
     return StreamingResponse(content=generator, media_type="text/event-stream")
 
 
+@router.post(
+    "/v1/messages/count_tokens",
+    dependencies=[Depends(validate_json_request)],
+    responses={
+        HTTPStatus.OK.value: {"model": AnthropicCountTokensResponse},
+        HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse},
+        HTTPStatus.NOT_FOUND.value: {"model": AnthropicErrorResponse},
+        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": AnthropicErrorResponse},
+    },
+)
+@with_cancellation
+@load_aware_call
+async def count_tokens(request: AnthropicCountTokensRequest, raw_request: Request):
+    handler = messages(raw_request)
+    if handler is None:
+        base_server = raw_request.app.state.openai_serving_tokenization
+        error = base_server.create_error_response(
+            message="The model does not support Messages API"
+        )
+        return translate_error_response(error)
+
+    try:
+        response = await handler.count_tokens(request, raw_request)
+    except Exception as e:
+        logger.exception("Error in count_tokens: %s", e)
+        return JSONResponse(
+            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
+            content=AnthropicErrorResponse(
+                error=AnthropicError(
+                    type="internal_error",
+                    message=str(e),
+                )
+            ).model_dump(),
+        )
+
+    if isinstance(response, ErrorResponse):
+        return translate_error_response(response)
+
+    return JSONResponse(content=response.model_dump(exclude_none=True))
+
+
 def attach_router(app: FastAPI):
     app.include_router(router)
diff --git a/vllm/entrypoints/anthropic/protocol.py b/vllm/entrypoints/anthropic/protocol.py
index 3081e9781..19ca28f1d 100644
--- a/vllm/entrypoints/anthropic/protocol.py
+++ b/vllm/entrypoints/anthropic/protocol.py
@@ -175,3 +175,33 @@ class AnthropicMessagesResponse(BaseModel):
     def model_post_init(self, __context):
         if not self.id:
             self.id = f"msg_{int(time.time() * 1000)}"
+
+
+class AnthropicContextManagement(BaseModel):
+    """Context management information for token counting."""
+
+    original_input_tokens: int
+
+
+class AnthropicCountTokensRequest(BaseModel):
+    """Anthropic messages.count_tokens request"""
+
+    model: str
+    messages: list[AnthropicMessage]
+    system: str | list[AnthropicContentBlock] | None = None
+    tool_choice: AnthropicToolChoice | None = None
+    tools: list[AnthropicTool] | None = None
+
+    @field_validator("model")
+    @classmethod
+    def validate_model(cls, v):
+        if not v:
+            raise ValueError("Model is required")
+        return v
+
+
+class AnthropicCountTokensResponse(BaseModel):
+    """Anthropic messages.count_tokens response"""
+
+    input_tokens: int
+    context_management: AnthropicContextManagement | None = None
diff --git a/vllm/entrypoints/anthropic/serving.py b/vllm/entrypoints/anthropic/serving.py
index 82af26476..f0110de38 100644
--- a/vllm/entrypoints/anthropic/serving.py
+++ b/vllm/entrypoints/anthropic/serving.py
@@ -17,6 +17,9 @@ from fastapi import Request
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.anthropic.protocol import (
     AnthropicContentBlock,
+    AnthropicContextManagement,
+    AnthropicCountTokensRequest,
+    AnthropicCountTokensResponse,
     AnthropicDelta,
     AnthropicError,
     AnthropicMessagesRequest,
@@ -109,135 +112,202 @@ class AnthropicServingMessages(OpenAIServingChat):
     @classmethod
     def _convert_anthropic_to_openai_request(
-        cls, anthropic_request: AnthropicMessagesRequest
+        cls, anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest
     ) -> ChatCompletionRequest:
         """Convert Anthropic message format to OpenAI format"""
-        openai_messages = []
+        openai_messages: list[dict[str, Any]] = []
 
-        # Add system message if provided
-        if anthropic_request.system:
-            if isinstance(anthropic_request.system, str):
-                openai_messages.append(
-                    {"role": "system", "content": anthropic_request.system}
-                )
-            else:
-                system_prompt = ""
-                for block in anthropic_request.system:
-                    if block.type == "text" and block.text:
-                        system_prompt += block.text
-                openai_messages.append({"role": "system", "content": system_prompt})
+        cls._convert_system_message(anthropic_request, openai_messages)
+        cls._convert_messages(anthropic_request.messages, openai_messages)
+        req = cls._build_base_request(anthropic_request, openai_messages)
+        cls._handle_streaming_options(req, anthropic_request)
+        cls._convert_tool_choice(anthropic_request, req)
+        cls._convert_tools(anthropic_request, req)
+        return req
 
-        for msg in anthropic_request.messages:
+    @classmethod
+    def _convert_system_message(
+        cls,
+        anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
+        openai_messages: list[dict[str, Any]],
+    ) -> None:
+        """Convert Anthropic system message to OpenAI format"""
+        if not anthropic_request.system:
+            return
+
+        if isinstance(anthropic_request.system, str):
+            openai_messages.append(
+                {"role": "system", "content": anthropic_request.system}
+            )
+        else:
+            system_prompt = ""
+            for block in anthropic_request.system:
+                if block.type == "text" and block.text:
+                    system_prompt += block.text
+            openai_messages.append({"role": "system", "content": system_prompt})
+
+    @classmethod
+    def _convert_messages(
+        cls, messages: list, openai_messages: list[dict[str, Any]]
+    ) -> None:
+        """Convert Anthropic messages to OpenAI format"""
+        for msg in messages:
             openai_msg: dict[str, Any] = {"role": msg.role}  # type: ignore
+
             if isinstance(msg.content, str):
                 openai_msg["content"] = msg.content
             else:
-                # Handle complex content blocks
-                content_parts: list[dict[str, Any]] = []
-                tool_calls: list[dict[str, Any]] = []
-                reasoning_parts: list[str] = []
-
-                for block in msg.content:
-                    if block.type == "text" and block.text:
-                        content_parts.append({"type": "text", "text": block.text})
-                    elif block.type == "image" and block.source:
-                        image_url = cls._convert_image_source_to_url(block.source)
-                        content_parts.append(
-                            {
-                                "type": "image_url",
-                                "image_url": {"url": image_url},
-                            }
-                        )
-                    elif block.type == "thinking" and block.thinking is not None:
-                        reasoning_parts.append(block.thinking)
-                    elif block.type == "tool_use":
-                        # Convert tool use to function call format
-                        tool_call = {
-                            "id": block.id or f"call_{int(time.time())}",
-                            "type": "function",
-                            "function": {
-                                "name": block.name or "",
-                                "arguments": json.dumps(block.input or {}),
-                            },
-                        }
-                        tool_calls.append(tool_call)
-                    elif block.type == "tool_result":
-                        if msg.role == "user":
-                            # Parse tool_result content which can be
-                            # a string or a list of content blocks
-                            # (text, image, etc.)
-                            tool_text = ""
-                            tool_image_urls: list[str] = []
-                            if isinstance(block.content, str):
-                                tool_text = block.content
-                            elif isinstance(block.content, list):
-                                text_parts: list[str] = []
-                                for item in block.content:
-                                    if not isinstance(item, dict):
-                                        continue
-                                    item_type = item.get("type")
-                                    if item_type == "text":
-                                        text_parts.append(item.get("text", ""))
-                                    elif item_type == "image":
-                                        source = item.get("source", {})
-                                        url = cls._convert_image_source_to_url(source)
-                                        if url:
-                                            tool_image_urls.append(url)
-                                tool_text = "\n".join(text_parts)
-                            openai_messages.append(
-                                {
-                                    "role": "tool",
-                                    "tool_call_id": block.tool_use_id or "",
-                                    "content": tool_text or "",
-                                }
-                            )
-                            # OpenAI tool messages only support string
-                            # content, so inject images from tool
-                            # results as a follow-up user message
-                            if tool_image_urls:
-                                openai_messages.append(
-                                    {
-                                        "role": "user",
-                                        "content": [  # type: ignore[dict-item]
-                                            {
-                                                "type": "image_url",
-                                                "image_url": {"url": img},
-                                            }
-                                            for img in tool_image_urls
-                                        ],
-                                    }
-                                )
-                        else:
-                            # Assistant tool result becomes regular text
-                            tool_result_text = (
-                                str(block.content) if block.content else ""
-                            )
-                            content_parts.append(
-                                {
-                                    "type": "text",
-                                    "text": f"Tool result: {tool_result_text}",
-                                }
-                            )
-
-                if reasoning_parts:
-                    openai_msg["reasoning"] = "".join(reasoning_parts)
-
-                # Add tool calls to the message if any
-                if tool_calls:
-                    openai_msg["tool_calls"] = tool_calls  # type: ignore
-
-                # Add content parts if any
-                if content_parts:
-                    if len(content_parts) == 1 and content_parts[0]["type"] == "text":
-                        openai_msg["content"] = content_parts[0]["text"]
-                    else:
-                        openai_msg["content"] = content_parts  # type: ignore
-                elif not tool_calls and not reasoning_parts:
-                    continue
+                cls._convert_message_content(msg, openai_msg, openai_messages)
+                if len(openai_msg) == 1:
+                    continue  # nothing extracted (e.g. only tool_result blocks)
 
             openai_messages.append(openai_msg)
 
-        req = ChatCompletionRequest(
+    @classmethod
+    def _convert_message_content(
+        cls,
+        msg,
+        openai_msg: dict[str, Any],
+        openai_messages: list[dict[str, Any]],
+    ) -> None:
+        """Convert complex message content blocks"""
+        content_parts: list[dict[str, Any]] = []
+        tool_calls: list[dict[str, Any]] = []
+        reasoning_parts: list[str] = []
+
+        for block in msg.content:
+            cls._convert_block(
+                block,
+                msg.role,
+                content_parts,
+                tool_calls,
+                reasoning_parts,
+                openai_messages,
+            )
+
+        if reasoning_parts:
+            openai_msg["reasoning"] = "".join(reasoning_parts)
+
+        if tool_calls:
+            openai_msg["tool_calls"] = tool_calls  # type: ignore
+
+        if content_parts:
+            if len(content_parts) == 1 and content_parts[0]["type"] == "text":
+                openai_msg["content"] = content_parts[0]["text"]
+            else:
+                openai_msg["content"] = content_parts  # type: ignore
+
+    @classmethod
+    def _convert_block(
+        cls,
+        block,
+        role: str,
+        content_parts: list[dict[str, Any]],
+        tool_calls: list[dict[str, Any]],
+        reasoning_parts: list[str],
+        openai_messages: list[dict[str, Any]],
+    ) -> None:
+        """Convert individual content block"""
+        if block.type == "text" and block.text:
+            content_parts.append({"type": "text", "text": block.text})
+        elif block.type == "image" and block.source:
+            image_url = cls._convert_image_source_to_url(block.source)
+            content_parts.append({"type": "image_url", "image_url": {"url": image_url}})
+        elif block.type == "thinking" and block.thinking is not None:
+            reasoning_parts.append(block.thinking)
+        elif block.type == "tool_use":
+            cls._convert_tool_use_block(block, tool_calls)
+        elif block.type == "tool_result":
+            cls._convert_tool_result_block(block, role, openai_messages, content_parts)
+
+    @classmethod
+    def _convert_tool_use_block(cls, block, tool_calls: list[dict[str, Any]]) -> None:
+        """Convert tool_use block to OpenAI function call format"""
+        tool_call = {
+            "id": block.id or f"call_{int(time.time())}",
+            "type": "function",
+            "function": {
+                "name": block.name or "",
+                "arguments": json.dumps(block.input or {}),
+            },
+        }
+        tool_calls.append(tool_call)
+
+    @classmethod
+    def _convert_tool_result_block(
+        cls,
+        block,
+        role: str,
+        openai_messages: list[dict[str, Any]],
+        content_parts: list[dict[str, Any]],
+    ) -> None:
+        """Convert tool_result block to OpenAI format"""
+        if role == "user":
+            cls._convert_user_tool_result(block, openai_messages)
+        else:
+            tool_result_text = str(block.content) if block.content else ""
+            content_parts.append(
+                {"type": "text", "text": f"Tool result: {tool_result_text}"}
+            )
+
+    @classmethod
+    def _convert_user_tool_result(
+        cls, block, openai_messages: list[dict[str, Any]]
+    ) -> None:
+        """Convert user tool_result with text and image support"""
+        tool_text = ""
+        tool_image_urls: list[str] = []
+
+        if isinstance(block.content, str):
+            tool_text = block.content
+        elif isinstance(block.content, list):
+            text_parts: list[str] = []
+            for item in block.content:
+                if not isinstance(item, dict):
+                    continue
+                item_type = item.get("type")
+                if item_type == "text":
+                    text_parts.append(item.get("text", ""))
+                elif item_type == "image":
+                    source = item.get("source", {})
+                    url = cls._convert_image_source_to_url(source)
+                    if url:
+                        tool_image_urls.append(url)
+            tool_text = "\n".join(text_parts)
+
+        openai_messages.append(
+            {
+                "role": "tool",
+                "tool_call_id": block.tool_use_id or "",
+                "content": tool_text or "",
+            }
+        )
+
+        if tool_image_urls:
+            openai_messages.append(
+                {
+                    "role": "user",
+                    "content": [  # type: ignore[dict-item]
+                        {"type": "image_url", "image_url": {"url": img}}
+                        for img in tool_image_urls
+                    ],
+                }
+            )
+
+    @classmethod
+    def _build_base_request(
+        cls,
+        anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
+        openai_messages: list[dict[str, Any]],
+    ) -> ChatCompletionRequest:
+        """Build base ChatCompletionRequest"""
+        if isinstance(anthropic_request, AnthropicCountTokensRequest):
+            return ChatCompletionRequest(
+                model=anthropic_request.model,
+                messages=openai_messages,
+            )
+
+        return ChatCompletionRequest(
             model=anthropic_request.model,
             messages=openai_messages,
             max_tokens=anthropic_request.max_tokens,
@@ -248,19 +318,38 @@ class AnthropicServingMessages(OpenAIServingChat):
             top_k=anthropic_request.top_k,
         )
 
+    @classmethod
+    def _handle_streaming_options(
+        cls,
+        req: ChatCompletionRequest,
+        anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
+    ) -> None:
+        """Handle streaming configuration"""
+        if isinstance(anthropic_request, AnthropicCountTokensRequest):
+            return
         if anthropic_request.stream:
             req.stream = anthropic_request.stream
-            req.stream_options = StreamOptions.validate(
+            req.stream_options = StreamOptions.model_validate(
                 {"include_usage": True, "continuous_usage_stats": True}
             )
 
+    @classmethod
+    def _convert_tool_choice(
+        cls,
+        anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
+        req: ChatCompletionRequest,
+    ) -> None:
+        """Convert Anthropic tool_choice to OpenAI format"""
         if anthropic_request.tool_choice is None:
             req.tool_choice = None
-        elif anthropic_request.tool_choice.type == "auto":
+            return
+
+        tool_choice_type = anthropic_request.tool_choice.type
+        if tool_choice_type == "auto":
             req.tool_choice = "auto"
-        elif anthropic_request.tool_choice.type == "any":
+        elif tool_choice_type == "any":
             req.tool_choice = "required"
-        elif anthropic_request.tool_choice.type == "tool":
+        elif tool_choice_type == "tool":
             req.tool_choice = ChatCompletionNamedToolChoiceParam.model_validate(
                 {
                     "type": "function",
@@ -268,9 +357,17 @@ class AnthropicServingMessages(OpenAIServingChat):
                 }
             )
 
-        tools = []
+    @classmethod
+    def _convert_tools(
+        cls,
+        anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
+        req: ChatCompletionRequest,
+    ) -> None:
+        """Convert Anthropic tools to OpenAI format"""
         if anthropic_request.tools is None:
-            return req
+            return
+
+        tools = []
 
         for tool in anthropic_request.tools:
             tools.append(
@@ -284,10 +381,10 @@ class AnthropicServingMessages(OpenAIServingChat):
                     }
                 )
             )
+
         if req.tool_choice is None:
             req.tool_choice = "auto"
         req.tools = tools
-        return req
 
     async def create_messages(
         self,
@@ -670,3 +767,31 @@ class AnthropicServingMessages(OpenAIServingChat):
             data = error_response.model_dump_json(exclude_unset=True)
             yield wrap_data_with_event(data, "error")
             yield "data: [DONE]\n\n"
+
+    async def count_tokens(
+        self,
+        request: AnthropicCountTokensRequest,
+        raw_request: Request | None = None,
+    ) -> AnthropicCountTokensResponse | ErrorResponse:
+        """Implements Anthropic's messages.count_tokens endpoint."""
+        chat_req = self._convert_anthropic_to_openai_request(request)
+        result = await self.render_chat_request(chat_req)
+        if isinstance(result, ErrorResponse):
+            return result
+
+        _, engine_prompts = result
+
+        input_tokens = sum(  # type: ignore
+            len(prompt["prompt_token_ids"])  # type: ignore[typeddict-item, misc]
+            for prompt in engine_prompts
+            if "prompt_token_ids" in prompt
+        )
+
+        response = AnthropicCountTokensResponse(
+            input_tokens=input_tokens,
+            context_management=AnthropicContextManagement(
+                original_input_tokens=input_tokens
+            ),
+        )
+
+        return response
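
Example exercising the new endpoint end to end — a minimal sketch, not part of
the diff: it assumes a vLLM server with this router attached is already
listening on localhost:8000, and the model name is a placeholder for whatever
model the server actually serves. The returned input_tokens value depends on
that model's tokenizer and chat template.

    import requests

    resp = requests.post(
        "http://localhost:8000/v1/messages/count_tokens",
        json={
            "model": "my-model",  # placeholder: use the served model's name
            "messages": [{"role": "user", "content": "Hello, world"}],
        },
    )
    resp.raise_for_status()
    # Expected shape (context_management mirrors input_tokens in this PR):
    # {"input_tokens": <int>, "context_management": {"original_input_tokens": <int>}}
    print(resp.json())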