Add /v1/chat/completions/batch endpoint for batched chat completions (#38011)

Signed-off-by: Matej Rojec <64556640+MatejRojec@users.noreply.github.com>
This commit is contained in:
Matej Rojec
2026-03-26 05:13:33 +01:00
committed by GitHub
parent e6bf9f15ec
commit 2908094567
8 changed files with 771 additions and 21 deletions

View File

@@ -0,0 +1,194 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Examples of batched chat completions via the vLLM OpenAI-compatible API.
The /v1/chat/completions/batch endpoint accepts ``messages`` as a list of
conversations. Each conversation is processed independently and the response
contains one choice per conversation, indexed 0, 1, ..., N-1.
Start a server first, e.g.:
vllm serve Qwen/Qwen2.5-1.5B-Instruct --port 8000
Current limitations compared to /v1/chat/completions:
- Streaming is not supported.
- Tool use is not supported.
- Beam search is not supported.
"""
import json
import os
import httpx
# Root URL of the server to talk to; taken from VLLM_BASE_URL when set.
BASE_URL = os.environ.get("VLLM_BASE_URL", "http://localhost:8000")
# Model name sent with every request; taken from VLLM_MODEL when set.
MODEL = os.environ.get("VLLM_MODEL", "Qwen/Qwen2.5-1.5B-Instruct")
# Full URL of the batched chat completions endpoint, derived from BASE_URL.
BATCH_URL = f"{BASE_URL}/v1/chat/completions/batch"
def post_batch(payload: dict) -> dict:
    """POST ``payload`` as JSON to ``BATCH_URL`` and return the decoded JSON reply.

    The request uses a 60 second timeout. ``raise_for_status()`` is called on
    the reply before the body is decoded, so an error status surfaces as an
    exception instead of being returned.
    """
    reply = httpx.post(BATCH_URL, json=payload, timeout=60)
    reply.raise_for_status()
    decoded = reply.json()
    return decoded
def main() -> None:
    """Run every example request in order and print the returned choices.

    Example 1a posts to the standard single-conversation endpoint for
    comparison; all later examples go through ``post_batch``.
    """

    def ask(text: str) -> list[dict]:
        # A one-turn conversation containing a single user message.
        return [{"role": "user", "content": text}]

    def json_schema_format(name: str, schema: dict) -> dict:
        # ``response_format`` value asking for strict JSON-schema output.
        return {
            "type": "json_schema",
            "json_schema": {
                "name": name,
                "strict": True,
                "schema": schema,
            },
        }

    def show_text(result: dict) -> None:
        # One line per choice: its index, then the message text as returned.
        for item in result["choices"]:
            print(f" [{item['index']}] {item['message']['content']}")

    def show_json(result: dict) -> None:
        # The message content is decoded as JSON before being printed.
        for item in result["choices"]:
            decoded = json.loads(item["message"]["content"])
            print(f" [{item['index']}] {decoded}")

    print("=== Example 1a: single conversation (standard endpoint) ===")
    single_payload = {
        "model": MODEL,
        "messages": ask("What is the capital of Japan?"),
    }
    reply = httpx.post(
        f"{BASE_URL}/v1/chat/completions",
        json=single_payload,
        timeout=60,
    )
    reply.raise_for_status()
    show_text(reply.json())

    print("\n=== Example 1b: batched plain text (2 conversations) ===")
    capitals_payload = {
        "model": MODEL,
        "messages": [
            ask("What is the capital of France?"),
            ask("What is the capital of Japan?"),
        ],
    }
    show_text(post_batch(capitals_payload))

    print("\n=== Example 2: batch with regex constraint (yes|no) ===")
    yes_no_payload = {
        "model": MODEL,
        "messages": [
            ask("Is the sky blue? Answer yes or no."),
            ask("Is fire cold? Answer yes or no."),
        ],
        "structured_outputs": {"regex": "(yes|no)"},
    }
    show_text(post_batch(yes_no_payload))

    print("\n=== Example 3: batch with json_schema ===")
    person_schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string", "description": "Full name of the person"},
            "age": {"type": "integer", "description": "Age in years"},
        },
        "required": ["name", "age"],
    }
    people_payload = {
        "model": MODEL,
        "messages": [
            ask("Describe the person: name Alice, age 30."),
            ask("Describe the person: name Bob, age 25."),
        ],
        "response_format": json_schema_format("person", person_schema),
    }
    show_json(post_batch(people_payload))

    print("\n=== Example 4: batch book summaries ===")
    book_schema = {
        "type": "object",
        "properties": {
            "author": {
                "type": "string",
                "description": "Full name of the author",
            },
            "num_pages": {
                "type": "integer",
                "description": "Number of pages in the book",
            },
            "short_summary": {
                "type": "string",
                "description": "A one-sentence summary of the book",
            },
            "long_summary": {
                "type": "string",
                "description": (
                    "A detailed two to three sentence summary covering "
                    "the main themes and plot"
                ),
            },
        },
        "required": ["author", "num_pages", "short_summary", "long_summary"],
    }
    # The same system message opens both conversations.
    system_msg = {
        "role": "system",
        "content": (
            "You are a literary analyst. Extract structured information "
            "from book descriptions."
        ),
    }
    orwell_prompt = (
        "Extract information from this book: '1984' by George"
        " Orwell, published in 1949, 328 pages. A dystopian"
        " novel set in a totalitarian society ruled by Big"
        " Brother, following Winston Smith as he secretly"
        " rebels against the oppressive Party that surveils"
        " and controls every aspect of life."
    )
    adams_prompt = (
        "Extract information from this book: 'The Hitchhiker's"
        " Guide to the Galaxy' by Douglas Adams, published in"
        " 1979, 193 pages. A comedic science fiction novel"
        " following Arthur Dent, an ordinary Englishman who is"
        " whisked off Earth moments before it is demolished to"
        " make way for a hyperspace bypass, and his subsequent"
        " absurd adventures across the universe."
    )
    books_payload = {
        "model": MODEL,
        "messages": [
            [system_msg, *ask(orwell_prompt)],
            [system_msg, *ask(adams_prompt)],
        ],
        "response_format": json_schema_format("book_summary", book_schema),
    }
    show_json(post_batch(books_payload))
# Run the examples only when this file is executed as a script.
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,113 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import httpx
import pytest
from tests.utils import RemoteOpenAIServer
# any model with a chat template defined in tokenizer_config should work here
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
@pytest.fixture(scope="module")
def default_server_args():
    """Command-line flags for the server shared by the tests in this module."""
    # Limits chosen for speed and memory savings in the CI environment.
    # No dtype flag is passed here, so precision is left at the server default.
    flags = ["--max-model-len", "2048", "--max-num-seqs", "128"]
    flags.append("--enforce-eager")
    return flags
@pytest.fixture(scope="module")
def server(default_server_args):
    """Yield a ``RemoteOpenAIServer`` for ``MODEL_NAME``, entered as a context manager."""
    launcher = RemoteOpenAIServer(MODEL_NAME, default_server_args)
    with launcher as running_server:
        yield running_server
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_batched_chat_completions(
    server: RemoteOpenAIServer, model_name: str
) -> None:
    """Two plain-text conversations come back as two non-empty choices, indexed 0 and 1."""
    payload = {
        "model": model_name,
        "messages": [
            [{"role": "user", "content": "Reply with exactly the word: alpha"}],
            [{"role": "user", "content": "Reply with exactly the word: beta"}],
        ],
    }

    async with httpx.AsyncClient() as http_client:
        response = await http_client.post(
            f"{server.url_for('v1/chat/completions/batch')}",
            json=payload,
            timeout=60,
        )

    assert response.status_code == 200, response.text
    choices = response.json()["choices"]
    assert len(choices) == 2

    # Only the set of indices is checked, not the order of the choices.
    assert {entry["index"] for entry in choices} == {0, 1}

    # Each conversation should produce a non-empty text response.
    assert all(entry["message"]["content"] for entry in choices)
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_batched_chat_completions_with_json_schema(
    server: RemoteOpenAIServer, model_name: str
) -> None:
    """Each batched choice must be JSON whose ``answer`` is ``yes`` or ``no``."""
    schema = {
        "type": "object",
        "properties": {
            "answer": {"type": "string", "enum": ["yes", "no"]},
        },
        "required": ["answer"],
    }
    payload = {
        "model": model_name,
        "messages": [
            [{"role": "user", "content": "Is the sky blue? Answer in JSON."}],
            [{"role": "user", "content": "Is fire cold? Answer in JSON."}],
        ],
        "response_format": {
            "type": "json_schema",
            "json_schema": {"name": "answer", "schema": schema, "strict": True},
        },
    }

    async with httpx.AsyncClient() as http_client:
        response = await http_client.post(
            f"{server.url_for('v1/chat/completions/batch')}",
            json=payload,
            timeout=60,
        )

    assert response.status_code == 200, response.text
    choices = response.json()["choices"]
    assert len(choices) == 2

    # Decode every reply and check it against the enum declared in the schema.
    for entry in choices:
        decoded = json.loads(entry["message"]["content"])
        assert "answer" in decoded
        assert decoded["answer"] in ("yes", "no")

View File

@@ -174,6 +174,7 @@ def test_openapi_stateless(case: Case):
timeout = {
# requires a longer timeout
("POST", "/v1/chat/completions"): LONG_TIMEOUT_SECONDS,
("POST", "/v1/chat/completions/batch"): LONG_TIMEOUT_SECONDS,
("POST", "/v1/completions"): LONG_TIMEOUT_SECONDS,
("POST", "/v1/messages"): LONG_TIMEOUT_SECONDS,
}.get(key, DEFAULT_TIMEOUT_SECONDS)

View File

@@ -7,7 +7,9 @@ from http import HTTPStatus
from fastapi import APIRouter, Depends, FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from vllm.entrypoints.openai.chat_completion.batch_serving import OpenAIServingChatBatch
from vllm.entrypoints.openai.chat_completion.protocol import (
BatchChatCompletionRequest,
ChatCompletionRequest,
ChatCompletionResponse,
)
@@ -31,6 +33,10 @@ def chat(request: Request) -> OpenAIServingChat | None:
return request.app.state.openai_serving_chat
def batch_chat(request: Request) -> OpenAIServingChatBatch | None:
    """Look up the batch chat handler stored on the application state.

    Returns whatever ``app.state.openai_serving_chat_batch`` holds, which per
    the return annotation may be ``None``.
    """
    app_state = request.app.state
    return app_state.openai_serving_chat_batch
@router.post(
"/v1/chat/completions",
dependencies=[Depends(validate_json_request)],
@@ -68,5 +74,33 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post(
    "/v1/chat/completions/batch",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
        HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_batch_chat_completion(
    request: BatchChatCompletionRequest, raw_request: Request
):
    """HTTP entry point for ``POST /v1/chat/completions/batch``.

    Delegates to the batch handler found on the app state and wraps its result
    in a ``JSONResponse``. An ``ErrorResponse`` result is returned with the
    status code carried in ``result.error.code``.

    Raises:
        NotImplementedError: if no batch handler is available on the app state.
    """
    serving = batch_chat(raw_request)
    if serving is None:
        raise NotImplementedError("The model does not support Chat Completions API")

    outcome = await serving.create_batch_chat_completion(request, raw_request)
    body = outcome.model_dump()
    if not isinstance(outcome, ErrorResponse):
        return JSONResponse(content=body)
    return JSONResponse(content=body, status_code=outcome.error.code)
def attach_router(app: FastAPI):
    """Register this module's ``router`` (and the routes declared on it) with ``app``."""
    app.include_router(router)

View File

@@ -0,0 +1,317 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import time
from collections.abc import AsyncGenerator
from http import HTTPStatus
from fastapi import Request
from vllm.entrypoints.chat_utils import ConversationMessage
from vllm.entrypoints.openai.chat_completion.protocol import (
BatchChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatMessage,
)
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
RequestResponseMetadata,
UsageInfo,
)
from vllm.entrypoints.utils import get_max_tokens
from vllm.inputs import EngineInput
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.reasoning import ReasoningParser
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import as_list
logger = init_logger(__name__)
class OpenAIServingChatBatch(OpenAIServingChat):
"""Extends OpenAIServingChat with the /v1/chat/completions/batch endpoint.
Processes N conversations from a single request concurrently and returns
one choice per conversation indexed 0, 1, ..., N-1.
"""
async def render_batch_chat_request(
    self,
    request: BatchChatCompletionRequest,
) -> tuple[list[list[ConversationMessage]], list[EngineInput]] | ErrorResponse:
    """Check the model and turn every conversation of a batch into an engine prompt.

    The model check runs first, then the engine health check. Unless the
    renderer is in harmony mode, the chat template is validated once for the
    whole batch. Each conversation is then converted to a single-conversation
    request and preprocessed on its own.

    Returns:
        ``(conversations, engine_prompts)`` with one entry per conversation,
        in request order, or an ``ErrorResponse`` if the model check or the
        chat-template validation fails.

    Raises:
        The engine client's ``dead_error`` if the engine client reports
        ``errored``.
    """
    model_error = await self._check_model(request)
    if model_error is not None:
        logger.error("Error with model %s", model_error)
        return model_error

    if self.engine_client.errored:
        raise self.engine_client.dead_error

    render = self.openai_serving_render

    if not render.use_harmony:
        # Common case: validate the chat template once for the whole batch.
        template_error = render.validate_chat_template(
            request_chat_template=request.chat_template,
            chat_template_kwargs=request.chat_template_kwargs,
            trust_request_chat_template=render.trust_request_chat_template,
        )
        if template_error is not None:
            return template_error

    tool_parser = render.tool_parser
    # Always None here: no tool definitions are passed on for batch requests.
    tool_dicts: list[dict] | None = None

    rendered_conversations: list[list[ConversationMessage]] = []
    rendered_prompts: list[EngineInput] = []

    for messages in request.messages:
        single_request = request.to_chat_completion_request(messages)

        if render.use_harmony:
            conversation, prompts = render._make_request_with_harmony(
                single_request, should_include_tools=tool_dicts is not None
            )
        else:
            conversation, prompts = await render.preprocess_chat(
                single_request,
                messages,
                default_template=render.chat_template,
                default_template_content_format=render.chat_template_content_format,
                default_template_kwargs=render.default_chat_template_kwargs,
                tool_dicts=tool_dicts,
                tool_parser=tool_parser,
            )

        rendered_conversations.append(conversation)
        # Only the first engine prompt of each conversation is kept.
        rendered_prompts.append(prompts[0])

    return rendered_conversations, rendered_prompts
async def create_batch_chat_completion(
    self,
    request: BatchChatCompletionRequest,
    raw_request: Request | None = None,
) -> ChatCompletionResponse | ErrorResponse:
    """Batch Chat Completion endpoint (/v1/chat/completions/batch).

    Renders every conversation in ``request.messages``, submits one engine
    generation per conversation under the sub-request id
    ``"<request_id>_<i>"``, and hands the generators to
    ``chat_completion_full_generator_batch`` to build the response.

    Streaming, tool use, and beam search are not supported.

    Args:
        request: The batched request; sampling fields apply to every
            conversation.
        raw_request: The incoming HTTP request, if any. When given, request
            metadata is stored on ``raw_request.state`` and its headers are
            used to obtain trace headers.

    Returns:
        A ``ChatCompletionResponse`` with one choice per conversation, or an
        ``ErrorResponse`` when rendering fails.
    """
    tokenizer = self.renderer.tokenizer
    assert tokenizer is not None

    reasoning_parser: ReasoningParser | None = None
    if self.reasoning_parser_cls:
        chat_template_kwargs = self._prepare_extra_chat_template_kwargs(
            request.chat_template_kwargs,
            self.default_chat_template_kwargs,
        )
        reasoning_parser = self.reasoning_parser_cls(
            tokenizer,
            chat_template_kwargs=chat_template_kwargs,  # type: ignore[call-arg]
        )

    render_result = await self.render_batch_chat_request(request)
    if isinstance(render_result, ErrorResponse):
        return render_result
    all_conversations, engine_prompts = render_result

    request_id = (
        f"chatcmpl-{self._base_request_id(raw_request, request.request_id)}"
    )
    request_metadata = RequestResponseMetadata(request_id=request_id)
    if raw_request:
        raw_request.state.request_metadata = request_metadata

    lora_request = self._maybe_get_adapters(request, supports_default_mm_loras=True)
    model_name = self.models.model_name(lora_request)
    data_parallel_rank = self._get_data_parallel_rank(raw_request)
    max_model_len = self.model_config.max_model_len

    # The values below depend only on the request, not on the individual
    # conversation, so they are computed once instead of once per iteration.
    requested_max_tokens = (
        request.max_completion_tokens
        if request.max_completion_tokens is not None
        else request.max_tokens
    )
    priority = getattr(request, "priority", 0)
    trace_headers = (
        None
        if raw_request is None
        else await self._get_trace_headers(raw_request.headers)
    )

    generators: list[AsyncGenerator[RequestOutput, None]] = []
    # render_batch_chat_request yields exactly one engine prompt per
    # conversation; strict=True turns any mismatch into an immediate error.
    for i, (messages, engine_prompt) in enumerate(
        zip(request.messages, engine_prompts, strict=True)
    ):
        sub_request_id = f"{request_id}_{i}"

        max_tokens = get_max_tokens(
            max_model_len,
            requested_max_tokens,
            self._extract_prompt_len(engine_prompt),
            self.default_sampling_params,
            self.override_max_tokens,
        )

        # Sampling parameters are derived from the equivalent
        # single-conversation request.
        single_request = request.to_chat_completion_request(messages)
        sampling_params = single_request.to_sampling_params(
            max_tokens, self.default_sampling_params
        )

        self._log_inputs(
            sub_request_id,
            engine_prompt,
            params=sampling_params,
            lora_request=lora_request,
        )

        generators.append(
            self.engine_client.generate(
                engine_prompt,
                sampling_params,
                sub_request_id,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
                reasoning_ended=None,
            )
        )

    return await self.chat_completion_full_generator_batch(
        request,  # type: ignore[arg-type]
        generators,
        request_id,
        model_name,
        all_conversations,
        tokenizer,
        request_metadata,
        reasoning_parser,
    )
