[Renderer] Move InputPreprocessor into Renderer (1/2) (#34510)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
Cyrus Leung
2026-02-15 02:14:21 +08:00
committed by GitHub
parent b3c14229b0
commit 73391a1baa
39 changed files with 456 additions and 458 deletions

View File

@@ -54,6 +54,7 @@ class MockModelConfig:
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False
is_encoder_decoder: bool = False
is_multimodal_model: bool = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
@@ -67,7 +68,7 @@ class MockVllmConfig:
def _build_renderer(model_config: MockModelConfig):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer(
return HfRenderer.from_config(
MockVllmConfig(model_config),
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
)

View File

@@ -53,6 +53,7 @@ class MockModelConfig:
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False
is_encoder_decoder: bool = False
is_multimodal_model: bool = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
@@ -78,7 +79,7 @@ def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
def _build_renderer(model_config: MockModelConfig):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer(
return HfRenderer.from_config(
MockVllmConfig(model_config),
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
)

View File

@@ -52,6 +52,7 @@ class MockModelConfig:
generation_config: str = "auto"
skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
is_multimodal_model: bool = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
@@ -95,7 +96,7 @@ def register_mock_resolver():
def _build_renderer(model_config: MockModelConfig):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer(
return HfRenderer.from_config(
MockVllmConfig(model_config),
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
)

View File

@@ -529,6 +529,7 @@ class MockModelConfig:
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
is_multimodal_model: bool = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
@@ -542,7 +543,7 @@ class MockVllmConfig:
def _build_renderer(model_config: MockModelConfig):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer(
return HfRenderer.from_config(
MockVllmConfig(model_config),
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
)
@@ -756,9 +757,8 @@ async def test_serving_chat_mistral_token_ids_prompt_is_validated():
mock_tokenizer = MagicMock(spec=MistralTokenizer)
mock_renderer = MistralRenderer(
MockVllmConfig(mock_engine.model_config),
tokenizer_kwargs={},
tokenizer=mock_tokenizer,
)
mock_renderer._tokenizer = mock_tokenizer
# Force the Mistral chat template renderer to return token IDs.
# Choose a prompt length that is < max_model_len, but large enough that
# adding max_tokens should exceed the model context window.
@@ -798,9 +798,8 @@ async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected():
mock_tokenizer = MagicMock(spec=MistralTokenizer)
mock_renderer = MistralRenderer(
MockVllmConfig(mock_engine.model_config),
tokenizer_kwargs={},
tokenizer=mock_tokenizer,
)
mock_renderer._tokenizer = mock_tokenizer
# prompt_token_ids length == max_model_len should be rejected for
# completion-like requests (ChatCompletionRequest).
mock_renderer.render_messages_async = AsyncMock(

View File

@@ -38,6 +38,7 @@ class MockModelConfig:
enable_prompt_embeds: bool = True
skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
is_multimodal_model: bool = False
@dataclass
@@ -78,15 +79,16 @@ def _build_renderer(
renderer = HfRenderer(
MockVllmConfig(model_config),
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
tokenizer=(
None
if model_config.skip_tokenizer_init
else DummyTokenizer(
truncation_side=truncation_side,
max_chars_per_token=max_chars_per_token,
)
),
)
if not model_config.skip_tokenizer_init:
renderer._tokenizer = DummyTokenizer(
truncation_side=truncation_side,
max_chars_per_token=max_chars_per_token,
)
return renderer
@@ -277,7 +279,7 @@ class TestRenderPrompt:
)
# Should not even attempt tokenization
assert renderer._tokenizer._captured_encode_kwargs == {}
assert renderer.tokenizer._captured_encode_kwargs == {}
def test_text_max_length_exceeded_nonobvious(self):
renderer = _build_renderer(MockModelConfig(), max_chars_per_token=2)
@@ -298,8 +300,8 @@ class TestRenderPrompt:
)
# Should only tokenize the first max_total_tokens + 1 tokens
assert renderer._tokenizer._captured_encode_kwargs["truncation"] is True
assert renderer._tokenizer._captured_encode_kwargs["max_length"] == 101
assert renderer.tokenizer._captured_encode_kwargs["truncation"] is True
assert renderer.tokenizer._captured_encode_kwargs["max_length"] == 101
def test_token_max_length_exceeded(self):
renderer = _build_renderer(MockModelConfig())

View File

@@ -36,6 +36,7 @@ class MockModelConfig:
enable_prompt_embeds: bool = True
skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
is_multimodal_model: bool = False
@dataclass
@@ -57,9 +58,8 @@ async def test_async_mistral_tokenizer_does_not_block_event_loop():
mock_tokenizer.apply_chat_template = mocked_apply_chat_template
mock_renderer = MistralRenderer(
MockVllmConfig(mock_model_config),
tokenizer_kwargs={},
tokenizer=mock_tokenizer,
)
mock_renderer._tokenizer = mock_tokenizer
task = mock_renderer.render_messages_async([], ChatParams())

View File

@@ -19,7 +19,7 @@ import pytest
import pytest_asyncio
from vllm import SamplingParams
from vllm.inputs import StreamingInput
from vllm.engine.protocol import StreamingInput
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind

View File

@@ -7,7 +7,7 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from vllm.inputs import StreamingInput
from vllm.engine.protocol import StreamingInput
from vllm.outputs import RequestOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.v1.engine.async_llm import AsyncLLM

View File

@@ -18,7 +18,7 @@ import dataclasses
import json
import time
from datetime import datetime
from typing import Any
from typing import TYPE_CHECKING, Any
import numpy as np
@@ -28,9 +28,6 @@ from vllm.benchmarks.datasets import (
)
from vllm.benchmarks.throughput import get_requests
from vllm.engine.arg_utils import EngineArgs
from vllm.multimodal.processing.context import (
get_timing_stats_from_engine_client,
)
from vllm.utils.gc_utils import freeze_gc_heap
from vllm.utils.import_utils import PlaceholderModule
@@ -39,16 +36,103 @@ try:
except ImportError:
pd = PlaceholderModule("pandas")
if TYPE_CHECKING: # Avoid having to mock during docs build
from vllm.v1.engine.llm_engine import LLMEngine
else:
LLMEngine = object
def get_timing_stats_from_engine(
    llm_engine: "LLMEngine",
) -> dict[str, dict[str, float]]:
    """
    Get all multimodal timing stats from the LLM engine.

    Collects both preprocessing stats (HF processor, hashing, cache lookup,
    prompt update) and encoder forward pass timing, merged by request_id.

    Args:
        llm_engine: The LLM engine (exposes ``renderer`` for preprocessing
            stats and ``collective_rpc`` to reach the workers).

    Returns:
        Dictionary mapping request_id to merged stats dict containing
        both preprocessing and encoder timing metrics.

    Example:
        {
            'request-123': {
                'hf_processor_time': 0.45,
                'hashing_time': 0.02,
                'cache_lookup_time': 0.01,
                'prompt_update_time': 0.03,
                'preprocessor_total_time': 0.51,
                'encoder_forward_time': 0.23,
                'num_encoder_calls': 1
            }
        }
    """
    observability_config = llm_engine.vllm_config.observability_config
    # Stats are only recorded when explicitly enabled; bail out early.
    if not observability_config or not observability_config.enable_mm_processor_stats:
        return {}

    renderer = llm_engine.renderer
    mm_processor = renderer.get_mm_processor()
    preprocessing_stats = mm_processor.info.ctx.get_all_timing_stats()

    encoder_stats = dict[str, dict[str, float]]()
    for worker_stats in llm_engine.collective_rpc("get_encoder_timing_stats"):
        if not worker_stats:
            continue

        for request_id, stats_dict in worker_stats.items():
            if request_id not in encoder_stats:
                encoder_stats[request_id] = dict(stats_dict)
            else:
                # Aggregate timing metrics across workers. Use max rather
                # than sum so duplicate reports of the same forward pass
                # (e.g. from TP ranks) are not double-counted.
                current_time = encoder_stats[request_id].get(
                    "encoder_forward_time", 0.0
                )
                new_time = stats_dict.get("encoder_forward_time", 0.0)
                encoder_stats[request_id]["encoder_forward_time"] = max(
                    current_time, new_time
                )

                current_calls = encoder_stats[request_id].get("num_encoder_calls", 0)
                new_calls = stats_dict.get("num_encoder_calls", 0)
                encoder_stats[request_id]["num_encoder_calls"] = max(
                    current_calls, new_calls
                )

    merged_stats = dict[str, dict[str, float]]()
    for request_id, prep_dict in preprocessing_stats.items():
        merged_stats[request_id] = dict(prep_dict)

    for request_id, enc_dict in encoder_stats.items():
        if request_id in merged_stats:
            merged_stats[request_id].update(enc_dict)
            continue

        # In V1 engine, the request_id in encoder_stats has a suffix
        # appended to the original request_id (which is used in
        # preprocessing_stats).
        # We try to strip the suffix to find the matching request.
        possible_original_id = request_id.rpartition("-")[0]
        if possible_original_id and possible_original_id in merged_stats:
            merged_stats[possible_original_id].update(enc_dict)
        else:
            merged_stats[request_id] = dict(enc_dict)

    return merged_stats
def collect_mm_processor_stats(
llm_engine: Any,
llm_engine: LLMEngine,
num_warmup_reqs: int = 0,
) -> dict[str, list[float]]:
"""
Collect multimodal processor timing stats.
Returns a dictionary mapping stage names to lists of timing values (in seconds).
"""
all_stats = get_timing_stats_from_engine_client(llm_engine)
all_stats = get_timing_stats_from_engine(llm_engine)
stat_keys = [
"hf_processor_time",

View File

@@ -3,6 +3,7 @@
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Iterable, Mapping
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from vllm.config import ModelConfig, VllmConfig
@@ -10,7 +11,7 @@ from vllm.distributed.weight_transfer.base import (
WeightTransferInitRequest,
WeightTransferUpdateRequest,
)
from vllm.inputs.data import PromptType, StreamingInput
from vllm.inputs.data import PromptType
from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import IOProcessor
@@ -26,6 +27,18 @@ if TYPE_CHECKING:
from vllm.v1.engine import PauseMode
@dataclass
class StreamingInput:
    """One turn of input for a streaming generation session.

    Passed to generate() via an async generator to support multi-turn
    streaming sessions.
    """

    # The prompt supplied for this turn.
    prompt: PromptType
    # Sampling parameters for this turn; None when not overridden.
    sampling_params: SamplingParams | None = None
class EngineClient(ABC):
"""Protocol class for Clients to Engine"""

View File

@@ -72,7 +72,7 @@ from vllm.outputs import (
)
from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.renderers import ChatParams, merge_kwargs
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.renderers.inputs.preprocess import (
conversation_to_seq,
@@ -384,7 +384,7 @@ class LLM:
return parallel_config.world_size
def reset_mm_cache(self) -> None:
self.input_processor.clear_mm_cache()
self.renderer.clear_mm_cache()
self.llm_engine.reset_mm_cache()
def get_default_sampling_params(self) -> SamplingParams:
@@ -876,19 +876,6 @@ class LLM:
return outputs
def _get_cmpl_tok_params(self, tokenization_kwargs: dict[str, Any] | None):
model_config = self.model_config
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
do_lower_case=encoder_config.get("do_lower_case", False),
# For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
add_special_tokens=not model_config.is_encoder_decoder,
).with_kwargs(tokenization_kwargs)
def _preprocess_cmpl(
self,
prompts: Sequence[PromptType],
@@ -910,20 +897,12 @@ class LLM:
parsed_prompts = [
parse_model_prompt(model_config, prompt) for prompt in prompts
]
tok_params = self._get_cmpl_tok_params(tokenization_kwargs)
tok_params = renderer.default_cmpl_tok_params.with_kwargs(
**(tokenization_kwargs or {})
)
return renderer.render_cmpl(parsed_prompts, tok_params)
def _get_chat_tok_params(self, tokenization_kwargs: dict[str, Any] | None):
model_config = self.model_config
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=False,
).with_kwargs(tokenization_kwargs)
def _preprocess_chat(
self,
conversations: Sequence[list[ChatCompletionMessageParam]],
@@ -961,7 +940,9 @@ class LLM:
),
),
)
tok_params = self._get_chat_tok_params(tokenization_kwargs)
tok_params = renderer.default_chat_tok_params.with_kwargs(
**(tokenization_kwargs or {})
)
_, engine_prompts = renderer.render_chat(
conversations,
@@ -1653,7 +1634,10 @@ class LLM:
architecture=architecture,
)
tok_params = self._get_cmpl_tok_params(tokenization_kwargs)
renderer = self.renderer
tok_params = renderer.default_cmpl_tok_params.with_kwargs(
**(tokenization_kwargs or {})
)
encode_kwargs = tok_params.get_encode_kwargs()
if model_config.is_cross_encoder:
@@ -1970,7 +1954,10 @@ class LLM:
dict(truncate_prompt_tokens=params.truncate_prompt_tokens),
)
tok_params = self._get_cmpl_tok_params(tokenization_kwargs)
renderer = self.renderer
tok_params = renderer.default_cmpl_tok_params.with_kwargs(
**(tokenization_kwargs or {})
)
tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs(

View File

@@ -8,11 +8,11 @@ from typing import Literal, cast
import numpy as np
from vllm.engine.protocol import EngineClient
from vllm.engine.protocol import EngineClient, StreamingInput
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.inputs.data import PromptType, StreamingInput
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import SupportsRealtime

View File

@@ -12,7 +12,6 @@ from .data import (
PromptType,
SingletonInputs,
SingletonPrompt,
StreamingInput,
TextPrompt,
TokenInputs,
TokensPrompt,
@@ -36,5 +35,4 @@ __all__ = [
"EncoderDecoderInputs",
"ProcessorInputs",
"SingletonInputs",
"StreamingInput",
]

View File

@@ -1,13 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, TypeAlias
import torch
from typing_extensions import NotRequired, TypedDict
from vllm.sampling_params import SamplingParams
if TYPE_CHECKING:
from vllm.multimodal.inputs import (
MultiModalDataDict,
@@ -299,15 +296,3 @@ which can be passed to
SingletonInputs: TypeAlias = DecoderOnlyInputs | MultiModalEncDecInputs
"""The inputs for a single encoder/decoder prompt."""
@dataclass
class StreamingInput:
"""Input data for a streaming generation request.
This is used with generate() to support multi-turn streaming sessions
where inputs are provided via an async generator.
"""
prompt: PromptType
sampling_params: SamplingParams | None = None

View File

@@ -9,13 +9,11 @@ from typing_extensions import assert_never
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalInputs,
MultiModalUUIDDict,
)
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.renderers import BaseRenderer, renderer_from_config
from vllm.renderers.inputs import (
DecoderDictPrompt,
@@ -28,8 +26,6 @@ from vllm.renderers.inputs import (
)
from vllm.renderers.inputs.preprocess import parse_dec_only_prompt, parse_enc_dec_prompt
from vllm.tokenizers import TokenizerLike
from vllm.utils.jsontree import json_iter_leaves
from vllm.v1.metrics.stats import MultiModalCacheStats
from .data import (
DecoderInputs,
@@ -57,17 +53,12 @@ class InputPreprocessor:
vllm_config: VllmConfig,
renderer: BaseRenderer | None = None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
mm_processor_cache: BaseMultiModalProcessorCache | None = None,
) -> None:
super().__init__()
self.model_config = vllm_config.model_config
self.observability_config = vllm_config.observability_config
self.renderer = renderer or renderer_from_config(vllm_config)
self.mm_registry = mm_registry
self.mm_processor_cache = mm_processor_cache
self.mm_cache_stats = MultiModalCacheStats() if mm_processor_cache else None
@property
def tokenizer(self) -> TokenizerLike | None:
@@ -124,23 +115,6 @@ class InputPreprocessor:
return decoder_input_ids
def _get_tokenization_kw(
self,
overrides: dict[str, Any] | None = None,
) -> dict[str, Any]:
kwargs = dict[str, Any]()
if self.model_config.is_encoder_decoder:
# For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
kwargs["add_special_tokens"] = False
if overrides:
kwargs.update(overrides)
return kwargs
def _tokenize_prompt(
self,
prompt: str,
@@ -150,26 +124,18 @@ class InputPreprocessor:
Apply the model's tokenizer to a text prompt, returning the
corresponding token IDs.
"""
tokenizer = self.get_tokenizer()
tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
renderer = self.renderer
encoder_config = self.model_config.encoder_config
tok_params = renderer.default_cmpl_tok_params.with_kwargs(
**(tokenization_kwargs or {})
)
if encoder_config and encoder_config.get("do_lower_case", False):
prompt = prompt.lower()
tok_prompt = renderer.tokenize_prompt(
TextPrompt(prompt=prompt),
tok_params,
)
return tokenizer.encode(prompt, **tokenization_kwargs)
def _get_mm_processor(self) -> BaseMultiModalProcessor:
if not hasattr(self, "_mm_processor"):
self._mm_processor = self.mm_registry.create_processor(
self.model_config,
self.observability_config,
tokenizer=self.tokenizer,
cache=self.mm_processor_cache,
)
return self._mm_processor
return tok_prompt["prompt_token_ids"]
def _process_multimodal(
self,
@@ -184,33 +150,20 @@ class InputPreprocessor:
Apply the model's multi-modal processor to a multi-modal prompt,
returning the corresponding token IDs and metadata.
"""
mm_processor = self._get_mm_processor()
mm_processor = self.renderer.get_mm_processor()
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
mm_items = mm_processor.info.parse_mm_data(mm_data)
mm_input = mm_processor.apply(
return mm_processor.apply(
prompt,
mm_items,
hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
mm_hashes = mm_input["mm_hashes"]
# Validate that all mm items have a string as their hash
contains_only_strings = all(
isinstance(leaf, str) for leaf in json_iter_leaves(mm_hashes)
)
if not contains_only_strings:
raise ValueError(
f"mm_hashes must contain only strings, got: {mm_hashes}. "
"This is likely due to an incorrect custom implementation of "
"MultiModalProcessor.apply method."
)
return mm_input
def _process_embeds(
self,
@@ -245,19 +198,18 @@ class InputPreprocessor:
def _truncate_inputs(
self, inputs: list[int], tokenization_kwargs: dict[str, Any] | None = None
) -> list[int]:
if (
not tokenization_kwargs
or "truncation" not in tokenization_kwargs
or self.tokenizer is None
):
return inputs
renderer = self.renderer
max_length = tokenization_kwargs["max_length"]
tok_params = renderer.default_cmpl_tok_params.with_kwargs(
**(tokenization_kwargs or {})
)
if self.tokenizer.truncation_side == "left":
return inputs[-max_length:]
else:
return inputs[:max_length]
tok_prompt = renderer.tokenize_prompt(
TokensPrompt(prompt_token_ids=inputs),
tok_params,
)
return tok_prompt["prompt_token_ids"]
def _process_tokens(
self,
@@ -539,26 +491,6 @@ class InputPreprocessor:
"""Preprocess the input prompt."""
res = self._preprocess(prompt, tokenization_kwargs, mm_uuids=mm_uuids)
if self.mm_processor_cache and self.mm_cache_stats is not None:
delta = self.mm_processor_cache.make_stats(delta=True)
self.mm_cache_stats.requests += 1
self.mm_cache_stats.queries += delta.total
self.mm_cache_stats.hits += delta.hits
self.renderer.update_mm_cache_stats()
return res
def stat_mm_cache(self) -> MultiModalCacheStats | None:
mm_cache_stats = self.mm_cache_stats
if mm_cache_stats is None:
return None
self.mm_cache_stats = MultiModalCacheStats()
return mm_cache_stats
def clear_mm_cache(self) -> None:
if self.mm_processor_cache is not None:
self.mm_processor_cache.clear_cache()
if self.mm_cache_stats is not None:
self.mm_cache_stats.reset = True

View File

@@ -208,14 +208,23 @@ class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]):
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> MultiModalInputs:
if prompt and mm_items:
raise ValueError(
"CLIP accepts text-only or image-only inputs, not both! "
"Image-only inputs means passing an image with an empty text "
"prompt."
)
if mm_items:
if isinstance(prompt, str):
if len(prompt) > 0:
raise ValueError(
"CLIP accepts text-only or image-only inputs, not both! "
"You must pass an image with an empty text prompt."
)
else:
special_tokens = self.info.get_tokenizer().all_special_ids
if all(tok in special_tokens for tok in prompt):
prompt = []
else:
raise ValueError(
"CLIP accepts text-only or image-only inputs, not both! "
"You must pass an image with an empty token prompt."
)
# For multi-modal data, the prompt after processing should
# only contain the dummy image tokens
tokenization_kwargs = {

View File

@@ -42,6 +42,7 @@ from vllm.multimodal.processing import (
PromptReplacement,
PromptUpdateDetails,
)
from vllm.renderers import TokenizeParams
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -90,6 +91,9 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo):
def get_image_processor(self, **kwargs: object) -> Lfm2VlImageProcessorFast:
return self.get_hf_processor(**kwargs).image_processor
def get_default_tok_params(self) -> TokenizeParams:
    """Default tokenization params with special-token insertion disabled."""
    base = super().get_default_tok_params()
    return base.with_kwargs(add_special_tokens=False)
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None}

View File

@@ -66,6 +66,7 @@ from vllm.multimodal.processing import (
PromptUpdate,
PromptUpdateDetails,
)
from vllm.renderers import TokenizeParams
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -554,6 +555,9 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
Llama4Processor, use_fast=kwargs.pop("use_fast", True), **kwargs
)
def get_default_tok_params(self) -> TokenizeParams:
    """Default tokenization params; do not add special tokens."""
    defaults = super().get_default_tok_params()
    return defaults.with_kwargs(add_special_tokens=False)
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
# Although vLLM can support more images from an infra capability
# perspective, we do not recommend using >10 images in practice.

View File

@@ -76,6 +76,7 @@ from vllm.multimodal.processing.processor import (
PromptUpdateDetails,
_seq2tokens,
)
from vllm.renderers import TokenizeParams
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from vllm.transformers_utils.configs.radio import RadioConfig
@@ -1093,6 +1094,9 @@ class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
) -> BaseNanoNemotronVLProcessor:
raise NotImplementedError
def get_default_tok_params(self) -> TokenizeParams:
    """Default tokenization params with ``add_special_tokens`` forced off."""
    params = super().get_default_tok_params()
    return params.with_kwargs(add_special_tokens=False)
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None}

View File

@@ -58,6 +58,7 @@ from vllm.multimodal.processing import (
PromptReplacement,
PromptUpdate,
)
from vllm.renderers import TokenizeParams
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.configs.radio import RadioConfig
from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -608,6 +609,9 @@ class NemotronParseProcessingInfo(BaseProcessingInfo):
**kwargs,
)
def get_default_tok_params(self) -> TokenizeParams:
    """Default tokenization params; special tokens are not inserted."""
    base_params = super().get_default_tok_params()
    return base_params.with_kwargs(add_special_tokens=False)
@property
def skip_prompt_length_check(self) -> bool:
    """Skip prompt length validation because the encoder prompt is padded."""
    return True

View File

@@ -53,6 +53,7 @@ from vllm.multimodal.processing import (
BaseProcessingInfo,
PromptReplacement,
)
from vllm.renderers import TokenizeParams
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processors.ovis import OvisProcessor
from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -264,6 +265,9 @@ class OvisProcessingInfo(BaseProcessingInfo):
**kwargs,
)
def get_default_tok_params(self) -> TokenizeParams:
    """Default tokenization params without special-token insertion."""
    inherited = super().get_default_tok_params()
    return inherited.with_kwargs(add_special_tokens=False)
def get_image_segment_len(self) -> int:
visual_tokenizer_config = self.get_hf_config().visual_tokenizer_config
image_size = visual_tokenizer_config.backbone_config.image_size

View File

@@ -35,6 +35,7 @@ from vllm.multimodal.processing import (
BaseProcessingInfo,
PromptReplacement,
)
from vllm.renderers import TokenizeParams
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor
from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -183,6 +184,9 @@ class Ovis2_5ProcessingInfo(BaseProcessingInfo):
temporal_patch_size=vit_config.temporal_patch_size,
)
def get_default_tok_params(self) -> TokenizeParams:
    """Default tokenization params; disables adding special tokens."""
    parent_params = super().get_default_tok_params()
    return parent_params.with_kwargs(add_special_tokens=False)
def get_image_processor(self) -> BaseImageProcessor:
return self.get_hf_processor().image_processor # type: ignore

View File

@@ -32,6 +32,7 @@ from vllm.multimodal.processing import (
PromptUpdate,
PromptUpdateDetails,
)
from vllm.renderers import TokenizeParams
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -102,6 +103,9 @@ class PaliGemmaProcessingInfo(BaseProcessingInfo):
def get_vision_encoder_info(self):
return get_vision_encoder_info(self.get_hf_config())
def get_default_tok_params(self) -> TokenizeParams:
    """Default tokenization params with special tokens disabled."""
    base = super().get_default_tok_params()
    return base.with_kwargs(add_special_tokens=False)
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": 1}

View File

@@ -194,14 +194,23 @@ class SiglipMultiModalProcessor(BaseMultiModalProcessor[SiglipProcessingInfo]):
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> MultiModalInputs:
if prompt and mm_items:
raise ValueError(
"Siglip accepts text-only or image-only inputs, not both! "
"Image-only inputs means passing an image with an empty text "
"prompt."
)
if mm_items:
if isinstance(prompt, str):
if len(prompt) > 0:
raise ValueError(
"SigLIP accepts text-only or image-only inputs, not both! "
"You must pass an image with an empty text prompt."
)
else:
special_tokens = self.info.get_tokenizer().all_special_ids
if all(tok in special_tokens for tok in prompt):
prompt = []
else:
raise ValueError(
"SigLIP accepts text-only or image-only inputs, not both! "
"You must pass an image with an empty token prompt."
)
# For multi-modal data, the prompt after processing should
# only contain the image token
tokenization_kwargs = {

View File

@@ -42,6 +42,7 @@ from vllm.multimodal.processing import (
PromptReplacement,
PromptUpdate,
)
from vllm.renderers import TokenizeParams
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -133,6 +134,9 @@ class UltravoxProcessingInfo(BaseProcessingInfo):
assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor
def get_default_tok_params(self) -> TokenizeParams:
    """Default tokenization params; special tokens are not added."""
    defaults = super().get_default_tok_params()
    return defaults.with_kwargs(add_special_tokens=False)
def get_data_parser(self):
feature_extractor = self.get_feature_extractor()

View File

@@ -17,8 +17,9 @@ from mistral_common.tokens.tokenizers.audio import Audio, AudioConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.engine.protocol import StreamingInput
from vllm.envs import VLLM_ENGINE_ITERATION_TIMEOUT_S
from vllm.inputs.data import PromptType, StreamingInput, TokensPrompt
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsRealtime
from vllm.model_executor.models.voxtral import (

View File

@@ -55,6 +55,7 @@ from vllm.multimodal.processing import (
PromptReplacement,
PromptUpdate,
)
from vllm.renderers import TokenizeParams
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.utils.jsontree import json_map_leaves
from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -644,6 +645,12 @@ class WhisperProcessingInfo(BaseProcessingInfo):
def get_hf_config(self) -> WhisperConfig:
return self.ctx.get_hf_config(WhisperConfig)
def get_default_tok_params(self) -> TokenizeParams:
    """Default tokenization params with special-token insertion disabled.

    Special tokens should be provided by the user based on the task and
    language of their request. Disabling them also avoids appending an EOS
    token to the prompt, which disrupts generation.
    """
    base = super().get_default_tok_params()
    return base.with_kwargs(add_special_tokens=False)
def get_data_parser(self):
feature_extractor = self.get_feature_extractor()

View File

@@ -21,6 +21,7 @@ from vllm.multimodal.parse import (
MultiModalDataItems,
MultiModalDataParser,
)
from vllm.renderers import TokenizeParams
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
@@ -93,110 +94,6 @@ class MultiModalProcessorTimingStats:
}
def get_timing_stats_from_engine_client(
engine_client: Any,
) -> dict[str, dict[str, float]]:
"""
Get all multimodal timing stats from the engine client.
Collects both preprocessing stats (HF processor, hashing, cache lookup,
prompt update) and encoder forward pass timing, merged by request_id.
Args:
engine_client: The engine client (has input_processor and workers).
Returns:
Dictionary mapping request_id to merged stats dict containing
both preprocessing and encoder timing metrics.
Example:
{
'request-123': {
'hf_processor_time': 0.45,
'hashing_time': 0.02,
'cache_lookup_time': 0.01,
'prompt_update_time': 0.03,
'preprocessor_total_time': 0.51,
'encoder_forward_time': 0.23,
'num_encoder_calls': 1
}
}
"""
try:
if not engine_client.vllm_config.observability_config.enable_mm_processor_stats:
return {}
except (AttributeError, RuntimeError):
return {}
preprocessing_stats = {}
try:
input_processor = engine_client.input_processor
input_preprocessor = input_processor.input_preprocessor
if hasattr(input_preprocessor, "_get_mm_processor"):
mm_processor = input_preprocessor._get_mm_processor()
if mm_processor is not None and hasattr(mm_processor, "info"):
ctx = mm_processor.info.ctx
preprocessing_stats = ctx.get_all_timing_stats()
except (AttributeError, RuntimeError):
pass
encoder_stats = {}
try:
if hasattr(engine_client, "collective_rpc"):
encoder_stats_results = engine_client.collective_rpc(
"get_encoder_timing_stats"
)
if encoder_stats_results and len(encoder_stats_results) > 0:
for worker_stats in encoder_stats_results:
if not worker_stats:
continue
for request_id, stats_dict in worker_stats.items():
if request_id not in encoder_stats:
encoder_stats[request_id] = dict(stats_dict)
else:
# Aggregate timing metrics across workers
current_time = encoder_stats[request_id].get(
"encoder_forward_time", 0.0
)
new_time = stats_dict.get("encoder_forward_time", 0.0)
encoder_stats[request_id]["encoder_forward_time"] = max(
current_time, new_time
)
current_calls = encoder_stats[request_id].get(
"num_encoder_calls", 0
)
new_calls = stats_dict.get("num_encoder_calls", 0)
encoder_stats[request_id]["num_encoder_calls"] = max(
current_calls, new_calls
)
except (AttributeError, RuntimeError):
pass
merged_stats = {}
for request_id, prep_dict in preprocessing_stats.items():
merged_stats[request_id] = dict(prep_dict)
for request_id, enc_dict in encoder_stats.items():
if request_id in merged_stats:
merged_stats[request_id].update(enc_dict)
continue
# In V1 engine, the request_id in encoder_stats has a suffix
# appended to the original request_id (which is used in
# preprocessing_stats).
# We try to strip the suffix to find the matching request.
possible_original_id = request_id.rpartition("-")[0]
if possible_original_id and possible_original_id in merged_stats:
merged_stats[possible_original_id].update(enc_dict)
else:
merged_stats[request_id] = dict(enc_dict)
return merged_stats
@contextmanager
def timed_preprocessor_operation(ctx: "InputProcessingContext", stage_name: str):
"""
@@ -576,6 +473,21 @@ class BaseProcessingInfo:
"""
return self.ctx.get_hf_processor(**kwargs)
def get_default_tok_params(self) -> TokenizeParams:
"""Construct the default parameters for tokenization."""
model_config = self.ctx.model_config
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=True,
)
@cached_property
def default_tok_params(self) -> TokenizeParams:
return self.get_default_tok_params()
def _get_expected_hidden_size(self) -> int | None:
"""
Get expected hidden size for embedding validation if `mm_embeds` are enabled.

View File

@@ -3,12 +3,17 @@
import asyncio
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, overload
from functools import cached_property
from typing import TYPE_CHECKING, Any, Generic, overload
from typing_extensions import TypeVar
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.metrics.stats import MultiModalCacheStats
from .embed_utils import safe_load_prompt_embeds
from .inputs import (
@@ -26,11 +31,16 @@ if TYPE_CHECKING:
ChatCompletionMessageParam,
ConversationMessage,
)
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.processing import BaseMultiModalProcessor
logger = init_logger(__name__)
class BaseRenderer(ABC):
_T = TypeVar("_T", bound=TokenizerLike, default=TokenizerLike)
class BaseRenderer(ABC, Generic[_T]):
@classmethod
@abstractmethod
def from_config(
@@ -40,20 +50,36 @@ class BaseRenderer(ABC):
) -> "BaseRenderer":
raise NotImplementedError
def __init__(self, config: "VllmConfig") -> None:
def __init__(self, config: "VllmConfig", tokenizer: _T | None) -> None:
super().__init__()
self.config = config
self.model_config = config.model_config
self.tokenizer = tokenizer
# Lazy initialization since offline LLM doesn't use async
self._async_tokenizer: AsyncMicrobatchTokenizer | None = None
@property
@abstractmethod
def tokenizer(self) -> TokenizerLike | None:
raise NotImplementedError
self.mm_processor: BaseMultiModalProcessor | None = None
self._mm_cache_stats: MultiModalCacheStats | None = None
if config.model_config.is_multimodal_model:
from vllm.multimodal import MULTIMODAL_REGISTRY as mm_registry
def get_tokenizer(self) -> TokenizerLike:
mm_processor_cache = mm_registry.processor_cache_from_config(config)
with set_default_torch_num_threads():
self.mm_processor = mm_registry.create_processor(
config.model_config,
config.observability_config,
tokenizer=tokenizer,
cache=mm_processor_cache,
)
if mm_processor_cache:
self._mm_cache_stats = MultiModalCacheStats()
def get_tokenizer(self) -> _T:
tokenizer = self.tokenizer
if tokenizer is None:
raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")
@@ -66,6 +92,49 @@ class BaseRenderer(ABC):
return self._async_tokenizer
def get_mm_processor(self) -> "BaseMultiModalProcessor":
if self.mm_processor is None:
raise ValueError("Multi-modal processor not available for text-only models")
return self.mm_processor
@property
def mm_processor_cache(self) -> "BaseMultiModalProcessorCache | None":
if self.mm_processor is None:
return None
return self.mm_processor.cache
def stat_mm_cache(self) -> MultiModalCacheStats | None:
mm_cache_stats = self._mm_cache_stats
if mm_cache_stats is None:
return None
self._mm_cache_stats = MultiModalCacheStats()
return mm_cache_stats
def update_mm_cache_stats(self) -> None:
mm_processor_cache = self.mm_processor_cache
mm_cache_stats = self._mm_cache_stats
if mm_processor_cache and mm_cache_stats:
delta = mm_processor_cache.make_stats(delta=True)
mm_cache_stats.record(delta.total, delta.hits)
def clear_mm_cache(self) -> None:
mm_processor_cache = self.mm_processor_cache
if mm_processor_cache is not None:
mm_processor_cache.clear_cache()
if self._mm_cache_stats is not None:
self._mm_cache_stats.reset = True
def shutdown(self) -> None:
mm_processor_cache = self.mm_processor_cache
if mm_processor_cache is not None:
mm_processor_cache.close()
def get_bos_token_id(self) -> int | None:
if self.tokenizer is None:
logger.warning_once(
@@ -84,6 +153,36 @@ class BaseRenderer(ABC):
return self.tokenizer.eos_token_id
@cached_property
def default_cmpl_tok_params(self) -> TokenizeParams:
mm_processor = self.mm_processor
if mm_processor is not None:
return mm_processor.info.default_tok_params
model_config = self.model_config
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=True,
)
@cached_property
def default_chat_tok_params(self) -> TokenizeParams:
mm_processor = self.mm_processor
if mm_processor is not None:
return mm_processor.info.default_tok_params
model_config = self.model_config
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=False,
)
# Step 1: Convert raw inputs to prompts
def render_prompt(
self,
@@ -317,18 +416,14 @@ class BaseRenderer(ABC):
def render_cmpl(
self,
prompts: Sequence[DictPrompt | bytes],
tok_params: TokenizeParams,
tok_params: TokenizeParams | None = None,
*,
prompt_extras: dict[str, Any] | None = None,
):
if tok_params is None:
tok_params = self.default_cmpl_tok_params
dict_prompts = self.render_prompts(prompts)
# NOTE: Some MM models have non-default `add_special_tokens`
# so we handle tokenization in multi-modal processor
if self.model_config.is_multimodal_model:
self._apply_prompt_extras(dict_prompts, prompt_extras)
return dict_prompts
tok_prompts = self.tokenize_prompts(dict_prompts, tok_params)
self._apply_prompt_extras(tok_prompts, prompt_extras)
@@ -339,14 +434,14 @@ class BaseRenderer(ABC):
async def render_cmpl_async(
self,
prompts: Sequence[DictPrompt | bytes],
tok_params: TokenizeParams,
tok_params: TokenizeParams | None = None,
*,
prompt_extras: dict[str, Any] | None = None,
):
dict_prompts = await self.render_prompts_async(prompts)
if tok_params is None:
tok_params = self.default_cmpl_tok_params
# NOTE: MM data cannot be passed to online Completions API
# so we don't have the special case that is in the offline version
dict_prompts = await self.render_prompts_async(prompts)
tok_prompts = await self.tokenize_prompts_async(dict_prompts, tok_params)
self._apply_prompt_extras(tok_prompts, prompt_extras)
@@ -358,10 +453,13 @@ class BaseRenderer(ABC):
self,
conversations: Sequence[list["ChatCompletionMessageParam"]],
chat_params: ChatParams,
tok_params: TokenizeParams,
tok_params: TokenizeParams | None = None,
*,
prompt_extras: dict[str, Any] | None = None,
):
if tok_params is None:
tok_params = self.default_chat_tok_params
rendered = [
self.render_messages(conversation, chat_params)
for conversation in conversations
@@ -384,10 +482,13 @@ class BaseRenderer(ABC):
self,
conversations: Sequence[list["ChatCompletionMessageParam"]],
chat_params: ChatParams,
tok_params: TokenizeParams,
tok_params: TokenizeParams | None = None,
*,
prompt_extras: dict[str, Any] | None = None,
):
if tok_params is None:
tok_params = self.default_chat_tok_params
rendered = [
self.render_messages_async(conversation, chat_params)
for conversation in conversations

View File

@@ -13,7 +13,6 @@ from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer
from ..tokenizers.hf import HfTokenizer
from .base import BaseRenderer
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
@@ -22,23 +21,14 @@ from .params import ChatParams
logger = init_logger(__name__)
class DeepseekV32Renderer(BaseRenderer):
class DeepseekV32Renderer(BaseRenderer[DeepseekV32Tokenizer]):
@classmethod
def from_config(
def from_config( # type: ignore[override]
cls,
config: VllmConfig,
tokenizer_kwargs: dict[str, Any],
) -> "BaseRenderer":
return cls(config, tokenizer_kwargs)
def __init__(
self,
config: VllmConfig,
tokenizer_kwargs: dict[str, Any],
) -> None:
super().__init__(config)
model_config = self.model_config
) -> "DeepseekV32Renderer":
model_config = config.model_config
if model_config.skip_tokenizer_init:
tokenizer = None
else:
@@ -47,18 +37,7 @@ class DeepseekV32Renderer(BaseRenderer):
**tokenizer_kwargs,
)
self._tokenizer = tokenizer
@property
def tokenizer(self) -> HfTokenizer | None:
return self._tokenizer
def get_tokenizer(self) -> HfTokenizer:
tokenizer = self.tokenizer
if tokenizer is None:
raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")
return tokenizer
return cls(config, tokenizer)
def render_messages(
self,

View File

@@ -21,23 +21,14 @@ from .params import ChatParams
logger = init_logger(__name__)
class Grok2Renderer(BaseRenderer):
class Grok2Renderer(BaseRenderer[Grok2Tokenizer]):
@classmethod
def from_config(
def from_config( # type: ignore[override]
cls,
config: VllmConfig,
tokenizer_kwargs: dict[str, Any],
) -> "BaseRenderer":
return cls(config, tokenizer_kwargs)
def __init__(
self,
config: VllmConfig,
tokenizer_kwargs: dict[str, Any],
) -> None:
super().__init__(config)
model_config = self.model_config
) -> "Grok2Renderer":
model_config = config.model_config
if model_config.skip_tokenizer_init:
tokenizer = None
else:
@@ -46,18 +37,7 @@ class Grok2Renderer(BaseRenderer):
**tokenizer_kwargs,
)
self._tokenizer = tokenizer
@property
def tokenizer(self) -> Grok2Tokenizer | None:
return self._tokenizer
def get_tokenizer(self) -> Grok2Tokenizer:
tokenizer = self.tokenizer
if tokenizer is None:
raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")
return tokenizer
return cls(config, tokenizer)
def render_messages(
self,

View File

@@ -585,27 +585,14 @@ def replace_vision_chunk_video_placeholder(
return prompt_raw
class HfRenderer(BaseRenderer):
class HfRenderer(BaseRenderer[HfTokenizer]):
@classmethod
def from_config(
def from_config( # type: ignore[override]
cls,
config: VllmConfig,
tokenizer_kwargs: dict[str, Any],
) -> "BaseRenderer":
return cls(config, tokenizer_kwargs)
def __init__(
self,
config: VllmConfig,
tokenizer_kwargs: dict[str, Any],
) -> None:
super().__init__(config)
model_config = self.model_config
self.use_unified_vision_chunk = getattr(
model_config.hf_config, "use_unified_vision_chunk", False
)
) -> "HfRenderer":
model_config = config.model_config
if model_config.skip_tokenizer_init:
tokenizer = None
else:
@@ -617,18 +604,18 @@ class HfRenderer(BaseRenderer):
),
)
self._tokenizer = tokenizer
return cls(config, tokenizer)
@property
def tokenizer(self) -> HfTokenizer | None:
return self._tokenizer
def __init__(
self,
config: VllmConfig,
tokenizer: HfTokenizer | None,
) -> None:
super().__init__(config, tokenizer)
def get_tokenizer(self) -> HfTokenizer:
tokenizer = self.tokenizer
if tokenizer is None:
raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")
return tokenizer
self.use_unified_vision_chunk = getattr(
config.model_config.hf_config, "use_unified_vision_chunk", False
)
def render_messages(
self,

View File

@@ -50,23 +50,14 @@ def safe_apply_chat_template(
raise ValueError(str(e)) from e
class MistralRenderer(BaseRenderer):
class MistralRenderer(BaseRenderer[MistralTokenizer]):
@classmethod
def from_config(
def from_config( # type: ignore[override]
cls,
config: VllmConfig,
tokenizer_kwargs: dict[str, Any],
) -> "BaseRenderer":
return cls(config, tokenizer_kwargs)
def __init__(
self,
config: VllmConfig,
tokenizer_kwargs: dict[str, Any],
) -> None:
super().__init__(config)
model_config = self.model_config
) -> "MistralRenderer":
model_config = config.model_config
if model_config.skip_tokenizer_init:
tokenizer = None
else:
@@ -75,24 +66,20 @@ class MistralRenderer(BaseRenderer):
**tokenizer_kwargs,
)
self._tokenizer = tokenizer
return cls(config, tokenizer)
def __init__(
self,
config: VllmConfig,
tokenizer: MistralTokenizer | None,
) -> None:
super().__init__(config, tokenizer)
self._apply_chat_template_executor = ThreadPoolExecutor(max_workers=1)
self._apply_chat_template_async = make_async(
safe_apply_chat_template, executor=self._apply_chat_template_executor
)
@property
def tokenizer(self) -> MistralTokenizer | None:
return self._tokenizer
def get_tokenizer(self) -> MistralTokenizer:
tokenizer = self.tokenizer
if tokenizer is None:
raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")
return tokenizer
def render_messages(
self,
messages: list[ChatCompletionMessageParam],

View File

@@ -3,7 +3,6 @@
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, TypeVar
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.exceptions import VLLMValidationError
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
@@ -12,9 +11,13 @@ from vllm.utils.import_utils import LazyLoader
if TYPE_CHECKING:
import torch
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
else:
torch = LazyLoader("torch", globals(), "torch")
ChatTemplateContentFormatOption = object
logger = init_logger(__name__)
@@ -43,7 +46,7 @@ class ChatParams:
chat_template: str | None = None
"""The chat template to apply."""
chat_template_content_format: ChatTemplateContentFormatOption = "auto"
chat_template_content_format: "ChatTemplateContentFormatOption" = "auto"
"""The format of the chat template."""
chat_template_kwargs: dict[str, Any] = field(default_factory=dict)
@@ -163,10 +166,7 @@ class TokenizeParams:
value=truncate_prompt_tokens,
)
def with_kwargs(self, tokenization_kwargs: dict[str, Any] | None):
if tokenization_kwargs is None:
tokenization_kwargs = {}
def with_kwargs(self, **tokenization_kwargs: Any):
max_length = tokenization_kwargs.pop("max_length", self.max_input_tokens)
pad_prompt_tokens = tokenization_kwargs.pop(
"pad_prompt_tokens", self.pad_prompt_tokens

View File

@@ -10,7 +10,6 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages_async,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from .base import BaseRenderer
from .inputs import DictPrompt
@@ -24,24 +23,14 @@ class TerratorchRenderer(BaseRenderer):
@classmethod
def from_config(
cls,
config: VllmConfig,
config: VllmConfig, # type: ignore[override]
tokenizer_kwargs: dict[str, Any],
) -> "BaseRenderer":
return cls(config)
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
model_config = self.model_config
) -> "TerratorchRenderer":
model_config = config.model_config
if not model_config.skip_tokenizer_init:
raise ValueError("Terratorch renderer requires `skip_tokenizer_init=True`")
@property
def tokenizer(self) -> TokenizerLike | None:
return None
def get_tokenizer(self) -> TokenizerLike:
raise ValueError("Tokenizer not available for Terratorch renderer")
return cls(config, None)
def render_messages(
self,

View File

@@ -19,8 +19,8 @@ from vllm.distributed.weight_transfer.base import (
WeightTransferUpdateRequest,
)
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
from vllm.inputs import PromptType, StreamingInput
from vllm.engine.protocol import EngineClient, StreamingInput
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
@@ -268,12 +268,12 @@ class AsyncLLM(EngineClient):
shutdown_prometheus()
if renderer := getattr(self, "renderer", None):
renderer.shutdown()
if engine_core := getattr(self, "engine_core", None):
engine_core.shutdown()
if input_processor := getattr(self, "input_processor", None):
input_processor.close()
handler = getattr(self, "output_handler", None)
if handler is not None:
cancel_task_threadsafe(handler)
@@ -654,7 +654,7 @@ class AsyncLLM(EngineClient):
output_processor = self.output_processor
log_stats = self.log_stats
logger_manager = self.logger_manager
input_processor = self.input_processor
renderer = self.renderer
chunk_size = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
async def output_handler():
@@ -702,7 +702,7 @@ class AsyncLLM(EngineClient):
engine_idx=outputs.engine_index,
scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats,
mm_cache_stats=input_processor.stat_mm_cache(),
mm_cache_stats=renderer.stat_mm_cache(),
)
except Exception as e:
logger.exception("AsyncLLM output_handler failed.")
@@ -881,7 +881,7 @@ class AsyncLLM(EngineClient):
await asyncio.gather(*coros)
async def reset_mm_cache(self) -> None:
self.input_processor.clear_mm_cache()
self.renderer.clear_mm_cache()
await self.engine_core.reset_mm_cache_async()
async def reset_prefix_cache(

View File

@@ -33,9 +33,9 @@ from vllm.sampling_params import SamplingParams
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.tokenizers import TokenizerLike
from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid
from vllm.utils.jsontree import json_iter_leaves
from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.metrics.stats import MultiModalCacheStats
logger = init_logger(__name__)
@@ -60,8 +60,6 @@ class InputProcessor:
self.generation_config_fields = model_config.try_get_generation_config()
self.renderer = renderer or renderer_from_config(vllm_config)
self.mm_registry = mm_registry
self.mm_processor_cache = mm_registry.processor_cache_from_config(vllm_config)
self.supports_mm_inputs = mm_registry.supports_multimodal_inputs(model_config)
self.mm_encoder_cache_size = 0
@@ -78,7 +76,6 @@ class InputProcessor:
vllm_config,
renderer=renderer,
mm_registry=mm_registry,
mm_processor_cache=self.mm_processor_cache,
)
@property
@@ -136,7 +133,7 @@ class InputProcessor:
)
def _parse_mm_items(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
mm_processor = self.input_preprocessor._get_mm_processor()
mm_processor = self.renderer.get_mm_processor()
return mm_processor.info.parse_mm_data(mm_data)
def _validate_singleton_mm_uuids(self, prompt: SingletonPrompt) -> None:
@@ -415,6 +412,15 @@ class InputProcessor:
decoder_mm_positions = decoder_inputs["mm_placeholders"]
decoder_mm_hashes = decoder_inputs["mm_hashes"]
if not all(
isinstance(leaf, str) for leaf in json_iter_leaves(decoder_mm_hashes)
):
raise ValueError(
f"mm_hashes must contain only strings, got: {decoder_mm_hashes}. "
"This is likely due to an incorrect custom implementation of "
"MultiModalProcessor.apply method."
)
# Merge and flatten multimodal placeholders, hashes and inputs
# from dictionaries to lists, and sort them by each item's position
# in the input sequence.
@@ -562,13 +568,3 @@ class InputProcessor:
self._validate_model_input(encoder_inputs, prompt_type="encoder")
self._validate_model_input(decoder_inputs, prompt_type="decoder")
def stat_mm_cache(self) -> MultiModalCacheStats | None:
return self.input_preprocessor.stat_mm_cache()
def clear_mm_cache(self) -> None:
self.input_preprocessor.clear_mm_cache()
def close(self) -> None:
if self.mm_processor_cache is not None:
self.mm_processor_cache.close()

View File

@@ -320,7 +320,7 @@ class LLMEngine:
self.logger_manager.record(
scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats,
mm_cache_stats=self.input_processor.stat_mm_cache(),
mm_cache_stats=self.renderer.stat_mm_cache(),
)
self.do_log_stats_with_interval()
@@ -333,7 +333,7 @@ class LLMEngine:
self.engine_core.profile(False)
def reset_mm_cache(self):
self.input_processor.clear_mm_cache()
self.renderer.clear_mm_cache()
self.engine_core.reset_mm_cache()
def reset_prefix_cache(

View File

@@ -151,6 +151,12 @@ class MultiModalCacheStats(BaseCacheStats):
that were queried.
"""
def record(self, num_queries: int, num_hits: int) -> None:
"""Aggregate request information into the stats."""
self.requests += 1
self.queries += num_queries
self.hits += num_hits
@dataclass
class KVCacheEvictionEvent: