[Renderer] Move InputPreprocessor into Renderer (2/2) (#34560)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung
2026-02-17 21:29:01 +08:00
committed by GitHub
parent c61a98f529
commit 574fe75245
32 changed files with 984 additions and 1054 deletions

View File

@@ -195,18 +195,15 @@ def test_chat_batch_failure_cleanup(llm_for_failure_test):
valid_msg = [{"role": "user", "content": "Hello"}]
long_text = "This is a very long text to test the error " * 50
invalid_msg = [{"role": "user", "content": long_text}]
batch_1 = [
valid_msg,
valid_msg,
invalid_msg,
]
batch_2 = [
valid_msg,
valid_msg,
]
batch_1 = [valid_msg, valid_msg, invalid_msg]
batch_2 = [valid_msg, valid_msg]
sampling_params = SamplingParams(temperature=0, max_tokens=10)
with pytest.raises(ValueError, match="context length is only"):
llm.chat(batch_1, sampling_params=sampling_params)
assert llm.llm_engine.get_num_unfinished_requests() == 0
outputs_2 = llm.chat(batch_2, sampling_params=sampling_params)
assert len(outputs_2) == len(batch_2)
assert llm.llm_engine.get_num_unfinished_requests() == 0

View File

@@ -489,8 +489,9 @@ def _assert_inputs_equal(
if ignore_mm_keys is None:
ignore_mm_keys = set()
a_rest = {k: v for k, v in a.items() if k != "mm_kwargs"}
b_rest = {k: v for k, v in b.items() if k != "mm_kwargs"}
ignore_prompt_keys = ("prompt", "mm_kwargs")
a_rest = {k: v for k, v in a.items() if k not in ignore_prompt_keys}
b_rest = {k: v for k, v in b.items() if k not in ignore_prompt_keys}
assert a_rest == b_rest, msg

View File

@@ -0,0 +1,165 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config
cherry_pil_image = ImageAsset("cherry_blossom").pil_image
stop_pil_image = ImageAsset("stop_sign").pil_image
baby_reading_np_ndarrays = VideoAsset("baby_reading").np_ndarrays
def _build_renderer(
*, mm_cache_gb: float = 4.0, enable_prefix_caching: bool = True
) -> HfRenderer:
model_config = ModelConfig(
model="Qwen/Qwen2.5-VL-3B-Instruct",
max_model_len=128,
mm_processor_cache_gb=mm_cache_gb,
)
vllm_config = VllmConfig(
model_config=model_config,
cache_config=CacheConfig(enable_prefix_caching=enable_prefix_caching),
)
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer.from_config(
vllm_config,
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
)
def test_multi_modal_uuids_length_mismatch_raises():
renderer = _build_renderer()
mm_data = {"image": [cherry_pil_image, stop_pil_image]}
# Mismatch: 2 items but only 1 uuid provided
mm_uuids = {"image": ["hash_cherry"]}
mm_processor = renderer.get_mm_processor()
mm_items = mm_processor.info.parse_mm_data(mm_data)
with pytest.raises(ValueError, match="must have same length as"):
renderer._process_mm_uuids(mm_data, mm_items, mm_uuids, "req-1")
def test_multi_modal_uuids_missing_modality_raises():
renderer = _build_renderer()
mm_data = {
"image": [cherry_pil_image],
"video": None,
}
# Only image uuids provided; the missing video uuid should raise an error
mm_uuids = {"image": ["hash_cherry"]}
mm_processor = renderer.get_mm_processor()
mm_items = mm_processor.info.parse_mm_data(mm_data)
with pytest.raises(ValueError, match="is empty but .* is missing"):
renderer._process_mm_uuids(mm_data, mm_items, mm_uuids, "req-2")
@pytest.mark.parametrize(
"mm_cache_gb, enable_prefix_caching",
[
(4.0, True), # default behavior
(4.0, False), # prefix caching disabled
(0.0, True), # processor cache disabled
],
)
def test_multi_modal_uuids_accepts_none_and_passes_through(
monkeypatch, mm_cache_gb: float, enable_prefix_caching: bool
):
renderer = _build_renderer(
mm_cache_gb=mm_cache_gb,
enable_prefix_caching=enable_prefix_caching,
)
mm_data = {
"image": [cherry_pil_image, stop_pil_image],
"video": baby_reading_np_ndarrays,
}
# Use a consistent two-image scenario across all configurations
mm_uuids = {"image": [None, "hash_stop"], "video": None}
mm_processor = renderer.get_mm_processor()
mm_items = mm_processor.info.parse_mm_data(mm_data)
processed_mm_uuids = renderer._process_mm_uuids(
mm_data, mm_items, mm_uuids, "req-3"
)
assert processed_mm_uuids == mm_uuids
@pytest.mark.parametrize(
"mm_cache_gb, enable_prefix_caching",
[
(4.0, True), # default behavior
(4.0, False), # prefix caching disabled
(0.0, True), # processor cache disabled
],
)
def test_multi_modal_uuids_accepts_empty(
monkeypatch, mm_cache_gb: float, enable_prefix_caching: bool
):
renderer = _build_renderer(
mm_cache_gb=mm_cache_gb,
enable_prefix_caching=enable_prefix_caching,
)
# While None means cached multi-modal input requiring UUIDs,
# an empty list means no multi-modal input
mm_data = {"image": [], "video": []} # type: ignore[var-annotated]
mm_uuids = {"image": [], "video": None} # type: ignore[var-annotated]
mm_processor = renderer.get_mm_processor()
mm_items = mm_processor.info.parse_mm_data(mm_data)
processed_mm_uuids = renderer._process_mm_uuids(
mm_data, mm_items, mm_uuids, "req-4"
)
assert processed_mm_uuids == mm_uuids
def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
# When the processor cache is 0 and prefix caching is disabled, the
# processor builds overrides from the request id instead of using user UUIDs.
renderer = _build_renderer(mm_cache_gb=0.0, enable_prefix_caching=False)
request_id = "req-42"
mm_data = {
"image": [cherry_pil_image, stop_pil_image],
"video": baby_reading_np_ndarrays,
}
mm_uuids = {"image": ["hash_cherry", "hash_stop"], "video": ["hash_video"]}
mm_processor = renderer.get_mm_processor()
mm_items = mm_processor.info.parse_mm_data(mm_data)
processed_mm_uuids = renderer._process_mm_uuids(
mm_data, mm_items, mm_uuids, request_id
)
# Expect request-id-based overrides to be passed through
assert set(mm_uuids.keys()) == {"image", "video"}
assert len(mm_uuids["image"]) == 2
assert len(mm_uuids["video"]) == 1
assert processed_mm_uuids["image"][0].startswith(
f"{request_id}-image-"
) and processed_mm_uuids["image"][0].endswith("-0")
assert processed_mm_uuids["image"][1].startswith(
f"{request_id}-image-"
) and processed_mm_uuids["image"][1].endswith("-1")
assert processed_mm_uuids["video"][0].startswith(
f"{request_id}-video-"
) and processed_mm_uuids["video"][0].endswith("-0")
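
The assertions above pin down only the ends of the request-id-based override format. A minimal sketch, assuming the middle component is an opaque unique token (stubbed with uuid4() here; this is not the actual vLLM implementation), of a builder that would satisfy them:

import uuid

def build_request_id_uuids(request_id: str, modality: str, num_items: int) -> list[str]:
    # Matches the assertions: each override starts with
    # "<request_id>-<modality>-" and ends with "-<item index>".
    return [f"{request_id}-{modality}-{uuid.uuid4()}-{i}" for i in range(num_items)]

# build_request_id_uuids("req-42", "image", 2) yields strings that start
# with "req-42-image-" and end with "-0" and "-1" respectively.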

View File

@@ -20,7 +20,6 @@ MM_BEAM_WIDTHS = [2]
MODELS = ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"]
@pytest.mark.skip_v1 # V1 engine does not yet support beam search
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
@@ -62,7 +61,6 @@ def test_beam_search_single_input(
)
@pytest.mark.skip_v1 # V1 engine does not yet support beam search
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)

View File

@@ -1,174 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.multimodal import MultiModalUUIDDict
from vllm.sampling_params import SamplingParams
from vllm.v1.engine.input_processor import InputProcessor
cherry_pil_image = ImageAsset("cherry_blossom").pil_image
stop_pil_image = ImageAsset("stop_sign").pil_image
baby_reading_np_ndarrays = VideoAsset("baby_reading").np_ndarrays
def _build_input_processor(
*, mm_cache_gb: float = 4.0, enable_prefix_caching: bool = True
) -> InputProcessor:
model_config = ModelConfig(
model="Qwen/Qwen2.5-VL-3B-Instruct",
max_model_len=128,
mm_processor_cache_gb=mm_cache_gb,
)
vllm_config = VllmConfig(
model_config=model_config,
cache_config=CacheConfig(enable_prefix_caching=enable_prefix_caching),
)
return InputProcessor(vllm_config)
def test_multi_modal_uuids_length_mismatch_raises():
input_processor = _build_input_processor()
prompt = {
"prompt": "USER: <image>\nDescribe\nASSISTANT:",
"multi_modal_data": {"image": [cherry_pil_image, stop_pil_image]},
# Mismatch: 2 items but only 1 uuid provided
"multi_modal_uuids": {"image": ["hash_cherry"]},
}
with pytest.raises(ValueError, match="must have same length as"):
input_processor.process_inputs(
request_id="req-1",
prompt=prompt, # type: ignore[arg-type]
params=SamplingParams(),
)
def test_multi_modal_uuids_missing_modality_raises():
input_processor = _build_input_processor()
prompt = {
"prompt": "USER: <image><video>\nDescribe\nASSISTANT:",
# Two modalities provided in data
"multi_modal_data": {
"image": [cherry_pil_image],
"video": None,
},
# Only image uuids provided; video missing should raise
"multi_modal_uuids": {"image": ["hash_cherry"]},
}
with pytest.raises(ValueError, match="is empty but .* is missing"):
input_processor.process_inputs(
request_id="req-2",
prompt=prompt, # type: ignore[arg-type]
params=SamplingParams(),
)
@pytest.mark.parametrize(
"mm_cache_gb, enable_prefix_caching",
[
(4.0, True), # default behavior
(4.0, False), # prefix caching disabled
(0.0, True), # processor cache disabled
],
)
def test_multi_modal_uuids_accepts_none_and_passes_through(
monkeypatch, mm_cache_gb: float, enable_prefix_caching: bool
):
input_processor = _build_input_processor(
mm_cache_gb=mm_cache_gb,
enable_prefix_caching=enable_prefix_caching,
)
# Capture the overrides passed to InputPreprocessor.preprocess
captured: dict[str, object] = {}
def fake_preprocess(
prompt, *, tokenization_kwargs=None, lora_request=None, mm_uuids=None
):
captured["mm_uuids"] = mm_uuids
# Minimal processed inputs for decoder-only flow
return {"type": "token", "prompt_token_ids": [1]}
# Monkeypatch only the bound preprocess method on this instance
monkeypatch.setattr(
input_processor.input_preprocessor, "preprocess", fake_preprocess, raising=True
)
# Use a consistent two-image scenario across all configurations
mm_uuids = {"image": [None, "hash_stop"], "video": None}
prompt = {
"prompt": "USER: <image><image>\nTwo images\nASSISTANT:",
"multi_modal_data": {
"image": [cherry_pil_image, stop_pil_image],
"video": baby_reading_np_ndarrays,
},
"multi_modal_uuids": mm_uuids,
}
input_processor.process_inputs(
request_id="req-3",
prompt=prompt, # type: ignore[arg-type]
params=SamplingParams(),
)
assert captured["mm_uuids"] == mm_uuids
def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
# When both processor cache is 0 and prefix caching disabled, the
# processor builds overrides from request id instead of using user UUIDs.
input_processor = _build_input_processor(
mm_cache_gb=0.0, enable_prefix_caching=False
)
captured: dict[str, MultiModalUUIDDict] = {}
def fake_preprocess(
prompt, *, tokenization_kwargs=None, lora_request=None, mm_uuids=None
):
captured["mm_uuids"] = mm_uuids
return {"type": "token", "prompt_token_ids": [1]}
monkeypatch.setattr(
input_processor.input_preprocessor, "preprocess", fake_preprocess, raising=True
)
request_id = "req-42"
mm_uuids = {"image": ["hash_cherry", "hash_stop"], "video": ["hash_video"]}
prompt = {
"prompt": "USER: <image><image><video>\nDescribe\nASSISTANT:",
"multi_modal_data": {
"image": [cherry_pil_image, stop_pil_image],
"video": [baby_reading_np_ndarrays],
},
"multi_modal_uuids": mm_uuids,
}
input_processor.process_inputs(
request_id=request_id,
prompt=prompt, # type: ignore[arg-type]
params=SamplingParams(),
)
# Expect request-id-based overrides are passed through
assert set(mm_uuids.keys()) == {"image", "video"}
assert len(mm_uuids["image"]) == 2
assert len(mm_uuids["video"]) == 1
assert captured["mm_uuids"]["image"][0].startswith(
f"{request_id}-image-"
) and captured["mm_uuids"]["image"][0].endswith("-0")
assert captured["mm_uuids"]["image"][1].startswith(
f"{request_id}-image-"
) and captured["mm_uuids"]["image"][1].endswith("-1")
assert captured["mm_uuids"]["video"][0].startswith(
f"{request_id}-video-"
) and captured["mm_uuids"]["video"][0].endswith("-0")

View File

@@ -2,13 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from vllm.inputs import TokenInputs, token_inputs
from vllm.logprobs import Logprob
from vllm.lora.request import LoRARequest
if TYPE_CHECKING:
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.inputs import MultiModalInputs, mm_inputs
@dataclass
@@ -19,6 +17,8 @@ class BeamSearchSequence:
about to be returned to the user.
"""
orig_prompt: TokenInputs | MultiModalInputs
# The tokens include the prompt.
tokens: list[int]
logprobs: list[dict[int, Logprob]]
@@ -27,8 +27,28 @@ class BeamSearchSequence:
text: str | None = None
finish_reason: str | None = None
stop_reason: int | str | None = None
multi_modal_data: "MultiModalDataDict | None" = None
mm_processor_kwargs: dict[str, Any] | None = None
def get_prompt(self):
prompt = self.orig_prompt
prompt_text = prompt.get("prompt")
cache_salt = prompt.get("cache_salt")
if prompt["type"] == "token":
return token_inputs(
self.tokens,
prompt=prompt_text,
cache_salt=cache_salt,
)
return mm_inputs(
prompt_token_ids=self.tokens,
mm_kwargs=prompt["mm_kwargs"],
mm_hashes=prompt["mm_hashes"],
mm_placeholders=prompt["mm_placeholders"],
prompt=prompt_text,
cache_salt=cache_salt,
)
@dataclass
@@ -44,14 +64,15 @@ class BeamSearchOutput:
class BeamSearchInstance:
def __init__(
self,
prompt_tokens: list[int],
prompt: TokenInputs | MultiModalInputs,
lora_request: LoRARequest | None = None,
logprobs: list[dict[int, Logprob]] | None = None,
**kwargs,
):
self.beams: list[BeamSearchSequence] = [
BeamSearchSequence(
tokens=prompt_tokens,
orig_prompt=prompt,
tokens=prompt["prompt_token_ids"],
logprobs=[] if logprobs is None else list(logprobs),
lora_request=lora_request,
**kwargs,
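
For reference, a hedged walk-through of the new flow, assuming BeamSearchSequence and BeamSearchInstance live in vllm.beam_search as elsewhere in this PR: each instance now keeps the fully processed prompt, and get_prompt() rebuilds ProcessorInputs from the accumulated token list at every step, carrying the prompt text and cache salt forward.

from vllm.beam_search import BeamSearchInstance  # assumed module path
from vllm.inputs import token_inputs

# Seed an instance from already-processed token inputs; beam.tokens starts
# out as the prompt's token ids (tokens=prompt["prompt_token_ids"] above).
prompt = token_inputs([1, 2, 3], prompt="Hello")
instance = BeamSearchInstance(prompt)
beam = instance.beams[0]

# After one generated token, get_prompt() yields the inputs for the next
# decoding step without re-tokenizing anything.
beam.tokens = beam.tokens + [42]
assert beam.get_prompt()["prompt_token_ids"] == [1, 2, 3, 42]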

View File

@@ -11,13 +11,12 @@ from vllm.distributed.weight_transfer.base import (
WeightTransferInitRequest,
WeightTransferUpdateRequest,
)
from vllm.inputs.data import PromptType
from vllm.inputs.data import ProcessorInputs, PromptType
from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import IOProcessor
from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.v1.engine import EngineCoreRequest
@@ -35,7 +34,7 @@ class StreamingInput:
where inputs are provided via an async generator.
"""
prompt: PromptType
prompt: ProcessorInputs
sampling_params: SamplingParams | None = None
@@ -69,8 +68,7 @@ class EngineClient(ABC):
self,
prompt: EngineCoreRequest
| PromptType
| DictPrompt
| TokPrompt
| ProcessorInputs
| AsyncGenerator[StreamingInput, None],
sampling_params: SamplingParams,
request_id: str,
@@ -81,6 +79,7 @@ class EngineClient(ABC):
trace_headers: Mapping[str, str] | None = None,
priority: int = 0,
data_parallel_rank: int | None = None,
reasoning_ended: bool | None = None,
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request."""
...
@@ -88,13 +87,14 @@ class EngineClient(ABC):
@abstractmethod
def encode(
self,
prompt: PromptType | DictPrompt | TokPrompt,
prompt: PromptType | ProcessorInputs,
pooling_params: PoolingParams,
request_id: str,
lora_request: LoRARequest | None = None,
trace_headers: Mapping[str, str] | None = None,
priority: int = 0,
tokenization_kwargs: dict[str, Any] | None = None,
reasoning_ended: bool | None = None,
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from a pooling model."""
...
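
A hedged sketch of the updated call pattern, using only names visible in this diff: callers now hand fully processed ProcessorInputs straight to the engine client, and the new reasoning_ended flag travels with the request instead of being patched onto an EngineCoreRequest afterwards.

from vllm.inputs import token_inputs
from vllm.sampling_params import SamplingParams

async def submit(engine_client: EngineClient, token_ids: list[int]) -> None:
    # EngineClient is the ABC defined above. token_inputs(...) already
    # yields ProcessorInputs, so the client does not need to tokenize or
    # run the multi-modal processor again.
    async for output in engine_client.generate(
        token_inputs(token_ids),
        SamplingParams(max_tokens=16),
        request_id="demo-req",
        reasoning_ended=False,  # prompt holds no finished reasoning block
    ):
        print(output.outputs[0].text)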

View File

@@ -3,8 +3,8 @@
import itertools
import warnings
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any, cast
from collections.abc import Callable, Iterable, Sequence
from typing import TYPE_CHECKING, Any
import cloudpickle
import torch.nn as nn
@@ -55,6 +55,7 @@ from vllm.entrypoints.pooling.score.utils import (
from vllm.entrypoints.utils import log_non_default_args
from vllm.inputs.data import (
DataPrompt,
ProcessorInputs,
PromptType,
SingletonPrompt,
TextPrompt,
@@ -73,10 +74,8 @@ from vllm.outputs import (
from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams
from vllm.renderers import ChatParams, merge_kwargs
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.renderers.inputs.preprocess import (
conversation_to_seq,
extract_prompt_components,
parse_model_prompt,
prompt_to_seq,
)
@@ -86,6 +85,7 @@ from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils.counter import Counter
from vllm.utils.tqdm_utils import maybe_tqdm
from vllm.v1.engine.llm_engine import LLMEngine
from vllm.v1.sample.logits_processor import LogitsProcessor
@@ -400,7 +400,7 @@ class LLM:
sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
*,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: list[LoRARequest] | LoRARequest | None = None,
lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
priority: list[int] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
) -> list[RequestOutput]:
@@ -462,7 +462,7 @@ class LLM:
self,
prompts: PromptType | Sequence[PromptType],
sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
lora_request: list[LoRARequest] | LoRARequest | None = None,
lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
priority: list[int] | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
tokenization_kwargs: dict[str, Any] | None = None,
@@ -495,34 +495,32 @@ class LLM:
# Use the same preprocessing as _run_completion
seq_prompts = prompt_to_seq(prompts)
seq_params = self._params_to_seq(sampling_params, len(seq_prompts))
if any(param.truncate_prompt_tokens is not None for param in seq_params):
engine_prompts: Sequence[DictPrompt | TokPrompt] = [
engine_prompt
for prompt, param in zip(seq_prompts, seq_params)
for engine_prompt in self._preprocess_cmpl(
[prompt],
tokenization_kwargs=merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
),
)
]
else:
engine_prompts = self._preprocess_cmpl(
seq_prompts,
tokenization_kwargs=tokenization_kwargs,
seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_prompts))
seq_tok_kwargs = [
merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
)
for param in seq_params
]
seq_priority = self._priority_to_seq(priority, len(seq_prompts))
request_ids = self._validate_and_add_requests(
prompts=engine_prompts,
params=seq_params,
use_tqdm=use_tqdm,
lora_request=self._get_modality_specific_lora_reqs(
engine_prompts, lora_request
request_ids = self._render_and_add_requests(
prompts=(
self._preprocess_cmpl_one(prompt, tok_kwargs)
for prompt, tok_kwargs in zip(
maybe_tqdm(
seq_prompts,
use_tqdm=use_tqdm,
desc="Rendering prompts",
),
seq_tok_kwargs,
)
),
params=seq_params,
lora_requests=seq_lora_requests,
tokenization_kwargs=tokenization_kwargs,
priority=priority,
priorities=seq_priority,
)
return request_ids
@@ -545,53 +543,41 @@ class LLM:
outputs = self._run_engine(use_tqdm=use_tqdm)
return self.engine_class.validate_outputs(outputs, RequestOutput)
def _get_modality_specific_lora_reqs(
def _resolve_lora_reqs(
self,
prompts: Sequence[DictPrompt | TokPrompt],
lora_request: list[LoRARequest] | LoRARequest | None,
prompts: Sequence[ProcessorInputs],
lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
):
# Grab the lora config off the vllm config on the engine,
# since this is the same for both v0 & v1.
lora_config = self.llm_engine.vllm_config.lora_config
seq_lora_requests = self._lora_request_to_seq(lora_request, len(prompts))
# If there's no lora config / default_mm_loras, or the model
# isn't multimodal, leave the lora as is.
if (
lora_config is None
or not self.model_config.is_multimodal_model
or (lora_config and lora_config.default_mm_loras is None)
):
return lora_request
optional_loras = (
[lora_request] * len(prompts)
if not isinstance(lora_request, Sequence)
else lora_request
)
return seq_lora_requests
return [
self._resolve_single_prompt_mm_lora(
prompt,
opt_lora_req,
lora_req,
lora_config.default_mm_loras,
)
for prompt, opt_lora_req in zip(prompts, optional_loras)
for prompt, lora_req in zip(prompts, seq_lora_requests)
]
def _resolve_single_prompt_mm_lora(
self,
prompt: DictPrompt | TokPrompt,
prompt: ProcessorInputs,
lora_request: LoRARequest | None,
default_mm_loras: dict[str, str] | None,
):
if not default_mm_loras or not (
mm_data := prompt.get("multi_modal_data") or {}
):
if not default_mm_loras or prompt["type"] != "multimodal":
return lora_request
intersection = set(
mm_data.keys() # type: ignore
).intersection(default_mm_loras.keys())
prompt_modalities = prompt["mm_placeholders"].keys()
intersection = set(prompt_modalities).intersection(default_mm_loras.keys())
if not intersection:
return lora_request
if len(intersection) > 1:
@@ -674,22 +660,6 @@ class LLM:
"""
return self.llm_engine.apply_model(func)
def _get_beam_search_lora_requests(
self,
lora_request: list[LoRARequest] | LoRARequest | None,
prompts: list[TokensPrompt | TextPrompt],
) -> list[LoRARequest | None]:
"""Get the optional lora request corresponding to each prompt."""
if isinstance(lora_request, Sequence) and len(lora_request) != len(prompts):
raise ValueError(
"Lora request list should be the same length as the prompts"
)
if lora_request is None or isinstance(lora_request, LoRARequest):
return [lora_request] * len(prompts)
raise TypeError(f"Invalid lora_request type {type(lora_request)}")
def beam_search(
self,
prompts: list[TokensPrompt | TextPrompt],
@@ -718,13 +688,12 @@ class LLM:
ignore_eos = params.ignore_eos
length_penalty = params.length_penalty
lora_requests = self._get_beam_search_lora_requests(lora_request, prompts)
tokenizer = self.renderer.get_tokenizer()
eos_token_id = tokenizer.eos_token_id
sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
tokenizer = self.get_tokenizer()
sort_beams_key = create_sort_beams_key_function(
tokenizer.eos_token_id,
length_penalty,
)
engine_prompts = self._preprocess_cmpl(prompts)
lora_requests = self._lora_request_to_seq(lora_request, len(engine_prompts))
if use_tqdm and concurrency_limit is not None:
logger.warning(
@@ -734,21 +703,12 @@ class LLM:
use_tqdm = False
if concurrency_limit is None:
concurrency_limit = len(prompts)
def create_tokens_prompt_from_beam(beam: BeamSearchSequence) -> TokensPrompt:
token_prompt_kwargs: TokensPrompt = {"prompt_token_ids": beam.tokens}
if beam.multi_modal_data is not None:
token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data
if beam.mm_processor_kwargs is not None:
token_prompt_kwargs["mm_processor_kwargs"] = beam.mm_processor_kwargs
return TokensPrompt(**token_prompt_kwargs)
concurrency_limit = len(engine_prompts)
# generate 2 * beam_width candidates at each step
# following the huggingface transformers implementation
# at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
beam_search_params = SamplingParams(
sampling_params = SamplingParams(
logprobs=2 * beam_width,
max_tokens=1,
temperature=temperature,
@@ -756,30 +716,25 @@ class LLM:
)
instances: list[BeamSearchInstance] = []
for lora_req, prompt in zip(lora_requests, prompts):
# Add multimodal processor kwargs & data
mm_kwargs = {}
if "multi_modal_data" in prompt:
mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
if "mm_processor_kwargs" in prompt:
mm_kwargs["mm_processor_kwargs"] = prompt["mm_processor_kwargs"]
if "prompt_token_ids" in prompt:
prompt = cast(TokensPrompt, prompt) # Needed for mypy
prompt_tokens = prompt["prompt_token_ids"]
else:
prompt_tokens = tokenizer.encode(prompt["prompt"])
for lora_req, prompt in zip(lora_requests, engine_prompts):
if prompt["type"] == "embeds":
raise NotImplementedError(
"Embedding prompt not supported for beam search"
)
if prompt["type"] == "enc_dec":
raise NotImplementedError(
"Encoder-decoder prompt not supported for beam search"
)
instances.append(
BeamSearchInstance(
prompt_tokens,
prompt,
lora_request=lora_req,
logprobs=None,
**mm_kwargs,
),
)
for prompt_start in range(0, len(prompts), concurrency_limit):
for prompt_start in range(0, len(instances), concurrency_limit):
instances_batch = instances[prompt_start : prompt_start + concurrency_limit]
token_iter = range(max_tokens)
@@ -808,22 +763,15 @@ class LLM:
if len(all_beams) == 0:
break
# create corresponding batch entries for prompt & optional lora
prompts_batch, lora_req_batch = zip(
*[
(create_tokens_prompt_from_beam(beam), beam.lora_request)
for beam in all_beams
]
)
# only runs for one step
# we don't need to use tqdm here
output = self.generate(
prompts_batch,
sampling_params=beam_search_params,
raw_output = self._render_and_run_requests(
prompts=(beam.get_prompt() for beam in all_beams),
params=self._params_to_seq(sampling_params, len(all_beams)),
lora_requests=[beam.lora_request for beam in all_beams],
use_tqdm=False,
lora_request=lora_req_batch,
)
output = self.engine_class.validate_outputs(raw_output, RequestOutput)
for (start, end), instance in zip(
instance_start_and_end, instances_batch
@@ -841,19 +789,15 @@ class LLM:
logprobs = result.outputs[0].logprobs[0]
for token_id, logprob_obj in logprobs.items():
new_beam = BeamSearchSequence(
current_beam.orig_prompt,
tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs + [logprobs],
lora_request=current_beam.lora_request,
cum_logprob=current_beam.cum_logprob
+ logprob_obj.logprob,
multi_modal_data=current_beam.multi_modal_data,
mm_processor_kwargs=current_beam.mm_processor_kwargs,
)
if (
token_id == tokenizer.eos_token_id
and not ignore_eos
):
if token_id == eos_token_id and not ignore_eos:
instance.completed.append(new_beam)
else:
instance_new_beams.append(new_beam)
@@ -872,6 +816,7 @@ class LLM:
for beam in best_beams:
beam.text = tokenizer.decode(beam.tokens)
outputs.append(BeamSearchOutput(sequences=best_beams))
return outputs
@@ -880,7 +825,7 @@ class LLM:
self,
prompts: Sequence[PromptType],
tokenization_kwargs: dict[str, Any] | None = None,
) -> Sequence[DictPrompt | TokPrompt]:
) -> Sequence[ProcessorInputs]:
"""
Convert prompt inputs from LLM APIs (other than [LLM.chat][]) into
a format that can be passed to `_add_request`.
@@ -888,8 +833,7 @@ class LLM:
Refer to [LLM.generate][] for a complete description of the arguments.
Returns:
A list of `TokPrompt` objects containing the tokenized prompt
after chat template interpolation, and the raw multi-modal inputs.
A list of `ProcessorInputs` objects ready to be passed into LLMEngine.
"""
renderer = self.renderer
model_config = self.model_config
@@ -903,6 +847,14 @@ class LLM:
return renderer.render_cmpl(parsed_prompts, tok_params)
def _preprocess_cmpl_one(
self,
prompt: PromptType,
tokenization_kwargs: dict[str, Any] | None = None,
) -> ProcessorInputs:
(engine_prompt,) = self._preprocess_cmpl([prompt], tokenization_kwargs)
return engine_prompt
def _preprocess_chat(
self,
conversations: Sequence[list[ChatCompletionMessageParam]],
@@ -914,7 +866,7 @@ class LLM:
tools: list[dict[str, Any]] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
mm_processor_kwargs: dict[str, Any] | None = None,
) -> Sequence[TokPrompt]:
) -> Sequence[ProcessorInputs]:
"""
Convert a list of conversations into prompts so that they can then
be used as input for other LLM APIs.
@@ -922,8 +874,7 @@ class LLM:
Refer to [LLM.chat][] for a complete description of the arguments.
Returns:
A list of `TokPrompt` objects containing the tokenized prompt
after chat template interpolation, and the raw multi-modal inputs.
A list of `ProcessorInputs` objects ready to be passed into LLMEngine.
"""
renderer = self.renderer
@@ -953,13 +904,39 @@ class LLM:
return engine_prompts
def _preprocess_chat_one(
self,
conversation: list[ChatCompletionMessageParam],
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
chat_template_kwargs: dict[str, Any] | None = None,
add_generation_prompt: bool = True,
continue_final_message: bool = False,
tools: list[dict[str, Any]] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
mm_processor_kwargs: dict[str, Any] | None = None,
) -> ProcessorInputs:
(engine_prompt,) = self._preprocess_chat(
[conversation],
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
chat_template_kwargs=chat_template_kwargs,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools,
tokenization_kwargs=tokenization_kwargs,
mm_processor_kwargs=mm_processor_kwargs,
)
return engine_prompt
def chat(
self,
messages: list[ChatCompletionMessageParam]
| Sequence[list[ChatCompletionMessageParam]],
sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: LoRARequest | None = None,
lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
add_generation_prompt: bool = True,
@@ -1805,47 +1782,41 @@ class LLM:
| Sequence[SamplingParams | PoolingParams],
*,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: list[LoRARequest] | LoRARequest | None = None,
lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
priority: list[int] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
):
seq_prompts = prompt_to_seq(prompts)
seq_params = self._params_to_seq(params, len(seq_prompts))
if any(param.truncate_prompt_tokens is not None for param in seq_params):
# TODO: Remove this after deprecating `param.truncate_prompt_tokens`
# Then, move the code from the `else` block to the top and let
# `self._preprocess_cmpl` handle prompt normalization
engine_prompts: Sequence[DictPrompt | TokPrompt] = [
engine_prompt
for prompt, param in zip(seq_prompts, seq_params)
for engine_prompt in self._preprocess_cmpl(
[prompt],
tokenization_kwargs=merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
),
)
]
else:
engine_prompts = self._preprocess_cmpl(
seq_prompts,
tokenization_kwargs=tokenization_kwargs,
seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_prompts))
seq_tok_kwargs = [
merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
)
for param in seq_params
]
seq_priority = self._priority_to_seq(priority, len(seq_prompts))
self._validate_and_add_requests(
prompts=engine_prompts,
return self._render_and_run_requests(
prompts=(
self._preprocess_cmpl_one(prompt, tok_kwargs)
for prompt, tok_kwargs in zip(
maybe_tqdm(
seq_prompts,
use_tqdm=use_tqdm,
desc="Rendering prompts",
),
seq_tok_kwargs,
)
),
params=seq_params,
use_tqdm=use_tqdm,
lora_request=self._get_modality_specific_lora_reqs(
engine_prompts, lora_request
),
lora_requests=seq_lora_requests,
tokenization_kwargs=tokenization_kwargs,
priority=priority,
priorities=seq_priority,
)
return self._run_engine(use_tqdm=use_tqdm)
def _run_chat(
self,
messages: list[ChatCompletionMessageParam]
@@ -1855,7 +1826,7 @@ class LLM:
| Sequence[SamplingParams | PoolingParams],
*,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: LoRARequest | None = None,
lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
add_generation_prompt: bool = True,
@@ -1865,68 +1836,94 @@ class LLM:
tokenization_kwargs: dict[str, Any] | None = None,
mm_processor_kwargs: dict[str, Any] | None = None,
):
engine_prompts = self._preprocess_chat(
conversation_to_seq(messages),
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
chat_template_kwargs=chat_template_kwargs,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools,
seq_convs = conversation_to_seq(messages)
seq_params = self._params_to_seq(params, len(seq_convs))
seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_convs))
seq_tok_kwargs = [
merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
)
for param in seq_params
]
return self._render_and_run_requests(
prompts=(
self._preprocess_chat_one(
conversation,
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
chat_template_kwargs=chat_template_kwargs,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools,
tokenization_kwargs=tok_kwargs,
mm_processor_kwargs=mm_processor_kwargs,
)
for conversation, tok_kwargs in zip(
maybe_tqdm(
seq_convs,
use_tqdm=use_tqdm,
desc="Rendering conversations",
),
seq_tok_kwargs,
)
),
params=seq_params,
lora_requests=seq_lora_requests,
use_tqdm=use_tqdm,
tokenization_kwargs=tokenization_kwargs,
mm_processor_kwargs=mm_processor_kwargs,
)
self._validate_and_add_requests(
prompts=engine_prompts,
def _render_and_run_requests(
self,
prompts: Iterable[ProcessorInputs],
params: Sequence[SamplingParams | PoolingParams],
*,
lora_requests: Sequence[LoRARequest | None] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
priorities: Sequence[int] | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
):
if isinstance(prompts, (list, tuple)):
logger.warning_once(
"Rendering all prompts before adding them to the engine "
"is less efficient than performing both on the same prompt "
"before processing the next prompt. You should instead pass "
"a generator that renders one prompt per iteration, as that allows "
"engine execution to begin for the first prompt while processing "
"the next prompt."
)
self._render_and_add_requests(
prompts=prompts,
params=params,
use_tqdm=use_tqdm,
lora_request=self._get_modality_specific_lora_reqs(
engine_prompts, lora_request
),
lora_requests=lora_requests,
tokenization_kwargs=tokenization_kwargs,
priorities=priorities,
)
return self._run_engine(use_tqdm=use_tqdm)
def _validate_and_add_requests(
def _render_and_add_requests(
self,
prompts: Sequence[DictPrompt | TokPrompt],
params: SamplingParams
| PoolingParams
| Sequence[SamplingParams | PoolingParams],
prompts: Iterable[ProcessorInputs],
params: Sequence[SamplingParams | PoolingParams],
*,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
lora_requests: Sequence[LoRARequest | None] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
priority: list[int] | None = None,
priorities: Sequence[int] | None = None,
) -> list[str]:
num_requests = len(prompts)
seq_params = self._params_to_seq(params, num_requests)
seq_lora_requests = self._lora_request_to_seq(lora_request, num_requests)
seq_priority = self._priority_to_seq(priority, num_requests)
for sp in seq_params:
if isinstance(sp, SamplingParams):
# We only care about the final output
sp.output_kind = RequestOutputKind.FINAL_ONLY
# Add requests to the engine.
it = prompts
if use_tqdm:
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
it = tqdm_func(it, desc="Adding requests")
added_request_ids: list[str] = []
try:
for i, prompt in enumerate(it):
for i, prompt in enumerate(prompts):
request_id = self._add_request(
prompt,
seq_params[i],
lora_request=seq_lora_requests[i],
params[i],
lora_request=None if lora_requests is None else lora_requests[i],
tokenization_kwargs=tokenization_kwargs,
priority=seq_priority[i],
priority=0 if priorities is None else priorities[i],
)
added_request_ids.append(request_id)
except Exception as e:
@@ -1938,13 +1935,16 @@ class LLM:
def _add_request(
self,
prompt: PromptType | DictPrompt | TokPrompt,
prompt: ProcessorInputs,
params: SamplingParams | PoolingParams,
lora_request: LoRARequest | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
priority: int = 0,
) -> str:
prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)
if isinstance(params, SamplingParams):
# We only care about the final output
params.output_kind = RequestOutputKind.FINAL_ONLY
request_id = str(next(self.request_counter))
if params.truncate_prompt_tokens is not None:
@@ -1962,33 +1962,15 @@ class LLM:
dict(truncate_prompt_tokens=params.truncate_prompt_tokens),
)
renderer = self.renderer
tok_params = renderer.default_cmpl_tok_params.with_kwargs(
**(tokenization_kwargs or {})
)
tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs(
return self.llm_engine.add_request(
request_id,
prompt,
params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
priority=priority,
supported_tasks=self.supported_tasks,
)
self.llm_engine.add_request(
request_id,
engine_request,
params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
priority=priority,
prompt_text=prompt_text,
)
return engine_request.request_id
def _run_engine(
self,
*,
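
The warning added in _render_and_run_requests is easiest to see with a sketch that uses the private helpers shown above (illustration only; the helper signatures are taken from this diff): a generator lets rendering of prompt i+1 overlap with engine-side work on prompt i, whereas a pre-built list forces all rendering to finish before the first request is added.

from vllm import LLM
from vllm.sampling_params import SamplingParams

def run_lazily(llm: LLM, raw_prompts: list[str]):
    params = llm._params_to_seq(SamplingParams(max_tokens=8), len(raw_prompts))
    # A generator, not a list: each prompt is rendered only after the
    # previous one has already been handed to the engine.
    rendered = (llm._preprocess_cmpl_one(p) for p in raw_prompts)
    return llm._render_and_run_requests(prompts=rendered, params=params)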

View File

@@ -67,13 +67,12 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
)
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import TokensPrompt
from vllm.inputs.data import ProcessorInputs, TokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.parser import ParserManager
from vllm.reasoning import ReasoningParser
from vllm.renderers.inputs import TokPrompt
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import (
@@ -221,7 +220,7 @@ class OpenAIServingChat(OpenAIServing):
async def render_chat_request(
self,
request: ChatCompletionRequest,
) -> tuple[list[ConversationMessage], list[TokPrompt]] | ErrorResponse:
) -> tuple[list[ConversationMessage], list[ProcessorInputs]] | ErrorResponse:
"""
Render the chat request by validating and preprocessing its inputs.
@@ -380,7 +379,9 @@ class OpenAIServingChat(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
prompt_text = self._extract_prompt_text(engine_prompt)
prompt_token_ids = self._extract_prompt_components(
engine_prompt
).token_ids
# If we are creating sub requests for multiple prompts, ensure that they
# have unique request ids.
@@ -431,35 +432,21 @@ class OpenAIServingChat(OpenAIServing):
trace_headers=trace_headers,
)
else:
tok_params = request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
reasoning_ended = (
reasoning_parser.is_reasoning_end(prompt_token_ids or [])
if reasoning_parser
else None
)
engine_request = self.input_processor.process_inputs(
sub_request_id,
generator = self.engine_client.generate(
engine_prompt,
sampling_params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=request.priority,
data_parallel_rank=data_parallel_rank,
)
reasoning_ended = None
if reasoning_parser:
reasoning_ended = reasoning_parser.is_reasoning_end(
engine_request.prompt_token_ids or [] # type: ignore[attr-defined]
)
engine_request.reasoning_ended = reasoning_ended
generator = self.engine_client.generate(
engine_request,
sampling_params,
sub_request_id,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
prompt_text=prompt_text,
tokenization_kwargs=tokenization_kwargs,
data_parallel_rank=data_parallel_rank,
reasoning_ended=reasoning_ended,
)
generators.append(generator)
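
The reasoning_ended flag above is computed once from the prompt token ids before generation starts. A toy stand-in for the parser contract (the real parsers live under vllm.reasoning; the marker id below is hypothetical):

END_THINK_TOKEN_ID = 151668  # hypothetical id of a "</think>"-style marker

def is_reasoning_end(prompt_token_ids: list[int]) -> bool:
    # If the prompt already contains the end-of-reasoning marker,
    # generation starts in the "reasoning ended" state.
    return END_THINK_TOKEN_ID in prompt_token_ids

assert is_reasoning_end([1, 2, END_THINK_TOKEN_ID, 4])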

View File

@@ -34,10 +34,10 @@ from vllm.entrypoints.openai.engine.serving import (
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import ProcessorInputs
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
from vllm.renderers.inputs import TokPrompt
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import merge_async_iterators
@@ -80,7 +80,7 @@ class OpenAIServingCompletion(OpenAIServing):
async def render_completion_request(
self,
request: CompletionRequest,
) -> list[TokPrompt] | ErrorResponse:
) -> list[ProcessorInputs] | ErrorResponse:
"""
Render the completion request by validating and preprocessing its inputs.
@@ -163,8 +163,6 @@ class OpenAIServingCompletion(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
prompt_text = self._extract_prompt_text(engine_prompt)
max_tokens = get_max_tokens(
max_model_len,
request.max_tokens,
@@ -208,29 +206,13 @@ class OpenAIServingCompletion(OpenAIServing):
trace_headers=trace_headers,
)
else:
tok_params = request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs(
request_id_item,
generator = self.engine_client.generate(
engine_prompt,
sampling_params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=request.priority,
data_parallel_rank=data_parallel_rank,
)
generator = self.engine_client.generate(
engine_request,
sampling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
prompt_text=prompt_text,
tokenization_kwargs=tokenization_kwargs,
data_parallel_rank=data_parallel_rank,
)
@@ -312,7 +294,7 @@ class OpenAIServingCompletion(OpenAIServing):
async def completion_stream_generator(
self,
request: CompletionRequest,
engine_prompts: list[TokPrompt],
engine_prompts: list[ProcessorInputs],
result_generator: AsyncIterator[tuple[int, RequestOutput]],
request_id: str,
created_time: int,

View File

@@ -96,15 +96,19 @@ from vllm.entrypoints.serve.tokenize.protocol import (
)
from vllm.entrypoints.utils import get_max_tokens, sanitize_message
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import PromptType, SingletonPrompt, TokensPrompt
from vllm.inputs.data import (
ProcessorInputs,
PromptType,
SingletonPrompt,
TokensPrompt,
token_inputs,
)
from vllm.logger import init_logger
from vllm.logprobs import Logprob, PromptLogprobs
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalDataDict
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.renderers.inputs import TokPrompt
from vllm.renderers.inputs.preprocess import (
extract_prompt_components,
extract_prompt_len,
@@ -206,7 +210,7 @@ class ServeContext(Generic[RequestT]):
request_id: str
created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None
engine_prompts: list[TokPrompt] | None = None
engine_prompts: list[ProcessorInputs] | None = None
result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
None
@@ -249,7 +253,7 @@ class OpenAIServing:
async def beam_search(
self,
prompt: TokPrompt,
prompt: ProcessorInputs,
request_id: str,
params: BeamSearchParams,
lora_request: LoRARequest | None = None,
@@ -262,86 +266,53 @@ class OpenAIServing:
length_penalty = params.length_penalty
include_stop_str_in_output = params.include_stop_str_in_output
input_processor = self.input_processor
tokenizer = input_processor.tokenizer
if tokenizer is None:
raise VLLMValidationError(
"You cannot use beam search when `skip_tokenizer_init=True`",
parameter="skip_tokenizer_init",
value=True,
)
eos_token_id: int = tokenizer.eos_token_id # type: ignore
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
raise NotImplementedError("Encoder-decoder prompt not supported")
prompt_text: str | None = prompt.get("prompt") # type: ignore
prompt_token_ids: list[int] = prompt.get("prompt_token_ids", []) # type: ignore
multi_modal_data: MultiModalDataDict | None = prompt.get("multi_modal_data") # type: ignore
mm_processor_kwargs: dict[str, Any] | None = None
# This is a workaround to fix multimodal beam search; this is a
# bandaid fix for 2 small problems:
# 1. Multi_modal_data on the processed_inputs currently resolves to
# `None`.
# 2. preprocessing above expands the multimodal placeholders. However,
# this happens again in generation, so the double expansion causes
# a mismatch.
# TODO - would be ideal to handle this more gracefully.
tokenized_length = len(prompt_token_ids)
tokenizer = self.renderer.get_tokenizer()
eos_token_id = tokenizer.eos_token_id
sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
if prompt["type"] == "embeds":
raise NotImplementedError("Embedding prompt not supported for beam search")
if prompt["type"] == "enc_dec":
raise NotImplementedError(
"Encoder-decoder prompt not supported for beam search"
)
prompt_text = prompt.get("prompt")
prompt_token_ids = prompt["prompt_token_ids"]
tokenized_length = len(prompt_token_ids)
logprobs_num = 2 * beam_width
beam_search_params = SamplingParams(
sampling_params = SamplingParams(
logprobs=logprobs_num,
max_tokens=1,
temperature=temperature,
)
all_beams = [
BeamSearchSequence(
orig_prompt=prompt,
tokens=prompt_token_ids,
cum_logprob=0,
logprobs=[],
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs,
lora_request=lora_request,
)
]
completed = []
for _ in range(max_tokens):
prompts_batch, lora_req_batch = zip(
*[
(
TokensPrompt(
prompt_token_ids=beam.tokens,
multi_modal_data=beam.multi_modal_data,
mm_processor_kwargs=beam.mm_processor_kwargs,
),
beam.lora_request,
)
for beam in all_beams
]
)
tasks = []
request_id_batch = f"{request_id}-{random_uuid()}"
for i, (individual_prompt, lora_req) in enumerate(
zip(prompts_batch, lora_req_batch)
):
for i, beam in enumerate(all_beams):
prompt_item = beam.get_prompt()
lora_request_item = beam.lora_request
request_id_item = f"{request_id_batch}-beam-{i}"
task = asyncio.create_task(
collect_from_async_generator(
self.engine_client.generate(
individual_prompt,
beam_search_params,
prompt_item,
sampling_params,
request_id_item,
lora_request=lora_req,
lora_request=lora_request_item,
trace_headers=trace_headers,
)
)
@@ -406,6 +377,7 @@ class OpenAIServing:
logprobs_entry = result.outputs[0].logprobs[0]
completed.append(
BeamSearchSequence(
orig_prompt=prompt,
tokens=current_beam.tokens + [eos_token_id]
if include_stop_str_in_output
else current_beam.tokens,
@@ -433,12 +405,11 @@ class OpenAIServing:
logprobs_entry = result.outputs[0].logprobs[0]
new_beams.append(
BeamSearchSequence(
orig_prompt=prompt,
tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs + [logprobs_entry],
lora_request=current_beam.lora_request,
cum_logprob=float(all_beams_logprob[idx]),
multi_modal_data=current_beam.multi_modal_data,
mm_processor_kwargs=current_beam.mm_processor_kwargs,
)
)
@@ -958,7 +929,7 @@ class OpenAIServing:
request: RendererRequest,
prompt_input: str | list[str] | list[int] | list[list[int]] | None,
prompt_embeds: bytes | list[bytes] | None,
) -> list[TokPrompt]:
) -> list[ProcessorInputs]:
prompts = list[SingletonPrompt | bytes]()
if prompt_embeds is not None: # embeds take higher priority
prompts.extend(prompt_to_seq(prompt_embeds))
@@ -971,7 +942,7 @@ class OpenAIServing:
self,
request: RendererRequest,
prompts: Sequence[PromptType | bytes],
) -> list[TokPrompt]:
) -> list[ProcessorInputs]:
renderer = self.renderer
model_config = self.model_config
@@ -1004,7 +975,7 @@ class OpenAIServing:
default_template_kwargs: dict[str, Any] | None,
tool_dicts: list[dict[str, Any]] | None = None,
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
) -> tuple[list[ConversationMessage], list[TokPrompt]]:
) -> tuple[list[ConversationMessage], list[ProcessorInputs]]:
from vllm.tokenizers.mistral import MistralTokenizer
renderer = self.renderer
@@ -1052,13 +1023,13 @@ class OpenAIServing:
return conversation, [engine_prompt]
def _extract_prompt_components(self, prompt: object):
def _extract_prompt_components(self, prompt: PromptType | ProcessorInputs):
return extract_prompt_components(self.model_config, prompt)
def _extract_prompt_text(self, prompt: object):
def _extract_prompt_text(self, prompt: ProcessorInputs):
return self._extract_prompt_components(prompt).text
def _extract_prompt_len(self, prompt: object):
def _extract_prompt_len(self, prompt: ProcessorInputs):
return extract_prompt_len(self.model_config, prompt)
async def _render_next_turn(
@@ -1088,16 +1059,14 @@ class OpenAIServing:
async def _generate_with_builtin_tools(
self,
request_id: str,
engine_prompt: TokPrompt,
engine_prompt: ProcessorInputs,
sampling_params: SamplingParams,
tok_params: TokenizeParams,
context: ConversationContext,
lora_request: LoRARequest | None = None,
priority: int = 0,
trace_headers: Mapping[str, str] | None = None,
):
max_model_len = self.model_config.max_model_len
prompt_text = self._extract_prompt_text(engine_prompt)
orig_priority = priority
sub_request = 0
@@ -1112,26 +1081,13 @@ class OpenAIServing:
lora_request=lora_request,
)
tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs(
sub_request_id,
generator = self.engine_client.generate(
engine_prompt,
sampling_params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=priority,
)
generator = self.engine_client.generate(
engine_request,
sampling_params,
sub_request_id,
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
prompt_text=prompt_text,
tokenization_kwargs=tokenization_kwargs,
)
async for res in generator:
@@ -1154,11 +1110,11 @@ class OpenAIServing:
# Render the next prompt token ids and update sampling_params.
if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
token_ids = context.render_for_completion()
engine_prompt = TokensPrompt(prompt_token_ids=token_ids)
engine_prompt = token_inputs(token_ids)
sampling_params.max_tokens = max_model_len - len(token_ids)
elif isinstance(context, ParsableContext):
engine_prompts = await self._render_next_turn(
(engine_prompt,) = await self._render_next_turn(
context.request,
context.parser.response_messages,
context.tool_dicts,
@@ -1166,8 +1122,6 @@ class OpenAIServing:
context.chat_template,
context.chat_template_content_format,
)
engine_prompt = engine_prompts[0]
prompt_text = self._extract_prompt_text(engine_prompt)
sampling_params.max_tokens = get_max_tokens(
max_model_len,
@@ -1184,7 +1138,7 @@ class OpenAIServing:
def _log_inputs(
self,
request_id: str,
inputs: PromptType | TokPrompt,
inputs: PromptType | ProcessorInputs,
params: SamplingParams | PoolingParams | BeamSearchParams | None,
lora_request: LoRARequest | None,
) -> None:

View File

@@ -15,6 +15,7 @@ from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import SupportsRealtime
from vllm.renderers.inputs.preprocess import parse_model_prompt
logger = init_logger(__name__)
@@ -70,15 +71,20 @@ class OpenAIServingRealtime(OpenAIServing):
Yields:
StreamingInput objects containing audio prompts for the engine
"""
model_config = self.model_config
renderer = self.renderer
# mypy is being stupid
# TODO(Patrick) - fix this
stream_input_iter = cast(
AsyncGenerator[PromptType, None],
self.model_cls.buffer_realtime_audio(
audio_stream, input_stream, self.model_config
audio_stream, input_stream, model_config
),
)
async for prompt in stream_input_iter:
yield StreamingInput(prompt=prompt)
parsed_prompt = parse_model_prompt(model_config, prompt)
(engine_prompt,) = await renderer.render_cmpl_async([parsed_prompt])
yield StreamingInput(prompt=engine_prompt)

View File

@@ -9,7 +9,7 @@ from abc import ABC, abstractmethod
from collections.abc import Callable
from contextlib import AsyncExitStack
from dataclasses import replace
from typing import TYPE_CHECKING, Union
from typing import TYPE_CHECKING, Final, Union
from openai.types.responses.response_function_tool_call_output_item import (
ResponseFunctionToolCallOutputItem,
@@ -304,7 +304,7 @@ class ParsableContext(ConversationContext):
self.tool_dicts = construct_tool_dicts(request.tools, request.tool_choice)
self.chat_template = chat_template
self.chat_template_content_format = chat_template_content_format
self.chat_template_content_format: Final = chat_template_content_format
self.input_messages: list[ResponseRawMessageAndToken] = []
self.output_messages: list[ResponseRawMessageAndToken] = []
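
The Final annotation added here changes nothing at runtime; it marks the attribute as assigned exactly once so type checkers treat it as a constant. A minimal sketch of the effect (names hypothetical):

from typing import Final

class Context:
    def __init__(self, content_format: str) -> None:
        # Assigned exactly once in __init__; a type checker such as mypy
        # flags any later reassignment of this attribute.
        self.content_format: Final = content_format

ctx = Context("auto")
# ctx.content_format = "string"  # mypy: cannot assign to a Final attribute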

View File

@@ -116,13 +116,12 @@ from vllm.entrypoints.openai.responses.utils import (
)
from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import TokensPrompt
from vllm.inputs.data import ProcessorInputs, token_inputs
from vllm.logger import init_logger
from vllm.logprobs import Logprob as SampleLogprob
from vllm.logprobs import SampleLogprobs
from vllm.outputs import CompletionOutput
from vllm.parser import ParserManager
from vllm.renderers.inputs import TokPrompt
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid
@@ -298,7 +297,7 @@ class OpenAIServingResponses(OpenAIServing):
def _validate_generator_input(
self,
engine_prompt: TokPrompt,
engine_prompt: ProcessorInputs,
) -> ErrorResponse | None:
"""Add validations to the input to the generator here."""
prompt_len = self._extract_prompt_len(engine_prompt)
@@ -458,7 +457,6 @@ class OpenAIServingResponses(OpenAIServing):
sampling_params = request.to_sampling_params(
default_max_tokens, self.default_sampling_params
)
tok_params = request.build_tok_params(self.model_config)
trace_headers = (
None
@@ -512,7 +510,6 @@ class OpenAIServingResponses(OpenAIServing):
request_id=request.request_id,
engine_prompt=engine_prompt,
sampling_params=sampling_params,
tok_params=tok_params,
context=context,
lora_request=lora_request,
priority=request.priority,
@@ -647,7 +644,7 @@ class OpenAIServingResponses(OpenAIServing):
messages = self._construct_input_messages_with_harmony(request, prev_response)
prompt_token_ids = render_for_completion(messages)
engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
engine_prompt = token_inputs(prompt_token_ids)
# Add cache_salt if provided in the request
if request.cache_salt is not None:

View File

@@ -36,14 +36,15 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
TranslationSegment,
TranslationStreamResponse,
)
from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import PromptType
from vllm.inputs import ProcessorInputs
from vllm.logger import init_logger
from vllm.logprobs import FlatLogprobs, Logprob
from vllm.model_executor.models import SupportsTranscription, supports_transcription
from vllm.outputs import RequestOutput
from vllm.renderers.inputs import EncoderDecoderDictPrompt
from vllm.renderers.inputs.preprocess import parse_enc_dec_prompt
from vllm.renderers.inputs import DictPrompt, EncoderDecoderDictPrompt
from vllm.renderers.inputs.preprocess import parse_enc_dec_prompt, parse_model_prompt
from vllm.tokenizers import get_tokenizer
from vllm.utils.import_utils import PlaceholderModule
@@ -202,8 +203,6 @@ class OpenAISpeechToText(OpenAIServing):
return
try:
from vllm.sampling_params import SamplingParams
warmup_start = time.perf_counter()
logger.info("Warming up multimodal input processor...")
@@ -221,21 +220,11 @@ class OpenAISpeechToText(OpenAIServing):
request_prompt="",
to_language=None,
)
# Create minimal sampling params
dummy_params = SamplingParams(
max_tokens=1,
temperature=0.0,
skip_clone=True, # Internal warmup, safe to skip clone
)
parsed_prompt = parse_model_prompt(self.model_config, dummy_prompt)
# Process the dummy input through the input processor
# This will trigger all the multimodal processing initialization
_ = self.input_processor.process_inputs(
request_id="warmup",
prompt=dummy_prompt,
params=dummy_params,
)
_ = self.renderer.render_cmpl([parsed_prompt])
warmup_elapsed = time.perf_counter() - warmup_start
logger.info("Input processor warmup completed in %.2fs", warmup_elapsed)
@@ -257,7 +246,7 @@ class OpenAISpeechToText(OpenAIServing):
self,
request: SpeechToTextRequest,
audio_data: bytes,
) -> tuple[list[PromptType], float]:
) -> tuple[list[ProcessorInputs], float]:
# Validate request
language = self.model_cls.validate_language(request.language)
# Skip to_language validation to avoid extra logging for Whisper.
@@ -285,7 +274,7 @@ class OpenAISpeechToText(OpenAIServing):
and duration > self.asr_config.max_audio_clip_s
)
chunks = [y] if not do_split_audio else self._split_audio(y, int(sr))
prompts = []
parsed_prompts: list[DictPrompt] = []
for chunk in chunks:
# The model has control over the construction, as long as it
# returns a valid PromptType.
@@ -298,12 +287,19 @@ class OpenAISpeechToText(OpenAIServing):
request_prompt=request.prompt,
to_language=to_language,
)
parsed_prompt: DictPrompt
if request.response_format == "verbose_json":
prompt = self._preprocess_verbose_prompt(parse_enc_dec_prompt(prompt))
parsed_prompt = parse_enc_dec_prompt(prompt)
parsed_prompt = self._preprocess_verbose_prompt(parsed_prompt)
else:
parsed_prompt = parse_model_prompt(self.model_config, prompt)
prompts.append(prompt)
parsed_prompts.append(parsed_prompt)
return prompts, duration
engine_prompts = await self.renderer.render_cmpl_async(parsed_prompts)
return engine_prompts, duration
def _preprocess_verbose_prompt(self, prompt: EncoderDecoderDictPrompt):
dec_prompt = prompt["decoder_prompt"]
@@ -436,7 +432,7 @@ class OpenAISpeechToText(OpenAIServing):
try:
lora_request = self._maybe_get_adapters(request)
prompts, duration_s = await self._preprocess_speech_to_text(
engine_prompts, duration_s = await self._preprocess_speech_to_text(
request=request,
audio_data=audio_data,
)
@@ -445,57 +441,54 @@ class OpenAISpeechToText(OpenAIServing):
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(e)
# Schedule the request and get the result generator.
max_model_len = self.model_config.max_model_len
list_result_generator: list[AsyncGenerator[RequestOutput, None]] | None = None
try:
# Unlike most decoder-only models, Whisper generation length is not
# constrained by the size of the input audio, which is mapped to a
# fixed-size log-mel-spectrogram. Still, allow for fewer tokens to be
# generated by respecting the extra completion tokens arg.
if request.max_completion_tokens is None:
default_max_tokens = self.model_config.max_model_len
else:
default_max_tokens = min(
self.model_config.max_model_len, request.max_completion_tokens
)
max_tokens = get_max_tokens(
max_model_len,
request.max_completion_tokens,
0,
self.default_sampling_params,
)
sampling_params = request.to_sampling_params(
default_max_tokens, self.default_sampling_params
max_tokens,
self.default_sampling_params,
)
if request.response_format == "verbose_json":
sampling_params.logprobs = 1
self._log_inputs(
request_id,
# It will not display special tokens like <|startoftranscript|>
request.prompt,
params=sampling_params,
lora_request=lora_request,
)
trace_headers = (
None
if raw_request is None
else await self._get_trace_headers(raw_request.headers)
)
list_result_generator = []
for i, prompt in enumerate(prompts):
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}_{i}"
engine_request = self.input_processor.process_inputs(
self._log_inputs(
request_id_item,
prompt,
engine_prompt,
params=sampling_params,
lora_request=lora_request,
)
trace_headers = (
None
if raw_request is None
else await self._get_trace_headers(raw_request.headers)
)
generator = self.engine_client.generate(
engine_prompt,
sampling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=0,
)
list_result_generator.append(
self.engine_client.generate(
engine_request,
sampling_params,
request_id_item,
lora_request=lora_request,
)
)
list_result_generator.append(generator)
except ValueError as e:
return self.create_error_response(e)
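For context, a minimal sketch of the clamp `get_max_tokens` is assumed to perform for this call. The zero passed above is the prompt-token offset, which is zero here because the audio maps to a fixed-size spectrogram; the real helper may also consult `default_sampling_params`:

def sketch_get_max_tokens(
    max_model_len: int,
    max_completion_tokens: int | None,
    prompt_len: int = 0,
) -> int:
    # Assumed equivalent of the removed inline min() logic.
    budget = max_model_len - prompt_len
    if max_completion_tokens is None:
        return budget
    return min(budget, max_completion_tokens)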

View File

@@ -28,11 +28,10 @@ from vllm.entrypoints.pooling.utils import (
encode_pooling_output_base64,
encode_pooling_output_float,
)
from vllm.inputs.data import TokensPrompt
from vllm.inputs.data import ProcessorInputs, TokensPrompt, token_inputs
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.pooling_params import PoolingParams
from vllm.renderers.inputs import TokPrompt
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import chunk_list
from vllm.utils.serial_utils import EmbedDType, Endianness
@@ -256,7 +255,7 @@ class OpenAIServingEmbedding(OpenAIServing):
chunk_request_id = f"{ctx.request_id}-prompt-{prompt_idx}-chunk-{chunk_idx}"
# Create engine prompt for this chunk
chunk_engine_prompt = TokensPrompt(prompt_token_ids=chunk_tokens)
chunk_engine_prompt = token_inputs(chunk_tokens)
# Log the chunk
self._log_inputs(
@@ -266,16 +265,12 @@ class OpenAIServingEmbedding(OpenAIServing):
lora_request=ctx.lora_request,
)
tok_params = ctx.request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
# Create generator for this chunk and wrap it to return indices
original_generator = self.engine_client.encode(
chunk_engine_prompt,
pooling_params,
chunk_request_id,
lora_request=ctx.lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=ctx.request.priority,
)
@@ -362,7 +357,7 @@ class OpenAIServingEmbedding(OpenAIServing):
async def _create_single_prompt_generator(
self,
ctx: EmbeddingServeContext,
engine_prompt: TokPrompt,
engine_prompt: ProcessorInputs,
pooling_params: PoolingParams,
trace_headers: Mapping[str, str] | None,
prompt_index: int,
@@ -377,16 +372,12 @@ class OpenAIServingEmbedding(OpenAIServing):
lora_request=ctx.lora_request,
)
tok_params = ctx.request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
# Return the original generator without wrapping
return self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=ctx.lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=ctx.request.priority,
)

View File

@@ -33,10 +33,9 @@ from vllm.entrypoints.pooling.utils import (
encode_pooling_output_base64,
encode_pooling_output_float,
)
from vllm.inputs import PromptType
from vllm.inputs import ProcessorInputs
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.renderers.inputs import TokPrompt
from vllm.renderers.inputs.preprocess import prompt_to_seq
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
@@ -93,7 +92,7 @@ class OpenAIServingPooling(OpenAIServing):
"dimensions is currently not supported"
)
engine_prompts: Sequence[PromptType | TokPrompt]
engine_prompts: Sequence[ProcessorInputs]
if use_io_processor := isinstance(request, IOProcessorRequest):
if self.io_processor is None:
raise ValueError(
@@ -152,9 +151,6 @@ class OpenAIServingPooling(OpenAIServing):
else:
pooling_params = request.to_pooling_params() # type: ignore
tok_params = request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
@@ -176,7 +172,6 @@ class OpenAIServingPooling(OpenAIServing):
pooling_params,
request_id_item,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=request.priority,
)

View File

@@ -35,7 +35,7 @@ from vllm.entrypoints.pooling.score.utils import (
get_score_prompt,
validate_score_input,
)
from vllm.inputs.data import TokensPrompt
from vllm.inputs.data import ProcessorInputs, TokensPrompt, token_inputs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
@@ -108,12 +108,15 @@ class ServingScores(OpenAIServing):
*(encode_async(t, **tokenization_kwargs) for t in input_texts)
)
engine_prompts: list[TokensPrompt] = []
engine_prompts: list[ProcessorInputs] = []
for tok_result, input_text in zip(tokenized_prompts, input_texts):
text_token_prompt = self._validate_input(request, tok_result, input_text)
engine_prompts.append(
TokensPrompt(prompt_token_ids=text_token_prompt["prompt_token_ids"])
token_inputs(
text_token_prompt["prompt_token_ids"],
prompt=input_text,
)
)
# Schedule the request and get the result generator.
@@ -125,7 +128,7 @@ class ServingScores(OpenAIServing):
self._log_inputs(
request_id_item,
input_texts[i],
engine_prompt,
params=pooling_params,
lora_request=lora_request,
)
@@ -207,12 +210,15 @@ class ServingScores(OpenAIServing):
*(encode_async(t, **tokenization_kwargs) for t in input_texts)
)
engine_prompts: list[TokensPrompt] = []
engine_prompts: list[ProcessorInputs] = []
for tok_result, input_text in zip(tokenized_prompts, input_texts):
text_token_prompt = self._validate_input(request, tok_result, input_text)
engine_prompts.append(
TokensPrompt(prompt_token_ids=text_token_prompt["prompt_token_ids"])
token_inputs(
text_token_prompt["prompt_token_ids"],
prompt=input_text,
)
)
# Schedule the request and get the result generator.
@@ -225,7 +231,7 @@ class ServingScores(OpenAIServing):
self._log_inputs(
request_id_item,
input_texts[i],
engine_prompt,
params=pooling_params,
lora_request=lora_request,
)

View File

@@ -29,7 +29,6 @@ from vllm.entrypoints.serve.disagg.protocol import (
GenerateResponse,
GenerateResponseChoice,
)
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
@@ -116,7 +115,7 @@ class ServingTokens(OpenAIServing):
self._log_inputs(
request_id,
TokensPrompt(prompt_token_ids=request.token_ids),
engine_prompt,
params=sampling_params,
lora_request=lora_request,
)
@@ -127,27 +126,13 @@ class ServingTokens(OpenAIServing):
else await self._get_trace_headers(raw_request.headers)
)
tok_params = request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs(
request_id,
result_generator = self.engine_client.generate(
engine_prompt,
sampling_params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=request.priority,
)
result_generator = self.engine_client.generate(
engine_request,
sampling_params,
request_id,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
tokenization_kwargs=tokenization_kwargs,
)
except ValueError as e:

View File

@@ -20,7 +20,7 @@ from vllm.entrypoints.serve.tokenize.protocol import (
TokenizeResponse,
TokenizerInfoResponse,
)
from vllm.inputs import TokensPrompt
from vllm.inputs import TokensPrompt, token_inputs
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
@@ -135,7 +135,7 @@ class OpenAIServingTokenization(OpenAIServing):
self._log_inputs(
request_id,
TokensPrompt(prompt_token_ids=request.tokens),
token_inputs(request.tokens),
params=None,
lora_request=lora_request,
)

View File

@@ -187,6 +187,9 @@ class _InputOptions(TypedDict):
Additional options available to all input types.
"""
arrival_time: NotRequired[float]
"""The time when the input was received (before rendering)."""
cache_salt: NotRequired[str]
"""Optional cache salt to be used for prefix caching."""
@@ -300,6 +303,9 @@ class EncoderDecoderInputs(TypedDict):
decoder_prompt: DecoderInputs
"""The inputs for the decoder portion."""
arrival_time: NotRequired[float]
"""The time when the input was received (before rendering)."""
ProcessorInputs: TypeAlias = DecoderOnlyInputs | EncoderDecoderInputs
"""

View File

@@ -19,11 +19,9 @@ from vllm.renderers import BaseRenderer, renderer_from_config
from vllm.renderers.inputs import (
DecoderDictPrompt,
DecoderOnlyDictPrompt,
DictPrompt,
EncoderDecoderDictPrompt,
EncoderDictPrompt,
SingletonDictPrompt,
TokPrompt,
)
from vllm.renderers.inputs.preprocess import parse_dec_only_prompt, parse_enc_dec_prompt
from vllm.tokenizers import TokenizerLike
@@ -41,7 +39,6 @@ from .data import (
TextPrompt,
TokenInputs,
TokensPrompt,
embeds_inputs,
token_inputs,
)
@@ -83,7 +80,7 @@ class InputPreprocessor:
**(tokenization_kwargs or {})
)
tok_prompt = renderer.tokenize_prompt(
tok_prompt = renderer._tokenize_singleton_prompt(
TextPrompt(prompt=prompt),
tok_params,
)
@@ -103,17 +100,10 @@ class InputPreprocessor:
Apply the model's multi-modal processor to a multi-modal prompt,
returning the corresponding token IDs and metadata.
"""
mm_processor = self.renderer.get_mm_processor()
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
mm_items = mm_processor.info.parse_mm_data(mm_data)
return mm_processor.apply(
return self.renderer._process_multimodal(
prompt,
mm_items,
hf_processor_mm_kwargs=mm_processor_kwargs,
mm_data,
mm_processor_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
@@ -122,31 +112,7 @@ class InputPreprocessor:
self,
parsed_content: EmbedsPrompt,
) -> EmbedsInputs:
if not self.model_config.enable_prompt_embeds:
raise ValueError(
"You must set `--enable-prompt-embeds` to input `prompt_embeds`."
)
prompt_embeds = parsed_content["prompt_embeds"]
# prompt_embeds must be (seq_len, hidden_size), but if the user
# passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
# we can unambiguously process the intent by squeezing the batch
# dimension.
if prompt_embeds.ndim == 3:
prompt_embeds = prompt_embeds.squeeze(dim=0)
if prompt_embeds.ndim != 2:
raise ValueError("prompt_embeds must be of shape (seq_len, hidden_size).")
# Tensors must be on CPU for serialization between processes
# in the MsgpackEncoder. Casting to CPU here ensures that there is no
# hidden device transfer in the critical path of generation.
prompt_embeds = prompt_embeds.cpu()
return embeds_inputs(
prompt_embeds=prompt_embeds, cache_salt=parsed_content.get("cache_salt")
)
return self.renderer._process_embeds(parsed_content)
def _truncate_inputs(
self, inputs: list[int], tokenization_kwargs: dict[str, Any] | None = None
@@ -157,7 +123,7 @@ class InputPreprocessor:
**(tokenization_kwargs or {})
)
tok_prompt = renderer.tokenize_prompt(
tok_prompt = renderer._tokenize_singleton_prompt(
TokensPrompt(prompt_token_ids=inputs),
tok_params,
)
@@ -168,8 +134,6 @@ class InputPreprocessor:
self,
parsed_content: TokensPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> TokenInputs | MultiModalInputs:
prompt_token_ids = self._truncate_inputs(
parsed_content["prompt_token_ids"], tokenization_kwargs
@@ -182,11 +146,13 @@ class InputPreprocessor:
multi_modal_data,
parsed_content.get("mm_processor_kwargs") or {},
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
mm_uuids=parsed_content.get("multi_modal_uuids"),
)
else:
inputs = token_inputs(prompt_token_ids)
if prompt_text := parsed_content.get("prompt"):
inputs["prompt"] = prompt_text
if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt
@@ -196,8 +162,6 @@ class InputPreprocessor:
self,
parsed_content: TextPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> TokenInputs | MultiModalInputs:
prompt_text = parsed_content["prompt"]
@@ -208,7 +172,6 @@ class InputPreprocessor:
multi_modal_data,
parsed_content.get("mm_processor_kwargs") or {},
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
else:
prompt_token_ids = self._tokenize_prompt(
@@ -217,6 +180,8 @@ class InputPreprocessor:
)
inputs = token_inputs(prompt_token_ids)
inputs["prompt"] = prompt_text
if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt
@@ -227,8 +192,6 @@ class InputPreprocessor:
self,
prompt: EncoderDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> EncoderInputs: ...
@overload
@@ -236,8 +199,6 @@ class InputPreprocessor:
self,
prompt: DecoderDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> DecoderInputs: ...
@overload
@@ -245,16 +206,12 @@ class InputPreprocessor:
self,
prompt: DecoderOnlyDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> DecoderOnlyInputs: ...
def _prompt_to_llm_inputs(
self,
prompt: SingletonDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> SingletonInputs:
"""
Extract the singleton inputs from a prompt.
@@ -271,16 +228,12 @@ class InputPreprocessor:
return self._process_embeds(prompt) # type: ignore[arg-type]
if "prompt_token_ids" in prompt:
return self._process_tokens(
prompt, # type: ignore[arg-type]
mm_uuids=mm_uuids,
)
return self._process_tokens(prompt) # type: ignore[arg-type]
if "prompt" in prompt:
return self._process_text(
prompt, # type: ignore[arg-type]
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
assert_never(prompt) # type: ignore[arg-type]
@@ -289,8 +242,6 @@ class InputPreprocessor:
self,
prompt: EncoderDecoderDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> EncoderDecoderInputs:
"""
For encoder/decoder models only:
@@ -314,7 +265,6 @@ class InputPreprocessor:
encoder_inputs=self._prompt_to_llm_inputs(
encoder_prompt,
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
),
decoder_inputs=(
None
@@ -331,8 +281,6 @@ class InputPreprocessor:
self,
prompt: DecoderOnlyDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> DecoderOnlyInputs:
"""
For decoder-only models:
@@ -350,41 +298,23 @@ class InputPreprocessor:
return self._prompt_to_llm_inputs(
prompt,
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
def _preprocess(
def preprocess(
self,
prompt: PromptType | DictPrompt | TokPrompt,
prompt: PromptType,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> ProcessorInputs:
"""Preprocess the input prompt."""
if self.model_config.is_encoder_decoder:
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder.
return self._process_encoder_decoder_prompt(
parse_enc_dec_prompt(prompt),
tokenization_kwargs,
mm_uuids=mm_uuids,
)
return self._process_decoder_only_prompt(
parse_dec_only_prompt(prompt),
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
def preprocess(
self,
prompt: PromptType | DictPrompt | TokPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> ProcessorInputs:
"""Preprocess the input prompt."""
res = self._preprocess(prompt, tokenization_kwargs, mm_uuids=mm_uuids)
self.renderer.update_mm_cache_stats()
return res

View File

@@ -48,7 +48,6 @@ from vllm.multimodal.processing import (
BaseProcessingInfo,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.transformers_utils.processors.funasr_processor import FunASRFeatureExtractor
@@ -810,13 +809,7 @@ class FunASRMultiModalProcessor(BaseMultiModalProcessor[FunASRProcessingInfo]):
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
# Use getattr with default to be compatible with transformers<4.48
audio_token = getattr(processor, "audio_token", "<|AUDIO|>")
audio_token_id = vocab[audio_token]
audio_token_id = processor.audio_token_id
out_mm_data = out_mm_kwargs.get_data()
@@ -836,17 +829,12 @@ class FunASRMultiModalProcessor(BaseMultiModalProcessor[FunASRProcessingInfo]):
assert len(audio_embeds.shape) == 2, "audio_embeds must be a 2D tensor"
num_features = audio_embeds.shape[0]
audio_tokens = [audio_token_id] * num_features
return PromptUpdateDetails.select_token_id(
audio_tokens,
embed_token_id=audio_token_id,
)
return [audio_token_id] * num_features
return [
PromptReplacement(
modality="audio",
target=audio_token,
target=[audio_token_id],
replacement=get_replacement_qwen2_audio,
)
]

View File

@@ -59,7 +59,6 @@ from vllm.multimodal.processing import (
BaseProcessingInfo,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -187,8 +186,10 @@ class Qwen2AudioDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2AudioProcessingIn
hf_processor = self.info.get_hf_processor()
audio_token = hf_processor.audio_token
audio_bos_token = hf_processor.audio_bos_token
audio_eos_token = hf_processor.audio_eos_token
return audio_token * num_audios
return (audio_bos_token + audio_token + audio_eos_token) * num_audios
def get_dummy_mm_data(
self,
@@ -262,17 +263,7 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor[Qwen2AudioProcessing
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
# Use getattr with default to be compatible with transformers<4.48
audio_token = getattr(processor, "audio_token", "<|AUDIO|>")
audio_bos_token = getattr(processor, "audio_bos_token", "<|audio_bos|>")
audio_eos_token = getattr(processor, "audio_eos_token", "<|audio_eos|>")
audio_token_id = vocab[audio_token]
audio_bos_id = vocab[audio_bos_token]
audio_eos_id = vocab[audio_eos_token]
audio_token_id = processor.audio_token_id
out_mm_data = out_mm_kwargs.get_data()
feature_attention_mask = out_mm_data.get("feature_attention_mask")
@@ -303,17 +294,12 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor[Qwen2AudioProcessing
"to be represented inside the model"
)
audio_tokens = [audio_token_id] * num_features
return PromptUpdateDetails.select_token_id(
[audio_bos_id] + audio_tokens + [audio_eos_id],
embed_token_id=audio_token_id,
)
return [audio_token_id] * num_features
return [
PromptReplacement(
modality="audio",
target=audio_token,
target=[audio_token_id],
replacement=get_replacement_qwen2_audio,
)
]
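An illustration (token id value assumed) of the net effect: the replacement now targets the bare audio placeholder id and expands it in place, while the surrounding bos/eos tokens, now emitted by the dummy builder's prompt text, are left untouched:

audio_token_id = 151646          # assumed id for <|AUDIO|>
num_features = 3
replacement = [audio_token_id] * num_features
# before: <|audio_bos|> <|AUDIO|> <|audio_eos|>
# after:  <|audio_bos|> <|AUDIO|> <|AUDIO|> <|AUDIO|> <|audio_eos|>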

View File

@@ -1843,15 +1843,18 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
tokenizer = self.info.get_tokenizer()
decoder_prompt_raw = self.create_decoder_prompt(prompt, mm_items)
if isinstance(decoder_prompt_raw, str):
decoder_prompt_text = decoder_prompt_raw
decoder_prompt_ids = tokenizer.encode(
decoder_prompt_raw, add_special_tokens=False
)
else:
decoder_prompt_text = None
decoder_prompt_ids = decoder_prompt_raw
return mm_enc_dec_inputs(
encoder_inputs,
decoder_prompt_ids,
decoder_prompt=decoder_prompt_text,
)
def apply(

View File

@@ -19,7 +19,6 @@ if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs, PromptType
from vllm.pooling_params import PoolingParams
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.sampling_params import SamplingParams
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.attention.selector import AttentionSelectorConfig
@@ -569,7 +568,7 @@ class Platform:
@classmethod
def validate_request(
cls,
prompt: "PromptType | DictPrompt | TokPrompt",
prompt: "PromptType | ProcessorInputs",
params: "SamplingParams | PoolingParams",
processed_inputs: "ProcessorInputs",
) -> None:

View File

@@ -1,17 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import time
from abc import ABC, abstractmethod
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any, Generic, overload
from typing_extensions import TypeVar
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.inputs import (
EmbedsInputs,
EmbedsPrompt,
EncoderDecoderInputs,
ProcessorInputs,
SingletonInputs,
TextPrompt,
TokenInputs,
TokensPrompt,
)
from vllm.inputs.data import build_enc_dec_inputs, embeds_inputs, token_inputs
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
from vllm.utils.counter import AtomicCounter
from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.metrics.stats import MultiModalCacheStats
@@ -20,6 +32,8 @@ from .inputs import (
DictPrompt,
EncoderDecoderDictPrompt,
EncoderDecoderTokPrompt,
SingletonDictPrompt,
SingletonTokPrompt,
TokPrompt,
)
from .inputs.preprocess import extract_target_prompt
@@ -32,6 +46,12 @@ if TYPE_CHECKING:
ConversationMessage,
)
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalInputs,
MultiModalUUIDDict,
)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import BaseMultiModalProcessor
logger = init_logger(__name__)
@@ -79,6 +99,10 @@ class BaseRenderer(ABC, Generic[_T]):
if mm_processor_cache:
self._mm_cache_stats = MultiModalCacheStats()
# This is used to generate an internal request ID for MM processing.
# It has no relation to the request ID used by the engine core.
self._mm_req_counter = AtomicCounter()
def get_tokenizer(self) -> _T:
tokenizer = self.tokenizer
if tokenizer is None:
@@ -284,17 +308,79 @@ class BaseRenderer(ABC, Generic[_T]):
return prompt
@overload
def _tokenize_singleton_prompt(
self,
prompt: TextPrompt | TokensPrompt,
params: TokenizeParams,
) -> TokensPrompt: ...
@overload
def _tokenize_singleton_prompt( # type: ignore[misc]
self,
prompt: EmbedsPrompt,
params: TokenizeParams,
) -> EmbedsPrompt: ...
def _tokenize_singleton_prompt(
self,
prompt: SingletonDictPrompt,
params: TokenizeParams,
) -> SingletonTokPrompt:
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
prompt = params.apply_pre_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
prompt = self._tokenize_prompt(prompt, params)
if params.needs_detokenization and "prompt" not in prompt:
if "prompt_token_ids" not in prompt:
raise RuntimeError("Cannot run detokenization on embeddings")
prompt = self._detokenize_prompt(prompt) # type: ignore[arg-type]
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
@overload
async def _tokenize_singleton_prompt_async(
self,
prompt: TextPrompt | TokensPrompt,
params: TokenizeParams,
) -> TokensPrompt: ...
@overload
async def _tokenize_singleton_prompt_async( # type: ignore[misc]
self,
prompt: EmbedsPrompt,
params: TokenizeParams,
) -> EmbedsPrompt: ...
async def _tokenize_singleton_prompt_async(
self,
prompt: SingletonDictPrompt,
params: TokenizeParams,
) -> SingletonTokPrompt:
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
prompt = params.apply_pre_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
prompt = await self._tokenize_prompt_async(prompt, params)
if params.needs_detokenization and "prompt" not in prompt:
if "prompt_token_ids" not in prompt:
raise RuntimeError("Cannot run detokenization on embeddings")
prompt = await self._detokenize_prompt_async(prompt) # type: ignore[arg-type]
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
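A sketch (renderer and params assumed) of the three paths both the sync and async variants share:

# renderer: a BaseRenderer subclass; params: TokenizeParams (both assumed)
# 1) Text prompt: the pre-tokenization hook runs, then tokenization; the
#    original text already satisfies any detokenization requirement.
tok = renderer._tokenize_singleton_prompt(TextPrompt(prompt="hi"), params)
# 2) Token prompt: tokenization is skipped; if params.needs_detokenization,
#    the ids are detokenized back into a "prompt" string.
tok = renderer._tokenize_singleton_prompt(TokensPrompt(prompt_token_ids=[1, 2]), params)
# 3) Embeds prompt: passed through; requesting detokenization raises
#    RuntimeError("Cannot run detokenization on embeddings").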
def _tokenize_enc_dec_prompt(
self,
prompt: EncoderDecoderDictPrompt,
params: TokenizeParams,
) -> EncoderDecoderTokPrompt:
enc_prompt, dec_prompt = (
self.tokenize_prompt(prompt["encoder_prompt"], params),
self._tokenize_singleton_prompt(prompt["encoder_prompt"], params),
(
None
if prompt["decoder_prompt"] is None
else self.tokenize_prompt(prompt["decoder_prompt"], params)
else self._tokenize_singleton_prompt(prompt["decoder_prompt"], params)
),
)
@@ -309,11 +395,13 @@ class BaseRenderer(ABC, Generic[_T]):
params: TokenizeParams,
) -> EncoderDecoderTokPrompt:
enc_prompt, dec_prompt = await asyncio.gather(
self.tokenize_prompt_async(prompt["encoder_prompt"], params),
self._tokenize_singleton_prompt_async(prompt["encoder_prompt"], params),
(
asyncio.sleep(0)
if prompt["decoder_prompt"] is None
else self.tokenize_prompt_async(prompt["decoder_prompt"], params)
else self._tokenize_singleton_prompt_async(
prompt["decoder_prompt"], params
)
),
)
@@ -322,27 +410,6 @@ class BaseRenderer(ABC, Generic[_T]):
decoder_prompt=dec_prompt,
)
@overload
def tokenize_prompt(
self,
prompt: TextPrompt | TokensPrompt,
params: TokenizeParams,
) -> TokensPrompt: ...
@overload
def tokenize_prompt( # type: ignore[misc]
self,
prompt: EmbedsPrompt,
params: TokenizeParams,
) -> EmbedsPrompt: ...
@overload
def tokenize_prompt( # type: ignore[misc]
self,
prompt: EncoderDecoderDictPrompt,
params: TokenizeParams,
) -> EncoderDecoderTokPrompt: ...
def tokenize_prompt(
self,
prompt: DictPrompt,
@@ -351,17 +418,7 @@ class BaseRenderer(ABC, Generic[_T]):
if "encoder_prompt" in prompt:
return self._tokenize_enc_dec_prompt(prompt, params) # type: ignore[arg-type]
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
prompt = params.apply_pre_tokenization(self.tokenizer, prompt)
prompt = self._tokenize_prompt(prompt, params)
if params.needs_detokenization and "prompt" not in prompt:
if "prompt_token_ids" not in prompt:
raise RuntimeError("Cannot run detokenization on embeddings")
prompt = self._detokenize_prompt(prompt) # type: ignore[arg-type]
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
return self._tokenize_singleton_prompt(prompt, params)
def tokenize_prompts(
self,
@@ -370,27 +427,6 @@ class BaseRenderer(ABC, Generic[_T]):
) -> list[TokPrompt]:
return [self.tokenize_prompt(prompt, params) for prompt in prompts]
@overload
async def tokenize_prompt_async(
self,
prompt: TextPrompt | TokensPrompt,
params: TokenizeParams,
) -> TokensPrompt: ...
@overload
async def tokenize_prompt_async( # type: ignore[misc]
self,
prompt: EmbedsPrompt,
params: TokenizeParams,
) -> EmbedsPrompt: ...
@overload
async def tokenize_prompt_async( # type: ignore[misc]
self,
prompt: EncoderDecoderDictPrompt,
params: TokenizeParams,
) -> EncoderDecoderTokPrompt: ...
async def tokenize_prompt_async(
self,
prompt: DictPrompt,
@@ -399,17 +435,7 @@ class BaseRenderer(ABC, Generic[_T]):
if "encoder_prompt" in prompt:
return await self._tokenize_enc_dec_prompt_async(prompt, params) # type: ignore[arg-type]
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
prompt = params.apply_pre_tokenization(self.tokenizer, prompt)
prompt = await self._tokenize_prompt_async(prompt, params)
if params.needs_detokenization and "prompt" not in prompt:
if "prompt_token_ids" not in prompt:
raise RuntimeError("Cannot run detokenization on embeddings")
prompt = await self._detokenize_prompt_async(prompt) # type: ignore[arg-type]
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
return await self._tokenize_singleton_prompt_async(prompt, params)
async def tokenize_prompts_async(
self,
@@ -423,7 +449,7 @@ class BaseRenderer(ABC, Generic[_T]):
# Step 3: Add extra keys to the prompts
def _apply_prompt_extras(
self,
prompts: Sequence[DictPrompt | TokPrompt],
prompts: Sequence[TokPrompt],
prompt_extras: dict[str, Any] | None,
):
if not prompt_extras:
@@ -433,6 +459,200 @@ class BaseRenderer(ABC, Generic[_T]):
target_prompt = extract_target_prompt(self.model_config, prompt)
target_prompt.update(prompt_extras) # type: ignore[arg-type]
# Step 4: Convert to engine inputs
def _validate_mm_uuids(
self,
mm_data: "MultiModalDataDict",
mm_items: "MultiModalDataItems",
mm_uuids: "MultiModalUUIDDict | None",
) -> None:
if mm_uuids is None:
mm_uuids = {}
# NOTE: Keys corresponding to `None` in `mm_data` don't appear in `mm_items`
modalities = mm_data.keys() | mm_uuids.keys()
for modality in modalities:
data_items = mm_items.get(modality) or list[Any]()
uuid_items = mm_uuids.get(modality) or list[str | None]()
if isinstance(uuid_items, str):
uuid_items = [uuid_items]
if len(data_items) > 0:
if len(uuid_items) > 0 and len(data_items) != len(uuid_items):
raise ValueError(
f"If given, multi_modal_uuids[{modality!r}] must have "
f"same length as multi_modal_data[{modality!r}], but "
f"got {len(uuid_items)} vs {len(data_items)}."
)
for i, item in enumerate(data_items):
if item is None:
if not uuid_items:
raise ValueError(
f"multi_modal_data[{modality!r}][{i}] is empty but "
f"multi_modal_uuids[{modality!r}] is missing."
)
if uuid_items[i] is None:
raise ValueError(
f"multi_modal_data[{modality!r}][{i}] is empty but "
f"multi_modal_uuids[{modality!r}][{i}] is missing."
)
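A sketch (image objects and parsed items assumed) of inputs that pass and fail the checks above:

# img_a, img_b: PIL images (assumed); mm_processor from renderer.get_mm_processor()
mm_data = {"image": [img_a, img_b]}
mm_items = mm_processor.info.parse_mm_data(mm_data)
renderer._validate_mm_uuids(mm_data, mm_items, {"image": ["u-a", "u-b"]})  # ok
renderer._validate_mm_uuids(mm_data, mm_items, None)  # ok: uuids are optional
# A bare str is promoted to a one-element list, so this is a length
# mismatch (1 vs 2) and raises ValueError:
renderer._validate_mm_uuids(mm_data, mm_items, {"image": "u-a"})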
def _process_mm_uuids(
self,
mm_data: "MultiModalDataDict",
mm_items: "MultiModalDataItems",
mm_uuids: "MultiModalUUIDDict | None",
mm_req_id: str,
):
model_config = self.model_config
# NOTE: When users explicitly turn off BOTH prefix caching and input
# processing caching, no multimodal features or embeddings will be
# reused across requests, therefore identifying multimodal data items
# by their content is no longer necessary, and we create uuids with
# `<mm_req_id>-<modality>-<index>`, overriding even user-provided ones.
if (
model_config.multimodal_config
and model_config.multimodal_config.mm_processor_cache_gb == 0
and not self.config.cache_config.enable_prefix_caching
):
mm_uuids = {
modality: [f"{mm_req_id}-{modality}-{i}" for i in range(data_count)]
for modality, data_count in mm_items.get_all_counts().items()
}
self._validate_mm_uuids(mm_data, mm_items, mm_uuids)
return mm_uuids
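For reference, a small sketch of the content-independent override this produces when both caches are disabled (counter value assumed):

mm_req_id = "renderer-mm-7"          # from self._mm_req_counter (assumed value)
counts = {"image": 2, "video": 1}    # mm_items.get_all_counts()
mm_uuids = {
    modality: [f"{mm_req_id}-{modality}-{i}" for i in range(n)]
    for modality, n in counts.items()
}
# {"image": ["renderer-mm-7-image-0", "renderer-mm-7-image-1"],
#  "video": ["renderer-mm-7-video-0"]}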
# TODO: Remove str and tokenization_kwargs after deprecating InputPreprocessor
def _process_multimodal(
self,
prompt: list[int] | str,
mm_data: "MultiModalDataDict",
mm_processor_kwargs: Mapping[str, object] | None,
tokenization_kwargs: dict[str, Any] | None,
mm_uuids: "MultiModalUUIDDict | None",
) -> "MultiModalInputs":
from vllm.multimodal.processing.context import set_request_id
mm_req_id = f"renderer-mm-{self._mm_req_counter.inc(1)}"
mm_processor = self.get_mm_processor()
mm_items = mm_processor.info.parse_mm_data(mm_data)
mm_uuids = self._process_mm_uuids(mm_data, mm_items, mm_uuids, mm_req_id)
with set_request_id(mm_req_id), set_default_torch_num_threads():
mm_inputs = mm_processor.apply(
prompt,
mm_items,
hf_processor_mm_kwargs=mm_processor_kwargs or {},
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
self.update_mm_cache_stats()
return mm_inputs
def _process_tokens(
self,
prompt: TokensPrompt,
) -> "TokenInputs | MultiModalInputs":
prompt_token_ids = prompt["prompt_token_ids"]
inputs: TokenInputs | MultiModalInputs
if multi_modal_data := prompt.get("multi_modal_data"):
inputs = self._process_multimodal(
prompt_token_ids,
multi_modal_data,
mm_processor_kwargs=prompt.get("mm_processor_kwargs"),
tokenization_kwargs=None, # Tokenization already done in Step 2
mm_uuids=prompt.get("multi_modal_uuids"),
)
else:
inputs = token_inputs(prompt_token_ids)
if prompt_text := prompt.get("prompt"):
inputs["prompt"] = prompt_text
if cache_salt := prompt.get("cache_salt"):
inputs["cache_salt"] = cache_salt
return inputs
def _process_embeds(
self,
prompt: EmbedsPrompt,
) -> EmbedsInputs:
if not self.model_config.enable_prompt_embeds:
raise ValueError(
"You must set `--enable-prompt-embeds` to input `prompt_embeds`."
)
prompt_embeds = prompt["prompt_embeds"]
# prompt_embeds must be (seq_len, hidden_size), but if the user
# passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
# we can unambiguously process the intent by squeezing the batch
# dimension.
if prompt_embeds.ndim == 3:
prompt_embeds = prompt_embeds.squeeze(dim=0)
if prompt_embeds.ndim != 2:
raise ValueError("prompt_embeds must be of shape (seq_len, hidden_size).")
# Tensors must be on CPU for serialization between processes
# in the MsgpackEncoder. Casting to CPU here ensures that there is no
# hidden device transfer in the critical path of generation.
prompt_embeds = prompt_embeds.cpu()
return embeds_inputs(
prompt_embeds=prompt_embeds,
cache_salt=prompt.get("cache_salt"),
)
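A worked sketch of the shape handling above (hidden size assumed):

import torch

emb = torch.randn(1, 8, 4096)  # a batch of size 1 is accepted...
emb = emb.squeeze(dim=0)       # ...and squeezed to (seq_len, hidden_size)
assert emb.shape == (8, 4096)
emb = emb.cpu()                # CPU tensor for MsgpackEncoder serialization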
def _process_singleton(
self,
prompt: SingletonTokPrompt,
) -> SingletonInputs:
if "prompt_embeds" in prompt:
return self._process_embeds(prompt) # type: ignore[arg-type]
return self._process_tokens(prompt) # type: ignore[arg-type]
def _process_enc_dec(
self,
prompt: EncoderDecoderTokPrompt,
) -> EncoderDecoderInputs:
enc_prompt = prompt["encoder_prompt"]
dec_prompt = prompt["decoder_prompt"]
return build_enc_dec_inputs(
encoder_inputs=self._process_singleton(enc_prompt),
decoder_inputs=(
None if dec_prompt is None else self._process_singleton(dec_prompt)
),
decoder_start_token_id=self.get_dec_start_token_id(),
)
def process_for_engine(
self, prompt: TokPrompt, arrival_time: float
) -> ProcessorInputs:
engine_prompt: ProcessorInputs
if "encoder_prompt" in prompt:
engine_prompt = self._process_enc_dec(prompt) # type: ignore[arg-type]
else:
engine_prompt = self._process_singleton(prompt)
engine_prompt["arrival_time"] = arrival_time
return engine_prompt
# Top-level methods
def render_cmpl(
self,
@@ -441,6 +661,8 @@ class BaseRenderer(ABC, Generic[_T]):
*,
prompt_extras: dict[str, Any] | None = None,
):
arrival_time = time.time()
if tok_params is None:
tok_params = self.default_cmpl_tok_params
@@ -449,8 +671,7 @@ class BaseRenderer(ABC, Generic[_T]):
self._apply_prompt_extras(tok_prompts, prompt_extras)
# TODO: Apply multi-modal processor
return tok_prompts
return [self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts]
async def render_cmpl_async(
self,
@@ -459,6 +680,8 @@ class BaseRenderer(ABC, Generic[_T]):
*,
prompt_extras: dict[str, Any] | None = None,
):
arrival_time = time.time()
if tok_params is None:
tok_params = self.default_cmpl_tok_params
@@ -467,8 +690,7 @@ class BaseRenderer(ABC, Generic[_T]):
self._apply_prompt_extras(tok_prompts, prompt_extras)
# TODO: Apply multi-modal processor
return tok_prompts
return [self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts]
def render_chat(
self,
@@ -478,6 +700,8 @@ class BaseRenderer(ABC, Generic[_T]):
*,
prompt_extras: dict[str, Any] | None = None,
):
arrival_time = time.time()
if tok_params is None:
tok_params = self.default_chat_tok_params
@@ -496,8 +720,11 @@ class BaseRenderer(ABC, Generic[_T]):
self._apply_prompt_extras(tok_prompts, prompt_extras)
# TODO: Apply multi-modal processor
return out_conversations, tok_prompts
eng_prompts = [
self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts
]
return out_conversations, eng_prompts
async def render_chat_async(
self,
@@ -507,6 +734,8 @@ class BaseRenderer(ABC, Generic[_T]):
*,
prompt_extras: dict[str, Any] | None = None,
):
arrival_time = time.time()
if tok_params is None:
tok_params = self.default_chat_tok_params
@@ -525,5 +754,8 @@ class BaseRenderer(ABC, Generic[_T]):
self._apply_prompt_extras(tok_prompts, prompt_extras)
# TODO: Apply multi-modal processor
return out_conversations, tok_prompts
eng_prompts = [
self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts
]
return out_conversations, eng_prompts
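Taken together, a sketch (renderer construction assumed) of the new contract: the render methods now return engine-ready `ProcessorInputs` stamped with `arrival_time`, rather than bare token prompts:

engine_prompts = renderer.render_cmpl([TextPrompt(prompt="Hello")])
ep = engine_prompts[0]
# ep is a ProcessorInputs dict (token ids, plus multimodal kwargs if any),
# with ep["arrival_time"] recorded at the start of render_cmpl()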

View File

@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, NamedTuple, TypeAlias, TypedDict, overload
from vllm.inputs import (
EmbedsPrompt,
ExplicitEncoderDecoderPrompt,
ProcessorInputs,
PromptType,
SingletonPrompt,
TextPrompt,
@@ -115,7 +116,7 @@ that has been standardized into a dictionary.
"""
def parse_dec_only_prompt(prompt: object) -> DecoderOnlyDictPrompt:
def parse_dec_only_prompt(prompt: PromptType | object) -> DecoderOnlyDictPrompt:
"""
Parse a prompt for a decoder-only model and normalize it to a dictionary.
"""
@@ -144,7 +145,7 @@ def parse_dec_only_prompt(prompt: object) -> DecoderOnlyDictPrompt:
raise TypeError("Prompt should be a string, list of tokens, or dictionary")
def _parse_enc_prompt(prompt: object) -> EncoderDictPrompt:
def _parse_enc_prompt(prompt: PromptType | object) -> EncoderDictPrompt:
if isinstance(prompt, str):
return TextPrompt(prompt=prompt)
@@ -166,7 +167,7 @@ def _parse_enc_prompt(prompt: object) -> EncoderDictPrompt:
raise TypeError("Prompt should be a string, list of tokens, or dictionary")
def _parse_dec_prompt(prompt: object) -> DecoderDictPrompt:
def _parse_dec_prompt(prompt: PromptType | object) -> DecoderDictPrompt:
if isinstance(prompt, str):
return TextPrompt(prompt=prompt)
@@ -195,13 +196,13 @@ def _parse_dec_prompt(prompt: object) -> DecoderDictPrompt:
raise TypeError("Prompt should be a string, list of tokens, or dictionary")
def parse_enc_dec_prompt(prompt: object) -> EncoderDecoderDictPrompt:
def parse_enc_dec_prompt(prompt: PromptType | object) -> EncoderDecoderDictPrompt:
"""
Parse a prompt for an encoder-decoder model and normalize it to a dictionary.
"""
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
enc_prompt: object = prompt["encoder_prompt"] # type: ignore[typeddict-item]
dec_prompt: object | None = prompt["decoder_prompt"] # type: ignore[typeddict-item]
enc_prompt = prompt["encoder_prompt"] # type: ignore[typeddict-item]
dec_prompt = prompt["decoder_prompt"] # type: ignore[typeddict-item]
else:
enc_prompt = prompt
dec_prompt = None
@@ -235,21 +236,23 @@ def extract_target_prompt(model_config: "ModelConfig", prompt: object):
def extract_prompt_components(
model_config: "ModelConfig",
prompt: object,
prompt: PromptType | ProcessorInputs,
) -> PromptComponents:
target_prompt = extract_target_prompt(model_config, prompt)
return PromptComponents(
text=target_prompt.get("prompt"),
token_ids=target_prompt.get("prompt_token_ids"), # type: ignore[arg-type]
token_ids=target_prompt.get("prompt_token_ids"),
embeds=target_prompt.get("prompt_embeds"),
)
def extract_prompt_len(model_config: "ModelConfig", prompt: object):
def extract_prompt_len(
model_config: "ModelConfig", prompt: PromptType | ProcessorInputs
):
target_prompt = extract_target_prompt(model_config, prompt)
return length_from_prompt_token_ids_or_embeds(
target_prompt.get("prompt_token_ids"), # type: ignore[arg-type]
target_prompt.get("prompt_token_ids"),
target_prompt.get("prompt_embeds"),
)
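A sketch (decoder-only model_config assumed) of what the two helpers return for a plain token prompt:

from vllm.inputs.data import token_inputs

comps = extract_prompt_components(model_config, token_inputs([1, 2, 3]))
# comps.text is None, comps.token_ids == [1, 2, 3], comps.embeds is None
assert extract_prompt_len(model_config, token_inputs([1, 2, 3])) == 3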

vllm/utils/tqdm_utils.py Normal file (39 lines added)
View File

@@ -0,0 +1,39 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Iterable, Sequence
from typing import Any, TypeVar, overload
from tqdm.auto import tqdm
_T = TypeVar("_T", bound=Iterable)
@overload
def maybe_tqdm(
it: Sequence[_T],
*,
use_tqdm: bool | Callable[..., tqdm],
**tqdm_kwargs: Any,
) -> Sequence[_T]: ...
@overload
def maybe_tqdm(
it: Iterable[_T],
*,
use_tqdm: bool | Callable[..., tqdm],
**tqdm_kwargs: Any,
) -> Iterable[_T]: ...
def maybe_tqdm(
it: Iterable[_T],
*,
use_tqdm: bool | Callable[..., tqdm],
**tqdm_kwargs: Any,
) -> Iterable[_T]:
if not use_tqdm:
return it
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
return tqdm_func(it, **tqdm_kwargs)
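A usage sketch for the helper above:

from tqdm.auto import tqdm
from vllm.utils.tqdm_utils import maybe_tqdm

items = range(100)
for x in maybe_tqdm(items, use_tqdm=True, desc="processing"):
    pass  # iterates under a tqdm progress bar
for x in maybe_tqdm(items, use_tqdm=False):
    pass  # plain iteration, no bar
# A callable can be passed to customize bar construction:
for x in maybe_tqdm(items, use_tqdm=lambda it, **kw: tqdm(it, colour="green", **kw)):
    pass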

View File

@@ -20,7 +20,7 @@ from vllm.distributed.weight_transfer.base import (
)
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient, StreamingInput
from vllm.inputs import PromptType
from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
@@ -28,7 +28,6 @@ from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.renderers import merge_kwargs, renderer_from_config
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.renderers.inputs.preprocess import extract_prompt_components
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.tasks import SupportedTask
@@ -290,8 +289,7 @@ class AsyncLLM(EngineClient):
request_id: str,
prompt: EngineCoreRequest
| PromptType
| DictPrompt
| TokPrompt
| ProcessorInputs
| AsyncGenerator[StreamingInput, None],
params: SamplingParams | PoolingParams,
arrival_time: float | None = None,
@@ -301,6 +299,7 @@ class AsyncLLM(EngineClient):
priority: int = 0,
data_parallel_rank: int | None = None,
prompt_text: str | None = None,
reasoning_ended: bool | None = None,
) -> RequestOutputCollector:
"""Add new request to the AsyncLLM."""
@@ -336,6 +335,9 @@ class AsyncLLM(EngineClient):
)
if isinstance(prompt, AsyncGenerator):
if reasoning_ended is not None:
raise NotImplementedError
# Streaming input case.
return await self._add_streaming_input_request(
request_id,
@@ -359,10 +361,6 @@ class AsyncLLM(EngineClient):
"latter will be used, and the former will be ignored."
)
else:
if prompt_text is not None:
raise ValueError(
"should only provide prompt_text with EngineCoreRequest"
)
request = self.input_processor.process_inputs(
request_id,
prompt,
@@ -377,6 +375,9 @@ class AsyncLLM(EngineClient):
)
prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)
if reasoning_ended is not None:
request.reasoning_ended = reasoning_ended
self.input_processor.assign_request_id(request)
# We start the output_handler on the first call to add_request() so
@@ -536,8 +537,7 @@ class AsyncLLM(EngineClient):
self,
prompt: EngineCoreRequest
| PromptType
| DictPrompt
| TokPrompt
| ProcessorInputs
| AsyncGenerator[StreamingInput, None],
sampling_params: SamplingParams,
request_id: str,
@@ -548,6 +548,7 @@ class AsyncLLM(EngineClient):
trace_headers: Mapping[str, str] | None = None,
priority: int = 0,
data_parallel_rank: int | None = None,
reasoning_ended: bool | None = None,
) -> AsyncGenerator[RequestOutput, None]:
"""
Main function called by the API server to kick off a request
@@ -576,6 +577,7 @@ class AsyncLLM(EngineClient):
priority=priority,
data_parallel_rank=data_parallel_rank,
prompt_text=prompt_text,
reasoning_ended=reasoning_ended,
)
# The output_handler task pushes items into the queue.
@@ -770,13 +772,14 @@ class AsyncLLM(EngineClient):
async def encode(
self,
prompt: PromptType | DictPrompt | TokPrompt,
prompt: PromptType | ProcessorInputs,
pooling_params: PoolingParams,
request_id: str,
lora_request: LoRARequest | None = None,
trace_headers: Mapping[str, str] | None = None,
priority: int = 0,
tokenization_kwargs: dict[str, Any] | None = None,
reasoning_ended: bool | None = None,
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""
Main function called by the API server to kick off a request
@@ -802,6 +805,7 @@ class AsyncLLM(EngineClient):
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=priority,
reasoning_ended=reasoning_ended,
)
# The output_handler task pushes items into the queue.

View File

@@ -3,7 +3,7 @@
import time
from collections.abc import Mapping
from typing import Any, Literal, cast
from typing import Any, Literal
import vllm.envs as envs
from vllm.config import VllmConfig
@@ -11,7 +11,6 @@ from vllm.inputs.data import (
ProcessorInputs,
PromptType,
SingletonInputs,
SingletonPrompt,
)
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
@@ -20,22 +19,16 @@ from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.encoder_budget import MultiModalBudget
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFeatureSpec,
MultiModalUUIDDict,
)
from vllm.multimodal.parse import ModalityDataItems, MultiModalDataItems
from vllm.multimodal.processing.context import set_request_id
from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer, renderer_from_config
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.sampling_params import SamplingParams
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.tokenizers import TokenizerLike
from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid
from vllm.utils.jsontree import json_iter_leaves
from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.engine import EngineCoreRequest
logger = init_logger(__name__)
@@ -133,81 +126,6 @@ class InputProcessor:
f"but got {type(params).__name__}"
)
def _parse_mm_items(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
mm_processor = self.renderer.get_mm_processor()
return mm_processor.info.parse_mm_data(mm_data)
def _validate_singleton_mm_uuids(self, prompt: SingletonPrompt) -> None:
if not isinstance(prompt, dict):
return
mm_data = cast(MultiModalDataDict, prompt.get("multi_modal_data") or {})
mm_uuids = cast(MultiModalUUIDDict, prompt.get("multi_modal_uuids") or {})
if not mm_data and not mm_uuids:
return
mm_data_parsed = self._parse_mm_items(
{k: v for k, v in mm_data.items() if v is not None}
)
mm_uuids_parsed = {
k: [v] if isinstance(v, str) else v
for k, v in mm_uuids.items()
if v is not None
}
# NOTE: Include the keys corresponding to `None`
modalities = mm_data.keys() | mm_uuids.keys()
for modality in modalities:
data_items = cast(
ModalityDataItems | list[Any], mm_data_parsed.get(modality, [])
)
uuid_items = cast(list[str | None], mm_uuids_parsed.get(modality, []))
if len(data_items) > 0:
if len(uuid_items) > 0 and len(data_items) != len(uuid_items):
raise ValueError(
f"If given, multi_modal_uuids[{modality!r}] must have "
f"same length as multi_modal_data[{modality!r}], but "
f"got {len(uuid_items)} vs {len(data_items)}."
)
for i, item in enumerate(data_items):
if item is None:
if not uuid_items:
raise ValueError(
f"multi_modal_data[{modality!r}][{i}] is empty but "
f"multi_modal_uuids[{modality!r}] is missing."
)
if uuid_items[i] is None:
raise ValueError(
f"multi_modal_data[{modality!r}][{i}] is empty but "
f"multi_modal_uuids[{modality!r}][{i}] is missing."
)
else:
if len(uuid_items) == 0:
raise ValueError(
f"multi_modal_data[{modality!r}] is empty but "
f"multi_modal_uuids[{modality!r}] is missing."
)
def _validate_mm_uuids(self, prompt: PromptType | DictPrompt | TokPrompt) -> None:
"""
Validate that user-provided multi_modal_uuids align with
multi_modal_data in the incoming request prompt(s).
Only checks lengths; `None` entries are allowed and will be
auto-hashed downstream.
"""
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
self._validate_singleton_mm_uuids(prompt["encoder_prompt"]) # type: ignore[typeddict-item]
if (dec_prompt := prompt["decoder_prompt"]) is not None: # type: ignore[typeddict-item]
self._validate_singleton_mm_uuids(dec_prompt)
else:
self._validate_singleton_mm_uuids(prompt)
def _validate_lora(self, lora_request: LoRARequest | None) -> None:
if lora_request is None:
return
@@ -227,47 +145,6 @@ class InputProcessor:
"[lora_path]` to use the LoRA tokenizer."
)
def _extract_singleton_mm_data(
self, prompt: SingletonPrompt
) -> MultiModalDataDict | None:
if not isinstance(prompt, dict):
return None
return prompt.get("multi_modal_data")
def _extract_mm_data(
self, prompt: PromptType | DictPrompt | TokPrompt
) -> MultiModalDataDict | None:
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
return self._extract_singleton_mm_data(prompt["encoder_prompt"]) # type: ignore[typeddict-item]
else:
return self._extract_singleton_mm_data(prompt)
def _maybe_build_mm_uuids(
self,
request_id: str,
prompt: PromptType | DictPrompt | TokPrompt,
) -> MultiModalUUIDDict | None:
"""Build per-item multimodal hash overrides when enabled. In this case,
multimodal data items are identified by their request id, modality and
index rather than their content.
Returns a dictionary of modality -> list[str] of overrides, or None if
disabled or no multimodal data is present.
"""
mm_data = self._extract_mm_data(prompt)
if not mm_data:
return None
mm_items = self._parse_mm_items(
{k: v for k, v in mm_data.items() if v is not None}
)
return {
modality: [f"{request_id}-{modality}-{i}" for i in range(data_count)]
for modality, data_count in mm_items.get_all_counts().items()
}
def _get_mm_identifier(
self,
mm_hash: str,
@@ -309,7 +186,7 @@ class InputProcessor:
def process_inputs(
self,
request_id: str,
prompt: PromptType | DictPrompt | TokPrompt,
prompt: PromptType | ProcessorInputs,
params: SamplingParams | PoolingParams,
arrival_time: float | None = None,
lora_request: LoRARequest | None = None,
@@ -333,43 +210,18 @@ class InputProcessor:
f"is out of range [0, {num_ranks})."
)
if arrival_time is None:
arrival_time = time.time()
if isinstance(prompt, dict) and "type" in prompt:
if arrival_time is None:
arrival_time = prompt.get("arrival_time", time.time()) # type: ignore[assignment]
# Optionally generate multimodal hash overrides to avoid hashing
# multimodal data items by their content as their identifiers.
# NOTE: when users explicitly turn off BOTH prefix caching and input
# processing caching, no multimodal features or embeddings will be
# reused across requests, therefore identifying multimodal data items
# by their content is no longer necessary, and we create uuids with
# request id-modality-index as multimodal hash overrides.
if (
self.model_config.multimodal_config
and self.model_config.multimodal_config.mm_processor_cache_gb == 0
and not self.cache_config.enable_prefix_caching
):
mm_uuids = self._maybe_build_mm_uuids(request_id, prompt)
processed_inputs: ProcessorInputs = prompt # type: ignore[assignment]
else:
# Otherwise, use user-provided uuids as multimodal hash overrides
# if provided.
self._validate_mm_uuids(prompt)
if isinstance(prompt, dict):
mm_uuids = cast(
MultiModalUUIDDict | None, prompt.get("multi_modal_uuids")
)
else:
mm_uuids = None
if arrival_time is None:
arrival_time = time.time()
# Process inputs, which includes:
# 1. Tokenize text prompt, with LoRA request if one exists.
# 2. For multimodal models with a merged preprocessor, preprocess
# multimodal data and expand prompt token ids accordingly.
with set_request_id(request_id), set_default_torch_num_threads():
processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
processed_inputs = self.input_preprocessor.preprocess(
prompt,
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
from vllm.platforms import current_platform
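A sketch (objects assumed) of the two input paths `process_inputs` now accepts:

# renderer, input_processor, parsed_prompt, sampling_params: assumed constructed
# Path 1: engine-ready ProcessorInputs from the renderer (a dict carrying a
# "type" key and an "arrival_time" stamp); preprocessing is skipped entirely.
engine_prompt = renderer.render_cmpl([parsed_prompt])[0]
req = input_processor.process_inputs("req-0", engine_prompt, sampling_params)

# Path 2: a raw PromptType, which still goes through
# InputPreprocessor.preprocess() during the deprecation window.
req = input_processor.process_inputs("req-1", "Hello", sampling_params)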

View File

@@ -14,7 +14,7 @@ from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.distributed.parallel_state import get_dp_group
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
@@ -22,7 +22,6 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.renderers import renderer_from_config
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.renderers.inputs.preprocess import extract_prompt_components
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
@@ -220,7 +219,7 @@ class LLMEngine:
def add_request(
self,
request_id: str,
prompt: EngineCoreRequest | PromptType | DictPrompt | TokPrompt,
prompt: EngineCoreRequest | PromptType | ProcessorInputs,
params: SamplingParams | PoolingParams,
arrival_time: float | None = None,
lora_request: LoRARequest | None = None,
@@ -228,7 +227,7 @@ class LLMEngine:
trace_headers: Mapping[str, str] | None = None,
priority: int = 0,
prompt_text: str | None = None,
) -> None:
) -> str:
# Validate the request_id type.
if not isinstance(request_id, str):
raise TypeError(f"request_id must be a string, got {type(request_id)}")
@@ -243,7 +242,6 @@ class LLMEngine:
"latter will be used, and the former will be ignored."
)
else:
assert prompt_text is None
request = self.input_processor.process_inputs(
request_id,
prompt,
@@ -259,6 +257,8 @@ class LLMEngine:
self.input_processor.assign_request_id(request)
req_id = request.request_id
# Use cloned params that may have been updated in process_inputs()
params = request.params
@@ -269,7 +269,7 @@ class LLMEngine:
self.output_processor.add_request(request, prompt_text, None, 0)
# Add the request to EngineCore.
self.engine_core.add_request(request)
return
return req_id
# Fan out child requests (for n>1).
parent_req = ParentRequest(request)
@@ -286,6 +286,8 @@ class LLMEngine:
# Add the request to EngineCore.
self.engine_core.add_request(child_request)
return req_id
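A usage sketch (engine construction assumed) of the new return value; since `assign_request_id` may reassign the id internally, callers should track the returned one:

# llm_engine, engine_prompt, sampling_params: assumed constructed
req_id = llm_engine.add_request("my-req", engine_prompt, sampling_params)
while llm_engine.get_num_unfinished_requests() > 0:
    for output in llm_engine.step():
        if output.request_id == req_id and output.finished:
            print(output.outputs[0].text)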
def step(self) -> list[RequestOutput | PoolingRequestOutput]:
if self.should_execute_dummy_batch:
self.should_execute_dummy_batch = False