async def chat_completion_full_generator_batch(
    self,
    request: BatchChatCompletionRequest,  # type: ignore[override]
    generators: list[AsyncGenerator[RequestOutput, None]],
    request_id: str,
    model_name: str,
    all_conversations: list[list[ConversationMessage]],
    tokenizer: TokenizerLike,
    request_metadata: RequestResponseMetadata,
    reasoning_parser: ReasoningParser | None = None,
) -> ErrorResponse | ChatCompletionResponse:
    """Handle batched (non-streaming) chat completions.

    Drains N generators (one per conversation in the batch), keeps the
    last output seen for each, and assembles a single
    ``ChatCompletionResponse`` whose ``choices`` are indexed 0,...,N-1.
    Usage counts are summed over all conversations.

    This method contains no tool-call or streaming handling; such requests
    are expected to have been rejected before it is called.

    Returns:
        The assembled ``ChatCompletionResponse``, or an ``ErrorResponse`` if
        the task is cancelled while collecting outputs or if a conversation
        produced no output at all.
    """
    created_time = int(time.time())
    role = self.get_chat_request_role(request)  # type: ignore[arg-type]

    # Maps generator position -> most recent output; later items from the same
    # generator overwrite earlier ones.
    final_results: dict[int, RequestOutput] = {}
    try:
        async for prompt_idx, res in merge_async_iterators(*generators):
            final_results[prompt_idx] = res
    except asyncio.CancelledError:
        # Cancellation is reported as an error response rather than re-raised.
        return self.create_error_response("Client disconnected")

    choices: list[ChatCompletionResponseChoice] = []
    total_prompt_tokens = 0
    total_completion_tokens = 0

    for prompt_idx in range(len(generators)):
        final_res = final_results.get(prompt_idx)
        if final_res is None:
            # A generator that yielded nothing fails the whole batch.
            return self.create_error_response(
                f"No output received from the engine for prompt {prompt_idx}.",
                err_type="InternalServerError",
                status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            )

        assert final_res.prompt_token_ids is not None
        num_prompt_tokens = len(final_res.prompt_token_ids)
        # Encoder prompt tokens, when present, count towards prompt usage.
        if final_res.encoder_prompt_token_ids is not None:
            num_prompt_tokens += len(final_res.encoder_prompt_token_ids)
        total_prompt_tokens += num_prompt_tokens
        total_completion_tokens += sum(
            len(output.token_ids) for output in final_res.outputs
        )

        for output in final_res.outputs:
            self._raise_if_error(output.finish_reason, request_id)

            # Logprobs are built only when `logprobs` is truthy and
            # `top_logprobs` is not None.
            if request.logprobs and request.top_logprobs is not None:
                assert output.logprobs is not None, "Did not output logprobs"
                logprobs = self._create_chat_logprobs(
                    token_ids=output.token_ids,
                    top_logprobs=output.logprobs,
                    num_output_top_logprobs=request.top_logprobs,
                    tokenizer=tokenizer,
                    return_as_token_id=request.return_token_ids,
                )
            else:
                logprobs = None

            if reasoning_parser:
                reasoning, content = reasoning_parser.extract_reasoning(
                    output.text,
                    request=request,  # type: ignore[arg-type]
                )
                # `include_reasoning` is read defensively and defaults to True
                # when the request has no such attribute.
                if not getattr(request, "include_reasoning", True):
                    reasoning = None
            else:
                reasoning = None
                content = output.text

            message = ChatMessage(role=role, reasoning=reasoning, content=content)

            if request.echo:
                # Prefix the reply with the content of the conversation's last
                # message; list-style content is flattened by joining each
                # part's "text" entry with newlines.
                conversation = all_conversations[prompt_idx]
                last_msg_content: str | list[dict[str, str]] = ""
                if conversation and "content" in conversation[-1]:
                    last_msg_content = conversation[-1]["content"] or ""
                if isinstance(last_msg_content, list):
                    last_msg_content = "\n".join(
                        msg["text"] for msg in last_msg_content
                    )
                message.content = last_msg_content + (message.content or "")

            choice_data = ChatCompletionResponseChoice(
                index=prompt_idx,
                message=message,
                logprobs=logprobs,
                # A missing finish reason is reported as "stop".
                finish_reason=output.finish_reason
                if output.finish_reason
                else "stop",
                stop_reason=output.stop_reason,
                token_ids=(
                    as_list(output.token_ids) if request.return_token_ids else None
                ),
            )
            choices.append(choice_data)

    usage = UsageInfo(
        prompt_tokens=total_prompt_tokens,
        completion_tokens=total_completion_tokens,
        total_tokens=total_prompt_tokens + total_completion_tokens,
    )
    # The aggregate usage is also recorded on the shared request metadata.
    request_metadata.final_usage_info = usage

    choices.sort(key=lambda c: c.index)

    return ChatCompletionResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        choices=choices,
        usage=usage,
    )

View File

@@ -787,3 +787,83 @@ class ChatCompletionRequest(OpenAIBaseModel):
if data.get("reasoning_effort") == "none":
data["include_reasoning"] = False
return data
class BatchChatCompletionRequest(OpenAIBaseModel):
    """Request model for the /v1/chat/completions/batch endpoint.

    Accepts the same fields as ChatCompletionRequest except that ``messages``
    is a list of conversations (each conversation is a
    ``list[ChatCompletionMessageParam]``). Each conversation is processed
    independently and the response contains one choice per conversation,
    indexed 0, 1, ..., N-1.

    Current limitations compared to the single-conversation endpoint, all
    enforced by ``check_batch_mode``:

    - Streaming is not supported (``stream`` must be False or omitted).
    - Tool use is not supported (``tools`` must be omitted or empty).
    - Beam search is not supported (``use_beam_search`` must be False or omitted).
    - The ``n`` parameter must be 1 (or omitted).
    """

    # At least one conversation is required.
    messages: list[list[ChatCompletionMessageParam]] = Field(..., min_length=1)
    model: str | None = None

    # Shared sampling / generation fields — mirror ChatCompletionRequest.
    frequency_penalty: float | None = 0.0
    logit_bias: dict[str, float] | None = None
    logprobs: bool | None = False
    top_logprobs: int | None = 0
    max_tokens: int | None = None
    max_completion_tokens: int | None = None
    n: int | None = 1
    presence_penalty: float | None = 0.0
    response_format: Any | None = None
    seed: int | None = Field(None, ge=_INT64_MIN, le=_INT64_MAX)
    stop: str | list[str] | None = Field(default_factory=list)
    temperature: float | None = 0.7
    top_p: float | None = 1.0
    user: str | None = None

    # vLLM extensions
    best_of: int | None = None
    use_beam_search: bool = False
    top_k: int | None = None
    min_p: float | None = 0.0
    repetition_penalty: float | None = 1.0
    length_penalty: float | None = 1.0
    early_stopping: bool = False
    structured_outputs: StructuredOutputsParams | None = None
    request_id: str | None = None
    add_generation_prompt: bool = True
    continue_final_message: bool = False
    chat_template: str | None = None
    chat_template_kwargs: dict[str, Any] | None = None
    include_stop_str_in_output: bool = False
    guided_decoding_backend: str | None = None
    echo: bool = False
    return_token_ids: bool = False

    @model_validator(mode="before")
    @classmethod
    def check_batch_mode(cls, data: Any) -> Any:
        """Reject options the batch endpoint does not support.

        Runs before field validation on the raw input. Raises ``ValueError``
        for streaming, tool use, beam search, or ``n`` other than 1.
        """
        if isinstance(data, BatchChatCompletionRequest):
            data = data.model_dump(exclude_unset=True)
        if not isinstance(data, dict):
            # Not a mapping: leave it to the regular field validation to
            # report the malformed payload instead of failing on `.get`.
            return data

        # `stream` and `tools` are not declared fields of this model, so they
        # are checked on the raw payload here.
        if data.get("stream"):
            raise ValueError(
                "Batch chat completions do not support streaming. "
                "Please set `stream` to False."
            )
        if data.get("tools"):
            raise ValueError(
                "Batch chat completions do not support tool use. "
                "Please omit `tools`."
            )
        if data.get("use_beam_search"):
            raise ValueError(
                "Batch chat completions do not support beam search. "
                "Please set `use_beam_search` to False."
            )
        n = data.get("n", 1)
        if n is not None and n != 1:
            raise ValueError(
                "Batch chat completions do not support `n > 1`. Please set `n` to 1."
            )
        return data

    def to_chat_completion_request(
        self, messages: list[ChatCompletionMessageParam]
    ) -> ChatCompletionRequest:
        """Build a single-conversation ChatCompletionRequest from one conversation.

        Every field of this request except ``messages`` is copied over, with
        ``None`` values dropped, and ``messages`` is set to the given
        conversation. The result is validated as a ``ChatCompletionRequest``.
        """
        # NOTE(review): fields left at this model's defaults (e.g.
        # temperature=0.7) are forwarded as explicit values — confirm this
        # matches the defaults ChatCompletionRequest would otherwise apply.
        data = self.model_dump(exclude={"messages"}, exclude_none=True)
        data["messages"] = messages
        return ChatCompletionRequest.model_validate(data)

