[Frontend] Complete OpenAI render delegation (#37287)
Signed-off-by: Sage Ahrac <sagiahrak@gmail.com>
This commit is contained in:
@@ -4,7 +4,7 @@ import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
|
||||
from collections.abc import AsyncGenerator, Callable, Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from http import HTTPStatus
|
||||
from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
|
||||
@@ -22,9 +22,7 @@ from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionMessageParam,
|
||||
ChatTemplateContentFormatOption,
|
||||
ConversationMessage,
|
||||
)
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
@@ -43,19 +41,9 @@ from vllm.entrypoints.openai.engine.protocol import (
|
||||
GenerationError,
|
||||
)
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.openai.responses.context import (
|
||||
ConversationContext,
|
||||
HarmonyContext,
|
||||
ParsableContext,
|
||||
StreamingHarmonyContext,
|
||||
)
|
||||
from vllm.entrypoints.openai.responses.protocol import (
|
||||
ResponseInputOutputItem,
|
||||
ResponsesRequest,
|
||||
)
|
||||
from vllm.entrypoints.openai.responses.utils import (
|
||||
construct_input_messages,
|
||||
)
|
||||
from vllm.entrypoints.openai.speech_to_text.protocol import (
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponse,
|
||||
@@ -82,26 +70,22 @@ from vllm.entrypoints.serve.tokenize.protocol import (
|
||||
TokenizeCompletionRequest,
|
||||
TokenizeResponse,
|
||||
)
|
||||
from vllm.entrypoints.utils import create_error_response, get_max_tokens
|
||||
from vllm.entrypoints.utils import create_error_response
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.inputs.data import (
|
||||
ProcessorInputs,
|
||||
PromptType,
|
||||
SingletonPrompt,
|
||||
TokensPrompt,
|
||||
token_inputs,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob, PromptLogprobs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
|
||||
from vllm.renderers import ChatParams, TokenizeParams
|
||||
from vllm.renderers.inputs.preprocess import (
|
||||
extract_prompt_components,
|
||||
extract_prompt_len,
|
||||
parse_model_prompt,
|
||||
prompt_to_seq,
|
||||
)
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
@@ -116,7 +100,6 @@ from vllm.utils.async_utils import (
|
||||
collect_from_async_generator,
|
||||
merge_async_iterators,
|
||||
)
|
||||
from vllm.utils.mistral import is_mistral_tokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -823,109 +806,6 @@ class OpenAIServing:
|
||||
# Apply server defaults first, then request kwargs override.
|
||||
return default_chat_template_kwargs | request_chat_template_kwargs
|
||||
|
||||
async def _preprocess_completion(
|
||||
self,
|
||||
request: RendererRequest,
|
||||
prompt_input: str | list[str] | list[int] | list[list[int]] | None,
|
||||
prompt_embeds: bytes | list[bytes] | None,
|
||||
) -> list[ProcessorInputs]:
|
||||
prompts = list[SingletonPrompt | bytes]()
|
||||
if prompt_embeds is not None: # embeds take higher priority
|
||||
prompts.extend(prompt_to_seq(prompt_embeds))
|
||||
if prompt_input is not None:
|
||||
prompts.extend(prompt_to_seq(prompt_input))
|
||||
|
||||
return await self._preprocess_cmpl(request, prompts)
|
||||
|
||||
async def _preprocess_cmpl(
|
||||
self,
|
||||
request: RendererRequest,
|
||||
prompts: Sequence[PromptType | bytes],
|
||||
) -> list[ProcessorInputs]:
|
||||
renderer = self.renderer
|
||||
model_config = self.model_config
|
||||
|
||||
parsed_prompts = [
|
||||
(
|
||||
prompt
|
||||
if isinstance(prompt, bytes)
|
||||
else parse_model_prompt(model_config, prompt)
|
||||
)
|
||||
for prompt in prompts
|
||||
]
|
||||
tok_params = request.build_tok_params(model_config)
|
||||
|
||||
return await renderer.render_cmpl_async(
|
||||
parsed_prompts,
|
||||
tok_params,
|
||||
prompt_extras={
|
||||
k: v
|
||||
for k in ("mm_processor_kwargs", "cache_salt")
|
||||
if (v := getattr(request, k, None)) is not None
|
||||
},
|
||||
)
|
||||
|
||||
async def _preprocess_chat(
|
||||
self,
|
||||
request: RendererChatRequest,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
default_template: str | None,
|
||||
default_template_content_format: ChatTemplateContentFormatOption,
|
||||
default_template_kwargs: dict[str, Any] | None,
|
||||
tool_dicts: list[dict[str, Any]] | None = None,
|
||||
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
|
||||
) -> tuple[list[ConversationMessage], list[ProcessorInputs]]:
|
||||
renderer = self.renderer
|
||||
|
||||
default_template_kwargs = merge_kwargs(
|
||||
default_template_kwargs,
|
||||
dict(
|
||||
tools=tool_dicts,
|
||||
tokenize=is_mistral_tokenizer(renderer.tokenizer),
|
||||
),
|
||||
)
|
||||
|
||||
mm_config = self.model_config.multimodal_config
|
||||
|
||||
tok_params = request.build_tok_params(self.model_config)
|
||||
chat_params = request.build_chat_params(
|
||||
default_template, default_template_content_format
|
||||
).with_defaults(
|
||||
default_template_kwargs,
|
||||
default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None),
|
||||
default_mm_processor_kwargs=getattr(request, "mm_processor_kwargs", None),
|
||||
)
|
||||
|
||||
(conversation,), (engine_prompt,) = await renderer.render_chat_async(
|
||||
[messages],
|
||||
chat_params,
|
||||
tok_params,
|
||||
prompt_extras={
|
||||
k: v
|
||||
for k in ("mm_processor_kwargs", "cache_salt")
|
||||
if (v := getattr(request, k, None)) is not None
|
||||
},
|
||||
)
|
||||
|
||||
# tool parsing is done only if a tool_parser has been set and if
|
||||
# tool_choice is not "none" (if tool_choice is "none" but a tool_parser
|
||||
# is set, we want to prevent parsing a tool_call hallucinated by the LLM
|
||||
if tool_parser is not None:
|
||||
tool_choice = getattr(request, "tool_choice", "none")
|
||||
if tool_choice != "none":
|
||||
if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
|
||||
msg = (
|
||||
"Tool usage is only supported for Chat Completions API "
|
||||
"or Responses API requests."
|
||||
)
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
# TODO: Update adjust_request to accept ResponsesRequest
|
||||
tokenizer = renderer.get_tokenizer()
|
||||
request = tool_parser(tokenizer).adjust_request(request=request) # type: ignore[arg-type]
|
||||
|
||||
return conversation, [engine_prompt]
|
||||
|
||||
def _extract_prompt_components(self, prompt: PromptType | ProcessorInputs):
|
||||
return extract_prompt_components(self.model_config, prompt)
|
||||
|
||||
@@ -935,109 +815,6 @@ class OpenAIServing:
|
||||
def _extract_prompt_len(self, prompt: ProcessorInputs):
|
||||
return extract_prompt_len(self.model_config, prompt)
|
||||
|
||||
async def _render_next_turn(
|
||||
self,
|
||||
request: ResponsesRequest,
|
||||
messages: list[ResponseInputOutputItem],
|
||||
tool_dicts: list[dict[str, Any]] | None,
|
||||
tool_parser: Callable[[TokenizerLike], ToolParser] | None,
|
||||
chat_template: str | None,
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
):
|
||||
new_messages = construct_input_messages(
|
||||
request_input=messages,
|
||||
)
|
||||
|
||||
_, engine_prompts = await self._preprocess_chat(
|
||||
request,
|
||||
new_messages,
|
||||
default_template=chat_template,
|
||||
default_template_content_format=chat_template_content_format,
|
||||
default_template_kwargs=None,
|
||||
tool_dicts=tool_dicts,
|
||||
tool_parser=tool_parser,
|
||||
)
|
||||
return engine_prompts
|
||||
|
||||
async def _generate_with_builtin_tools(
|
||||
self,
|
||||
request_id: str,
|
||||
engine_prompt: ProcessorInputs,
|
||||
sampling_params: SamplingParams,
|
||||
context: ConversationContext,
|
||||
lora_request: LoRARequest | None = None,
|
||||
priority: int = 0,
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
):
|
||||
max_model_len = self.model_config.max_model_len
|
||||
|
||||
orig_priority = priority
|
||||
sub_request = 0
|
||||
while True:
|
||||
# Ensure that each sub-request has a unique request id.
|
||||
sub_request_id = f"{request_id}_{sub_request}"
|
||||
|
||||
self._log_inputs(
|
||||
sub_request_id,
|
||||
engine_prompt,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
generator = self.engine_client.generate(
|
||||
engine_prompt,
|
||||
sampling_params,
|
||||
sub_request_id,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=priority,
|
||||
)
|
||||
|
||||
async for res in generator:
|
||||
context.append_output(res)
|
||||
# NOTE(woosuk): The stop condition is handled by the engine.
|
||||
yield context
|
||||
|
||||
if not context.need_builtin_tool_call():
|
||||
# The model did not ask for a tool call, so we're done.
|
||||
break
|
||||
|
||||
# Call the tool and update the context with the result.
|
||||
tool_output = await context.call_tool()
|
||||
context.append_tool_output(tool_output)
|
||||
|
||||
# TODO: uncomment this and enable tool output streaming
|
||||
# yield context
|
||||
|
||||
# Create inputs for the next turn.
|
||||
# Render the next prompt token ids and update sampling_params.
|
||||
if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
|
||||
token_ids = context.render_for_completion()
|
||||
engine_prompt = token_inputs(token_ids)
|
||||
|
||||
sampling_params.max_tokens = max_model_len - len(token_ids)
|
||||
elif isinstance(context, ParsableContext):
|
||||
(engine_prompt,) = await self._render_next_turn(
|
||||
context.request,
|
||||
context.parser.response_messages,
|
||||
context.tool_dicts,
|
||||
context.tool_parser_cls,
|
||||
context.chat_template,
|
||||
context.chat_template_content_format,
|
||||
)
|
||||
|
||||
sampling_params.max_tokens = get_max_tokens(
|
||||
max_model_len,
|
||||
context.request.max_output_tokens,
|
||||
self._extract_prompt_len(engine_prompt),
|
||||
self.default_sampling_params, # type: ignore
|
||||
self.override_max_tokens, # type: ignore
|
||||
)
|
||||
|
||||
# OPTIMIZATION
|
||||
priority = orig_priority - 1
|
||||
sub_request += 1
|
||||
|
||||
def _log_inputs(
|
||||
self,
|
||||
request_id: str,
|
||||
|
||||
@@ -80,6 +80,7 @@ async def init_generate_state(
|
||||
OpenAIServingResponses(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
state.openai_serving_render,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
@@ -157,6 +158,7 @@ async def init_generate_state(
|
||||
ServingTokens(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
state.openai_serving_render,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
|
||||
|
||||
@@ -5,11 +5,11 @@ import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from collections import deque
|
||||
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Sequence
|
||||
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Mapping, Sequence
|
||||
from contextlib import AsyncExitStack
|
||||
from copy import copy
|
||||
from http import HTTPStatus
|
||||
from typing import Final
|
||||
from typing import Any, Final
|
||||
|
||||
from fastapi import Request
|
||||
from openai.types.responses import (
|
||||
@@ -86,6 +86,7 @@ from vllm.entrypoints.openai.responses.protocol import (
|
||||
ResponseCompletedEvent,
|
||||
ResponseCreatedEvent,
|
||||
ResponseInProgressEvent,
|
||||
ResponseInputOutputItem,
|
||||
ResponseInputOutputMessage,
|
||||
ResponseReasoningPartAddedEvent,
|
||||
ResponseReasoningPartDoneEvent,
|
||||
@@ -105,16 +106,19 @@ from vllm.entrypoints.openai.responses.utils import (
|
||||
construct_tool_dicts,
|
||||
extract_tool_types,
|
||||
)
|
||||
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
|
||||
from vllm.entrypoints.utils import get_max_tokens
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.inputs.data import ProcessorInputs, token_inputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob as SampleLogprob
|
||||
from vllm.logprobs import SampleLogprobs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import CompletionOutput
|
||||
from vllm.parser import ParserManager
|
||||
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers import ToolParser
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.collection_utils import as_list
|
||||
|
||||
@@ -165,6 +169,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
models: OpenAIServingModels,
|
||||
openai_serving_render: OpenAIServingRender,
|
||||
*,
|
||||
request_logger: RequestLogger | None,
|
||||
chat_template: str | None,
|
||||
@@ -185,6 +190,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||
)
|
||||
|
||||
self.openai_serving_render = openai_serving_render
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_content_format: Final = chat_template_content_format
|
||||
self.enable_log_outputs = enable_log_outputs
|
||||
@@ -587,7 +593,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
prev_response_output=prev_response.output if prev_response else None,
|
||||
)
|
||||
|
||||
_, engine_prompts = await self._preprocess_chat(
|
||||
_, engine_prompts = await self.openai_serving_render.preprocess_chat(
|
||||
request,
|
||||
messages,
|
||||
default_template=self.chat_template,
|
||||
@@ -598,6 +604,109 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
)
|
||||
return messages, engine_prompts
|
||||
|
||||
async def _render_next_turn(
|
||||
self,
|
||||
request: ResponsesRequest,
|
||||
messages: list[ResponseInputOutputItem],
|
||||
tool_dicts: list[dict[str, Any]] | None,
|
||||
tool_parser: Callable[[TokenizerLike], ToolParser] | None,
|
||||
chat_template: str | None,
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
):
|
||||
new_messages = construct_input_messages(
|
||||
request_input=messages,
|
||||
)
|
||||
|
||||
_, engine_prompts = await self.openai_serving_render.preprocess_chat(
|
||||
request,
|
||||
new_messages,
|
||||
default_template=chat_template,
|
||||
default_template_content_format=chat_template_content_format,
|
||||
default_template_kwargs=None,
|
||||
tool_dicts=tool_dicts,
|
||||
tool_parser=tool_parser,
|
||||
)
|
||||
return engine_prompts
|
||||
|
||||
async def _generate_with_builtin_tools(
|
||||
self,
|
||||
request_id: str,
|
||||
engine_prompt: ProcessorInputs,
|
||||
sampling_params: SamplingParams,
|
||||
context: ConversationContext,
|
||||
lora_request: LoRARequest | None = None,
|
||||
priority: int = 0,
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
):
|
||||
max_model_len = self.model_config.max_model_len
|
||||
|
||||
orig_priority = priority
|
||||
sub_request = 0
|
||||
while True:
|
||||
# Ensure that each sub-request has a unique request id.
|
||||
sub_request_id = f"{request_id}_{sub_request}"
|
||||
|
||||
self._log_inputs(
|
||||
sub_request_id,
|
||||
engine_prompt,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
generator = self.engine_client.generate(
|
||||
engine_prompt,
|
||||
sampling_params,
|
||||
sub_request_id,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=priority,
|
||||
)
|
||||
|
||||
async for res in generator:
|
||||
context.append_output(res)
|
||||
# NOTE(woosuk): The stop condition is handled by the engine.
|
||||
yield context
|
||||
|
||||
if not context.need_builtin_tool_call():
|
||||
# The model did not ask for a tool call, so we're done.
|
||||
break
|
||||
|
||||
# Call the tool and update the context with the result.
|
||||
tool_output = await context.call_tool()
|
||||
context.append_tool_output(tool_output)
|
||||
|
||||
# TODO: uncomment this and enable tool output streaming
|
||||
# yield context
|
||||
|
||||
# Create inputs for the next turn.
|
||||
# Render the next prompt token ids and update sampling_params.
|
||||
if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
|
||||
token_ids = context.render_for_completion()
|
||||
engine_prompt = token_inputs(token_ids)
|
||||
|
||||
sampling_params.max_tokens = max_model_len - len(token_ids)
|
||||
elif isinstance(context, ParsableContext):
|
||||
(engine_prompt,) = await self._render_next_turn(
|
||||
context.request,
|
||||
context.parser.response_messages,
|
||||
context.tool_dicts,
|
||||
context.tool_parser_cls,
|
||||
context.chat_template,
|
||||
context.chat_template_content_format,
|
||||
)
|
||||
|
||||
sampling_params.max_tokens = get_max_tokens(
|
||||
max_model_len,
|
||||
context.request.max_output_tokens,
|
||||
self._extract_prompt_len(engine_prompt),
|
||||
self.default_sampling_params, # type: ignore
|
||||
self.override_max_tokens, # type: ignore
|
||||
)
|
||||
|
||||
# OPTIMIZATION
|
||||
priority = orig_priority - 1
|
||||
sub_request += 1
|
||||
|
||||
def _make_request_with_harmony(
|
||||
self,
|
||||
request: ResponsesRequest,
|
||||
|
||||
Reference in New Issue
Block a user