[Frontend] Complete OpenAI render delegation (#37287)

Signed-off-by: Sage Ahrac <sagiahrak@gmail.com>
Authored by Sage on 2026-03-17 15:53:55 +02:00; committed by GitHub
parent 56cb1baa66
commit 59192dfd39
8 changed files with 140 additions and 244 deletions
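In short, prompt preprocessing is now delegated from the serving classes to a shared OpenAIServingRender instance: OpenAIServingResponses, OpenAIServingPooling, and ServingTokens take it as a new constructor argument and call its preprocess_chat / preprocess_completion / preprocess_cmpl methods, while the corresponding private helpers are removed from OpenAIServing. A minimal sketch of the chat-side call path, assuming a serving object wired up as in this diff (the helper name and its arguments are illustrative, not a public API):

async def render_chat_prompts(serving, request):
    # Hypothetical helper illustrating the delegation introduced here.
    # `serving` stands for any serving object that now carries an
    # `openai_serving_render` attribute (OpenAIServingResponses,
    # OpenAIServingPooling, ...); `request` is the incoming chat-style request.
    _, engine_prompts = await serving.openai_serving_render.preprocess_chat(
        request,
        request.messages,
        default_template=serving.chat_template,
        default_template_content_format=serving.chat_template_content_format,
        default_template_kwargs=None,
    )
    return engine_prompts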

View File

@@ -159,6 +159,7 @@ class TestInitializeToolSessions:
instance = OpenAIServingResponses(
engine_client=engine_client,
models=models,
openai_serving_render=MagicMock(),
request_logger=None,
chat_template=None,
chat_template_content_format="auto",
@@ -245,6 +246,7 @@ class TestValidateGeneratorInput:
instance = OpenAIServingResponses(
engine_client=engine_client,
models=models,
openai_serving_render=MagicMock(),
request_logger=None,
chat_template=None,
chat_template_content_format="auto",
@@ -308,6 +310,7 @@ async def test_reasoning_tokens_counted_for_text_reasoning_model(monkeypatch):
serving = OpenAIServingResponses(
engine_client=engine_client,
models=models,
openai_serving_render=MagicMock(),
request_logger=None,
chat_template=None,
chat_template_content_format="auto",
@@ -607,6 +610,7 @@ def _make_serving_instance_with_reasoning():
serving = OpenAIServingResponses(
engine_client=engine_client,
models=models,
openai_serving_render=MagicMock(),
request_logger=None,
chat_template=None,
chat_template_content_format="auto",

View File

@@ -4,7 +4,7 @@ import asyncio
import contextlib
import json
import time
from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
from collections.abc import AsyncGenerator, Callable, Mapping
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
@@ -22,9 +22,7 @@ from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
ConversationMessage,
)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.chat_completion.protocol import (
@@ -43,19 +41,9 @@ from vllm.entrypoints.openai.engine.protocol import (
GenerationError,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.responses.context import (
ConversationContext,
HarmonyContext,
ParsableContext,
StreamingHarmonyContext,
)
from vllm.entrypoints.openai.responses.protocol import (
ResponseInputOutputItem,
ResponsesRequest,
)
from vllm.entrypoints.openai.responses.utils import (
construct_input_messages,
)
from vllm.entrypoints.openai.speech_to_text.protocol import (
TranscriptionRequest,
TranscriptionResponse,
@@ -82,26 +70,22 @@ from vllm.entrypoints.serve.tokenize.protocol import (
TokenizeCompletionRequest,
TokenizeResponse,
)
from vllm.entrypoints.utils import create_error_response, get_max_tokens
from vllm.entrypoints.utils import create_error_response
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import (
ProcessorInputs,
PromptType,
SingletonPrompt,
TokensPrompt,
token_inputs,
)
from vllm.logger import init_logger
from vllm.logprobs import Logprob, PromptLogprobs
from vllm.lora.request import LoRARequest
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.renderers import ChatParams, TokenizeParams
from vllm.renderers.inputs.preprocess import (
extract_prompt_components,
extract_prompt_len,
parse_model_prompt,
prompt_to_seq,
)
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
@@ -116,7 +100,6 @@ from vllm.utils.async_utils import (
collect_from_async_generator,
merge_async_iterators,
)
from vllm.utils.mistral import is_mistral_tokenizer
logger = init_logger(__name__)
@@ -823,109 +806,6 @@ class OpenAIServing:
# Apply server defaults first, then request kwargs override.
return default_chat_template_kwargs | request_chat_template_kwargs
async def _preprocess_completion(
self,
request: RendererRequest,
prompt_input: str | list[str] | list[int] | list[list[int]] | None,
prompt_embeds: bytes | list[bytes] | None,
) -> list[ProcessorInputs]:
prompts = list[SingletonPrompt | bytes]()
if prompt_embeds is not None: # embeds take higher priority
prompts.extend(prompt_to_seq(prompt_embeds))
if prompt_input is not None:
prompts.extend(prompt_to_seq(prompt_input))
return await self._preprocess_cmpl(request, prompts)
async def _preprocess_cmpl(
self,
request: RendererRequest,
prompts: Sequence[PromptType | bytes],
) -> list[ProcessorInputs]:
renderer = self.renderer
model_config = self.model_config
parsed_prompts = [
(
prompt
if isinstance(prompt, bytes)
else parse_model_prompt(model_config, prompt)
)
for prompt in prompts
]
tok_params = request.build_tok_params(model_config)
return await renderer.render_cmpl_async(
parsed_prompts,
tok_params,
prompt_extras={
k: v
for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None
},
)
async def _preprocess_chat(
self,
request: RendererChatRequest,
messages: list[ChatCompletionMessageParam],
default_template: str | None,
default_template_content_format: ChatTemplateContentFormatOption,
default_template_kwargs: dict[str, Any] | None,
tool_dicts: list[dict[str, Any]] | None = None,
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
) -> tuple[list[ConversationMessage], list[ProcessorInputs]]:
renderer = self.renderer
default_template_kwargs = merge_kwargs(
default_template_kwargs,
dict(
tools=tool_dicts,
tokenize=is_mistral_tokenizer(renderer.tokenizer),
),
)
mm_config = self.model_config.multimodal_config
tok_params = request.build_tok_params(self.model_config)
chat_params = request.build_chat_params(
default_template, default_template_content_format
).with_defaults(
default_template_kwargs,
default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None),
default_mm_processor_kwargs=getattr(request, "mm_processor_kwargs", None),
)
(conversation,), (engine_prompt,) = await renderer.render_chat_async(
[messages],
chat_params,
tok_params,
prompt_extras={
k: v
for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None
},
)
# tool parsing is done only if a tool_parser has been set and if
# tool_choice is not "none" (if tool_choice is "none" but a tool_parser
# is set, we want to prevent parsing a tool_call hallucinated by the LLM
if tool_parser is not None:
tool_choice = getattr(request, "tool_choice", "none")
if tool_choice != "none":
if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
msg = (
"Tool usage is only supported for Chat Completions API "
"or Responses API requests."
)
raise NotImplementedError(msg)
# TODO: Update adjust_request to accept ResponsesRequest
tokenizer = renderer.get_tokenizer()
request = tool_parser(tokenizer).adjust_request(request=request) # type: ignore[arg-type]
return conversation, [engine_prompt]
def _extract_prompt_components(self, prompt: PromptType | ProcessorInputs):
return extract_prompt_components(self.model_config, prompt)
@@ -935,109 +815,6 @@ class OpenAIServing:
def _extract_prompt_len(self, prompt: ProcessorInputs):
return extract_prompt_len(self.model_config, prompt)
async def _render_next_turn(
self,
request: ResponsesRequest,
messages: list[ResponseInputOutputItem],
tool_dicts: list[dict[str, Any]] | None,
tool_parser: Callable[[TokenizerLike], ToolParser] | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
):
new_messages = construct_input_messages(
request_input=messages,
)
_, engine_prompts = await self._preprocess_chat(
request,
new_messages,
default_template=chat_template,
default_template_content_format=chat_template_content_format,
default_template_kwargs=None,
tool_dicts=tool_dicts,
tool_parser=tool_parser,
)
return engine_prompts
async def _generate_with_builtin_tools(
self,
request_id: str,
engine_prompt: ProcessorInputs,
sampling_params: SamplingParams,
context: ConversationContext,
lora_request: LoRARequest | None = None,
priority: int = 0,
trace_headers: Mapping[str, str] | None = None,
):
max_model_len = self.model_config.max_model_len
orig_priority = priority
sub_request = 0
while True:
# Ensure that each sub-request has a unique request id.
sub_request_id = f"{request_id}_{sub_request}"
self._log_inputs(
sub_request_id,
engine_prompt,
params=sampling_params,
lora_request=lora_request,
)
generator = self.engine_client.generate(
engine_prompt,
sampling_params,
sub_request_id,
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
)
async for res in generator:
context.append_output(res)
# NOTE(woosuk): The stop condition is handled by the engine.
yield context
if not context.need_builtin_tool_call():
# The model did not ask for a tool call, so we're done.
break
# Call the tool and update the context with the result.
tool_output = await context.call_tool()
context.append_tool_output(tool_output)
# TODO: uncomment this and enable tool output streaming
# yield context
# Create inputs for the next turn.
# Render the next prompt token ids and update sampling_params.
if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
token_ids = context.render_for_completion()
engine_prompt = token_inputs(token_ids)
sampling_params.max_tokens = max_model_len - len(token_ids)
elif isinstance(context, ParsableContext):
(engine_prompt,) = await self._render_next_turn(
context.request,
context.parser.response_messages,
context.tool_dicts,
context.tool_parser_cls,
context.chat_template,
context.chat_template_content_format,
)
sampling_params.max_tokens = get_max_tokens(
max_model_len,
context.request.max_output_tokens,
self._extract_prompt_len(engine_prompt),
self.default_sampling_params, # type: ignore
self.override_max_tokens, # type: ignore
)
# OPTIMIZATION
priority = orig_priority - 1
sub_request += 1
def _log_inputs(
self,
request_id: str,
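The completion-side helpers removed above follow the same pattern: callers now go through openai_serving_render.preprocess_completion, as the ServingTokens and OpenAIServingPooling hunks later in this diff show. A minimal sketch, assuming the same wiring (argument values are placeholders):

async def render_completion_prompts(serving, request):
    # Hypothetical sketch of the replacement for the removed
    # _preprocess_completion helper; mirrors the pooling/tokens call sites
    # in this diff (prompt_input is request.input or request.token_ids there).
    return await serving.openai_serving_render.preprocess_completion(
        request,
        prompt_input=request.input,
        prompt_embeds=None,
    )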

View File

@@ -80,6 +80,7 @@ async def init_generate_state(
OpenAIServingResponses(
engine_client,
state.openai_serving_models,
state.openai_serving_render,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
@@ -157,6 +158,7 @@ async def init_generate_state(
ServingTokens(
engine_client,
state.openai_serving_models,
state.openai_serving_render,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,

View File

@@ -5,11 +5,11 @@ import asyncio
import time
import uuid
from collections import deque
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Sequence
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Mapping, Sequence
from contextlib import AsyncExitStack
from copy import copy
from http import HTTPStatus
from typing import Final
from typing import Any, Final
from fastapi import Request
from openai.types.responses import (
@@ -86,6 +86,7 @@ from vllm.entrypoints.openai.responses.protocol import (
ResponseCompletedEvent,
ResponseCreatedEvent,
ResponseInProgressEvent,
ResponseInputOutputItem,
ResponseInputOutputMessage,
ResponseReasoningPartAddedEvent,
ResponseReasoningPartDoneEvent,
@@ -105,16 +106,19 @@ from vllm.entrypoints.openai.responses.utils import (
construct_tool_dicts,
extract_tool_types,
)
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import ProcessorInputs, token_inputs
from vllm.logger import init_logger
from vllm.logprobs import Logprob as SampleLogprob
from vllm.logprobs import SampleLogprobs
from vllm.lora.request import LoRARequest
from vllm.outputs import CompletionOutput
from vllm.parser import ParserManager
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser
from vllm.utils import random_uuid
from vllm.utils.collection_utils import as_list
@@ -165,6 +169,7 @@ class OpenAIServingResponses(OpenAIServing):
self,
engine_client: EngineClient,
models: OpenAIServingModels,
openai_serving_render: OpenAIServingRender,
*,
request_logger: RequestLogger | None,
chat_template: str | None,
@@ -185,6 +190,7 @@ class OpenAIServingResponses(OpenAIServing):
return_tokens_as_token_ids=return_tokens_as_token_ids,
)
self.openai_serving_render = openai_serving_render
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.enable_log_outputs = enable_log_outputs
@@ -587,7 +593,7 @@ class OpenAIServingResponses(OpenAIServing):
prev_response_output=prev_response.output if prev_response else None,
)
_, engine_prompts = await self._preprocess_chat(
_, engine_prompts = await self.openai_serving_render.preprocess_chat(
request,
messages,
default_template=self.chat_template,
@@ -598,6 +604,109 @@ class OpenAIServingResponses(OpenAIServing):
)
return messages, engine_prompts
async def _render_next_turn(
self,
request: ResponsesRequest,
messages: list[ResponseInputOutputItem],
tool_dicts: list[dict[str, Any]] | None,
tool_parser: Callable[[TokenizerLike], ToolParser] | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
):
new_messages = construct_input_messages(
request_input=messages,
)
_, engine_prompts = await self.openai_serving_render.preprocess_chat(
request,
new_messages,
default_template=chat_template,
default_template_content_format=chat_template_content_format,
default_template_kwargs=None,
tool_dicts=tool_dicts,
tool_parser=tool_parser,
)
return engine_prompts
async def _generate_with_builtin_tools(
self,
request_id: str,
engine_prompt: ProcessorInputs,
sampling_params: SamplingParams,
context: ConversationContext,
lora_request: LoRARequest | None = None,
priority: int = 0,
trace_headers: Mapping[str, str] | None = None,
):
max_model_len = self.model_config.max_model_len
orig_priority = priority
sub_request = 0
while True:
# Ensure that each sub-request has a unique request id.
sub_request_id = f"{request_id}_{sub_request}"
self._log_inputs(
sub_request_id,
engine_prompt,
params=sampling_params,
lora_request=lora_request,
)
generator = self.engine_client.generate(
engine_prompt,
sampling_params,
sub_request_id,
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
)
async for res in generator:
context.append_output(res)
# NOTE(woosuk): The stop condition is handled by the engine.
yield context
if not context.need_builtin_tool_call():
# The model did not ask for a tool call, so we're done.
break
# Call the tool and update the context with the result.
tool_output = await context.call_tool()
context.append_tool_output(tool_output)
# TODO: uncomment this and enable tool output streaming
# yield context
# Create inputs for the next turn.
# Render the next prompt token ids and update sampling_params.
if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
token_ids = context.render_for_completion()
engine_prompt = token_inputs(token_ids)
sampling_params.max_tokens = max_model_len - len(token_ids)
elif isinstance(context, ParsableContext):
(engine_prompt,) = await self._render_next_turn(
context.request,
context.parser.response_messages,
context.tool_dicts,
context.tool_parser_cls,
context.chat_template,
context.chat_template_content_format,
)
sampling_params.max_tokens = get_max_tokens(
max_model_len,
context.request.max_output_tokens,
self._extract_prompt_len(engine_prompt),
self.default_sampling_params, # type: ignore
self.override_max_tokens, # type: ignore
)
# OPTIMIZATION
priority = orig_priority - 1
sub_request += 1
def _make_request_with_harmony(
self,
request: ResponsesRequest,

View File

@@ -68,6 +68,7 @@ def init_pooling_state(
OpenAIServingPooling(
engine_client,
state.openai_serving_models,
state.openai_serving_render,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,

View File

@@ -32,6 +32,7 @@ from vllm.entrypoints.pooling.utils import (
encode_pooling_output_base64,
encode_pooling_output_float,
)
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.inputs import ProcessorInputs
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
@@ -47,6 +48,7 @@ class OpenAIServingPooling(OpenAIServing):
self,
engine_client: EngineClient,
models: OpenAIServingModels,
openai_serving_render: OpenAIServingRender,
*,
request_logger: RequestLogger | None,
chat_template: str | None,
@@ -59,6 +61,7 @@ class OpenAIServingPooling(OpenAIServing):
request_logger=request_logger,
)
self.openai_serving_render = openai_serving_render
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
@@ -101,12 +104,12 @@ class OpenAIServingPooling(OpenAIServing):
raw_prompts = await self.io_processor.pre_process_async(
prompt=validated_prompt, request_id=request_id
)
engine_prompts = await self._preprocess_cmpl(
engine_prompts = await self.openai_serving_render.preprocess_cmpl(
request,
prompt_to_seq(raw_prompts),
)
elif isinstance(request, PoolingChatRequest):
error_check_ret = self._validate_chat_template(
error_check_ret = self.openai_serving_render.validate_chat_template(
request_chat_template=request.chat_template,
chat_template_kwargs=request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
@@ -114,7 +117,7 @@ class OpenAIServingPooling(OpenAIServing):
if error_check_ret is not None:
return error_check_ret
_, engine_prompts = await self._preprocess_chat(
_, engine_prompts = await self.openai_serving_render.preprocess_chat(
request,
request.messages,
default_template=self.chat_template,
@@ -122,7 +125,7 @@ class OpenAIServingPooling(OpenAIServing):
default_template_kwargs=None,
)
elif isinstance(request, PoolingCompletionRequest):
engine_prompts = await self._preprocess_completion(
engine_prompts = await self.openai_serving_render.preprocess_completion(
request,
prompt_input=request.input,
prompt_embeds=None,

View File

@@ -29,6 +29,7 @@ from vllm.entrypoints.serve.disagg.protocol import (
GenerateResponse,
GenerateResponseChoice,
)
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
@@ -45,6 +46,7 @@ class ServingTokens(OpenAIServing):
self,
engine_client: EngineClient,
models: OpenAIServingModels,
openai_serving_render: OpenAIServingRender,
*,
request_logger: RequestLogger | None,
force_no_detokenize: bool = False,
@@ -58,6 +60,7 @@ class ServingTokens(OpenAIServing):
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
)
self.openai_serving_render = openai_serving_render
self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.enable_log_outputs = enable_log_outputs
self.force_no_detokenize = force_no_detokenize
@@ -96,7 +99,7 @@ class ServingTokens(OpenAIServing):
if raw_request:
raw_request.state.request_metadata = request_metadata
engine_prompts = await self._preprocess_completion(
engine_prompts = await self.openai_serving_render.preprocess_completion(
request,
prompt_input=request.token_ids,
prompt_embeds=None,

View File

@@ -24,6 +24,7 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
parse_chat_inputs_to_harmony_messages,
render_for_completion,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.entrypoints.serve.disagg.protocol import (
GenerateRequest,
MultiModalFeatures,
@@ -459,9 +460,9 @@ class OpenAIServingRender:
prompts.extend(prompt_to_seq(prompt_embeds))
if prompt_input is not None:
prompts.extend(prompt_to_seq(prompt_input))
return await self._preprocess_cmpl(request, prompts)
return await self.preprocess_cmpl(request, prompts)
async def _preprocess_cmpl(
async def preprocess_cmpl(
self,
request: Any,
prompts: Sequence[PromptType | bytes],
@@ -500,11 +501,7 @@ class OpenAIServingRender:
tool_dicts: list[dict[str, Any]] | None = None,
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
) -> tuple[list[ConversationMessage], list[ProcessorInputs]]:
"""Copied from OpenAIServing._preprocess_chat.
Differences: isinstance check is ChatCompletionRequest-only
(ResponsesRequest not supported here); TODO comment dropped accordingly.
"""
"""Copied from OpenAIServing._preprocess_chat."""
renderer = self.renderer
mm_config = self.model_config.multimodal_config
@@ -542,11 +539,11 @@ class OpenAIServingRender:
if tool_parser is not None:
tool_choice = getattr(request, "tool_choice", "none")
if tool_choice != "none":
if not isinstance(request, ChatCompletionRequest):
if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
msg = (
"Tool usage is only supported "
" for ChatCompletionRequest, but got "
f"{type(request).__name__}"
"for Chat Completions API or Responses API requests, "
f"but got {type(request).__name__}"
)
raise NotImplementedError(msg)
tokenizer = renderer.get_tokenizer()