View File

@@ -26,6 +26,7 @@ from vllm.entrypoints.chat_utils import (
)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.chat_completion.protocol import (
BatchChatCompletionRequest,
ChatCompletionNamedToolChoiceParam,
ChatCompletionRequest,
ChatCompletionResponse,
@@ -124,7 +125,10 @@ CompletionLikeRequest: TypeAlias = (
)
ChatLikeRequest: TypeAlias = (
ChatCompletionRequest | TokenizeChatRequest | PoolingChatRequest
ChatCompletionRequest
| BatchChatCompletionRequest
| TokenizeChatRequest
| PoolingChatRequest
)
SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest

View File

@@ -56,6 +56,9 @@ async def init_generate_state(
MCPToolServer,
ToolServer,
)
from vllm.entrypoints.openai.chat_completion.batch_serving import (
OpenAIServingChatBatch,
)
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion
from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses
@@ -96,27 +99,31 @@ async def init_generate_state(
if "generate" in supported_tasks
else None
)
_chat_kwargs = dict(
engine_client=engine_client,
models=state.openai_serving_models,
response_role=args.response_role,
openai_serving_render=state.openai_serving_render,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
default_chat_template_kwargs=args.default_chat_template_kwargs,
trust_request_chat_template=args.trust_request_chat_template,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
tool_parser=args.tool_call_parser,
reasoning_parser=args.structured_outputs_config.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
enable_log_outputs=args.enable_log_outputs,
enable_log_deltas=args.enable_log_deltas,
)
state.openai_serving_chat = (
OpenAIServingChat(
engine_client,
state.openai_serving_models,
args.response_role,
openai_serving_render=state.openai_serving_render,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
default_chat_template_kwargs=args.default_chat_template_kwargs,
trust_request_chat_template=args.trust_request_chat_template,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
tool_parser=args.tool_call_parser,
reasoning_parser=args.structured_outputs_config.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
enable_log_outputs=args.enable_log_outputs,
enable_log_deltas=args.enable_log_deltas,
)
OpenAIServingChat(**_chat_kwargs) if "generate" in supported_tasks else None
)
state.openai_serving_chat_batch = (
OpenAIServingChatBatch(**_chat_kwargs)
if "generate" in supported_tasks
else None
)