[Refactor] Decouple TimingContext from InputProcessingContext (#35083)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung
2026-02-23 22:15:50 +08:00
committed by GitHub
parent 1e8438a89a
commit 392645454b
38 changed files with 419 additions and 649 deletions
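
At a glance, the refactor replaces the keyword-heavy `apply()` signature with a single `ProcessorInputs` dataclass plus an explicit `TimingContext`. Below is a minimal sketch of the new call convention, assuming a processor built via `MULTIMODAL_REGISTRY`; the helper function and variable names are illustrative, not part of this commit.

```python
from vllm.multimodal.processing import ProcessorInputs, TimingContext

def run_mm_processor(processor, prompt, mm_data, hf_kwargs=None):
    # All per-request arguments now travel in one dataclass...
    inputs = ProcessorInputs(
        prompt=prompt,
        mm_data_items=processor.info.parse_mm_data(mm_data),
        hf_processor_mm_kwargs=hf_kwargs or {},
    )
    # ...and timing is an explicit collaborator instead of hanging off
    # InputProcessingContext via request-id context variables.
    timing_ctx = TimingContext()
    mm_inputs = processor.apply(inputs, timing_ctx)
    return mm_inputs, timing_ctx.get_stats_dict()
```

Tests that do not care about timing use the `processor(...)` shorthand seen throughout the hunks below, which builds the `ProcessorInputs` internally and passes a disabled `TimingContext`.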

View File

@@ -389,13 +389,13 @@ def _test_processing_correctness_one(
mm_items = baseline_processor.info.parse_mm_data(mm_data)
ignore_mm_keys = _IGNORE_MM_KEYS.get(model_type, set[str]())
baseline_tokenized_result = baseline_processor.apply(
baseline_tokenized_result = baseline_processor(
token_prompt,
mm_items=mm_items,
hf_processor_mm_kwargs={},
)
cached_tokenized_result = cached_processor.apply(
cached_tokenized_result = cached_processor(
token_prompt,
mm_items=mm_items,
hf_processor_mm_kwargs={},
@@ -409,12 +409,12 @@ def _test_processing_correctness_one(
)
if text_prompt is not None:
baseline_text_result = baseline_processor.apply(
baseline_text_result = baseline_processor(
text_prompt,
mm_items=mm_items,
hf_processor_mm_kwargs={},
)
cached_text_result = cached_processor.apply(
cached_text_result = cached_processor(
text_prompt,
mm_items=mm_items,
hf_processor_mm_kwargs={},

View File

@@ -176,7 +176,7 @@ def test_get_image_size_with_most_features(
for asset in image_assets:
mm_data = {"image": [asset.pil_image]}
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs,

View File

@@ -52,7 +52,7 @@ def test_processor_override(
metadata["fps"] = fps
mm_data = {"video": [(video, metadata)]}
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
@@ -104,12 +104,12 @@ def test_video_loader_consistency(
static_mm_data = {"video": [(static_video, static_metadata)]}
dynamic_mm_data = {"video": [(dynamic_video, dynamic_metadata)]}
static_outputs = processor.apply(
static_outputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(static_mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
dynamic_outputs = processor.apply(
dynamic_outputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(dynamic_mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs,

View File

@@ -106,7 +106,7 @@ def _run_check(
for image in images
)
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=mm_processor_kwargs,

View File

@@ -61,7 +61,7 @@ def test_processor_override(
dummy_image = image_assets[0].pil_image.resize(dummy_image_size)
mm_data = {"image": [dummy_image] * num_imgs}
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs,

View File

@@ -66,7 +66,7 @@ def _run_check(
for image in images
)
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=mm_processor_kwargs,

View File

@@ -49,7 +49,7 @@ def test_processor_override(
if tokenized_prompt:
prompt = tokenizer.encode(prompt)
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=mm_processor_kwargs,

View File

@@ -87,7 +87,7 @@ def _validate_image_prompt_replacements_one(
try:
# The processor will throw an error if there is a mismatch
# in the prompt replacements
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs={},

View File

@@ -87,7 +87,7 @@ def _validate_image_prompt_replacements_one(
try:
# The processor will throw an error if there is a mismatch
# in the prompt replacements
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs={},

View File

@@ -29,7 +29,7 @@ def test_processor_override(
image = Image.new("RGB", size=(364, 364))
mm_data = {"image": [image] * num_imgs}
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs={},
@@ -50,7 +50,7 @@ def _validate_image_prompt_replacements_one(
mm_data = {"image": [image] * num_imgs}
try:
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs={},

View File

@@ -68,7 +68,7 @@ def _run_check(
for image in images
)
print(total_expected_num_patches)
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=mm_processor_kwargs,

View File

@@ -47,7 +47,7 @@ def test_processor_override(
prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
mm_data = {"image": [image_assets[0].pil_image] * num_imgs}
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs,

View File

@@ -51,7 +51,7 @@ def test_processor_override(
dummy_image = image_assets[0].pil_image.resize(dummy_image_size)
mm_data = {"image": [dummy_image] * num_imgs}
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs,

View File

@@ -42,7 +42,7 @@ def test_processor_override(
prompt = "<|vision_start|><|image_pad|><|vision_end|>" * num_imgs
mm_data = {"image": [image_assets[0].pil_image] * num_imgs}
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
@@ -88,7 +88,7 @@ def test_get_image_size_with_most_features(
prompt = "<|vision_start|><|image_pad|><|vision_end|>"
for asset in image_assets:
mm_data = {"image": [asset.pil_image]}
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs,

View File

@@ -51,7 +51,7 @@ def test_processor_with_audio_sample_rate(
hf_processor_mm_kwargs: dict[str, Any] = {
"audio_sample_rate": audio_sample_rate,
}
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
@@ -94,7 +94,7 @@ def test_longer_audio_generates_more_tokens(model_id: str) -> None:
hf_processor_mm_kwargs: dict[str, Any] = {
"audio_sample_rate": audio_sample_rate,
}
processed = processor.apply(
processed = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs,

View File

@@ -61,7 +61,7 @@ def test_processor_override(
dummy_image = image_assets[0].pil_image.resize(dummy_image_size)
mm_data = {"image": [dummy_image] * num_imgs}
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs,

View File

@@ -99,7 +99,7 @@ def create_batched_mm_kwargs(
mm_counts=mm_counts,
mm_options={},
)
mm_items = processor_inputs.mm_items
mm_items = processor_inputs.mm_data_items
resized_mm_data = {
modality: resize_mm_data(items.data, size_factors)
for modality, items in mm_items.items()
@@ -108,11 +108,10 @@ def create_batched_mm_kwargs(
# video metadata will be added back to the resized video data here.
text_prompt, token_prompt = get_text_token_prompts(processor, resized_mm_data)
mm_kwargs = processor.apply(
mm_kwargs = processor(
prompt=token_prompt if text_prompt is None else text_prompt,
mm_items=processor.info.parse_mm_data(resized_mm_data),
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
tokenization_kwargs=processor_inputs.tokenization_kwargs,
)["mm_kwargs"].require_data()
return group_mm_kwargs_by_modality(

View File

@@ -19,7 +19,7 @@ def test_multimodal_processor(model_id):
image_pil = ImageAsset("cherry_blossom").pil_image
mm_data = {"image": image_pil}
str_prompt = "<|im_start|>user <image>\nWhat is the content of this image?<|im_end|><|im_start|>assistant\n" # noqa: E501
str_processed_inputs = mm_processor.apply(
str_processed_inputs = mm_processor(
prompt=str_prompt,
mm_items=mm_processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs={},
@@ -44,7 +44,7 @@ def test_multimodal_processor(model_id):
77091,
198,
]
ids_processed_inputs = mm_processor.apply(
ids_processed_inputs = mm_processor(
prompt=ids_prompt,
mm_items=mm_processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs={},

View File

@@ -934,7 +934,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most")
with exc_ctx:
processor.apply(
processor(
"<image>" * num_images,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs={},

View File

@@ -17,8 +17,9 @@ import argparse
import dataclasses
import json
import time
from collections import defaultdict
from datetime import datetime
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal
import numpy as np
@@ -59,12 +60,13 @@ def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, f
Example:
{
'request-123': {
'hf_processor_time': 0.45,
'hashing_time': 0.02,
'cache_lookup_time': 0.01,
'prompt_update_time': 0.03,
'preprocessor_total_time': 0.51,
'encoder_forward_time': 0.23,
'get_mm_hashes_secs': 0.02,
'get_cache_missing_items_secs': 0.01,
'apply_hf_processor_secs': 0.45,
'merge_mm_kwargs_secs': 0.01,
'apply_prompt_updates_secs': 0.03,
'preprocessor_total_secs': 0.51,
'encoder_forward_secs': 0.23,
'num_encoder_calls': 1
}
}
@@ -74,8 +76,7 @@ def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, f
return {}
renderer = llm_engine.renderer
mm_processor = renderer.get_mm_processor()
preprocessing_stats = mm_processor.info.ctx.get_all_timing_stats()
mm_processor_stats = renderer._mm_timing_registry.stat()
encoder_stats = dict[str, dict[str, float]]()
for worker_stats in llm_engine.collective_rpc("get_encoder_timing_stats"):
@@ -88,10 +89,10 @@ def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, f
else:
# Aggregate timing metrics across workers
current_time = encoder_stats[request_id].get(
"encoder_forward_time", 0.0
"encoder_forward_secs", 0.0
)
new_time = stats_dict.get("encoder_forward_time", 0.0)
encoder_stats[request_id]["encoder_forward_time"] = max(
new_time = stats_dict.get("encoder_forward_secs", 0.0)
encoder_stats[request_id]["encoder_forward_secs"] = max(
current_time, new_time
)
@@ -103,7 +104,7 @@ def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, f
merged_stats = dict[str, dict[str, float]]()
for request_id, prep_dict in preprocessing_stats.items():
for request_id, prep_dict in mm_processor_stats.items():
merged_stats[request_id] = dict(prep_dict)
for request_id, enc_dict in encoder_stats.items():
@@ -124,34 +125,18 @@ def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, f
return merged_stats
def collect_mm_processor_stats(
llm_engine: LLMEngine,
num_warmup_reqs: int = 0,
) -> dict[str, list[float]]:
def collect_mm_processor_stats(llm_engine: LLMEngine) -> dict[str, list[float]]:
"""
Collect multimodal processor timing stats.
Returns a dictionary mapping stage names to lists of timing values (in seconds).
"""
all_stats = get_timing_stats_from_engine(llm_engine)
stat_keys = [
"hf_processor_time",
"hashing_time",
"cache_lookup_time",
"prompt_update_time",
"preprocessor_total_time",
"encoder_forward_time",
"num_encoder_calls",
]
stats_by_stage = {key: [] for key in stat_keys}
stats_by_stage = defaultdict[str, list[float]](list)
# Skip warmup requests
stats_list = list(all_stats.values())[num_warmup_reqs:]
for stats_dict in stats_list:
for key in stat_keys:
if key in stats_dict:
stats_by_stage[key].append(stats_dict[key])
for stats_dict in all_stats.values():
for stat_key, stat_val in stats_dict.items():
stats_by_stage[stat_key].append(stat_val)
return stats_by_stage
@@ -159,13 +144,20 @@ def collect_mm_processor_stats(
def calculate_mm_processor_metrics(
stats_by_stage: dict[str, list[float]],
selected_percentiles: list[float],
*,
unit: Literal["us", "ms", "s"] = "ms",
) -> dict[str, dict[str, float]]:
"""
Calculate aggregate metrics from stats by stage.
"""
unit2mult = {"us": 1000000, "ms": 1000, "s": 1}
unit_mult = unit2mult[unit]
metrics = {}
for stage_name, times in stats_by_stage.items():
for stage, times in stats_by_stage.items():
stage_name = stage.replace("_secs", "_" + unit)
if not times:
metrics[stage_name] = {
"mean": 0.0,
@@ -175,8 +167,8 @@ def calculate_mm_processor_metrics(
}
continue
is_count_metric = stage_name == "num_encoder_calls"
values = times if is_count_metric else [t * 1000 for t in times]
is_count_metric = stage == "num_encoder_calls"
values = times if is_count_metric else [t * unit_mult for t in times]
metrics[stage_name] = {
"mean": float(np.mean(values)),
@@ -285,6 +277,9 @@ def benchmark_multimodal_processor(
use_tqdm=not getattr(args, "disable_tqdm", False),
)
# Clear stats from warmup requests
collect_mm_processor_stats(llm.llm_engine)
print(f"Processing {len(prompts)} requests...")
start_time = time.perf_counter()
@@ -295,7 +290,7 @@ def benchmark_multimodal_processor(
end_time = time.perf_counter()
total_time = end_time - start_time
mm_stats_by_stage = collect_mm_processor_stats(llm.llm_engine, num_warmups)
mm_stats_by_stage = collect_mm_processor_stats(llm.llm_engine)
if not any(mm_stats_by_stage.values()):
print(
@@ -475,11 +470,8 @@ def main(args: argparse.Namespace) -> None:
]
mm_data = []
for stage, metrics in result["mm_processor_stats"].items():
is_count = stage == "num_encoder_calls"
unit = "" if is_count else " (ms)"
row = {
"Stage": stage + unit,
"Stage": stage,
"Mean": f"{metrics['mean']:.2f}",
"Median": f"{metrics['median']:.2f}",
"Std": f"{metrics['std']:.2f}",

View File

@@ -41,15 +41,16 @@ from vllm.multimodal.parse import (
ImageProcessorItems,
ImageSize,
MultiModalDataItems,
MultiModalUUIDItems,
)
from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
BaseMultiModalProcessor,
BaseProcessingInfo,
ProcessorInputs,
PromptIndexTargets,
PromptReplacement,
PromptUpdate,
TimingContext,
)
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -204,23 +205,20 @@ class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]):
def apply(
self,
prompt: str | list[int],
mm_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None,
inputs: ProcessorInputs,
timing_ctx: TimingContext,
) -> MultiModalInputs:
if mm_items:
if isinstance(prompt, str):
if len(prompt) > 0:
if inputs.mm_data_items:
if isinstance(inputs.prompt, str):
if len(inputs.prompt) > 0:
raise ValueError(
"CLIP accepts text-only or image-only inputs, not both! "
"You must pass an image with an empty text prompt."
)
else:
special_tokens = self.info.get_tokenizer().all_special_ids
if all(tok in special_tokens for tok in prompt):
prompt = []
if all(tok in special_tokens for tok in inputs.prompt):
inputs.prompt = []
else:
raise ValueError(
"CLIP accepts text-only or image-only inputs, not both! "
@@ -229,18 +227,12 @@ class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]):
# For multi-modal data, the prompt after processing should
# only contain the dummy image tokens
tokenization_kwargs = {
**(tokenization_kwargs or {}),
inputs.tokenization_kwargs = {
**inputs.tokenization_kwargs,
"add_special_tokens": False,
}
return super().apply(
prompt=prompt,
mm_items=mm_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
return super().apply(inputs, timing_ctx)
def _hf_processor_applies_updates(
self,

View File

@@ -30,15 +30,16 @@ from vllm.multimodal.parse import (
ImageProcessorItems,
ImageSize,
MultiModalDataItems,
MultiModalUUIDItems,
)
from vllm.multimodal.processing import BaseDummyInputsBuilder
from vllm.multimodal.processing.processor import (
BaseMultiModalProcessor,
BaseProcessingInfo,
MultiModalProcessingInfo,
ProcessorInputs,
PromptReplacement,
PromptUpdate,
TimingContext,
)
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_tokenizer_from_config
@@ -310,32 +311,17 @@ class DeepseekVL2MultiModalProcessor(
def _cached_apply_hf_processor(
self,
prompt: str | list[int],
mm_data_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
inputs: ProcessorInputs,
timing_ctx: TimingContext,
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
# The processor logic is different for len(images) <= 2 vs > 2
# Since the processing cache assumes that the processor output is
# invariant of how many images are passed per prompt, we only
# perform caching for the most common case
if mm_data_items.get_count("image", strict=False) > 2:
return self._apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
if inputs.mm_data_items.get_count("image", strict=False) > 2:
return self._apply_hf_processor(inputs, timing_ctx)
return super()._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
return super()._cached_apply_hf_processor(inputs, timing_ctx)
@MULTIMODAL_REGISTRY.register_processor(

View File

@@ -21,13 +21,14 @@ from vllm.multimodal.parse import (
ImageEmbeddingItems,
ImageProcessorItems,
MultiModalDataItems,
MultiModalUUIDItems,
)
from vllm.multimodal.processing.processor import (
MultiModalProcessingInfo,
ProcessorInputs,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
TimingContext,
)
from vllm.tokenizers import TokenizerLike
@@ -490,32 +491,17 @@ class H2OVLMultiModalProcessor(BaseInternVLMultiModalProcessor[H2OVLProcessingIn
def _cached_apply_hf_processor(
self,
prompt: str | list[int],
mm_data_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
inputs: ProcessorInputs,
timing_ctx: TimingContext,
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
# The processor logic is different for len(images) <= 1 vs > 1
# Since the processing cache assumes that the processor output is
# invariant of how many images are passed per prompt, we only
# perform caching for the most common case
if mm_data_items.get_count("image", strict=False) > 1:
return self._apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
if inputs.mm_data_items.get_count("image", strict=False) > 1:
return self._apply_hf_processor(inputs, timing_ctx)
return super()._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
return super()._cached_apply_hf_processor(inputs, timing_ctx)
@MULTIMODAL_REGISTRY.register_processor(

View File

@@ -37,16 +37,17 @@ from vllm.multimodal.parse import (
ImageProcessorItems,
ImageSize,
MultiModalDataItems,
MultiModalUUIDItems,
)
from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
BaseMultiModalProcessor,
BaseProcessingInfo,
InputProcessingContext,
ProcessorInputs,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
TimingContext,
)
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -770,11 +771,8 @@ class MantisProcessingInfo(LlavaProcessingInfo):
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
def apply(
self,
prompt: str | list[int],
mm_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None,
inputs: ProcessorInputs,
timing_ctx: TimingContext,
) -> MultiModalInputs:
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
@@ -785,15 +783,9 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
image_height=-1,
)
result = super().apply(
prompt,
mm_items,
mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
result = super().apply(inputs, timing_ctx)
mm_item_counts = mm_items.get_all_counts()
mm_item_counts = inputs.mm_data_items.get_all_counts()
mm_kwargs = result["mm_kwargs"]
mm_hashes = result["mm_hashes"]
@@ -825,8 +817,8 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
)
orig_repls = self._get_mm_prompt_updates(
mm_items,
hf_processor_mm_kwargs,
inputs.mm_data_items,
inputs.hf_processor_mm_kwargs,
mm_kwargs,
)
mm_placeholders = self._find_mm_placeholders(prompt_ids, orig_repls)

View File

@@ -21,16 +21,17 @@ from vllm.multimodal.parse import (
ImageEmbeddingItems,
ImageProcessorItems,
MultiModalDataItems,
MultiModalUUIDItems,
)
from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
BaseMultiModalProcessor,
BaseProcessingInfo,
ProcessorInputs,
PromptIndexTargets,
PromptInsertion,
PromptUpdate,
PromptUpdateDetails,
TimingContext,
)
from vllm.renderers import TokenizeParams
from vllm.sequence import IntermediateTensors
@@ -228,19 +229,10 @@ class PaliGemmaMultiModalProcessor(BaseMultiModalProcessor[PaliGemmaProcessingIn
def apply(
self,
prompt: str | list[int],
mm_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None,
inputs: ProcessorInputs,
timing_ctx: TimingContext,
) -> MultiModalInputs:
mm_inputs = super().apply(
prompt,
mm_items,
mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
mm_inputs = super().apply(inputs, timing_ctx)
prompt_token_ids = mm_inputs["prompt_token_ids"]
tokenizer = self.info.get_tokenizer()

View File

@@ -50,16 +50,17 @@ from vllm.multimodal.parse import (
ImageProcessorItems,
ImageSize,
MultiModalDataItems,
MultiModalUUIDItems,
)
from vllm.multimodal.processing import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.processing import BaseDummyInputsBuilder
from vllm.multimodal.processing.processor import (
BaseMultiModalProcessor,
BaseProcessingInfo,
MultiModalProcessingInfo,
ProcessorInputs,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
TimingContext,
)
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
@@ -277,7 +278,6 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
dummy_images = dummy_mm_data.get("image", [])
tokenization_kwargs = {"truncation": False}
request = ChatCompletionRequest(
messages=[
@@ -294,11 +294,7 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
dummy_mm_items = self.info.parse_mm_data(dummy_mm_data)
return ProcessorInputs(
prompt=dummy_tokens,
mm_items=dummy_mm_items,
tokenization_kwargs=tokenization_kwargs,
)
return ProcessorInputs(prompt=dummy_tokens, mm_data_items=dummy_mm_items)
class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]):
@@ -344,19 +340,10 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo])
def _cached_apply_hf_processor(
self,
prompt: str | list[int],
mm_data_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
inputs: ProcessorInputs,
timing_ctx: TimingContext,
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(inputs, timing_ctx)
# NOTE: The tokens are already inserted by the chat template
return prompt_ids, mm_info, True

View File

@@ -47,15 +47,16 @@ from vllm.multimodal.parse import (
ImageProcessorItems,
ImageSize,
MultiModalDataItems,
MultiModalUUIDItems,
)
from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
BaseMultiModalProcessor,
BaseProcessingInfo,
ProcessorInputs,
PromptIndexTargets,
PromptReplacement,
PromptUpdate,
TimingContext,
)
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -190,23 +191,20 @@ class SiglipMultiModalProcessor(BaseMultiModalProcessor[SiglipProcessingInfo]):
def apply(
self,
prompt: str | list[int],
mm_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None,
inputs: ProcessorInputs,
timing_ctx: TimingContext,
) -> MultiModalInputs:
if mm_items:
if isinstance(prompt, str):
if len(prompt) > 0:
if inputs.mm_data_items:
if isinstance(inputs.prompt, str):
if len(inputs.prompt) > 0:
raise ValueError(
"SigLIP accepts text-only or image-only inputs, not both! "
"You must pass an image with an empty text prompt."
)
else:
special_tokens = self.info.get_tokenizer().all_special_ids
if all(tok in special_tokens for tok in prompt):
prompt = []
if all(tok in special_tokens for tok in inputs.prompt):
inputs.prompt = []
else:
raise ValueError(
"SigLIP accepts text-only or image-only inputs, not both! "
@@ -214,19 +212,13 @@ class SiglipMultiModalProcessor(BaseMultiModalProcessor[SiglipProcessingInfo]):
)
# For multi-modal data, the prompt after processing should
# only contain the image token
tokenization_kwargs = {
**(tokenization_kwargs or {}),
# only contain the dummy image tokens
inputs.tokenization_kwargs = {
**inputs.tokenization_kwargs,
"add_special_tokens": False,
}
return super().apply(
prompt=prompt,
mm_items=mm_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
return super().apply(inputs, timing_ctx)
def _hf_processor_applies_updates(
self,

View File

@@ -54,13 +54,14 @@ from vllm.multimodal.parse import (
ModalityDataItems,
MultiModalDataItems,
MultiModalDataParser,
MultiModalUUIDItems,
)
from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
BaseMultiModalProcessor,
BaseProcessingInfo,
ProcessorInputs,
PromptUpdate,
TimingContext,
)
from vllm.sequence import IntermediateTensors
@@ -193,29 +194,21 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessing
def apply(
self,
prompt: str | list[int],
mm_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None,
inputs: ProcessorInputs,
timing_ctx: TimingContext,
) -> MultiModalInputs:
if hf_processor_mm_kwargs is None:
hf_processor_mm_kwargs = {}
if tokenization_kwargs is None:
tokenization_kwargs = {}
mm_items = inputs.mm_data_items
hf_processor_mm_kwargs = inputs.hf_processor_mm_kwargs
mm_hashes = self._hash_mm_items(
mm_items,
mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
_, passthrough_data = self._get_hf_mm_data(mm_items)
mm_processed_data = BatchFeature(
{k: torch.as_tensor(v).unsqueeze(0) for k, v in passthrough_data.items()},
tensor_type="pt",
)
mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
with timing_ctx.record("apply_hf_processor"):
_, passthrough_data = self._get_hf_mm_data(mm_items)
mm_processed_data = BatchFeature(
{
k: torch.as_tensor(v).unsqueeze(0)
for k, v in passthrough_data.items()
},
tensor_type="pt",
)
mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
mm_processed_data,
@@ -226,6 +219,11 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessing
),
)
with timing_ctx.record("get_mm_hashes"):
mm_hashes = inputs.get_mm_hashes(self.info.model_id)
mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
return mm_inputs(
prompt_token_ids=[1],
mm_kwargs=mm_kwargs,

View File

@@ -37,12 +37,13 @@ from vllm.multimodal.inputs import (
from vllm.multimodal.parse import (
ImageProcessorItems,
MultiModalDataItems,
MultiModalUUIDItems,
)
from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
BaseMultiModalProcessor,
BaseProcessingInfo,
ProcessorInputs,
TimingContext,
)
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
@@ -177,11 +178,8 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
def apply(
self,
prompt: str | list[int],
mm_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None,
inputs: ProcessorInputs,
timing_ctx: TimingContext,
) -> MultiModalInputs:
"""
Process multi-modal inputs to be used in vLLM.
@@ -189,29 +187,30 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
Apply HF Processor on prompt text and multi-modal data together,
outputting token IDs and processed tensors.
"""
if hf_processor_mm_kwargs is None:
hf_processor_mm_kwargs = {}
if tokenization_kwargs is None:
tokenization_kwargs = {}
prompt = inputs.prompt
mm_items = inputs.mm_data_items
hf_processor_mm_kwargs = inputs.hf_processor_mm_kwargs
tokenization_kwargs = inputs.tokenization_kwargs
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
if not isinstance(prompt, str):
# the prompt is the tokenized ids which is not supported
# by the hf_processor, which is why we would need to decode the ids
# into string
prompt = hf_processor.decode(prompt)
with timing_ctx.record("apply_hf_processor"):
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
if not isinstance(prompt, str):
# the prompt is the tokenized ids which is not supported
# by the hf_processor, which is why we would need to decode the ids
# into string
prompt = hf_processor.decode(prompt)
# Bypass cached processor and always apply to the full set of mm inputs
# NOTE: we can't just set caching=False because base class method
# transforms outputs to `MultiModalKwargs` which is not going to
# work for Transformers. We have a lot of logic tied to
# `mm_tokens_per_modality` below
prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm(
prompt_text=prompt,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
# Bypass cached processor and always apply to the full set of mm inputs
# NOTE: we can't just set caching=False because base class method
# transforms outputs to `MultiModalKwargs` which is not going to
# work for Transformers. We have a lot of logic tied to
# `mm_tokens_per_modality` below
prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm(
prompt_text=prompt,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
# For gemma3 we check `token_type_ids` as the key
token_type_key = (
@@ -225,15 +224,14 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
# it for each input `mm_data`.
mm_positions = torch.where(mm_token_type_ids == 1)[1]
images = mm_items.get_items("image", ImageProcessorItems)
multimodal_config = self.info.ctx.model_config.multimodal_config
mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}
image_sizes = []
for item_idx in range(len(images)):
image_size = images.get_image_size(item_idx)
image_sizes.append((image_size.height, image_size.width))
mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens(
image_sizes=image_sizes, **mm_processor_kwargs
image_sizes=image_sizes,
**self.info.ctx.get_merged_mm_kwargs({}),
)
mm_placeholders = {}
@@ -261,11 +259,8 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
)
# Use overrides if provided; fallback to data-dependent hashing.
mm_hashes = self._hash_mm_items(
mm_items,
mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
with timing_ctx.record("get_mm_hashes"):
mm_hashes = inputs.get_mm_hashes(self.info.model_id)
return mm_inputs(
prompt_token_ids=prompt_ids,

View File

@@ -47,16 +47,17 @@ from vllm.multimodal.parse import (
AudioProcessorItems,
MultiModalDataItems,
MultiModalDataParser,
MultiModalUUIDItems,
)
from vllm.multimodal.processing import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.processing import BaseDummyInputsBuilder
from vllm.multimodal.processing.processor import (
BaseMultiModalProcessor,
BaseProcessingInfo,
MultiModalProcessingInfo,
PlaceholderFeaturesInfo,
ProcessorInputs,
PromptReplacement,
PromptUpdate,
TimingContext,
)
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_tokenizer_from_config
@@ -265,13 +266,13 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
res = tokenizer.mistral.encode_chat_completion(request)
dummy_tokens = res.tokens
dummy_mm_inputs = self.info.parse_mm_data(
dummy_mm_items = self.info.parse_mm_data(
# whixtral tokenizer adds padding to the audio
# so we need to update the audio arrays
{**dummy_mm_data, "audio": [a.audio_array for a in res.audios]},
)
return ProcessorInputs(prompt=dummy_tokens, mm_items=dummy_mm_inputs)
return ProcessorInputs(prompt=dummy_tokens, mm_data_items=dummy_mm_items)
class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]):
@@ -361,19 +362,10 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
def _cached_apply_hf_processor(
self,
prompt: str | list[int],
mm_data_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
inputs: ProcessorInputs,
timing_ctx: TimingContext,
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(inputs, timing_ctx)
# NOTE: The tokens are already inserted by the chat template
return prompt_ids, mm_info, True

View File

@@ -1,7 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .context import BaseProcessingInfo, InputProcessingContext
from .dummy_inputs import BaseDummyInputsBuilder, ProcessorInputs
from .context import BaseProcessingInfo, InputProcessingContext, TimingContext
from .dummy_inputs import BaseDummyInputsBuilder
from .inputs import ProcessorInputs
from .processor import (
BaseMultiModalProcessor,
EncDecMultiModalProcessor,
@@ -15,6 +16,7 @@ from .processor import (
__all__ = [
"BaseProcessingInfo",
"InputProcessingContext",
"TimingContext",
"BaseDummyInputsBuilder",
"ProcessorInputs",
"BaseMultiModalProcessor",

View File

@@ -1,10 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextvars
import threading
import time
from abc import abstractmethod
from collections.abc import Generator, Mapping
from collections.abc import Mapping
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import cached_property
@@ -33,104 +31,53 @@ if TYPE_CHECKING:
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin
from vllm.config import ModelConfig, ObservabilityConfig
from vllm.config import ModelConfig
else:
PretrainedConfig = object
BatchFeature = object
ProcessorMixin = object
ModelConfig = object
ObservabilityConfig = object
logger = init_logger(__name__)
_request_id_context: contextvars.ContextVar[str | None] = contextvars.ContextVar(
"_request_id_context", default=None
)
def get_current_request_id() -> str | None:
"""Get the current request_id from the context, if available."""
return _request_id_context.get()
@contextmanager
def set_request_id(request_id: str) -> Generator[None, None, None]:
"""Context manager to set the request_id for the current context."""
token = _request_id_context.set(request_id)
try:
yield
finally:
_request_id_context.reset(token)
@dataclass
class MultiModalProcessorTimingStats:
"""Per-request timing statistics for multimodal processor stages."""
class TimingContext:
"""Helper class to record execution times during multi-modal processing."""
hf_processor_time: float = 0.0
"""Time spent in HuggingFace processor calls (seconds)."""
enabled: bool = True
"""If disabled, `TimingContext.record` becomes a no-op."""
hashing_time: float = 0.0
"""Time spent computing multimodal item hashes (seconds)."""
stage_secs: dict[str, float] = field(default_factory=dict)
"""The execution time (in seconds) for each processing stage."""
cache_lookup_time: float = 0.0
"""Time spent in cache lookups and merges (seconds)."""
@property
def total_secs(self) -> float:
return sum(self.stage_secs.values())
prompt_update_time: float = 0.0
"""Time spent applying prompt updates and finding placeholders (seconds)."""
@contextmanager
def record(self, stage: str):
"""Record the execution time for a processing stage."""
if not self.enabled:
yield
return
preprocessor_total_time: float = 0.0
"""Total preprocessing time (seconds)."""
start_time = time.perf_counter()
try:
yield
finally:
elapsed = time.perf_counter() - start_time
self.stage_secs.setdefault(stage, 0.0)
self.stage_secs[stage] += elapsed
def to_dict(self) -> dict[str, float]:
"""Convert stats to a dictionary for JSON serialization."""
return {
"hf_processor_time": self.hf_processor_time,
"hashing_time": self.hashing_time,
"cache_lookup_time": self.cache_lookup_time,
"prompt_update_time": self.prompt_update_time,
"preprocessor_total_time": self.preprocessor_total_time,
def get_stats_dict(self):
stats_dict = {
f"{stage}_secs": time_s for stage, time_s in self.stage_secs.items()
}
stats_dict["preprocessor_total_secs"] = self.total_secs
@contextmanager
def timed_preprocessor_operation(ctx: "InputProcessingContext", stage_name: str):
"""
Context manager to time an operation using the context's timing stats.
The request_id is automatically retrieved from the context variable,
so it doesn't need to be passed as a parameter.
Args:
ctx: The InputProcessingContext containing the timing stats registry.
stage_name: Name of the stage being timed.
"""
request_id = get_current_request_id()
if ctx is None or request_id is None:
yield
return
stats = ctx.get_timing_stats(request_id)
if stats is None:
yield
return
start_time = time.perf_counter()
try:
yield
finally:
elapsed = time.perf_counter() - start_time
if stage_name == "hf_processor":
stats.hf_processor_time += elapsed
elif stage_name == "hashing":
stats.hashing_time += elapsed
elif stage_name == "cache_lookup":
stats.cache_lookup_time += elapsed
elif stage_name == "prompt_update":
stats.prompt_update_time += elapsed
stats.preprocessor_total_time += elapsed
return stats_dict
_T = TypeVar("_T")
@@ -151,21 +98,6 @@ class InputProcessingContext:
tokenizer: TokenizerLike | None
"""The tokenizer used to tokenize the inputs."""
observability_config: "ObservabilityConfig | None" = field(
default=None, compare=False, repr=False
)
"""Configuration for observability features."""
timing_stats_registry: dict[str, MultiModalProcessorTimingStats] = field(
default_factory=dict, compare=False, repr=False
)
"""Registry for storing timing stats keyed by request_id."""
_timing_stats_registry_lock: threading.Lock = field(
default_factory=threading.Lock, compare=False, repr=False
)
"""Lock for thread-safe access to timing_stats_registry."""
def get_tokenizer(self) -> TokenizerLike:
if self.tokenizer is None:
raise ValueError(
@@ -379,71 +311,6 @@ class InputProcessingContext:
return self._postprocess_output(output)
def get_timing_stats(
self, request_id: str
) -> MultiModalProcessorTimingStats | None:
"""
Get timing stats for a request.
"""
if (
self.observability_config is None
or not self.observability_config.enable_mm_processor_stats
):
return None
with self._timing_stats_registry_lock:
return self.timing_stats_registry.get(request_id)
def create_timing_stats(self, request_id: str) -> MultiModalProcessorTimingStats:
"""
Create and store timing stats in the registry for a request.
This should be called at the start of processing for a request.
The stats object is created immediately and stored in the registry.
"""
if (
self.observability_config is None
or not self.observability_config.enable_mm_processor_stats
):
return MultiModalProcessorTimingStats()
with self._timing_stats_registry_lock:
if request_id in self.timing_stats_registry:
raise ValueError(
f"Timing stats already exist for request_id: {request_id}"
)
stats = MultiModalProcessorTimingStats()
self.timing_stats_registry[request_id] = stats
return stats
def clear_timing_stats_registry(self) -> int:
"""
Clear all stats from the registry. Returns the number of stats cleared.
"""
if (
self.observability_config is None
or not self.observability_config.enable_mm_processor_stats
):
return 0
with self._timing_stats_registry_lock:
count = len(self.timing_stats_registry)
self.timing_stats_registry.clear()
return count
def get_all_timing_stats(self) -> dict[str, dict[str, float]]:
"""
Get all timing stats as a dictionary for API endpoints.
"""
if (
self.observability_config is None
or not self.observability_config.enable_mm_processor_stats
):
return {}
with self._timing_stats_registry_lock:
return {
rid: stats.to_dict()
for rid, stats in self.timing_stats_registry.items()
}
class BaseProcessingInfo:
"""Base class to provide the information necessary for data processing."""

View File

@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Generic, TypeVar
import numpy as np
@@ -18,27 +17,14 @@ from vllm.config.multimodal import (
from vllm.logger import init_logger
from ..inputs import MultiModalDataDict
from ..parse import MultiModalDataItems
from .context import BaseProcessingInfo
from .inputs import ProcessorInputs
_I = TypeVar("_I", bound=BaseProcessingInfo)
logger = init_logger(__name__)
@dataclass
class ProcessorInputs:
"""
Represents the keyword arguments to
[`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][].
"""
prompt: str | list[int]
mm_items: MultiModalDataItems
hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)
class BaseDummyInputsBuilder(ABC, Generic[_I]):
"""
Abstract base class that constructs the dummy data to profile
@@ -101,7 +87,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
return ProcessorInputs(
prompt=dummy_text,
mm_items=dummy_mm_items,
mm_data_items=dummy_mm_items,
tokenization_kwargs=tokenization_kwargs,
)

View File

@@ -0,0 +1,70 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
from dataclasses import dataclass, field
from ..hasher import MultiModalHasher
from ..inputs import MultiModalHashes
from ..parse import MultiModalDataItems, MultiModalUUIDItems
@dataclass
class ProcessorInputs:
"""
Represents the keyword arguments to
[`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][].
"""
prompt: str | list[int]
mm_data_items: MultiModalDataItems
mm_uuid_items: MultiModalUUIDItems | None = None
hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)
def get_mm_hashes(self, model_id: str) -> MultiModalHashes:
mm_data_items = self.mm_data_items
mm_uuid_items = self.mm_uuid_items or {}
hf_processor_mm_kwargs = self.hf_processor_mm_kwargs
mm_hashes: MultiModalHashes = {}
hasher = MultiModalHasher
for modality, data_items in mm_data_items.items():
if modality in mm_uuid_items:
uuid_items = mm_uuid_items[modality]
# For None entries, compute a hash; otherwise, use provided ID.
hashes: list[str] = []
for i, item in enumerate(data_items.get_all_items_for_hash()):
uuid_item = uuid_items[i]
# NOTE: Even if a uuid_item is provided, we still compute a hash
# if `hf_processor_mm_kwargs` is provided.
# This is because the processed multimodal inputs can be different
# depending on the processor kwargs.
if uuid_item is None or hf_processor_mm_kwargs:
# NOTE: use provided hash string to hash with kwargs
# if available for better performance.
item = uuid_item if uuid_item is not None else item
hashes.append(
hasher.hash_kwargs(
model_id=model_id,
**{modality: item},
**hf_processor_mm_kwargs,
)
)
else:
hashes.append(uuid_item)
mm_hashes[modality] = hashes
else:
mm_hashes[modality] = [
hasher.hash_kwargs(
model_id=model_id,
**{modality: item},
**hf_processor_mm_kwargs,
)
for item in data_items
]
return mm_hashes
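
`get_mm_hashes` is the old `_hash_mm_items` logic moved onto `ProcessorInputs`. The per-item rule it applies is small enough to restate; here is a self-contained sketch of just that rule (the `resolve_hash` helper and `fake_hash` are illustrative, not from this commit):

```python
def resolve_hash(uuid_item, compute_hash, hf_processor_mm_kwargs):
    # A user-supplied UUID is taken verbatim, except when processor kwargs are
    # set: the processed output may then differ, so a hash is computed anyway,
    # preferring the (cheaper) UUID string over the raw item as hash input.
    if uuid_item is None or hf_processor_mm_kwargs:
        return compute_hash(uuid_item if uuid_item is not None else "<raw item>")
    return uuid_item

fake_hash = lambda item: f"hash({item})"
print(resolve_hash("user-id", fake_hash, {}))            # user-id
print(resolve_hash("user-id", fake_hash, {"fps": 2.0}))  # hash(user-id)
print(resolve_hash(None, fake_hash, {}))                 # hash(<raw item>)
```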

View File

@@ -23,7 +23,6 @@ from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.utils.collection_utils import flatten_2d_lists, full_groupby
from ..hasher import MultiModalHasher
from ..inputs import (
MultiModalEncDecInputs,
MultiModalFieldConfig,
@@ -42,12 +41,9 @@ from ..parse import (
MultiModalDataItems,
MultiModalUUIDItems,
)
from .context import (
BaseProcessingInfo,
get_current_request_id,
timed_preprocessor_operation,
)
from .context import BaseProcessingInfo, TimingContext
from .dummy_inputs import BaseDummyInputsBuilder
from .inputs import ProcessorInputs
if TYPE_CHECKING:
from transformers.feature_extraction_utils import BatchFeature
@@ -1017,13 +1013,15 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
) -> MultiModalInputs:
return self.apply(
processor_inputs = ProcessorInputs(
prompt,
mm_items,
mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
hf_processor_mm_kwargs=hf_processor_mm_kwargs or {},
)
return self.apply(processor_inputs, TimingContext(enabled=False))
@abstractmethod
def _get_mm_fields_config(
self,
@@ -1139,12 +1137,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
Call the HF processor on the prompt text and
associated multi-modal data.
"""
with timed_preprocessor_operation(self.info.ctx, "hf_processor"):
return self.info.ctx.call_hf_processor(
self.info.get_hf_processor(**mm_kwargs),
dict(text=prompt, **mm_data),
dict(**mm_kwargs, **tok_kwargs),
)
return self.info.ctx.call_hf_processor(
self.info.get_hf_processor(**mm_kwargs),
dict(text=prompt, **mm_data),
dict(**mm_kwargs, **tok_kwargs),
)
def _hf_processor_applies_updates(
self,
@@ -1306,60 +1303,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
return prompt_ids, mm_processed_data, False
def _hash_mm_items(
self,
mm_data_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalHashes:
model_id = self.info.model_id
if mm_uuid_items is None:
mm_uuid_items = {}
mm_hashes: MultiModalHashes = {}
hasher = MultiModalHasher
for modality, data_items in mm_data_items.items():
if modality in mm_uuid_items:
uuid_items = mm_uuid_items[modality]
# For None entries, compute a hash; otherwise, use provided ID.
hashes: list[str] = []
for i, item in enumerate(data_items.get_all_items_for_hash()):
uuid_item = uuid_items[i]
# NOTE: Even if a uuid_item is provided, we still compute a hash
# if `hf_processor_mm_kwargs` is provided.
# This is because the processed multimodal inputs can be different
# depending on the processor kwargs.
if uuid_item is None or hf_processor_mm_kwargs:
# NOTE: use provided hash string to hash with kwargs
# if available for better performance.
item = uuid_item if uuid_item is not None else item
hashes.append(
hasher.hash_kwargs(
model_id=model_id,
**{modality: item},
**hf_processor_mm_kwargs,
)
)
else:
hashes.append(uuid_item)
mm_hashes[modality] = hashes
else:
mm_hashes[modality] = [
hasher.hash_kwargs(
model_id=model_id,
**{modality: item},
**hf_processor_mm_kwargs,
)
for item in data_items
]
return mm_hashes
def _get_cache_missing_items(
self,
cache: BaseMultiModalProcessorCache,
@@ -1461,40 +1404,36 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def _apply_hf_processor(
self,
prompt: str | list[int],
mm_data_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
inputs: ProcessorInputs,
timing_ctx: TimingContext,
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
(
prompt_ids,
mm_processed_data,
is_update_applied,
) = self._apply_hf_processor_main(
prompt=prompt,
mm_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
enable_hf_prompt_update=True,
)
with timing_ctx.record("apply_hf_processor"):
(
prompt_ids,
mm_processed_data,
is_update_applied,
) = self._apply_hf_processor_main(
prompt=inputs.prompt,
mm_items=inputs.mm_data_items,
hf_processor_mm_kwargs=inputs.hf_processor_mm_kwargs,
tokenization_kwargs=inputs.tokenization_kwargs,
enable_hf_prompt_update=True,
)
mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
mm_processed_data,
self._get_mm_fields_config(mm_processed_data, hf_processor_mm_kwargs),
self._get_mm_fields_config(
mm_processed_data, inputs.hf_processor_mm_kwargs
),
)
# Use overrides if provided; fallback to data-dependent hashing.
with timed_preprocessor_operation(self.info.ctx, "hashing"):
mm_hashes = self._hash_mm_items(
mm_data_items,
mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
with timing_ctx.record("get_mm_hashes"):
mm_hashes = inputs.get_mm_hashes(self.info.model_id)
mm_prompt_updates = self._get_mm_prompt_updates(
mm_data_items,
hf_processor_mm_kwargs,
inputs.mm_data_items,
inputs.hf_processor_mm_kwargs,
mm_kwargs,
)
@@ -1508,11 +1447,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def _cached_apply_hf_processor(
self,
prompt: str | list[int],
mm_data_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
inputs: ProcessorInputs,
timing_ctx: TimingContext,
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
"""
Apply the HF processor on the full prompt text,
@@ -1520,59 +1456,50 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
"""
cache = self.cache
_, passthrough_data = self._get_hf_mm_data(mm_data_items)
_, passthrough_data = self._get_hf_mm_data(inputs.mm_data_items)
if cache is None or passthrough_data:
return self._apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
return self._apply_hf_processor(inputs, timing_ctx)
with timed_preprocessor_operation(self.info.ctx, "hashing"):
mm_hashes = self._hash_mm_items(
mm_data_items,
mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
with timing_ctx.record("get_mm_hashes"):
mm_hashes = inputs.get_mm_hashes(self.info.model_id)
with timed_preprocessor_operation(self.info.ctx, "cache_lookup"):
with timing_ctx.record("get_cache_missing_items"):
mm_is_cached, mm_missing_data_items = self._get_cache_missing_items(
cache=cache,
mm_data_items=mm_data_items,
mm_data_items=inputs.mm_data_items,
mm_hashes=mm_hashes,
)
# NOTE: `prompt` does not correspond to `mm_missing_data_items`,
# so we can't apply prompt updates until the new multimodal
# items are combined with the cached multimodal items
(
prompt_ids,
mm_missing_processed_data,
is_update_applied,
) = self._apply_hf_processor_main(
prompt=prompt,
mm_items=mm_missing_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
enable_hf_prompt_update=False,
)
with timing_ctx.record("apply_hf_processor"):
(
prompt_ids,
mm_missing_processed_data,
is_update_applied,
) = self._apply_hf_processor_main(
prompt=inputs.prompt,
mm_items=mm_missing_data_items,
hf_processor_mm_kwargs=inputs.hf_processor_mm_kwargs,
tokenization_kwargs=inputs.tokenization_kwargs,
enable_hf_prompt_update=False,
)
mm_missing_kwargs = MultiModalKwargsItems.from_hf_inputs(
mm_missing_processed_data,
self._get_mm_fields_config(
mm_missing_processed_data, hf_processor_mm_kwargs
mm_missing_processed_data, inputs.hf_processor_mm_kwargs
),
)
mm_missing_prompt_updates = self._get_mm_prompt_updates(
mm_missing_data_items,
hf_processor_mm_kwargs,
inputs.hf_processor_mm_kwargs,
mm_missing_kwargs,
)
with timed_preprocessor_operation(self.info.ctx, "cache_lookup"):
with timing_ctx.record("merge_mm_kwargs"):
mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs(
cache,
mm_hashes=mm_hashes,
@@ -1742,11 +1669,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def apply(
self,
prompt: str | list[int],
mm_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None,
inputs: ProcessorInputs,
timing_ctx: TimingContext,
) -> MultiModalInputs:
"""
Process multi-modal inputs to be used in vLLM.
@@ -1761,31 +1685,16 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
3. Extract information about the placeholder tokens from the
processed token IDs.
"""
request_id = get_current_request_id()
if request_id is not None:
self.info.ctx.create_timing_stats(request_id)
if hf_processor_mm_kwargs is None:
hf_processor_mm_kwargs = {}
if tokenization_kwargs is None:
tokenization_kwargs = {}
(
prompt_ids,
mm_info,
is_update_applied,
) = self._cached_apply_hf_processor(
prompt,
mm_items,
mm_uuid_items,
hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
) = self._cached_apply_hf_processor(inputs, timing_ctx)
# NOTE: tokenization_kwargs are not required to init processor
with timed_preprocessor_operation(self.info.ctx, "prompt_update"):
with timing_ctx.record("apply_prompt_updates"):
prompt_ids, mm_placeholders = self._maybe_apply_prompt_updates(
mm_items=mm_items,
mm_items=inputs.mm_data_items,
prompt_ids=prompt_ids,
mm_kwargs=mm_info.kwargs,
mm_prompt_updates=mm_info.prompt_updates,
@@ -1851,11 +1760,8 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
def apply(
self,
prompt: str | list[int],
mm_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None,
inputs: ProcessorInputs,
timing_ctx: TimingContext,
) -> MultiModalEncDecInputs:
"""
Process multi-modal inputs to be used in vLLM.
@@ -1864,17 +1770,22 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
2. Apply the HF processor on encoder prompt.
3. Copy the input prompt text as decoder prompt inputs.
"""
encoder_prompt = self.create_encoder_prompt(prompt, mm_items)
encoder_inputs = super().apply(
encoder_prompt = self.create_encoder_prompt(
inputs.prompt,
inputs.mm_data_items,
)
encoder_processor_inputs = ProcessorInputs(
encoder_prompt,
mm_items,
mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
inputs.mm_data_items,
inputs.mm_uuid_items,
hf_processor_mm_kwargs=inputs.hf_processor_mm_kwargs,
tokenization_kwargs=inputs.tokenization_kwargs,
)
encoder_inputs = super().apply(encoder_processor_inputs, timing_ctx)
return self._get_enc_dec_inputs(
prompt=prompt,
mm_items=mm_items,
prompt=inputs.prompt,
mm_items=inputs.mm_data_items,
encoder_inputs=encoder_inputs,
)

View File

@@ -1,11 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from collections import defaultdict
from collections.abc import Mapping
from dataclasses import dataclass
from multiprocessing.synchronize import Lock as LockType
from typing import TYPE_CHECKING, Generic, Literal, Protocol, TypeVar, cast
from vllm.config.observability import ObservabilityConfig
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
@@ -24,6 +25,7 @@ from .processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
InputProcessingContext,
TimingContext,
)
if TYPE_CHECKING:
@@ -174,32 +176,26 @@ class MultiModalRegistry:
def _create_processing_ctx(
self,
model_config: "ModelConfig",
observability_config: "ObservabilityConfig | None" = None,
tokenizer: TokenizerLike | None = None,
) -> InputProcessingContext:
if tokenizer is None:
tokenizer = cached_tokenizer_from_config(model_config)
return InputProcessingContext(
model_config, tokenizer, observability_config=observability_config
)
return InputProcessingContext(model_config, tokenizer)
def _create_processing_info(
self,
model_config: "ModelConfig",
observability_config: "ObservabilityConfig | None" = None,
*,
tokenizer: TokenizerLike | None = None,
) -> BaseProcessingInfo:
model_cls = self._get_model_cls(model_config)
factories = model_cls._processor_factory
ctx = self._create_processing_ctx(model_config, observability_config, tokenizer)
ctx = self._create_processing_ctx(model_config, tokenizer)
return factories.info(ctx)
def create_processor(
self,
model_config: "ModelConfig",
observability_config: "ObservabilityConfig | None" = None,
*,
tokenizer: TokenizerLike | None = None,
cache: BaseMultiModalProcessorCache | None = None,
@@ -213,7 +209,7 @@ class MultiModalRegistry:
model_cls = self._get_model_cls(model_config)
factories = model_cls._processor_factory
ctx = self._create_processing_ctx(model_config, observability_config, tokenizer)
ctx = self._create_processing_ctx(model_config, tokenizer)
return factories.build_processor(ctx, cache=cache)
@@ -242,10 +238,8 @@ class MultiModalRegistry:
mm_options=mm_config.limit_per_prompt,
)
mm_inputs = processor.apply(
prompt=processor_inputs.prompt,
mm_items=processor_inputs.mm_items,
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
tokenization_kwargs=processor_inputs.tokenization_kwargs,
processor_inputs,
timing_ctx=TimingContext(enabled=False),
)
prompt_token_ids = mm_inputs["prompt_token_ids"]
@@ -335,3 +329,34 @@ class MultiModalRegistry:
return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)
else:
raise ValueError(f"Unknown cache type: {cache_type!r}")
class MultiModalTimingRegistry:
def __init__(self, observability_config: "ObservabilityConfig | None") -> None:
super().__init__()
if observability_config and observability_config.enable_mm_processor_stats:
self._lock = threading.Lock()
self._ctx_by_request_id = defaultdict[str, TimingContext](TimingContext)
self._enabled = True
else:
self._enabled = False
def get(self, request_id: str) -> TimingContext:
if not self._enabled:
return TimingContext(enabled=False)
with self._lock:
return self._ctx_by_request_id[request_id]
def stat(self) -> dict[str, dict[str, float]]:
if not self._enabled:
return {}
with self._lock:
stats = {
req_id: ctx.get_stats_dict()
for req_id, ctx in self._ctx_by_request_id.items()
}
self._ctx_by_request_id.clear()
return stats
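
`MultiModalTimingRegistry` replaces the per-request registry that used to live on `InputProcessingContext`. A minimal sketch of its lifecycle, assuming `ObservabilityConfig` can be constructed with just `enable_mm_processor_stats` (that constructor call is an assumption for illustration):

```python
from vllm.config.observability import ObservabilityConfig
from vllm.multimodal.registry import MultiModalTimingRegistry

registry = MultiModalTimingRegistry(
    ObservabilityConfig(enable_mm_processor_stats=True)  # assumed constructible this way
)

# One TimingContext per request id; repeated get() calls return the same object.
timing_ctx = registry.get("renderer-mm-1")
with timing_ctx.record("apply_hf_processor"):
    pass  # stand-in for processor.apply(inputs, timing_ctx)

print(registry.stat())  # {"renderer-mm-1": {"apply_hf_processor_secs": ..., ...}}
print(registry.stat())  # {} -- stat() drains the registry after reporting
```

With observability disabled, `get()` hands back a throwaway disabled context, so the processing path stays free of timing overhead.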

View File

@@ -85,13 +85,13 @@ class BaseRenderer(ABC, Generic[_T]):
self._mm_cache_stats: MultiModalCacheStats | None = None
if config.model_config.is_multimodal_model:
from vllm.multimodal import MULTIMODAL_REGISTRY as mm_registry
from vllm.multimodal.registry import MultiModalTimingRegistry
mm_processor_cache = mm_registry.processor_cache_from_config(config)
with set_default_torch_num_threads():
self.mm_processor = mm_registry.create_processor(
config.model_config,
config.observability_config,
tokenizer=tokenizer,
cache=mm_processor_cache,
)
@@ -102,6 +102,9 @@ class BaseRenderer(ABC, Generic[_T]):
# This is used to generate internal request ID for MM processing
# It has no relation to the request ID for engine core
self._mm_req_counter = AtomicCounter()
self._mm_timing_registry = MultiModalTimingRegistry(
config.observability_config
)
def get_tokenizer(self) -> _T:
tokenizer = self.tokenizer
@@ -534,7 +537,7 @@ class BaseRenderer(ABC, Generic[_T]):
tokenization_kwargs: dict[str, Any] | None,
) -> "MultiModalInputs":
from vllm.multimodal.parse import parse_mm_uuids
from vllm.multimodal.processing.context import set_request_id
from vllm.multimodal.processing import ProcessorInputs as MMProcessorInputs
mm_req_id = f"renderer-mm-{self._mm_req_counter.inc(1)}"
@@ -543,18 +546,21 @@ class BaseRenderer(ABC, Generic[_T]):
mm_data_items = mm_processor.info.parse_mm_data(mm_data)
mm_uuid_items = parse_mm_uuids(mm_uuids)
mm_uuids = self._process_mm_uuids(
mm_uuid_items = self._process_mm_uuids(
mm_data, mm_data_items, mm_uuid_items, mm_req_id
)
with set_request_id(mm_req_id), set_default_torch_num_threads():
mm_inputs = mm_processor.apply(
prompt,
mm_data_items,
mm_uuid_items,
hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
mm_processor_inputs = MMProcessorInputs(
prompt,
mm_data_items,
mm_uuid_items,
hf_processor_mm_kwargs=mm_processor_kwargs or {},
tokenization_kwargs=tokenization_kwargs or {},
)
mm_timing_ctx = self._mm_timing_registry.get(mm_req_id)
with set_default_torch_num_threads():
mm_inputs = mm_processor.apply(mm_processor_inputs, mm_timing_ctx)
self.update_mm_cache_stats()

View File

@@ -6272,7 +6272,7 @@ class GPUModelRunner(
self.encoder_timing_registry[req_id] = EncoderTimingStats()
stats = self.encoder_timing_registry[req_id]
stats.encoder_forward_time += per_request_time
stats.encoder_forward_secs += per_request_time
stats.num_encoder_calls += 1
@@ -6280,7 +6280,7 @@ class GPUModelRunner(
class EncoderTimingStats:
"""Per-request timing statistics for encoder forward pass."""
encoder_forward_time: float = 0.0
encoder_forward_secs: float = 0.0
"""Time spent in vision encoder forward pass (seconds)."""
num_encoder_calls: int = 0
@@ -6288,6 +6288,6 @@ class EncoderTimingStats:
def to_dict(self) -> dict[str, float | int]:
return {
"encoder_forward_time": self.encoder_forward_time,
"encoder_forward_secs": self.encoder_forward_secs,
"num_encoder_calls": self.num_encoder_calls,
}