[Misc] Reorganize inputs (#35182)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author:       Cyrus Leung
Date:         2026-03-26 01:22:54 +08:00
Committed by: GitHub
Parent:       678b3c99e8
Commit:       ba2f0acc2d
142 changed files with 1212 additions and 1342 deletions
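Taken together, the hunks below are essentially a mechanical rename: the ProcessorInputs type becomes EngineInput, EncoderDecoderInputs becomes EncoderDecoderInput, the token_inputs helper becomes tokens_input, and imports are consolidated under vllm.inputs (several call sites previously pulled these names from vllm.inputs.data). A minimal before/after sketch of the import change, listing only the names that actually appear in the hunks below, not necessarily the full set of moved symbols:

# Before this commit, as seen in the removed lines below:
# from vllm.inputs.data import ProcessorInputs, PromptType, TokensPrompt, token_inputs
# from vllm.inputs import EncoderDecoderInputs, ProcessorInputs

# After this commit, the same call sites import directly from vllm.inputs:
from vllm.inputs import (
    EncoderDecoderInput,
    EngineInput,
    PromptType,
    TokensPrompt,
    tokens_input,
)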

View File

@@ -63,7 +63,7 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
)
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import ProcessorInputs
from vllm.inputs import EngineInput
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import CompletionOutput, RequestOutput
@@ -177,7 +177,7 @@ class OpenAIServingChat(OpenAIServing):
async def render_chat_request(
self,
request: ChatCompletionRequest,
) -> tuple[list[ConversationMessage], list[ProcessorInputs]] | ErrorResponse:
) -> tuple[list[ConversationMessage], list[EngineInput]] | ErrorResponse:
"""
Validate the model and preprocess a chat completion request.
@@ -185,7 +185,7 @@ class OpenAIServingChat(OpenAIServing):
engine-aware checks (LoRA model validation, engine health).
Returns:
A tuple of (conversation, engine_prompts) on success,
A tuple of (conversation, engine_inputs) on success,
or an ErrorResponse on failure.
"""
error_check_ret = await self._check_model(request)
@@ -231,7 +231,7 @@ class OpenAIServingChat(OpenAIServing):
if isinstance(result, ErrorResponse):
return result
conversation, engine_prompts = result
conversation, engine_inputs = result
request_id = (
f"chatcmpl-{self._base_request_id(raw_request, request.request_id)}"
@@ -251,13 +251,13 @@ class OpenAIServingChat(OpenAIServing):
# Schedule the request and get the result generator.
max_model_len = self.model_config.max_model_len
generators: list[AsyncGenerator[RequestOutput, None]] = []
for i, engine_prompt in enumerate(engine_prompts):
prompt_token_ids = self._extract_prompt_components(engine_prompt).token_ids
for i, engine_input in enumerate(engine_inputs):
prompt_token_ids = self._extract_prompt_components(engine_input).token_ids
# If we are creating sub requests for multiple prompts, ensure that they
# have unique request ids.
sub_request_id = (
request_id if len(engine_prompts) == 1 else f"{request_id}_{i}"
request_id if len(engine_inputs) == 1 else f"{request_id}_{i}"
)
max_tokens = get_max_tokens(
@@ -265,7 +265,7 @@ class OpenAIServingChat(OpenAIServing):
request.max_completion_tokens
if request.max_completion_tokens is not None
else request.max_tokens,
self._extract_prompt_len(engine_prompt),
self._extract_prompt_len(engine_input),
self.default_sampling_params,
self.override_max_tokens,
)
@@ -283,7 +283,7 @@ class OpenAIServingChat(OpenAIServing):
self._log_inputs(
sub_request_id,
engine_prompt,
engine_input,
params=sampling_params,
lora_request=lora_request,
)
@@ -296,7 +296,7 @@ class OpenAIServingChat(OpenAIServing):
if isinstance(sampling_params, BeamSearchParams):
generator = self.beam_search(
prompt=engine_prompt,
prompt=engine_input,
request_id=sub_request_id,
params=sampling_params,
lora_request=lora_request,
@@ -313,7 +313,7 @@ class OpenAIServingChat(OpenAIServing):
reasoning_ended = None
generator = self.engine_client.generate(
engine_prompt,
engine_input,
sampling_params,
sub_request_id,
lora_request=lora_request,
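A side note on the loop above: when one chat request expands into several engine inputs, each one is submitted under a derived id so the results can be told apart. A tiny illustration of that id scheme with hypothetical values (the real id comes from _base_request_id):

# Hypothetical request id; the real one is "chatcmpl-" plus a generated suffix.
request_id = "chatcmpl-abc123"
engine_inputs = ["first prompt", "second prompt"]  # stand-ins for EngineInput objects

sub_request_ids = [
    request_id if len(engine_inputs) == 1 else f"{request_id}_{i}"
    for i in range(len(engine_inputs))
]
# -> ["chatcmpl-abc123_0", "chatcmpl-abc123_1"]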

View File

@@ -33,7 +33,7 @@ from vllm.entrypoints.openai.engine.serving import (
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import ProcessorInputs
from vllm.inputs import EngineInput
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
@@ -82,7 +82,7 @@ class OpenAIServingCompletion(OpenAIServing):
async def render_completion_request(
self,
request: CompletionRequest,
) -> list[ProcessorInputs] | ErrorResponse:
) -> list[EngineInput] | ErrorResponse:
"""
Validate the model and preprocess a completion request.
@@ -90,8 +90,7 @@ class OpenAIServingCompletion(OpenAIServing):
engine-aware checks (LoRA model validation, engine health).
Returns:
A list of engine_prompts on success,
or an ErrorResponse on failure.
A list of engine_inputs on success, or an ErrorResponse on failure.
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
@@ -128,7 +127,7 @@ class OpenAIServingCompletion(OpenAIServing):
if isinstance(result, ErrorResponse):
return result
engine_prompts = result
engine_inputs = result
request_id = f"cmpl-{self._base_request_id(raw_request, request.request_id)}"
created_time = int(time.time())
@@ -145,11 +144,11 @@ class OpenAIServingCompletion(OpenAIServing):
# Schedule the request and get the result generator.
max_model_len = self.model_config.max_model_len
generators: list[AsyncGenerator[RequestOutput, None]] = []
for i, engine_prompt in enumerate(engine_prompts):
for i, engine_input in enumerate(engine_inputs):
max_tokens = get_max_tokens(
max_model_len,
request.max_tokens,
self._extract_prompt_len(engine_prompt),
self._extract_prompt_len(engine_input),
self.default_sampling_params,
self.override_max_tokens,
)
@@ -169,7 +168,7 @@ class OpenAIServingCompletion(OpenAIServing):
self._log_inputs(
request_id_item,
engine_prompt,
engine_input,
params=sampling_params,
lora_request=lora_request,
)
@@ -182,7 +181,7 @@ class OpenAIServingCompletion(OpenAIServing):
if isinstance(sampling_params, BeamSearchParams):
generator = self.beam_search(
prompt=engine_prompt,
prompt=engine_input,
request_id=request_id,
params=sampling_params,
lora_request=lora_request,
@@ -190,7 +189,7 @@ class OpenAIServingCompletion(OpenAIServing):
)
else:
generator = self.engine_client.generate(
engine_prompt,
engine_input,
sampling_params,
request_id_item,
lora_request=lora_request,
@@ -204,7 +203,7 @@ class OpenAIServingCompletion(OpenAIServing):
result_generator = merge_async_iterators(*generators)
model_name = self.models.model_name(lora_request)
num_prompts = len(engine_prompts)
num_prompts = len(engine_inputs)
# Streaming response
tokenizer = self.renderer.tokenizer
@@ -212,7 +211,7 @@ class OpenAIServingCompletion(OpenAIServing):
if request.stream:
return self.completion_stream_generator(
request,
engine_prompts,
engine_inputs,
result_generator,
request_id,
created_time,
@@ -235,8 +234,7 @@ class OpenAIServingCompletion(OpenAIServing):
# We did not pass it into vLLM engine to avoid being redundant
# with the inputs token IDs
if final_res.prompt is None:
engine_prompt = engine_prompts[i]
final_res.prompt = self._extract_prompt_text(engine_prompt)
final_res.prompt = self._extract_prompt_text(engine_inputs[i])
final_res_batch_checked = cast(list[RequestOutput], final_res_batch)
@@ -268,7 +266,7 @@ class OpenAIServingCompletion(OpenAIServing):
async def completion_stream_generator(
self,
request: CompletionRequest,
engine_prompts: list[ProcessorInputs],
engine_inputs: list[EngineInput],
result_generator: AsyncIterator[tuple[int, RequestOutput]],
request_id: str,
created_time: int,
@@ -301,8 +299,8 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_text = res.prompt
if prompt_text is None:
engine_prompt = engine_prompts[prompt_idx]
prompt_text = self._extract_prompt_text(engine_prompt)
engine_input = engine_inputs[prompt_idx]
prompt_text = self._extract_prompt_text(engine_input)
# Prompt details are excluded from later streamed outputs
if prompt_token_ids is not None:
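For context on the stream generator above: completion_stream_generator receives (prompt index, RequestOutput) pairs because the per-prompt generators are merged with merge_async_iterators, and that index is what lets the fallback path look up engine_inputs[prompt_idx]. A small, self-contained sketch of the index-tagged pattern in plain asyncio (not the vLLM helper itself, which additionally interleaves the sources concurrently):

import asyncio
from collections.abc import AsyncGenerator


async def fake_outputs(name: str) -> AsyncGenerator[str, None]:
    # Stand-in for the engine's RequestOutput stream for one prompt.
    for step in range(2):
        yield f"{name}-step{step}"


async def main() -> None:
    generators = [fake_outputs("promptA"), fake_outputs("promptB")]
    for prompt_idx, gen in enumerate(generators):
        async for item in gen:
            # The (int, output) tuples mirror AsyncIterator[tuple[int, RequestOutput]] above.
            print((prompt_idx, item))


asyncio.run(main())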

View File

@@ -72,11 +72,7 @@ from vllm.entrypoints.serve.tokenize.protocol import (
)
from vllm.entrypoints.utils import create_error_response
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import (
ProcessorInputs,
PromptType,
TokensPrompt,
)
from vllm.inputs import EngineInput, PromptType, TokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob, PromptLogprobs
from vllm.lora.request import LoRARequest
@@ -163,7 +159,7 @@ class ServeContext(Generic[RequestT]):
request_id: str
created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None
engine_prompts: list[ProcessorInputs] | None = None
engine_inputs: list[EngineInput] | None = None
result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
None
@@ -202,7 +198,7 @@ class OpenAIServing:
async def beam_search(
self,
prompt: ProcessorInputs,
prompt: EngineInput,
request_id: str,
params: BeamSearchParams,
lora_request: LoRARequest | None = None,
@@ -493,21 +489,21 @@ class OpenAIServing:
if isinstance(pooling_params, ErrorResponse):
return pooling_params
if ctx.engine_prompts is None:
if ctx.engine_inputs is None:
return self.create_error_response("Engine prompts not available")
for i, engine_prompt in enumerate(ctx.engine_prompts):
for i, engine_input in enumerate(ctx.engine_inputs):
request_id_item = f"{ctx.request_id}-{i}"
self._log_inputs(
request_id_item,
engine_prompt,
engine_input,
params=pooling_params,
lora_request=ctx.lora_request,
)
generator = self.engine_client.encode(
engine_prompt,
engine_input,
pooling_params,
request_id_item,
lora_request=ctx.lora_request,
@@ -526,10 +522,10 @@ class OpenAIServing:
ctx: ServeContext,
) -> ErrorResponse | None:
"""Collect batch results from the result generator."""
if ctx.engine_prompts is None:
if ctx.engine_inputs is None:
return self.create_error_response("Engine prompts not available")
num_prompts = len(ctx.engine_prompts)
num_prompts = len(ctx.engine_inputs)
final_res_batch: list[PoolingRequestOutput | None]
final_res_batch = [None] * num_prompts
@@ -806,19 +802,19 @@ class OpenAIServing:
# Apply server defaults first, then request kwargs override.
return default_chat_template_kwargs | request_chat_template_kwargs
def _extract_prompt_components(self, prompt: PromptType | ProcessorInputs):
def _extract_prompt_components(self, prompt: PromptType | EngineInput):
return extract_prompt_components(self.model_config, prompt)
def _extract_prompt_text(self, prompt: ProcessorInputs):
def _extract_prompt_text(self, prompt: PromptType | EngineInput):
return self._extract_prompt_components(prompt).text
def _extract_prompt_len(self, prompt: ProcessorInputs):
def _extract_prompt_len(self, prompt: EngineInput):
return extract_prompt_len(self.model_config, prompt)
def _log_inputs(
self,
request_id: str,
inputs: PromptType | ProcessorInputs,
inputs: PromptType | EngineInput,
params: SamplingParams | PoolingParams | BeamSearchParams | None,
lora_request: LoRARequest | None,
) -> None:

View File

@@ -12,7 +12,7 @@ from vllm.engine.protocol import EngineClient, StreamingInput
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.inputs.data import PromptType
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import SupportsRealtime
from vllm.renderers.inputs.preprocess import parse_model_prompt
@@ -83,6 +83,6 @@ class OpenAIServingRealtime(OpenAIServing):
async for prompt in stream_input_iter:
parsed_prompt = parse_model_prompt(model_config, prompt)
(engine_prompt,) = await renderer.render_cmpl_async([parsed_prompt])
(engine_input,) = await renderer.render_cmpl_async([parsed_prompt])
yield StreamingInput(prompt=engine_prompt)
yield StreamingInput(prompt=engine_input)

View File

@@ -110,7 +110,7 @@ from vllm.entrypoints.openai.responses.utils import (
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import ProcessorInputs, token_inputs
from vllm.inputs import EngineInput, tokens_input
from vllm.logger import init_logger
from vllm.logprobs import Logprob as SampleLogprob
from vllm.logprobs import SampleLogprobs
@@ -269,10 +269,10 @@ class OpenAIServingResponses(OpenAIServing):
def _validate_generator_input(
self,
engine_prompt: ProcessorInputs,
engine_input: EngineInput,
) -> ErrorResponse | None:
"""Add validations to the input to the generator here."""
prompt_len = self._extract_prompt_len(engine_prompt)
prompt_len = self._extract_prompt_len(engine_input)
max_model_len = self.model_config.max_model_len
if prompt_len >= max_model_len:
@@ -369,11 +369,11 @@ class OpenAIServingResponses(OpenAIServing):
model_name = self.models.model_name(lora_request)
if self.use_harmony:
messages, engine_prompts = self._make_request_with_harmony(
messages, engine_inputs = self._make_request_with_harmony(
request, prev_response
)
else:
messages, engine_prompts = await self._make_request(request, prev_response)
messages, engine_inputs = await self._make_request(request, prev_response)
request_metadata = RequestResponseMetadata(request_id=request.request_id)
if raw_request:
@@ -413,15 +413,15 @@ class OpenAIServingResponses(OpenAIServing):
available_tools = []
tokenizer = self.renderer.get_tokenizer()
for engine_prompt in engine_prompts:
maybe_error = self._validate_generator_input(engine_prompt)
for engine_input in engine_inputs:
maybe_error = self._validate_generator_input(engine_input)
if maybe_error is not None:
return maybe_error
default_max_tokens = get_max_tokens(
max_model_len,
request.max_output_tokens,
self._extract_prompt_len(engine_prompt),
self._extract_prompt_len(engine_input),
self.default_sampling_params,
self.override_max_tokens,
)
@@ -480,7 +480,7 @@ class OpenAIServingResponses(OpenAIServing):
)
generator = self._generate_with_builtin_tools(
request_id=request.request_id,
engine_prompt=engine_prompt,
engine_input=engine_input,
sampling_params=sampling_params,
context=context,
lora_request=lora_request,
@@ -586,7 +586,7 @@ class OpenAIServingResponses(OpenAIServing):
prev_response_output=prev_response.output if prev_response else None,
)
_, engine_prompts = await self.openai_serving_render.preprocess_chat(
_, engine_inputs = await self.openai_serving_render.preprocess_chat(
request,
messages,
default_template=self.chat_template,
@@ -595,7 +595,7 @@ class OpenAIServingResponses(OpenAIServing):
tool_dicts=tool_dicts,
tool_parser=self.parser.tool_parser_cls if self.parser else None,
)
return messages, engine_prompts
return messages, engine_inputs
async def _render_next_turn(
self,
@@ -610,7 +610,7 @@ class OpenAIServingResponses(OpenAIServing):
request_input=messages,
)
_, engine_prompts = await self.openai_serving_render.preprocess_chat(
_, engine_inputs = await self.openai_serving_render.preprocess_chat(
request,
new_messages,
default_template=chat_template,
@@ -619,12 +619,12 @@ class OpenAIServingResponses(OpenAIServing):
tool_dicts=tool_dicts,
tool_parser=tool_parser,
)
return engine_prompts
return engine_inputs
async def _generate_with_builtin_tools(
self,
request_id: str,
engine_prompt: ProcessorInputs,
engine_input: EngineInput,
sampling_params: SamplingParams,
context: ConversationContext,
lora_request: LoRARequest | None = None,
@@ -641,13 +641,13 @@ class OpenAIServingResponses(OpenAIServing):
self._log_inputs(
sub_request_id,
engine_prompt,
engine_input,
params=sampling_params,
lora_request=lora_request,
)
generator = self.engine_client.generate(
engine_prompt,
engine_input,
sampling_params,
sub_request_id,
lora_request=lora_request,
@@ -675,11 +675,11 @@ class OpenAIServingResponses(OpenAIServing):
# Render the next prompt token ids and update sampling_params.
if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
token_ids = context.render_for_completion()
engine_prompt = token_inputs(token_ids)
engine_input = tokens_input(token_ids)
sampling_params.max_tokens = max_model_len - len(token_ids)
elif isinstance(context, ParsableContext):
(engine_prompt,) = await self._render_next_turn(
(engine_input,) = await self._render_next_turn(
context.request,
context.parser.response_messages,
context.tool_dicts,
@@ -691,7 +691,7 @@ class OpenAIServingResponses(OpenAIServing):
sampling_params.max_tokens = get_max_tokens(
max_model_len,
context.request.max_output_tokens,
self._extract_prompt_len(engine_prompt),
self._extract_prompt_len(engine_input),
self.default_sampling_params, # type: ignore
self.override_max_tokens, # type: ignore
)
@@ -713,14 +713,10 @@ class OpenAIServingResponses(OpenAIServing):
arrival_time = time.time()
messages = self._construct_input_messages_with_harmony(request, prev_response)
prompt_token_ids = render_for_completion(messages)
engine_prompt = token_inputs(prompt_token_ids)
engine_prompt["arrival_time"] = arrival_time
engine_input = tokens_input(prompt_token_ids, cache_salt=request.cache_salt)
engine_input["arrival_time"] = arrival_time
# Add cache_salt if provided in the request
if request.cache_salt is not None:
engine_prompt["cache_salt"] = request.cache_salt
return messages, [engine_prompt]
return messages, [engine_input]
async def _initialize_tool_sessions(
self,
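One substantive detail in the harmony hunk above, beyond the rename: the cache salt is now passed to tokens_input as a keyword argument instead of being attached to the prompt dict afterwards, which removes the separate None check. A minimal sketch of the new call pattern with made-up values (tokens_input is assumed, per the hunk, to return a dict-like prompt that still accepts extra keys such as arrival_time):

import time

from vllm.inputs import tokens_input

engine_input = tokens_input([101, 2023, 2003], cache_salt="example-salt")
engine_input["arrival_time"] = time.time()  # extra metadata is still set after construction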

View File

@@ -38,7 +38,7 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
)
from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError
from vllm.inputs import EncoderDecoderInputs, ProcessorInputs
from vllm.inputs import EncoderDecoderInput, EngineInput
from vllm.logger import init_logger
from vllm.logprobs import FlatLogprobs, Logprob
from vllm.model_executor.models import SupportsTranscription
@@ -171,7 +171,7 @@ class OpenAISpeechToText(OpenAIServing):
request: SpeechToTextRequest,
audio_data: bytes,
request_id: str,
) -> tuple[list[ProcessorInputs], float]:
) -> tuple[list[EngineInput], float]:
# Validate request
language = self.model_cls.validate_language(request.language)
# Skip to_language validation to avoid extra logging for Whisper.
@@ -250,9 +250,9 @@ class OpenAISpeechToText(OpenAIServing):
parsed_prompts.append(parsed_prompt)
engine_prompts = await self.renderer.render_cmpl_async(parsed_prompts)
engine_inputs = await self.renderer.render_cmpl_async(parsed_prompts)
return engine_prompts, duration
return engine_inputs, duration
def _preprocess_verbose_prompt(self, prompt: EncoderDecoderDictPrompt):
dec_prompt = prompt["decoder_prompt"]
@@ -271,7 +271,7 @@ class OpenAISpeechToText(OpenAIServing):
return prompt
@staticmethod
def _get_decoder_prompt_len(engine_prompts: list[ProcessorInputs]) -> int:
def _get_decoder_prompt_len(engine_inputs: list[EngineInput]) -> int:
"""Get the length of the decoder prompt. Currently we need to offset
by the decoder prompt length when running beam search because the mm
encoder is not currently cached and runs on decode calls; because of
@@ -282,12 +282,13 @@ class OpenAISpeechToText(OpenAIServing):
encoder/decoder caching is implemented.
"""
input_len = 0
assert len(engine_prompts) > 0
first_eng_prompt = engine_prompts[0]
if first_eng_prompt.get("type") == "enc_dec":
first_eng_prompt = cast(EncoderDecoderInputs, first_eng_prompt)
input_len = len(first_eng_prompt["decoder_prompt"]["prompt_token_ids"])
assert len(engine_inputs) > 0
first_input = engine_inputs[0]
if first_input.get("type") == "enc_dec":
first_input = cast(EncoderDecoderInput, first_input)
input_len = len(first_input["decoder_prompt"]["prompt_token_ids"])
return input_len
def _get_verbose_segments(
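The enc_dec branch above leans on the dict layout of an encoder/decoder engine input: a "type" key equal to "enc_dec" and a nested "decoder_prompt" carrying "prompt_token_ids". A toy illustration of the length computation using a hand-built dict (field names taken from the hunk; the token values are made up):

# Hypothetical enc/dec input shaped like the fields accessed in _get_decoder_prompt_len.
fake_enc_dec_input = {
    "type": "enc_dec",
    "decoder_prompt": {"prompt_token_ids": [50258, 50259, 50359]},
}

input_len = 0
if fake_enc_dec_input.get("type") == "enc_dec":
    input_len = len(fake_enc_dec_input["decoder_prompt"]["prompt_token_ids"])
print(input_len)  # 3 -- the offset applied to beam-search lengths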
@@ -409,7 +410,7 @@ class OpenAISpeechToText(OpenAIServing):
lora_request = self._maybe_get_adapters(request)
engine_prompts, duration_s = await self._preprocess_speech_to_text(
engine_inputs, duration_s = await self._preprocess_speech_to_text(
request=request,
audio_data=audio_data,
request_id=request_id,
@@ -420,7 +421,7 @@ class OpenAISpeechToText(OpenAIServing):
list_result_generator: list[AsyncGenerator[RequestOutput, None]] | None = None
input_len = (
OpenAISpeechToText._get_decoder_prompt_len(engine_prompts)
OpenAISpeechToText._get_decoder_prompt_len(engine_inputs)
if request.use_beam_search
else 0
)
@@ -450,12 +451,12 @@ class OpenAISpeechToText(OpenAIServing):
sampling_params.logprobs = 1
list_result_generator = []
for i, engine_prompt in enumerate(engine_prompts):
for i, engine_input in enumerate(engine_inputs):
request_id_item = f"{request_id}_{i}"
self._log_inputs(
request_id_item,
engine_prompt,
engine_input,
params=sampling_params,
lora_request=lora_request,
)
@@ -468,7 +469,7 @@ class OpenAISpeechToText(OpenAIServing):
if isinstance(sampling_params, BeamSearchParams):
generator = self.beam_search(
prompt=engine_prompt,
prompt=engine_input,
params=sampling_params,
request_id=request_id_item,
lora_request=lora_request,
@@ -476,7 +477,7 @@ class OpenAISpeechToText(OpenAIServing):
)
else:
generator = self.engine_client.generate(
engine_prompt,
engine_input,
sampling_params,
request_id_item,
lora_request=lora_request,