[Renderer] Move InputPreprocessor into Renderer (2/2) (#34560)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-02-17 21:29:01 +08:00
committed by GitHub
parent c61a98f529
commit 574fe75245
32 changed files with 984 additions and 1054 deletions

View File

@@ -3,8 +3,8 @@
import itertools
import warnings
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any, cast
from collections.abc import Callable, Iterable, Sequence
from typing import TYPE_CHECKING, Any
import cloudpickle
import torch.nn as nn
@@ -55,6 +55,7 @@ from vllm.entrypoints.pooling.score.utils import (
from vllm.entrypoints.utils import log_non_default_args
from vllm.inputs.data import (
DataPrompt,
ProcessorInputs,
PromptType,
SingletonPrompt,
TextPrompt,
@@ -73,10 +74,8 @@ from vllm.outputs import (
from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams
from vllm.renderers import ChatParams, merge_kwargs
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.renderers.inputs.preprocess import (
conversation_to_seq,
extract_prompt_components,
parse_model_prompt,
prompt_to_seq,
)
@@ -86,6 +85,7 @@ from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils.counter import Counter
from vllm.utils.tqdm_utils import maybe_tqdm
from vllm.v1.engine.llm_engine import LLMEngine
from vllm.v1.sample.logits_processor import LogitsProcessor
@@ -400,7 +400,7 @@ class LLM:
sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
*,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: list[LoRARequest] | LoRARequest | None = None,
lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
priority: list[int] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
) -> list[RequestOutput]:
@@ -462,7 +462,7 @@ class LLM:
self,
prompts: PromptType | Sequence[PromptType],
sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
lora_request: list[LoRARequest] | LoRARequest | None = None,
lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
priority: list[int] | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
tokenization_kwargs: dict[str, Any] | None = None,
@@ -495,34 +495,32 @@ class LLM:
# Use the same preprocessing as _run_completion
seq_prompts = prompt_to_seq(prompts)
seq_params = self._params_to_seq(sampling_params, len(seq_prompts))
if any(param.truncate_prompt_tokens is not None for param in seq_params):
engine_prompts: Sequence[DictPrompt | TokPrompt] = [
engine_prompt
for prompt, param in zip(seq_prompts, seq_params)
for engine_prompt in self._preprocess_cmpl(
[prompt],
tokenization_kwargs=merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
),
)
]
else:
engine_prompts = self._preprocess_cmpl(
seq_prompts,
tokenization_kwargs=tokenization_kwargs,
seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_prompts))
seq_tok_kwargs = [
merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
)
for param in seq_params
]
seq_priority = self._priority_to_seq(priority, len(prompts))
request_ids = self._validate_and_add_requests(
prompts=engine_prompts,
params=seq_params,
use_tqdm=use_tqdm,
lora_request=self._get_modality_specific_lora_reqs(
engine_prompts, lora_request
request_ids = self._render_and_add_requests(
prompts=(
self._preprocess_cmpl_one(prompt, tok_kwargs)
for prompt, tok_kwargs in zip(
maybe_tqdm(
seq_prompts,
use_tqdm=use_tqdm,
desc="Rendering prompts",
),
seq_tok_kwargs,
)
),
params=seq_params,
lora_requests=seq_lora_requests,
tokenization_kwargs=tokenization_kwargs,
priority=priority,
priorities=seq_priority,
)
return request_ids
@@ -545,53 +543,41 @@ class LLM:
outputs = self._run_engine(use_tqdm=use_tqdm)
return self.engine_class.validate_outputs(outputs, RequestOutput)
def _get_modality_specific_lora_reqs(
def _resolve_lora_reqs(
self,
prompts: Sequence[DictPrompt | TokPrompt],
lora_request: list[LoRARequest] | LoRARequest | None,
prompts: Sequence[ProcessorInputs],
lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
):
# Grab the lora config off the vllm config on the engine,
# since this is the same for both v0 & v1.
lora_config = self.llm_engine.vllm_config.lora_config
seq_lora_requests = self._lora_request_to_seq(lora_request, len(prompts))
# If there's no lora config / default_mm_loras, or the model
# isn't multimodal, leave the lora as is.
if (
lora_config is None
or not self.model_config.is_multimodal_model
or (lora_config and lora_config.default_mm_loras is None)
):
return lora_request
optional_loras = (
[lora_request] * len(prompts)
if not isinstance(lora_request, Sequence)
else lora_request
)
return seq_lora_requests
return [
self._resolve_single_prompt_mm_lora(
prompt,
opt_lora_req,
lora_req,
lora_config.default_mm_loras,
)
for prompt, opt_lora_req in zip(prompts, optional_loras)
for prompt, lora_req in zip(prompts, seq_lora_requests)
]
def _resolve_single_prompt_mm_lora(
self,
prompt: DictPrompt | TokPrompt,
prompt: ProcessorInputs,
lora_request: LoRARequest | None,
default_mm_loras: dict[str, str] | None,
):
if not default_mm_loras or not (
mm_data := prompt.get("multi_modal_data") or {}
):
if not default_mm_loras or prompt["type"] != "multimodal":
return lora_request
intersection = set(
mm_data.keys() # type: ignore
).intersection(default_mm_loras.keys())
prompt_modalities = prompt["mm_placeholders"].keys()
intersection = set(prompt_modalities).intersection(default_mm_loras.keys())
if not intersection:
return lora_request
if len(intersection) > 1:
@@ -674,22 +660,6 @@ class LLM:
"""
return self.llm_engine.apply_model(func)
def _get_beam_search_lora_requests(
self,
lora_request: list[LoRARequest] | LoRARequest | None,
prompts: list[TokensPrompt | TextPrompt],
) -> list[LoRARequest | None]:
"""Get the optional lora request corresponding to each prompt."""
if isinstance(lora_request, Sequence) and len(lora_request) != len(prompts):
raise ValueError(
"Lora request list should be the same length as the prompts"
)
if lora_request is None or isinstance(lora_request, LoRARequest):
return [lora_request] * len(prompts)
raise TypeError(f"Invalid lora_request type {type(lora_request)}")
def beam_search(
self,
prompts: list[TokensPrompt | TextPrompt],
@@ -718,13 +688,12 @@ class LLM:
ignore_eos = params.ignore_eos
length_penalty = params.length_penalty
lora_requests = self._get_beam_search_lora_requests(lora_request, prompts)
tokenizer = self.renderer.get_tokenizer()
eos_token_id = tokenizer.eos_token_id
sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
tokenizer = self.get_tokenizer()
sort_beams_key = create_sort_beams_key_function(
tokenizer.eos_token_id,
length_penalty,
)
engine_prompts = self._preprocess_cmpl(prompts)
lora_requests = self._lora_request_to_seq(lora_request, len(engine_prompts))
if use_tqdm and concurrency_limit is not None:
logger.warning(
@@ -734,21 +703,12 @@ class LLM:
use_tqdm = False
if concurrency_limit is None:
concurrency_limit = len(prompts)
def create_tokens_prompt_from_beam(beam: BeamSearchSequence) -> TokensPrompt:
token_prompt_kwargs: TokensPrompt = {"prompt_token_ids": beam.tokens}
if beam.multi_modal_data is not None:
token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data
if beam.mm_processor_kwargs is not None:
token_prompt_kwargs["mm_processor_kwargs"] = beam.mm_processor_kwargs
return TokensPrompt(**token_prompt_kwargs)
concurrency_limit = len(engine_prompts)
# generate 2 * beam_width candidates at each step
# following the huggingface transformers implementation
# at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
beam_search_params = SamplingParams(
sampling_params = SamplingParams(
logprobs=2 * beam_width,
max_tokens=1,
temperature=temperature,
@@ -756,30 +716,25 @@ class LLM:
)
instances: list[BeamSearchInstance] = []
for lora_req, prompt in zip(lora_requests, prompts):
# Add multimodal processor kwargs & data
mm_kwargs = {}
if "multi_modal_data" in prompt:
mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
if "mm_processor_kwargs" in prompt:
mm_kwargs["mm_processor_kwargs"] = prompt["mm_processor_kwargs"]
if "prompt_token_ids" in prompt:
prompt = cast(TokensPrompt, prompt) # Needed for mypy
prompt_tokens = prompt["prompt_token_ids"]
else:
prompt_tokens = tokenizer.encode(prompt["prompt"])
for lora_req, prompt in zip(lora_requests, engine_prompts):
if prompt["type"] == "embeds":
raise NotImplementedError(
"Embedding prompt not supported for beam search"
)
if prompt["type"] == "enc_dec":
raise NotImplementedError(
"Encoder-decoder prompt not supported for beam search"
)
instances.append(
BeamSearchInstance(
prompt_tokens,
prompt,
lora_request=lora_req,
logprobs=None,
**mm_kwargs,
),
)
for prompt_start in range(0, len(prompts), concurrency_limit):
for prompt_start in range(0, len(instances), concurrency_limit):
instances_batch = instances[prompt_start : prompt_start + concurrency_limit]
token_iter = range(max_tokens)
@@ -808,22 +763,15 @@ class LLM:
if len(all_beams) == 0:
break
# create corresponding batch entries for prompt & optional lora
prompts_batch, lora_req_batch = zip(
*[
(create_tokens_prompt_from_beam(beam), beam.lora_request)
for beam in all_beams
]
)
# only runs for one step
# we don't need to use tqdm here
output = self.generate(
prompts_batch,
sampling_params=beam_search_params,
raw_output = self._render_and_run_requests(
prompts=(beam.get_prompt() for beam in all_beams),
params=self._params_to_seq(sampling_params, len(all_beams)),
lora_requests=[beam.lora_request for beam in all_beams],
use_tqdm=False,
lora_request=lora_req_batch,
)
output = self.engine_class.validate_outputs(raw_output, RequestOutput)
for (start, end), instance in zip(
instance_start_and_end, instances_batch
@@ -841,19 +789,15 @@ class LLM:
logprobs = result.outputs[0].logprobs[0]
for token_id, logprob_obj in logprobs.items():
new_beam = BeamSearchSequence(
current_beam.orig_prompt,
tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs + [logprobs],
lora_request=current_beam.lora_request,
cum_logprob=current_beam.cum_logprob
+ logprob_obj.logprob,
multi_modal_data=current_beam.multi_modal_data,
mm_processor_kwargs=current_beam.mm_processor_kwargs,
)
if (
token_id == tokenizer.eos_token_id
and not ignore_eos
):
if token_id == eos_token_id and not ignore_eos:
instance.completed.append(new_beam)
else:
instance_new_beams.append(new_beam)
@@ -872,6 +816,7 @@ class LLM:
for beam in best_beams:
beam.text = tokenizer.decode(beam.tokens)
outputs.append(BeamSearchOutput(sequences=best_beams))
return outputs
@@ -880,7 +825,7 @@ class LLM:
self,
prompts: Sequence[PromptType],
tokenization_kwargs: dict[str, Any] | None = None,
) -> Sequence[DictPrompt | TokPrompt]:
) -> Sequence[ProcessorInputs]:
"""
Convert prompt inputs from LLM APIs (other than [LLM.chat][]) into
a format that can be passed to `_add_request`.
@@ -888,8 +833,7 @@ class LLM:
Refer to [LLM.generate][] for a complete description of the arguments.
Returns:
A list of `TokPrompt` objects containing the tokenized prompt
after chat template interpolation, and the raw multi-modal inputs.
A list of `ProcessorInputs` objects ready to be passed into LLMEngine.
"""
renderer = self.renderer
model_config = self.model_config
@@ -903,6 +847,14 @@ class LLM:
return renderer.render_cmpl(parsed_prompts, tok_params)
def _preprocess_cmpl_one(
self,
prompt: PromptType,
tokenization_kwargs: dict[str, Any] | None = None,
) -> ProcessorInputs:
(engine_prompt,) = self._preprocess_cmpl([prompt], tokenization_kwargs)
return engine_prompt
def _preprocess_chat(
self,
conversations: Sequence[list[ChatCompletionMessageParam]],
@@ -914,7 +866,7 @@ class LLM:
tools: list[dict[str, Any]] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
mm_processor_kwargs: dict[str, Any] | None = None,
) -> Sequence[TokPrompt]:
) -> Sequence[ProcessorInputs]:
"""
Convert a list of conversations into prompts so that they can then
be used as input for other LLM APIs.
@@ -922,8 +874,7 @@ class LLM:
Refer to [LLM.chat][] for a complete description of the arguments.
Returns:
A list of `TokPrompt` objects containing the tokenized prompt
after chat template interpolation, and the raw multi-modal inputs.
A list of `ProcessorInputs` objects ready to be passed into LLMEngine.
"""
renderer = self.renderer
@@ -953,13 +904,39 @@ class LLM:
return engine_prompts
def _preprocess_chat_one(
self,
conversation: list[ChatCompletionMessageParam],
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
chat_template_kwargs: dict[str, Any] | None = None,
add_generation_prompt: bool = True,
continue_final_message: bool = False,
tools: list[dict[str, Any]] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
mm_processor_kwargs: dict[str, Any] | None = None,
) -> ProcessorInputs:
(engine_prompt,) = self._preprocess_chat(
[conversation],
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
chat_template_kwargs=chat_template_kwargs,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools,
tokenization_kwargs=tokenization_kwargs,
mm_processor_kwargs=mm_processor_kwargs,
)
return engine_prompt
def chat(
self,
messages: list[ChatCompletionMessageParam]
| Sequence[list[ChatCompletionMessageParam]],
sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: LoRARequest | None = None,
lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
add_generation_prompt: bool = True,
@@ -1805,47 +1782,41 @@ class LLM:
| Sequence[SamplingParams | PoolingParams],
*,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: list[LoRARequest] | LoRARequest | None = None,
lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
priority: list[int] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
):
seq_prompts = prompt_to_seq(prompts)
seq_params = self._params_to_seq(params, len(seq_prompts))
if any(param.truncate_prompt_tokens is not None for param in seq_params):
# TODO: Remove this after deprecating `param.truncate_prompt_tokens`
# Then, move the code from the `else` block to the top and let
# `self._preprocess_cmpl` handle prompt normalization
engine_prompts: Sequence[DictPrompt | TokPrompt] = [
engine_prompt
for prompt, param in zip(seq_prompts, seq_params)
for engine_prompt in self._preprocess_cmpl(
[prompt],
tokenization_kwargs=merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
),
)
]
else:
engine_prompts = self._preprocess_cmpl(
seq_prompts,
tokenization_kwargs=tokenization_kwargs,
seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_prompts))
seq_tok_kwargs = [
merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
)
for param in seq_params
]
seq_priority = self._priority_to_seq(priority, len(prompts))
self._validate_and_add_requests(
prompts=engine_prompts,
return self._render_and_run_requests(
prompts=(
self._preprocess_cmpl_one(prompt, tok_kwargs)
for prompt, tok_kwargs in zip(
maybe_tqdm(
seq_prompts,
use_tqdm=use_tqdm,
desc="Rendering prompts",
),
seq_tok_kwargs,
)
),
params=seq_params,
use_tqdm=use_tqdm,
lora_request=self._get_modality_specific_lora_reqs(
engine_prompts, lora_request
),
lora_requests=seq_lora_requests,
tokenization_kwargs=tokenization_kwargs,
priority=priority,
priorities=seq_priority,
)
return self._run_engine(use_tqdm=use_tqdm)
def _run_chat(
self,
messages: list[ChatCompletionMessageParam]
@@ -1855,7 +1826,7 @@ class LLM:
| Sequence[SamplingParams | PoolingParams],
*,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: LoRARequest | None = None,
lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
add_generation_prompt: bool = True,
@@ -1865,68 +1836,94 @@ class LLM:
tokenization_kwargs: dict[str, Any] | None = None,
mm_processor_kwargs: dict[str, Any] | None = None,
):
engine_prompts = self._preprocess_chat(
conversation_to_seq(messages),
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
chat_template_kwargs=chat_template_kwargs,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools,
seq_convs = conversation_to_seq(messages)
seq_params = self._params_to_seq(params, len(seq_convs))
seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_convs))
seq_tok_kwargs = [
merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
)
for param in seq_params
]
return self._render_and_run_requests(
prompts=(
self._preprocess_chat_one(
conversation,
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
chat_template_kwargs=chat_template_kwargs,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools,
tokenization_kwargs=tok_kwargs,
mm_processor_kwargs=mm_processor_kwargs,
)
for conversation, tok_kwargs in zip(
maybe_tqdm(
seq_convs,
use_tqdm=use_tqdm,
desc="Rendering conversations",
),
seq_tok_kwargs,
)
),
params=seq_params,
lora_requests=seq_lora_requests,
use_tqdm=use_tqdm,
tokenization_kwargs=tokenization_kwargs,
mm_processor_kwargs=mm_processor_kwargs,
)
self._validate_and_add_requests(
prompts=engine_prompts,
def _render_and_run_requests(
self,
prompts: Iterable[ProcessorInputs],
params: Sequence[SamplingParams | PoolingParams],
*,
lora_requests: Sequence[LoRARequest | None] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
priorities: Sequence[int] | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
):
if isinstance(prompts, (list, tuple)):
logger.warning_once(
"Rendering all prompts before adding them to the engine "
"is less efficient than performing both on the same prompt "
"before processing the next prompt. You should instead pass "
"a generator that renders one prompt per iteration, as that allows "
"engine execution to begin for the first prompt while processing "
"the next prompt."
)
self._render_and_add_requests(
prompts=prompts,
params=params,
use_tqdm=use_tqdm,
lora_request=self._get_modality_specific_lora_reqs(
engine_prompts, lora_request
),
lora_requests=lora_requests,
tokenization_kwargs=tokenization_kwargs,
priorities=priorities,
)
return self._run_engine(use_tqdm=use_tqdm)
def _validate_and_add_requests(
def _render_and_add_requests(
self,
prompts: Sequence[DictPrompt | TokPrompt],
params: SamplingParams
| PoolingParams
| Sequence[SamplingParams | PoolingParams],
prompts: Iterable[ProcessorInputs],
params: Sequence[SamplingParams | PoolingParams],
*,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
lora_requests: Sequence[LoRARequest | None] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
priority: list[int] | None = None,
priorities: Sequence[int] | None = None,
) -> list[str]:
num_requests = len(prompts)
seq_params = self._params_to_seq(params, num_requests)
seq_lora_requests = self._lora_request_to_seq(lora_request, num_requests)
seq_priority = self._priority_to_seq(priority, num_requests)
for sp in seq_params:
if isinstance(sp, SamplingParams):
# We only care about the final output
sp.output_kind = RequestOutputKind.FINAL_ONLY
# Add requests to the engine.
it = prompts
if use_tqdm:
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
it = tqdm_func(it, desc="Adding requests")
added_request_ids: list[str] = []
try:
for i, prompt in enumerate(it):
for i, prompt in enumerate(prompts):
request_id = self._add_request(
prompt,
seq_params[i],
lora_request=seq_lora_requests[i],
params[i],
lora_request=None if lora_requests is None else lora_requests[i],
tokenization_kwargs=tokenization_kwargs,
priority=seq_priority[i],
priority=0 if priorities is None else priorities[i],
)
added_request_ids.append(request_id)
except Exception as e:
@@ -1938,13 +1935,16 @@ class LLM:
def _add_request(
self,
prompt: PromptType | DictPrompt | TokPrompt,
prompt: ProcessorInputs,
params: SamplingParams | PoolingParams,
lora_request: LoRARequest | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
priority: int = 0,
) -> str:
prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)
if isinstance(params, SamplingParams):
# We only care about the final output
params.output_kind = RequestOutputKind.FINAL_ONLY
request_id = str(next(self.request_counter))
if params.truncate_prompt_tokens is not None:
@@ -1962,33 +1962,15 @@ class LLM:
dict(truncate_prompt_tokens=params.truncate_prompt_tokens),
)
renderer = self.renderer
tok_params = renderer.default_cmpl_tok_params.with_kwargs(
**(tokenization_kwargs or {})
)
tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs(
return self.llm_engine.add_request(
request_id,
prompt,
params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
priority=priority,
supported_tasks=self.supported_tasks,
)
self.llm_engine.add_request(
request_id,
engine_request,
params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
priority=priority,
prompt_text=prompt_text,
)
return engine_request.request_id
def _run_engine(
self,
*,