[Refactor] Decouple TimingContext from InputProcessingContext (#35083)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -389,13 +389,13 @@ def _test_processing_correctness_one(
|
||||
mm_items = baseline_processor.info.parse_mm_data(mm_data)
|
||||
ignore_mm_keys = _IGNORE_MM_KEYS.get(model_type, set[str]())
|
||||
|
||||
baseline_tokenized_result = baseline_processor.apply(
|
||||
baseline_tokenized_result = baseline_processor(
|
||||
token_prompt,
|
||||
mm_items=mm_items,
|
||||
hf_processor_mm_kwargs={},
|
||||
)
|
||||
|
||||
cached_tokenized_result = cached_processor.apply(
|
||||
cached_tokenized_result = cached_processor(
|
||||
token_prompt,
|
||||
mm_items=mm_items,
|
||||
hf_processor_mm_kwargs={},
|
||||
@@ -409,12 +409,12 @@ def _test_processing_correctness_one(
|
||||
)
|
||||
|
||||
if text_prompt is not None:
|
||||
baseline_text_result = baseline_processor.apply(
|
||||
baseline_text_result = baseline_processor(
|
||||
text_prompt,
|
||||
mm_items=mm_items,
|
||||
hf_processor_mm_kwargs={},
|
||||
)
|
||||
cached_text_result = cached_processor.apply(
|
||||
cached_text_result = cached_processor(
|
||||
text_prompt,
|
||||
mm_items=mm_items,
|
||||
hf_processor_mm_kwargs={},
|
||||
|
||||
@@ -176,7 +176,7 @@ def test_get_image_size_with_most_features(
|
||||
|
||||
for asset in image_assets:
|
||||
mm_data = {"image": [asset.pil_image]}
|
||||
processed_inputs = processor.apply(
|
||||
processed_inputs = processor(
|
||||
prompt,
|
||||
mm_items=processor.info.parse_mm_data(mm_data),
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
|
||||
@@ -52,7 +52,7 @@ def test_processor_override(
|
||||
metadata["fps"] = fps
|
||||
mm_data = {"video": [(video, metadata)]}
|
||||
|
||||
processed_inputs = processor.apply(
|
||||
processed_inputs = processor(
|
||||
prompt,
|
||||
mm_items=processor.info.parse_mm_data(mm_data),
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
@@ -104,12 +104,12 @@ def test_video_loader_consistency(
|
||||
static_mm_data = {"video": [(static_video, static_metadata)]}
|
||||
dynamic_mm_data = {"video": [(dynamic_video, dynamic_metadata)]}
|
||||
|
||||
static_outputs = processor.apply(
|
||||
static_outputs = processor(
|
||||
prompt,
|
||||
mm_items=processor.info.parse_mm_data(static_mm_data),
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
)
|
||||
dynamic_outputs = processor.apply(
|
||||
dynamic_outputs = processor(
|
||||
prompt,
|
||||
mm_items=processor.info.parse_mm_data(dynamic_mm_data),
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
|
||||
@@ -106,7 +106,7 @@ def _run_check(
|
||||
for image in images
|
||||
)
|
||||
|
||||
processed_inputs = processor.apply(
|
||||
processed_inputs = processor(
|
||||
prompt,
|
||||
mm_items=processor.info.parse_mm_data(mm_data),
|
||||
hf_processor_mm_kwargs=mm_processor_kwargs,
|
||||
|
||||
@@ -61,7 +61,7 @@ def test_processor_override(
|
||||
dummy_image = image_assets[0].pil_image.resize(dummy_image_size)
|
||||
mm_data = {"image": [dummy_image] * num_imgs}
|
||||
|
||||
processed_inputs = processor.apply(
|
||||
processed_inputs = processor(
|
||||
prompt,
|
||||
mm_items=processor.info.parse_mm_data(mm_data),
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
|
||||
@@ -66,7 +66,7 @@ def _run_check(
|
||||
for image in images
|
||||
)
|
||||
|
||||
processed_inputs = processor.apply(
|
||||
processed_inputs = processor(
|
||||
prompt,
|
||||
mm_items=processor.info.parse_mm_data(mm_data),
|
||||
hf_processor_mm_kwargs=mm_processor_kwargs,
|
||||
|
||||
@@ -49,7 +49,7 @@ def test_processor_override(
|
||||
if tokenized_prompt:
|
||||
prompt = tokenizer.encode(prompt)
|
||||
|
||||
processed_inputs = processor.apply(
|
||||
processed_inputs = processor(
|
||||
prompt,
|
||||
mm_items=processor.info.parse_mm_data(mm_data),
|
||||
hf_processor_mm_kwargs=mm_processor_kwargs,
|
||||
|
||||
@@ -87,7 +87,7 @@ def _validate_image_prompt_replacements_one(
|
||||
try:
|
||||
# The processor will throw an error if there is a mismatch
|
||||
# in the prompt replacements
|
||||
processed_inputs = processor.apply(
|
||||
processed_inputs = processor(
|
||||
prompt,
|
||||
mm_items=processor.info.parse_mm_data(mm_data),
|
||||
hf_processor_mm_kwargs={},
|
||||
|
||||
@@ -87,7 +87,7 @@ def _validate_image_prompt_replacements_one(
|
||||
try:
|
||||
# The processor will throw an error if there is a mismatch
|
||||
# in the prompt replacements
|
||||
processed_inputs = processor.apply(
|
||||
processed_inputs = processor(
|
||||
prompt,
|
||||
mm_items=processor.info.parse_mm_data(mm_data),
|
||||
hf_processor_mm_kwargs={},
|
||||
|
||||
@@ -29,7 +29,7 @@ def test_processor_override(
|
||||
image = Image.new("RGB", size=(364, 364))
|
||||
mm_data = {"image": [image] * num_imgs}
|
||||
|
||||
processed_inputs = processor.apply(
|
||||
processed_inputs = processor(
|
||||
prompt,
|
||||
mm_items=processor.info.parse_mm_data(mm_data),
|
||||
hf_processor_mm_kwargs={},
|
||||
@@ -50,7 +50,7 @@ def _validate_image_prompt_replacements_one(
|
||||
mm_data = {"image": [image] * num_imgs}
|
||||
|
||||
try:
|
||||
processed_inputs = processor.apply(
|
||||
processed_inputs = processor(
|
||||
prompt,
|
||||
mm_items=processor.info.parse_mm_data(mm_data),
|
||||
hf_processor_mm_kwargs={},
|
||||
|
||||
@@ -68,7 +68,7 @@ def _run_check(
|
||||
for image in images
|
||||
)
|
||||
print(total_expected_num_patches)
|
||||
processed_inputs = processor.apply(
|
||||
processed_inputs = processor(
|
||||
prompt,
|
||||
mm_items=processor.info.parse_mm_data(mm_data),
|
||||
hf_processor_mm_kwargs=mm_processor_kwargs,
|
||||
|
||||
@@ -47,7 +47,7 @@ def test_processor_override(
|
||||
prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
|
||||
mm_data = {"image": [image_assets[0].pil_image] * num_imgs}
|
||||
|
||||
processed_inputs = processor.apply(
|
||||
processed_inputs = processor(
|
||||
prompt,
|
||||
mm_items=processor.info.parse_mm_data(mm_data),
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
|
||||
@@ -51,7 +51,7 @@ def test_processor_override(
|
||||
dummy_image = image_assets[0].pil_image.resize(dummy_image_size)
|
||||
mm_data = {"image": [dummy_image] * num_imgs}
|
||||
|
||||
processed_inputs = processor.apply(
|
||||
processed_inputs = processor(
|
||||
prompt,
|
||||
mm_items=processor.info.parse_mm_data(mm_data),
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
|
||||
@@ -42,7 +42,7 @@ def test_processor_override(
|
||||
prompt = "<|vision_start|><|image_pad|><|vision_end|>" * num_imgs
|
||||
mm_data = {"image": [image_assets[0].pil_image] * num_imgs}
|
||||
|
||||
processed_inputs = processor.apply(
|
||||
processed_inputs = processor(
|
||||
prompt,
|
||||
mm_items=processor.info.parse_mm_data(mm_data),
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
@@ -88,7 +88,7 @@ def test_get_image_size_with_most_features(
|
||||
prompt = "<|vision_start|><|image_pad|><|vision_end|>"
|
||||
for asset in image_assets:
|
||||
mm_data = {"image": [asset.pil_image]}
|
||||
processed_inputs = processor.apply(
|
||||
processed_inputs = processor(
|
||||
prompt,
|
||||
mm_items=processor.info.parse_mm_data(mm_data),
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
|
||||
@@ -51,7 +51,7 @@ def test_processor_with_audio_sample_rate(
|
||||
hf_processor_mm_kwargs: dict[str, Any] = {
|
||||
"audio_sample_rate": audio_sample_rate,
|
||||
}
|
||||
processed_inputs = processor.apply(
|
||||
processed_inputs = processor(
|
||||
prompt,
|
||||
mm_items=processor.info.parse_mm_data(mm_data),
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
@@ -94,7 +94,7 @@ def test_longer_audio_generates_more_tokens(model_id: str) -> None:
|
||||
hf_processor_mm_kwargs: dict[str, Any] = {
|
||||
"audio_sample_rate": audio_sample_rate,
|
||||
}
|
||||
processed = processor.apply(
|
||||
processed = processor(
|
||||
prompt,
|
||||
mm_items=processor.info.parse_mm_data(mm_data),
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
|
||||
@@ -61,7 +61,7 @@ def test_processor_override(
|
||||
dummy_image = image_assets[0].pil_image.resize(dummy_image_size)
|
||||
mm_data = {"image": [dummy_image] * num_imgs}
|
||||
|
||||
processed_inputs = processor.apply(
|
||||
processed_inputs = processor(
|
||||
prompt,
|
||||
mm_items=processor.info.parse_mm_data(mm_data),
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
|
||||
@@ -99,7 +99,7 @@ def create_batched_mm_kwargs(
|
||||
mm_counts=mm_counts,
|
||||
mm_options={},
|
||||
)
|
||||
mm_items = processor_inputs.mm_items
|
||||
mm_items = processor_inputs.mm_data_items
|
||||
resized_mm_data = {
|
||||
modality: resize_mm_data(items.data, size_factors)
|
||||
for modality, items in mm_items.items()
|
||||
@@ -108,11 +108,10 @@ def create_batched_mm_kwargs(
|
||||
# video metadata will be added back to the resized video data here.
|
||||
text_prompt, token_prompt = get_text_token_prompts(processor, resized_mm_data)
|
||||
|
||||
mm_kwargs = processor.apply(
|
||||
mm_kwargs = processor(
|
||||
prompt=token_prompt if text_prompt is None else text_prompt,
|
||||
mm_items=processor.info.parse_mm_data(resized_mm_data),
|
||||
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=processor_inputs.tokenization_kwargs,
|
||||
)["mm_kwargs"].require_data()
|
||||
|
||||
return group_mm_kwargs_by_modality(
|
||||
|
||||
@@ -19,7 +19,7 @@ def test_multimodal_processor(model_id):
|
||||
image_pil = ImageAsset("cherry_blossom").pil_image
|
||||
mm_data = {"image": image_pil}
|
||||
str_prompt = "<|im_start|>user <image>\nWhat is the content of this image?<|im_end|><|im_start|>assistant\n" # noqa: E501
|
||||
str_processed_inputs = mm_processor.apply(
|
||||
str_processed_inputs = mm_processor(
|
||||
prompt=str_prompt,
|
||||
mm_items=mm_processor.info.parse_mm_data(mm_data),
|
||||
hf_processor_mm_kwargs={},
|
||||
@@ -44,7 +44,7 @@ def test_multimodal_processor(model_id):
|
||||
77091,
|
||||
198,
|
||||
]
|
||||
ids_processed_inputs = mm_processor.apply(
|
||||
ids_processed_inputs = mm_processor(
|
||||
prompt=ids_prompt,
|
||||
mm_items=mm_processor.info.parse_mm_data(mm_data),
|
||||
hf_processor_mm_kwargs={},
|
||||
|
||||
@@ -934,7 +934,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
|
||||
exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most")
|
||||
|
||||
with exc_ctx:
|
||||
processor.apply(
|
||||
processor(
|
||||
"<image>" * num_images,
|
||||
mm_items=processor.info.parse_mm_data(mm_data),
|
||||
hf_processor_mm_kwargs={},
|
||||
|
||||
@@ -17,8 +17,9 @@ import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -59,12 +60,13 @@ def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, f
|
||||
Example:
|
||||
{
|
||||
'request-123': {
|
||||
'hf_processor_time': 0.45,
|
||||
'hashing_time': 0.02,
|
||||
'cache_lookup_time': 0.01,
|
||||
'prompt_update_time': 0.03,
|
||||
'preprocessor_total_time': 0.51,
|
||||
'encoder_forward_time': 0.23,
|
||||
'get_mm_hashes_secs': 0.02,
|
||||
'get_cache_missing_items_secs': 0.01,
|
||||
'apply_hf_processor_secs': 0.45,
|
||||
'merge_mm_kwargs_secs': 0.01,
|
||||
'apply_prompt_updates_secs': 0.03,
|
||||
'preprocessor_total_secs': 0.51,
|
||||
'encoder_forward_secs': 0.23,
|
||||
'num_encoder_calls': 1
|
||||
}
|
||||
}
|
||||
@@ -74,8 +76,7 @@ def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, f
|
||||
return {}
|
||||
|
||||
renderer = llm_engine.renderer
|
||||
mm_processor = renderer.get_mm_processor()
|
||||
preprocessing_stats = mm_processor.info.ctx.get_all_timing_stats()
|
||||
mm_processor_stats = renderer._mm_timing_registry.stat()
|
||||
|
||||
encoder_stats = dict[str, dict[str, float]]()
|
||||
for worker_stats in llm_engine.collective_rpc("get_encoder_timing_stats"):
|
||||
@@ -88,10 +89,10 @@ def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, f
|
||||
else:
|
||||
# Aggregate timing metrics across workers
|
||||
current_time = encoder_stats[request_id].get(
|
||||
"encoder_forward_time", 0.0
|
||||
"encoder_forward_secs", 0.0
|
||||
)
|
||||
new_time = stats_dict.get("encoder_forward_time", 0.0)
|
||||
encoder_stats[request_id]["encoder_forward_time"] = max(
|
||||
new_time = stats_dict.get("encoder_forward_secs", 0.0)
|
||||
encoder_stats[request_id]["encoder_forward_secs"] = max(
|
||||
current_time, new_time
|
||||
)
|
||||
|
||||
@@ -103,7 +104,7 @@ def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, f
|
||||
|
||||
merged_stats = dict[str, dict[str, float]]()
|
||||
|
||||
for request_id, prep_dict in preprocessing_stats.items():
|
||||
for request_id, prep_dict in mm_processor_stats.items():
|
||||
merged_stats[request_id] = dict(prep_dict)
|
||||
|
||||
for request_id, enc_dict in encoder_stats.items():
|
||||
@@ -124,34 +125,18 @@ def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, f
|
||||
return merged_stats
|
||||
|
||||
|
||||
def collect_mm_processor_stats(
|
||||
llm_engine: LLMEngine,
|
||||
num_warmup_reqs: int = 0,
|
||||
) -> dict[str, list[float]]:
|
||||
def collect_mm_processor_stats(llm_engine: LLMEngine) -> dict[str, list[float]]:
|
||||
"""
|
||||
Collect multimodal processor timing stats.
|
||||
Returns a dictionary mapping stage names to lists of timing values (in seconds).
|
||||
"""
|
||||
all_stats = get_timing_stats_from_engine(llm_engine)
|
||||
|
||||
stat_keys = [
|
||||
"hf_processor_time",
|
||||
"hashing_time",
|
||||
"cache_lookup_time",
|
||||
"prompt_update_time",
|
||||
"preprocessor_total_time",
|
||||
"encoder_forward_time",
|
||||
"num_encoder_calls",
|
||||
]
|
||||
stats_by_stage = {key: [] for key in stat_keys}
|
||||
stats_by_stage = defaultdict[str, list[float]](list)
|
||||
|
||||
# Skip warmup requests
|
||||
stats_list = list(all_stats.values())[num_warmup_reqs:]
|
||||
|
||||
for stats_dict in stats_list:
|
||||
for key in stat_keys:
|
||||
if key in stats_dict:
|
||||
stats_by_stage[key].append(stats_dict[key])
|
||||
for stats_dict in all_stats.values():
|
||||
for stat_key, stat_val in stats_dict.items():
|
||||
stats_by_stage[stat_key].append(stat_val)
|
||||
|
||||
return stats_by_stage
|
||||
|
||||
@@ -159,13 +144,20 @@ def collect_mm_processor_stats(
|
||||
def calculate_mm_processor_metrics(
|
||||
stats_by_stage: dict[str, list[float]],
|
||||
selected_percentiles: list[float],
|
||||
*,
|
||||
unit: Literal["us", "ms", "s"] = "ms",
|
||||
) -> dict[str, dict[str, float]]:
|
||||
"""
|
||||
Calculate aggregate metrics from stats by stage.
|
||||
"""
|
||||
unit2mult = {"us": 1000000, "ms": 1000, "s": 1}
|
||||
unit_mult = unit2mult[unit]
|
||||
|
||||
metrics = {}
|
||||
|
||||
for stage_name, times in stats_by_stage.items():
|
||||
for stage, times in stats_by_stage.items():
|
||||
stage_name = stage.replace("_secs", "_" + unit)
|
||||
|
||||
if not times:
|
||||
metrics[stage_name] = {
|
||||
"mean": 0.0,
|
||||
@@ -175,8 +167,8 @@ def calculate_mm_processor_metrics(
|
||||
}
|
||||
continue
|
||||
|
||||
is_count_metric = stage_name == "num_encoder_calls"
|
||||
values = times if is_count_metric else [t * 1000 for t in times]
|
||||
is_count_metric = stage == "num_encoder_calls"
|
||||
values = times if is_count_metric else [t * unit_mult for t in times]
|
||||
|
||||
metrics[stage_name] = {
|
||||
"mean": float(np.mean(values)),
|
||||
@@ -285,6 +277,9 @@ def benchmark_multimodal_processor(
|
||||
use_tqdm=not getattr(args, "disable_tqdm", False),
|
||||
)
|
||||
|
||||
# Clear stats from warmup requests
|
||||
collect_mm_processor_stats(llm.llm_engine)
|
||||
|
||||
print(f"Processing {len(prompts)} requests...")
|
||||
start_time = time.perf_counter()
|
||||
|
||||
@@ -295,7 +290,7 @@ def benchmark_multimodal_processor(
|
||||
end_time = time.perf_counter()
|
||||
total_time = end_time - start_time
|
||||
|
||||
mm_stats_by_stage = collect_mm_processor_stats(llm.llm_engine, num_warmups)
|
||||
mm_stats_by_stage = collect_mm_processor_stats(llm.llm_engine)
|
||||
|
||||
if not any(mm_stats_by_stage.values()):
|
||||
print(
|
||||
@@ -475,11 +470,8 @@ def main(args: argparse.Namespace) -> None:
|
||||
]
|
||||
mm_data = []
|
||||
for stage, metrics in result["mm_processor_stats"].items():
|
||||
is_count = stage == "num_encoder_calls"
|
||||
unit = "" if is_count else " (ms)"
|
||||
|
||||
row = {
|
||||
"Stage": stage + unit,
|
||||
"Stage": stage,
|
||||
"Mean": f"{metrics['mean']:.2f}",
|
||||
"Median": f"{metrics['median']:.2f}",
|
||||
"Std": f"{metrics['std']:.2f}",
|
||||
|
||||
@@ -41,15 +41,16 @@ from vllm.multimodal.parse import (
|
||||
ImageProcessorItems,
|
||||
ImageSize,
|
||||
MultiModalDataItems,
|
||||
MultiModalUUIDItems,
|
||||
)
|
||||
from vllm.multimodal.processing import (
|
||||
BaseDummyInputsBuilder,
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
ProcessorInputs,
|
||||
PromptIndexTargets,
|
||||
PromptReplacement,
|
||||
PromptUpdate,
|
||||
TimingContext,
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
@@ -204,23 +205,20 @@ class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]):
|
||||
|
||||
def apply(
|
||||
self,
|
||||
prompt: str | list[int],
|
||||
mm_items: MultiModalDataItems,
|
||||
mm_uuid_items: MultiModalUUIDItems | None = None,
|
||||
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
|
||||
tokenization_kwargs: Mapping[str, object] | None = None,
|
||||
inputs: ProcessorInputs,
|
||||
timing_ctx: TimingContext,
|
||||
) -> MultiModalInputs:
|
||||
if mm_items:
|
||||
if isinstance(prompt, str):
|
||||
if len(prompt) > 0:
|
||||
if inputs.mm_data_items:
|
||||
if isinstance(inputs.prompt, str):
|
||||
if len(inputs.prompt) > 0:
|
||||
raise ValueError(
|
||||
"CLIP accepts text-only or image-only inputs, not both! "
|
||||
"You must pass an image with an empty text prompt."
|
||||
)
|
||||
else:
|
||||
special_tokens = self.info.get_tokenizer().all_special_ids
|
||||
if all(tok in special_tokens for tok in prompt):
|
||||
prompt = []
|
||||
if all(tok in special_tokens for tok in inputs.prompt):
|
||||
inputs.prompt = []
|
||||
else:
|
||||
raise ValueError(
|
||||
"CLIP accepts text-only or image-only inputs, not both! "
|
||||
@@ -229,18 +227,12 @@ class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]):
|
||||
|
||||
# For multi-modal data, the prompt after processing should
|
||||
# only contain the dummy image tokens
|
||||
tokenization_kwargs = {
|
||||
**(tokenization_kwargs or {}),
|
||||
inputs.tokenization_kwargs = {
|
||||
**inputs.tokenization_kwargs,
|
||||
"add_special_tokens": False,
|
||||
}
|
||||
|
||||
return super().apply(
|
||||
prompt=prompt,
|
||||
mm_items=mm_items,
|
||||
mm_uuid_items=mm_uuid_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
return super().apply(inputs, timing_ctx)
|
||||
|
||||
def _hf_processor_applies_updates(
|
||||
self,
|
||||
|
||||
@@ -30,15 +30,16 @@ from vllm.multimodal.parse import (
|
||||
ImageProcessorItems,
|
||||
ImageSize,
|
||||
MultiModalDataItems,
|
||||
MultiModalUUIDItems,
|
||||
)
|
||||
from vllm.multimodal.processing import BaseDummyInputsBuilder
|
||||
from vllm.multimodal.processing.processor import (
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
MultiModalProcessingInfo,
|
||||
ProcessorInputs,
|
||||
PromptReplacement,
|
||||
PromptUpdate,
|
||||
TimingContext,
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tokenizers import cached_tokenizer_from_config
|
||||
@@ -310,32 +311,17 @@ class DeepseekVL2MultiModalProcessor(
|
||||
|
||||
def _cached_apply_hf_processor(
|
||||
self,
|
||||
prompt: str | list[int],
|
||||
mm_data_items: MultiModalDataItems,
|
||||
mm_uuid_items: MultiModalUUIDItems | None,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
tokenization_kwargs: Mapping[str, object],
|
||||
inputs: ProcessorInputs,
|
||||
timing_ctx: TimingContext,
|
||||
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
|
||||
# The processor logic is different for len(images) <= 2 vs > 2
|
||||
# Since the processing cache assumes that the processor output is
|
||||
# invariant of how many images are passed per prompt, we only
|
||||
# perform caching for the most common case
|
||||
if mm_data_items.get_count("image", strict=False) > 2:
|
||||
return self._apply_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data_items=mm_data_items,
|
||||
mm_uuid_items=mm_uuid_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
if inputs.mm_data_items.get_count("image", strict=False) > 2:
|
||||
return self._apply_hf_processor(inputs, timing_ctx)
|
||||
|
||||
return super()._cached_apply_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data_items=mm_data_items,
|
||||
mm_uuid_items=mm_uuid_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
return super()._cached_apply_hf_processor(inputs, timing_ctx)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
|
||||
@@ -21,13 +21,14 @@ from vllm.multimodal.parse import (
|
||||
ImageEmbeddingItems,
|
||||
ImageProcessorItems,
|
||||
MultiModalDataItems,
|
||||
MultiModalUUIDItems,
|
||||
)
|
||||
from vllm.multimodal.processing.processor import (
|
||||
MultiModalProcessingInfo,
|
||||
ProcessorInputs,
|
||||
PromptReplacement,
|
||||
PromptUpdate,
|
||||
PromptUpdateDetails,
|
||||
TimingContext,
|
||||
)
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
@@ -490,32 +491,17 @@ class H2OVLMultiModalProcessor(BaseInternVLMultiModalProcessor[H2OVLProcessingIn
|
||||
|
||||
def _cached_apply_hf_processor(
|
||||
self,
|
||||
prompt: str | list[int],
|
||||
mm_data_items: MultiModalDataItems,
|
||||
mm_uuid_items: MultiModalUUIDItems | None,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
tokenization_kwargs: Mapping[str, object],
|
||||
inputs: ProcessorInputs,
|
||||
timing_ctx: TimingContext,
|
||||
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
|
||||
# The processor logic is different for len(images) <= 1 vs > 1
|
||||
# Since the processing cache assumes that the processor output is
|
||||
# invariant of how many images are passed per prompt, we only
|
||||
# perform caching for the most common case
|
||||
if mm_data_items.get_count("image", strict=False) > 1:
|
||||
return self._apply_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data_items=mm_data_items,
|
||||
mm_uuid_items=mm_uuid_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
if inputs.mm_data_items.get_count("image", strict=False) > 1:
|
||||
return self._apply_hf_processor(inputs, timing_ctx)
|
||||
|
||||
return super()._cached_apply_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data_items=mm_data_items,
|
||||
mm_uuid_items=mm_uuid_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
return super()._cached_apply_hf_processor(inputs, timing_ctx)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
|
||||
@@ -37,16 +37,17 @@ from vllm.multimodal.parse import (
|
||||
ImageProcessorItems,
|
||||
ImageSize,
|
||||
MultiModalDataItems,
|
||||
MultiModalUUIDItems,
|
||||
)
|
||||
from vllm.multimodal.processing import (
|
||||
BaseDummyInputsBuilder,
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
InputProcessingContext,
|
||||
ProcessorInputs,
|
||||
PromptReplacement,
|
||||
PromptUpdate,
|
||||
PromptUpdateDetails,
|
||||
TimingContext,
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
@@ -770,11 +771,8 @@ class MantisProcessingInfo(LlavaProcessingInfo):
|
||||
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
|
||||
def apply(
|
||||
self,
|
||||
prompt: str | list[int],
|
||||
mm_items: MultiModalDataItems,
|
||||
mm_uuid_items: MultiModalUUIDItems | None = None,
|
||||
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
|
||||
tokenization_kwargs: Mapping[str, object] | None = None,
|
||||
inputs: ProcessorInputs,
|
||||
timing_ctx: TimingContext,
|
||||
) -> MultiModalInputs:
|
||||
hf_config = self.info.get_hf_config()
|
||||
image_token_id = hf_config.image_token_index
|
||||
@@ -785,15 +783,9 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
|
||||
image_height=-1,
|
||||
)
|
||||
|
||||
result = super().apply(
|
||||
prompt,
|
||||
mm_items,
|
||||
mm_uuid_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
result = super().apply(inputs, timing_ctx)
|
||||
|
||||
mm_item_counts = mm_items.get_all_counts()
|
||||
mm_item_counts = inputs.mm_data_items.get_all_counts()
|
||||
mm_kwargs = result["mm_kwargs"]
|
||||
mm_hashes = result["mm_hashes"]
|
||||
|
||||
@@ -825,8 +817,8 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
|
||||
)
|
||||
|
||||
orig_repls = self._get_mm_prompt_updates(
|
||||
mm_items,
|
||||
hf_processor_mm_kwargs,
|
||||
inputs.mm_data_items,
|
||||
inputs.hf_processor_mm_kwargs,
|
||||
mm_kwargs,
|
||||
)
|
||||
mm_placeholders = self._find_mm_placeholders(prompt_ids, orig_repls)
|
||||
|
||||
@@ -21,16 +21,17 @@ from vllm.multimodal.parse import (
|
||||
ImageEmbeddingItems,
|
||||
ImageProcessorItems,
|
||||
MultiModalDataItems,
|
||||
MultiModalUUIDItems,
|
||||
)
|
||||
from vllm.multimodal.processing import (
|
||||
BaseDummyInputsBuilder,
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
ProcessorInputs,
|
||||
PromptIndexTargets,
|
||||
PromptInsertion,
|
||||
PromptUpdate,
|
||||
PromptUpdateDetails,
|
||||
TimingContext,
|
||||
)
|
||||
from vllm.renderers import TokenizeParams
|
||||
from vllm.sequence import IntermediateTensors
|
||||
@@ -228,19 +229,10 @@ class PaliGemmaMultiModalProcessor(BaseMultiModalProcessor[PaliGemmaProcessingIn
|
||||
|
||||
def apply(
|
||||
self,
|
||||
prompt: str | list[int],
|
||||
mm_items: MultiModalDataItems,
|
||||
mm_uuid_items: MultiModalUUIDItems | None = None,
|
||||
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
|
||||
tokenization_kwargs: Mapping[str, object] | None = None,
|
||||
inputs: ProcessorInputs,
|
||||
timing_ctx: TimingContext,
|
||||
) -> MultiModalInputs:
|
||||
mm_inputs = super().apply(
|
||||
prompt,
|
||||
mm_items,
|
||||
mm_uuid_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
mm_inputs = super().apply(inputs, timing_ctx)
|
||||
prompt_token_ids = mm_inputs["prompt_token_ids"]
|
||||
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
|
||||
@@ -50,16 +50,17 @@ from vllm.multimodal.parse import (
|
||||
ImageProcessorItems,
|
||||
ImageSize,
|
||||
MultiModalDataItems,
|
||||
MultiModalUUIDItems,
|
||||
)
|
||||
from vllm.multimodal.processing import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.multimodal.processing import BaseDummyInputsBuilder
|
||||
from vllm.multimodal.processing.processor import (
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
MultiModalProcessingInfo,
|
||||
ProcessorInputs,
|
||||
PromptReplacement,
|
||||
PromptUpdate,
|
||||
PromptUpdateDetails,
|
||||
TimingContext,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
@@ -277,7 +278,6 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
|
||||
dummy_text = self.get_dummy_text(mm_counts)
|
||||
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
|
||||
dummy_images = dummy_mm_data.get("image", [])
|
||||
tokenization_kwargs = {"truncation": False}
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
messages=[
|
||||
@@ -294,11 +294,7 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
|
||||
|
||||
dummy_mm_items = self.info.parse_mm_data(dummy_mm_data)
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt=dummy_tokens,
|
||||
mm_items=dummy_mm_items,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
return ProcessorInputs(prompt=dummy_tokens, mm_data_items=dummy_mm_items)
|
||||
|
||||
|
||||
class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]):
|
||||
@@ -344,19 +340,10 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo])
|
||||
|
||||
def _cached_apply_hf_processor(
|
||||
self,
|
||||
prompt: str | list[int],
|
||||
mm_data_items: MultiModalDataItems,
|
||||
mm_uuid_items: MultiModalUUIDItems | None,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
tokenization_kwargs: Mapping[str, object],
|
||||
inputs: ProcessorInputs,
|
||||
timing_ctx: TimingContext,
|
||||
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
|
||||
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data_items=mm_data_items,
|
||||
mm_uuid_items=mm_uuid_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(inputs, timing_ctx)
|
||||
|
||||
# NOTE: The tokens are already inserted by the chat template
|
||||
return prompt_ids, mm_info, True
|
||||
|
||||
@@ -47,15 +47,16 @@ from vllm.multimodal.parse import (
|
||||
ImageProcessorItems,
|
||||
ImageSize,
|
||||
MultiModalDataItems,
|
||||
MultiModalUUIDItems,
|
||||
)
|
||||
from vllm.multimodal.processing import (
|
||||
BaseDummyInputsBuilder,
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
ProcessorInputs,
|
||||
PromptIndexTargets,
|
||||
PromptReplacement,
|
||||
PromptUpdate,
|
||||
TimingContext,
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
@@ -190,23 +191,20 @@ class SiglipMultiModalProcessor(BaseMultiModalProcessor[SiglipProcessingInfo]):
|
||||
|
||||
def apply(
|
||||
self,
|
||||
prompt: str | list[int],
|
||||
mm_items: MultiModalDataItems,
|
||||
mm_uuid_items: MultiModalUUIDItems | None = None,
|
||||
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
|
||||
tokenization_kwargs: Mapping[str, object] | None = None,
|
||||
inputs: ProcessorInputs,
|
||||
timing_ctx: TimingContext,
|
||||
) -> MultiModalInputs:
|
||||
if mm_items:
|
||||
if isinstance(prompt, str):
|
||||
if len(prompt) > 0:
|
||||
if inputs.mm_data_items:
|
||||
if isinstance(inputs.prompt, str):
|
||||
if len(inputs.prompt) > 0:
|
||||
raise ValueError(
|
||||
"SigLIP accepts text-only or image-only inputs, not both! "
|
||||
"You must pass an image with an empty text prompt."
|
||||
)
|
||||
else:
|
||||
special_tokens = self.info.get_tokenizer().all_special_ids
|
||||
if all(tok in special_tokens for tok in prompt):
|
||||
prompt = []
|
||||
if all(tok in special_tokens for tok in inputs.prompt):
|
||||
inputs.prompt = []
|
||||
else:
|
||||
raise ValueError(
|
||||
"SigLIP accepts text-only or image-only inputs, not both! "
|
||||
@@ -214,19 +212,13 @@ class SiglipMultiModalProcessor(BaseMultiModalProcessor[SiglipProcessingInfo]):
|
||||
)
|
||||
|
||||
# For multi-modal data, the prompt after processing should
|
||||
# only contain the image token
|
||||
tokenization_kwargs = {
|
||||
**(tokenization_kwargs or {}),
|
||||
# only contain the dummy image tokens
|
||||
inputs.tokenization_kwargs = {
|
||||
**inputs.tokenization_kwargs,
|
||||
"add_special_tokens": False,
|
||||
}
|
||||
|
||||
return super().apply(
|
||||
prompt=prompt,
|
||||
mm_items=mm_items,
|
||||
mm_uuid_items=mm_uuid_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
return super().apply(inputs, timing_ctx)
|
||||
|
||||
def _hf_processor_applies_updates(
|
||||
self,
|
||||
|
||||
@@ -54,13 +54,14 @@ from vllm.multimodal.parse import (
|
||||
ModalityDataItems,
|
||||
MultiModalDataItems,
|
||||
MultiModalDataParser,
|
||||
MultiModalUUIDItems,
|
||||
)
|
||||
from vllm.multimodal.processing import (
|
||||
BaseDummyInputsBuilder,
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
ProcessorInputs,
|
||||
PromptUpdate,
|
||||
TimingContext,
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
@@ -193,29 +194,21 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessing
|
||||
|
||||
def apply(
|
||||
self,
|
||||
prompt: str | list[int],
|
||||
mm_items: MultiModalDataItems,
|
||||
mm_uuid_items: MultiModalUUIDItems | None = None,
|
||||
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
|
||||
tokenization_kwargs: Mapping[str, object] | None = None,
|
||||
inputs: ProcessorInputs,
|
||||
timing_ctx: TimingContext,
|
||||
) -> MultiModalInputs:
|
||||
if hf_processor_mm_kwargs is None:
|
||||
hf_processor_mm_kwargs = {}
|
||||
if tokenization_kwargs is None:
|
||||
tokenization_kwargs = {}
|
||||
mm_items = inputs.mm_data_items
|
||||
hf_processor_mm_kwargs = inputs.hf_processor_mm_kwargs
|
||||
|
||||
mm_hashes = self._hash_mm_items(
|
||||
mm_items,
|
||||
mm_uuid_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
)
|
||||
|
||||
_, passthrough_data = self._get_hf_mm_data(mm_items)
|
||||
mm_processed_data = BatchFeature(
|
||||
{k: torch.as_tensor(v).unsqueeze(0) for k, v in passthrough_data.items()},
|
||||
tensor_type="pt",
|
||||
)
|
||||
mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
|
||||
with timing_ctx.record("apply_hf_processor"):
|
||||
_, passthrough_data = self._get_hf_mm_data(mm_items)
|
||||
mm_processed_data = BatchFeature(
|
||||
{
|
||||
k: torch.as_tensor(v).unsqueeze(0)
|
||||
for k, v in passthrough_data.items()
|
||||
},
|
||||
tensor_type="pt",
|
||||
)
|
||||
|
||||
mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
|
||||
mm_processed_data,
|
||||
@@ -226,6 +219,11 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessing
|
||||
),
|
||||
)
|
||||
|
||||
with timing_ctx.record("get_mm_hashes"):
|
||||
mm_hashes = inputs.get_mm_hashes(self.info.model_id)
|
||||
|
||||
mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
|
||||
|
||||
return mm_inputs(
|
||||
prompt_token_ids=[1],
|
||||
mm_kwargs=mm_kwargs,
|
||||
|
||||
@@ -37,12 +37,13 @@ from vllm.multimodal.inputs import (
|
||||
from vllm.multimodal.parse import (
|
||||
ImageProcessorItems,
|
||||
MultiModalDataItems,
|
||||
MultiModalUUIDItems,
|
||||
)
|
||||
from vllm.multimodal.processing import (
|
||||
BaseDummyInputsBuilder,
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
ProcessorInputs,
|
||||
TimingContext,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
@@ -177,11 +178,8 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
|
||||
|
||||
def apply(
|
||||
self,
|
||||
prompt: str | list[int],
|
||||
mm_items: MultiModalDataItems,
|
||||
mm_uuid_items: MultiModalUUIDItems | None = None,
|
||||
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
|
||||
tokenization_kwargs: Mapping[str, object] | None = None,
|
||||
inputs: ProcessorInputs,
|
||||
timing_ctx: TimingContext,
|
||||
) -> MultiModalInputs:
|
||||
"""
|
||||
Process multi-modal inputs to be used in vLLM.
|
||||
@@ -189,29 +187,30 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
|
||||
Apply HF Processor on prompt text and multi-modal data together,
|
||||
outputting token IDs and processed tensors.
|
||||
"""
|
||||
if hf_processor_mm_kwargs is None:
|
||||
hf_processor_mm_kwargs = {}
|
||||
if tokenization_kwargs is None:
|
||||
tokenization_kwargs = {}
|
||||
prompt = inputs.prompt
|
||||
mm_items = inputs.mm_data_items
|
||||
hf_processor_mm_kwargs = inputs.hf_processor_mm_kwargs
|
||||
tokenization_kwargs = inputs.tokenization_kwargs
|
||||
|
||||
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
if not isinstance(prompt, str):
|
||||
# the prompt is the tokenized ids which is not supported
|
||||
# by the hf_processor, which is why we would need to decode the ids
|
||||
# into string
|
||||
prompt = hf_processor.decode(prompt)
|
||||
with timing_ctx.record("apply_hf_processor"):
|
||||
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
if not isinstance(prompt, str):
|
||||
# the prompt is the tokenized ids which is not supported
|
||||
# by the hf_processor, which is why we would need to decode the ids
|
||||
# into string
|
||||
prompt = hf_processor.decode(prompt)
|
||||
|
||||
# Bypass cached processor and always apply to the full set of mm inputs
|
||||
# NOTE: we can't just set caching=False because base class method
|
||||
# transforms outputs to `MultiModalKwargs` which is not going to
|
||||
# work for Transformers. We have a lot of logic tied to
|
||||
# `mm_tokens_per_modality` below
|
||||
prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm(
|
||||
prompt_text=prompt,
|
||||
mm_items=mm_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
# Bypass cached processor and always apply to the full set of mm inputs
|
||||
# NOTE: we can't just set caching=False because base class method
|
||||
# transforms outputs to `MultiModalKwargs` which is not going to
|
||||
# work for Transformers. We have a lot of logic tied to
|
||||
# `mm_tokens_per_modality` below
|
||||
prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm(
|
||||
prompt_text=prompt,
|
||||
mm_items=mm_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
|
||||
# For gemma3 we check `token_type_ids` as the key
|
||||
token_type_key = (
|
||||
@@ -225,15 +224,14 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
|
||||
# it for each input `mm_data`.
|
||||
mm_positions = torch.where(mm_token_type_ids == 1)[1]
|
||||
images = mm_items.get_items("image", ImageProcessorItems)
|
||||
multimodal_config = self.info.ctx.model_config.multimodal_config
|
||||
mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}
|
||||
image_sizes = []
|
||||
for item_idx in range(len(images)):
|
||||
image_size = images.get_image_size(item_idx)
|
||||
image_sizes.append((image_size.height, image_size.width))
|
||||
|
||||
mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens(
|
||||
image_sizes=image_sizes, **mm_processor_kwargs
|
||||
image_sizes=image_sizes,
|
||||
**self.info.ctx.get_merged_mm_kwargs({}),
|
||||
)
|
||||
|
||||
mm_placeholders = {}
|
||||
@@ -261,11 +259,8 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
|
||||
)
|
||||
|
||||
# Use overrides if provided; fallback to data-dependent hashing.
|
||||
mm_hashes = self._hash_mm_items(
|
||||
mm_items,
|
||||
mm_uuid_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
)
|
||||
with timing_ctx.record("get_mm_hashes"):
|
||||
mm_hashes = inputs.get_mm_hashes(self.info.model_id)
|
||||
|
||||
return mm_inputs(
|
||||
prompt_token_ids=prompt_ids,
|
||||
|
||||
@@ -47,16 +47,17 @@ from vllm.multimodal.parse import (
|
||||
AudioProcessorItems,
|
||||
MultiModalDataItems,
|
||||
MultiModalDataParser,
|
||||
MultiModalUUIDItems,
|
||||
)
|
||||
from vllm.multimodal.processing import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.multimodal.processing import BaseDummyInputsBuilder
|
||||
from vllm.multimodal.processing.processor import (
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
MultiModalProcessingInfo,
|
||||
PlaceholderFeaturesInfo,
|
||||
ProcessorInputs,
|
||||
PromptReplacement,
|
||||
PromptUpdate,
|
||||
TimingContext,
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tokenizers import cached_tokenizer_from_config
|
||||
@@ -265,13 +266,13 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
|
||||
res = tokenizer.mistral.encode_chat_completion(request)
|
||||
dummy_tokens = res.tokens
|
||||
|
||||
dummy_mm_inputs = self.info.parse_mm_data(
|
||||
dummy_mm_items = self.info.parse_mm_data(
|
||||
# whixtral tokenizer adds padding to the audio
|
||||
# so we need to update the audio arrays
|
||||
{**dummy_mm_data, "audio": [a.audio_array for a in res.audios]},
|
||||
)
|
||||
|
||||
return ProcessorInputs(prompt=dummy_tokens, mm_items=dummy_mm_inputs)
|
||||
return ProcessorInputs(prompt=dummy_tokens, mm_data_items=dummy_mm_items)
|
||||
|
||||
|
||||
class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]):
|
||||
@@ -361,19 +362,10 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
|
||||
|
||||
def _cached_apply_hf_processor(
|
||||
self,
|
||||
prompt: str | list[int],
|
||||
mm_data_items: MultiModalDataItems,
|
||||
mm_uuid_items: MultiModalUUIDItems | None,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
tokenization_kwargs: Mapping[str, object],
|
||||
inputs: ProcessorInputs,
|
||||
timing_ctx: TimingContext,
|
||||
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
|
||||
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data_items=mm_data_items,
|
||||
mm_uuid_items=mm_uuid_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(inputs, timing_ctx)
|
||||
|
||||
# NOTE: The tokens are already inserted by the chat template
|
||||
return prompt_ids, mm_info, True
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from .context import BaseProcessingInfo, InputProcessingContext
|
||||
from .dummy_inputs import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from .context import BaseProcessingInfo, InputProcessingContext, TimingContext
|
||||
from .dummy_inputs import BaseDummyInputsBuilder
|
||||
from .inputs import ProcessorInputs
|
||||
from .processor import (
|
||||
BaseMultiModalProcessor,
|
||||
EncDecMultiModalProcessor,
|
||||
@@ -15,6 +16,7 @@ from .processor import (
|
||||
__all__ = [
|
||||
"BaseProcessingInfo",
|
||||
"InputProcessingContext",
|
||||
"TimingContext",
|
||||
"BaseDummyInputsBuilder",
|
||||
"ProcessorInputs",
|
||||
"BaseMultiModalProcessor",
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextvars
|
||||
import threading
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Generator, Mapping
|
||||
from collections.abc import Mapping
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property
|
||||
@@ -33,104 +31,53 @@ if TYPE_CHECKING:
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.processing_utils import ProcessorMixin
|
||||
|
||||
from vllm.config import ModelConfig, ObservabilityConfig
|
||||
from vllm.config import ModelConfig
|
||||
else:
|
||||
PretrainedConfig = object
|
||||
BatchFeature = object
|
||||
ProcessorMixin = object
|
||||
|
||||
ModelConfig = object
|
||||
ObservabilityConfig = object
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
_request_id_context: contextvars.ContextVar[str | None] = contextvars.ContextVar(
|
||||
"_request_id_context", default=None
|
||||
)
|
||||
|
||||
|
||||
def get_current_request_id() -> str | None:
|
||||
"""Get the current request_id from the context, if available."""
|
||||
return _request_id_context.get()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_request_id(request_id: str) -> Generator[None, None, None]:
|
||||
"""Context manager to set the request_id for the current context."""
|
||||
token = _request_id_context.set(request_id)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_request_id_context.reset(token)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiModalProcessorTimingStats:
|
||||
"""Per-request timing statistics for multimodal processor stages."""
|
||||
class TimingContext:
|
||||
"""Helper class to record execution times during multi-modal processing."""
|
||||
|
||||
hf_processor_time: float = 0.0
|
||||
"""Time spent in HuggingFace processor calls (seconds)."""
|
||||
enabled: bool = True
|
||||
"""If disabled, `TimingContext.record` becomes a no-op."""
|
||||
|
||||
hashing_time: float = 0.0
|
||||
"""Time spent computing multimodal item hashes (seconds)."""
|
||||
stage_secs: dict[str, float] = field(default_factory=dict)
|
||||
"""The execution time (in seconds) for each processing stage."""
|
||||
|
||||
cache_lookup_time: float = 0.0
|
||||
"""Time spent in cache lookups and merges (seconds)."""
|
||||
@property
|
||||
def total_secs(self) -> float:
|
||||
return sum(self.stage_secs.values())
|
||||
|
||||
prompt_update_time: float = 0.0
|
||||
"""Time spent applying prompt updates and finding placeholders (seconds)."""
|
||||
@contextmanager
|
||||
def record(self, stage: str):
|
||||
"""Record the execution time for a processing stage."""
|
||||
if not self.enabled:
|
||||
yield
|
||||
return
|
||||
|
||||
preprocessor_total_time: float = 0.0
|
||||
"""Total preprocessing time (seconds)."""
|
||||
start_time = time.perf_counter()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
elapsed = time.perf_counter() - start_time
|
||||
self.stage_secs.setdefault(stage, 0.0)
|
||||
self.stage_secs[stage] += elapsed
|
||||
|
||||
def to_dict(self) -> dict[str, float]:
|
||||
"""Convert stats to a dictionary for JSON serialization."""
|
||||
return {
|
||||
"hf_processor_time": self.hf_processor_time,
|
||||
"hashing_time": self.hashing_time,
|
||||
"cache_lookup_time": self.cache_lookup_time,
|
||||
"prompt_update_time": self.prompt_update_time,
|
||||
"preprocessor_total_time": self.preprocessor_total_time,
|
||||
def get_stats_dict(self):
|
||||
stats_dict = {
|
||||
f"{stage}_secs": time_s for stage, time_s in self.stage_secs.items()
|
||||
}
|
||||
stats_dict["preprocessor_total_secs"] = self.total_secs
|
||||
|
||||
|
||||
@contextmanager
|
||||
def timed_preprocessor_operation(ctx: "InputProcessingContext", stage_name: str):
|
||||
"""
|
||||
Context manager to time an operation using the context's timing stats.
|
||||
|
||||
The request_id is automatically retrieved from the context variable,
|
||||
so it doesn't need to be passed as a parameter.
|
||||
|
||||
Args:
|
||||
ctx: The InputProcessingContext containing the timing stats registry.
|
||||
stage_name: Name of the stage being timed.
|
||||
"""
|
||||
request_id = get_current_request_id()
|
||||
if ctx is None or request_id is None:
|
||||
yield
|
||||
return
|
||||
|
||||
stats = ctx.get_timing_stats(request_id)
|
||||
if stats is None:
|
||||
yield
|
||||
return
|
||||
|
||||
start_time = time.perf_counter()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
elapsed = time.perf_counter() - start_time
|
||||
if stage_name == "hf_processor":
|
||||
stats.hf_processor_time += elapsed
|
||||
elif stage_name == "hashing":
|
||||
stats.hashing_time += elapsed
|
||||
elif stage_name == "cache_lookup":
|
||||
stats.cache_lookup_time += elapsed
|
||||
elif stage_name == "prompt_update":
|
||||
stats.prompt_update_time += elapsed
|
||||
stats.preprocessor_total_time += elapsed
|
||||
return stats_dict
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
@@ -151,21 +98,6 @@ class InputProcessingContext:
|
||||
tokenizer: TokenizerLike | None
|
||||
"""The tokenizer used to tokenize the inputs."""
|
||||
|
||||
observability_config: "ObservabilityConfig | None" = field(
|
||||
default=None, compare=False, repr=False
|
||||
)
|
||||
"""Configuration for observability features."""
|
||||
|
||||
timing_stats_registry: dict[str, MultiModalProcessorTimingStats] = field(
|
||||
default_factory=dict, compare=False, repr=False
|
||||
)
|
||||
"""Registry for storing timing stats keyed by request_id."""
|
||||
|
||||
_timing_stats_registry_lock: threading.Lock = field(
|
||||
default_factory=threading.Lock, compare=False, repr=False
|
||||
)
|
||||
"""Lock for thread-safe access to timing_stats_registry."""
|
||||
|
||||
def get_tokenizer(self) -> TokenizerLike:
|
||||
if self.tokenizer is None:
|
||||
raise ValueError(
|
||||
@@ -379,71 +311,6 @@ class InputProcessingContext:
|
||||
|
||||
return self._postprocess_output(output)
|
||||
|
||||
def get_timing_stats(
|
||||
self, request_id: str
|
||||
) -> MultiModalProcessorTimingStats | None:
|
||||
"""
|
||||
Get timing stats for a request.
|
||||
"""
|
||||
if (
|
||||
self.observability_config is None
|
||||
or not self.observability_config.enable_mm_processor_stats
|
||||
):
|
||||
return None
|
||||
with self._timing_stats_registry_lock:
|
||||
return self.timing_stats_registry.get(request_id)
|
||||
|
||||
def create_timing_stats(self, request_id: str) -> MultiModalProcessorTimingStats:
|
||||
"""
|
||||
Create and store timing stats in the registry for a request.
|
||||
|
||||
This should be called at the start of processing for a request.
|
||||
The stats object is created immediately and stored in the registry.
|
||||
"""
|
||||
if (
|
||||
self.observability_config is None
|
||||
or not self.observability_config.enable_mm_processor_stats
|
||||
):
|
||||
return MultiModalProcessorTimingStats()
|
||||
|
||||
with self._timing_stats_registry_lock:
|
||||
if request_id in self.timing_stats_registry:
|
||||
raise ValueError(
|
||||
f"Timing stats already exist for request_id: {request_id}"
|
||||
)
|
||||
stats = MultiModalProcessorTimingStats()
|
||||
self.timing_stats_registry[request_id] = stats
|
||||
return stats
|
||||
|
||||
def clear_timing_stats_registry(self) -> int:
|
||||
"""
|
||||
Clear all stats from the registry. Returns the number of stats cleared.
|
||||
"""
|
||||
if (
|
||||
self.observability_config is None
|
||||
or not self.observability_config.enable_mm_processor_stats
|
||||
):
|
||||
return 0
|
||||
with self._timing_stats_registry_lock:
|
||||
count = len(self.timing_stats_registry)
|
||||
self.timing_stats_registry.clear()
|
||||
return count
|
||||
|
||||
def get_all_timing_stats(self) -> dict[str, dict[str, float]]:
|
||||
"""
|
||||
Get all timing stats as a dictionary for API endpoints.
|
||||
"""
|
||||
if (
|
||||
self.observability_config is None
|
||||
or not self.observability_config.enable_mm_processor_stats
|
||||
):
|
||||
return {}
|
||||
with self._timing_stats_registry_lock:
|
||||
return {
|
||||
rid: stats.to_dict()
|
||||
for rid, stats in self.timing_stats_registry.items()
|
||||
}
|
||||
|
||||
|
||||
class BaseProcessingInfo:
|
||||
"""Base class to provide the information necessary for data processing."""
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
import numpy as np
|
||||
@@ -18,27 +17,14 @@ from vllm.config.multimodal import (
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from ..inputs import MultiModalDataDict
|
||||
from ..parse import MultiModalDataItems
|
||||
from .context import BaseProcessingInfo
|
||||
from .inputs import ProcessorInputs
|
||||
|
||||
_I = TypeVar("_I", bound=BaseProcessingInfo)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessorInputs:
|
||||
"""
|
||||
Represents the keyword arguments to
|
||||
[`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][].
|
||||
"""
|
||||
|
||||
prompt: str | list[int]
|
||||
mm_items: MultiModalDataItems
|
||||
hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
|
||||
tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)
|
||||
|
||||
|
||||
class BaseDummyInputsBuilder(ABC, Generic[_I]):
|
||||
"""
|
||||
Abstract base class that constructs the dummy data to profile
|
||||
@@ -101,7 +87,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt=dummy_text,
|
||||
mm_items=dummy_mm_items,
|
||||
mm_data_items=dummy_mm_items,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
|
||||
|
||||
70
vllm/multimodal/processing/inputs.py
Normal file
70
vllm/multimodal/processing/inputs.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from ..hasher import MultiModalHasher
|
||||
from ..inputs import MultiModalHashes
|
||||
from ..parse import MultiModalDataItems, MultiModalUUIDItems
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessorInputs:
|
||||
"""
|
||||
Represents the keyword arguments to
|
||||
[`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][].
|
||||
"""
|
||||
|
||||
prompt: str | list[int]
|
||||
mm_data_items: MultiModalDataItems
|
||||
mm_uuid_items: MultiModalUUIDItems | None = None
|
||||
hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
|
||||
tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)
|
||||
|
||||
def get_mm_hashes(self, model_id: str) -> MultiModalHashes:
|
||||
mm_data_items = self.mm_data_items
|
||||
mm_uuid_items = self.mm_uuid_items or {}
|
||||
hf_processor_mm_kwargs = self.hf_processor_mm_kwargs
|
||||
|
||||
mm_hashes: MultiModalHashes = {}
|
||||
hasher = MultiModalHasher
|
||||
|
||||
for modality, data_items in mm_data_items.items():
|
||||
if modality in mm_uuid_items:
|
||||
uuid_items = mm_uuid_items[modality]
|
||||
|
||||
# For None entries, compute a hash; otherwise, use provided ID.
|
||||
hashes: list[str] = []
|
||||
for i, item in enumerate(data_items.get_all_items_for_hash()):
|
||||
uuid_item = uuid_items[i]
|
||||
|
||||
# NOTE: Even if a uuid_item is provided, we still compute a hash
|
||||
# if `hf_processor_mm_kwargs` is provided.
|
||||
# This is because the processed multimodal inputs can be different
|
||||
# depending on the processor kwargs.
|
||||
if uuid_item is None or hf_processor_mm_kwargs:
|
||||
# NOTE: use provided hash string to hash with kwargs
|
||||
# if available for better performance.
|
||||
item = uuid_item if uuid_item is not None else item
|
||||
hashes.append(
|
||||
hasher.hash_kwargs(
|
||||
model_id=model_id,
|
||||
**{modality: item},
|
||||
**hf_processor_mm_kwargs,
|
||||
)
|
||||
)
|
||||
else:
|
||||
hashes.append(uuid_item)
|
||||
|
||||
mm_hashes[modality] = hashes
|
||||
else:
|
||||
mm_hashes[modality] = [
|
||||
hasher.hash_kwargs(
|
||||
model_id=model_id,
|
||||
**{modality: item},
|
||||
**hf_processor_mm_kwargs,
|
||||
)
|
||||
for item in data_items
|
||||
]
|
||||
|
||||
return mm_hashes
|
||||
@@ -23,7 +23,6 @@ from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils.collection_utils import flatten_2d_lists, full_groupby
|
||||
|
||||
from ..hasher import MultiModalHasher
|
||||
from ..inputs import (
|
||||
MultiModalEncDecInputs,
|
||||
MultiModalFieldConfig,
|
||||
@@ -42,12 +41,9 @@ from ..parse import (
|
||||
MultiModalDataItems,
|
||||
MultiModalUUIDItems,
|
||||
)
|
||||
from .context import (
|
||||
BaseProcessingInfo,
|
||||
get_current_request_id,
|
||||
timed_preprocessor_operation,
|
||||
)
|
||||
from .context import BaseProcessingInfo, TimingContext
|
||||
from .dummy_inputs import BaseDummyInputsBuilder
|
||||
from .inputs import ProcessorInputs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
@@ -1017,13 +1013,15 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
mm_uuid_items: MultiModalUUIDItems | None = None,
|
||||
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
|
||||
) -> MultiModalInputs:
|
||||
return self.apply(
|
||||
processor_inputs = ProcessorInputs(
|
||||
prompt,
|
||||
mm_items,
|
||||
mm_uuid_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs or {},
|
||||
)
|
||||
|
||||
return self.apply(processor_inputs, TimingContext(enabled=False))
|
||||
|
||||
@abstractmethod
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
@@ -1139,12 +1137,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
Call the HF processor on the prompt text and
|
||||
associated multi-modal data.
|
||||
"""
|
||||
with timed_preprocessor_operation(self.info.ctx, "hf_processor"):
|
||||
return self.info.ctx.call_hf_processor(
|
||||
self.info.get_hf_processor(**mm_kwargs),
|
||||
dict(text=prompt, **mm_data),
|
||||
dict(**mm_kwargs, **tok_kwargs),
|
||||
)
|
||||
return self.info.ctx.call_hf_processor(
|
||||
self.info.get_hf_processor(**mm_kwargs),
|
||||
dict(text=prompt, **mm_data),
|
||||
dict(**mm_kwargs, **tok_kwargs),
|
||||
)
|
||||
|
||||
def _hf_processor_applies_updates(
|
||||
self,
|
||||
@@ -1306,60 +1303,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
|
||||
return prompt_ids, mm_processed_data, False
|
||||
|
||||
def _hash_mm_items(
|
||||
self,
|
||||
mm_data_items: MultiModalDataItems,
|
||||
mm_uuid_items: MultiModalUUIDItems | None,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> MultiModalHashes:
|
||||
model_id = self.info.model_id
|
||||
|
||||
if mm_uuid_items is None:
|
||||
mm_uuid_items = {}
|
||||
|
||||
mm_hashes: MultiModalHashes = {}
|
||||
hasher = MultiModalHasher
|
||||
|
||||
for modality, data_items in mm_data_items.items():
|
||||
if modality in mm_uuid_items:
|
||||
uuid_items = mm_uuid_items[modality]
|
||||
|
||||
# For None entries, compute a hash; otherwise, use provided ID.
|
||||
hashes: list[str] = []
|
||||
for i, item in enumerate(data_items.get_all_items_for_hash()):
|
||||
uuid_item = uuid_items[i]
|
||||
|
||||
# NOTE: Even if a uuid_item is provided, we still compute a hash
|
||||
# if `hf_processor_mm_kwargs` is provided.
|
||||
# This is because the processed multimodal inputs can be different
|
||||
# depending on the processor kwargs.
|
||||
if uuid_item is None or hf_processor_mm_kwargs:
|
||||
# NOTE: use provided hash string to hash with kwargs
|
||||
# if available for better performance.
|
||||
item = uuid_item if uuid_item is not None else item
|
||||
hashes.append(
|
||||
hasher.hash_kwargs(
|
||||
model_id=model_id,
|
||||
**{modality: item},
|
||||
**hf_processor_mm_kwargs,
|
||||
)
|
||||
)
|
||||
else:
|
||||
hashes.append(uuid_item)
|
||||
|
||||
mm_hashes[modality] = hashes
|
||||
else:
|
||||
mm_hashes[modality] = [
|
||||
hasher.hash_kwargs(
|
||||
model_id=model_id,
|
||||
**{modality: item},
|
||||
**hf_processor_mm_kwargs,
|
||||
)
|
||||
for item in data_items
|
||||
]
|
||||
|
||||
return mm_hashes
|
||||
|
||||
def _get_cache_missing_items(
|
||||
self,
|
||||
cache: BaseMultiModalProcessorCache,
|
||||
@@ -1461,40 +1404,36 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
|
||||
def _apply_hf_processor(
|
||||
self,
|
||||
prompt: str | list[int],
|
||||
mm_data_items: MultiModalDataItems,
|
||||
mm_uuid_items: MultiModalUUIDItems | None,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
tokenization_kwargs: Mapping[str, object],
|
||||
inputs: ProcessorInputs,
|
||||
timing_ctx: TimingContext,
|
||||
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
|
||||
(
|
||||
prompt_ids,
|
||||
mm_processed_data,
|
||||
is_update_applied,
|
||||
) = self._apply_hf_processor_main(
|
||||
prompt=prompt,
|
||||
mm_items=mm_data_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
enable_hf_prompt_update=True,
|
||||
)
|
||||
with timing_ctx.record("apply_hf_processor"):
|
||||
(
|
||||
prompt_ids,
|
||||
mm_processed_data,
|
||||
is_update_applied,
|
||||
) = self._apply_hf_processor_main(
|
||||
prompt=inputs.prompt,
|
||||
mm_items=inputs.mm_data_items,
|
||||
hf_processor_mm_kwargs=inputs.hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=inputs.tokenization_kwargs,
|
||||
enable_hf_prompt_update=True,
|
||||
)
|
||||
|
||||
mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
|
||||
mm_processed_data,
|
||||
self._get_mm_fields_config(mm_processed_data, hf_processor_mm_kwargs),
|
||||
self._get_mm_fields_config(
|
||||
mm_processed_data, inputs.hf_processor_mm_kwargs
|
||||
),
|
||||
)
|
||||
|
||||
# Use overrides if provided; fallback to data-dependent hashing.
|
||||
with timed_preprocessor_operation(self.info.ctx, "hashing"):
|
||||
mm_hashes = self._hash_mm_items(
|
||||
mm_data_items,
|
||||
mm_uuid_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
)
|
||||
with timing_ctx.record("get_mm_hashes"):
|
||||
mm_hashes = inputs.get_mm_hashes(self.info.model_id)
|
||||
|
||||
mm_prompt_updates = self._get_mm_prompt_updates(
|
||||
mm_data_items,
|
||||
hf_processor_mm_kwargs,
|
||||
inputs.mm_data_items,
|
||||
inputs.hf_processor_mm_kwargs,
|
||||
mm_kwargs,
|
||||
)
|
||||
|
||||
@@ -1508,11 +1447,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
|
||||
def _cached_apply_hf_processor(
|
||||
self,
|
||||
prompt: str | list[int],
|
||||
mm_data_items: MultiModalDataItems,
|
||||
mm_uuid_items: MultiModalUUIDItems | None,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
tokenization_kwargs: Mapping[str, object],
|
||||
inputs: ProcessorInputs,
|
||||
timing_ctx: TimingContext,
|
||||
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
|
||||
"""
|
||||
Apply the HF processor on the full prompt text,
|
||||
@@ -1520,59 +1456,50 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
"""
|
||||
cache = self.cache
|
||||
|
||||
_, passthrough_data = self._get_hf_mm_data(mm_data_items)
|
||||
_, passthrough_data = self._get_hf_mm_data(inputs.mm_data_items)
|
||||
if cache is None or passthrough_data:
|
||||
return self._apply_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data_items=mm_data_items,
|
||||
mm_uuid_items=mm_uuid_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
return self._apply_hf_processor(inputs, timing_ctx)
|
||||
|
||||
with timed_preprocessor_operation(self.info.ctx, "hashing"):
|
||||
mm_hashes = self._hash_mm_items(
|
||||
mm_data_items,
|
||||
mm_uuid_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
)
|
||||
with timing_ctx.record("get_mm_hashes"):
|
||||
mm_hashes = inputs.get_mm_hashes(self.info.model_id)
|
||||
|
||||
with timed_preprocessor_operation(self.info.ctx, "cache_lookup"):
|
||||
with timing_ctx.record("get_cache_missing_items"):
|
||||
mm_is_cached, mm_missing_data_items = self._get_cache_missing_items(
|
||||
cache=cache,
|
||||
mm_data_items=mm_data_items,
|
||||
mm_data_items=inputs.mm_data_items,
|
||||
mm_hashes=mm_hashes,
|
||||
)
|
||||
|
||||
# NOTE: `prompt` does not correspond to `mm_missing_data_items`,
|
||||
# so we can't apply prompt updates until the new multimodal
|
||||
# items are combined with the cached multimodal items
|
||||
(
|
||||
prompt_ids,
|
||||
mm_missing_processed_data,
|
||||
is_update_applied,
|
||||
) = self._apply_hf_processor_main(
|
||||
prompt=prompt,
|
||||
mm_items=mm_missing_data_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
enable_hf_prompt_update=False,
|
||||
)
|
||||
with timing_ctx.record("apply_hf_processor"):
|
||||
(
|
||||
prompt_ids,
|
||||
mm_missing_processed_data,
|
||||
is_update_applied,
|
||||
) = self._apply_hf_processor_main(
|
||||
prompt=inputs.prompt,
|
||||
mm_items=mm_missing_data_items,
|
||||
hf_processor_mm_kwargs=inputs.hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=inputs.tokenization_kwargs,
|
||||
enable_hf_prompt_update=False,
|
||||
)
|
||||
|
||||
mm_missing_kwargs = MultiModalKwargsItems.from_hf_inputs(
|
||||
mm_missing_processed_data,
|
||||
self._get_mm_fields_config(
|
||||
mm_missing_processed_data, hf_processor_mm_kwargs
|
||||
mm_missing_processed_data, inputs.hf_processor_mm_kwargs
|
||||
),
|
||||
)
|
||||
|
||||
mm_missing_prompt_updates = self._get_mm_prompt_updates(
|
||||
mm_missing_data_items,
|
||||
hf_processor_mm_kwargs,
|
||||
inputs.hf_processor_mm_kwargs,
|
||||
mm_missing_kwargs,
|
||||
)
|
||||
|
||||
with timed_preprocessor_operation(self.info.ctx, "cache_lookup"):
|
||||
with timing_ctx.record("merge_mm_kwargs"):
|
||||
mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs(
|
||||
cache,
|
||||
mm_hashes=mm_hashes,
|
||||
@@ -1742,11 +1669,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
|
||||
def apply(
|
||||
self,
|
||||
prompt: str | list[int],
|
||||
mm_items: MultiModalDataItems,
|
||||
mm_uuid_items: MultiModalUUIDItems | None = None,
|
||||
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
|
||||
tokenization_kwargs: Mapping[str, object] | None = None,
|
||||
inputs: ProcessorInputs,
|
||||
timing_ctx: TimingContext,
|
||||
) -> MultiModalInputs:
|
||||
"""
|
||||
Process multi-modal inputs to be used in vLLM.
|
||||
@@ -1761,31 +1685,16 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
3. Extract information about the placeholder tokens from the
|
||||
processed token IDs.
|
||||
"""
|
||||
request_id = get_current_request_id()
|
||||
if request_id is not None:
|
||||
self.info.ctx.create_timing_stats(request_id)
|
||||
|
||||
if hf_processor_mm_kwargs is None:
|
||||
hf_processor_mm_kwargs = {}
|
||||
if tokenization_kwargs is None:
|
||||
tokenization_kwargs = {}
|
||||
|
||||
(
|
||||
prompt_ids,
|
||||
mm_info,
|
||||
is_update_applied,
|
||||
) = self._cached_apply_hf_processor(
|
||||
prompt,
|
||||
mm_items,
|
||||
mm_uuid_items,
|
||||
hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
) = self._cached_apply_hf_processor(inputs, timing_ctx)
|
||||
|
||||
# NOTE: tokenization_kwargs are not required to init processor
|
||||
with timed_preprocessor_operation(self.info.ctx, "prompt_update"):
|
||||
with timing_ctx.record("apply_prompt_updates"):
|
||||
prompt_ids, mm_placeholders = self._maybe_apply_prompt_updates(
|
||||
mm_items=mm_items,
|
||||
mm_items=inputs.mm_data_items,
|
||||
prompt_ids=prompt_ids,
|
||||
mm_kwargs=mm_info.kwargs,
|
||||
mm_prompt_updates=mm_info.prompt_updates,
|
||||
@@ -1851,11 +1760,8 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
|
||||
def apply(
|
||||
self,
|
||||
prompt: str | list[int],
|
||||
mm_items: MultiModalDataItems,
|
||||
mm_uuid_items: MultiModalUUIDItems | None = None,
|
||||
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
|
||||
tokenization_kwargs: Mapping[str, object] | None = None,
|
||||
inputs: ProcessorInputs,
|
||||
timing_ctx: TimingContext,
|
||||
) -> MultiModalEncDecInputs:
|
||||
"""
|
||||
Process multi-modal inputs to be used in vLLM.
|
||||
@@ -1864,17 +1770,22 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
2. Apply the HF processor on encoder prompt.
|
||||
3. Copy the input prompt text as decoder prompt inputs.
|
||||
"""
|
||||
encoder_prompt = self.create_encoder_prompt(prompt, mm_items)
|
||||
encoder_inputs = super().apply(
|
||||
encoder_prompt = self.create_encoder_prompt(
|
||||
inputs.prompt,
|
||||
inputs.mm_data_items,
|
||||
)
|
||||
encoder_processor_inputs = ProcessorInputs(
|
||||
encoder_prompt,
|
||||
mm_items,
|
||||
mm_uuid_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
inputs.mm_data_items,
|
||||
inputs.mm_uuid_items,
|
||||
hf_processor_mm_kwargs=inputs.hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=inputs.tokenization_kwargs,
|
||||
)
|
||||
|
||||
encoder_inputs = super().apply(encoder_processor_inputs, timing_ctx)
|
||||
|
||||
return self._get_enc_dec_inputs(
|
||||
prompt=prompt,
|
||||
mm_items=mm_items,
|
||||
prompt=inputs.prompt,
|
||||
mm_items=inputs.mm_data_items,
|
||||
encoder_inputs=encoder_inputs,
|
||||
)
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing.synchronize import Lock as LockType
|
||||
from typing import TYPE_CHECKING, Generic, Literal, Protocol, TypeVar, cast
|
||||
|
||||
from vllm.config.observability import ObservabilityConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
|
||||
|
||||
@@ -24,6 +25,7 @@ from .processing import (
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
InputProcessingContext,
|
||||
TimingContext,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -174,32 +176,26 @@ class MultiModalRegistry:
|
||||
def _create_processing_ctx(
|
||||
self,
|
||||
model_config: "ModelConfig",
|
||||
observability_config: "ObservabilityConfig | None" = None,
|
||||
tokenizer: TokenizerLike | None = None,
|
||||
) -> InputProcessingContext:
|
||||
if tokenizer is None:
|
||||
tokenizer = cached_tokenizer_from_config(model_config)
|
||||
|
||||
return InputProcessingContext(
|
||||
model_config, tokenizer, observability_config=observability_config
|
||||
)
|
||||
return InputProcessingContext(model_config, tokenizer)
|
||||
|
||||
def _create_processing_info(
|
||||
self,
|
||||
model_config: "ModelConfig",
|
||||
observability_config: "ObservabilityConfig | None" = None,
|
||||
*,
|
||||
tokenizer: TokenizerLike | None = None,
|
||||
) -> BaseProcessingInfo:
|
||||
model_cls = self._get_model_cls(model_config)
|
||||
factories = model_cls._processor_factory
|
||||
ctx = self._create_processing_ctx(model_config, observability_config, tokenizer)
|
||||
ctx = self._create_processing_ctx(model_config, tokenizer)
|
||||
return factories.info(ctx)
|
||||
|
||||
def create_processor(
|
||||
self,
|
||||
model_config: "ModelConfig",
|
||||
observability_config: "ObservabilityConfig | None" = None,
|
||||
*,
|
||||
tokenizer: TokenizerLike | None = None,
|
||||
cache: BaseMultiModalProcessorCache | None = None,
|
||||
@@ -213,7 +209,7 @@ class MultiModalRegistry:
|
||||
model_cls = self._get_model_cls(model_config)
|
||||
factories = model_cls._processor_factory
|
||||
|
||||
ctx = self._create_processing_ctx(model_config, observability_config, tokenizer)
|
||||
ctx = self._create_processing_ctx(model_config, tokenizer)
|
||||
|
||||
return factories.build_processor(ctx, cache=cache)
|
||||
|
||||
@@ -242,10 +238,8 @@ class MultiModalRegistry:
|
||||
mm_options=mm_config.limit_per_prompt,
|
||||
)
|
||||
mm_inputs = processor.apply(
|
||||
prompt=processor_inputs.prompt,
|
||||
mm_items=processor_inputs.mm_items,
|
||||
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=processor_inputs.tokenization_kwargs,
|
||||
processor_inputs,
|
||||
timing_ctx=TimingContext(enabled=False),
|
||||
)
|
||||
|
||||
prompt_token_ids = mm_inputs["prompt_token_ids"]
|
||||
@@ -335,3 +329,34 @@ class MultiModalRegistry:
|
||||
return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)
|
||||
else:
|
||||
raise ValueError(f"Unknown cache type: {cache_type!r}")
|
||||
|
||||
|
||||
class MultiModalTimingRegistry:
|
||||
def __init__(self, observability_config: "ObservabilityConfig | None") -> None:
|
||||
super().__init__()
|
||||
|
||||
if observability_config and observability_config.enable_mm_processor_stats:
|
||||
self._lock = threading.Lock()
|
||||
self._ctx_by_request_id = defaultdict[str, TimingContext](TimingContext)
|
||||
self._enabled = True
|
||||
else:
|
||||
self._enabled = False
|
||||
|
||||
def get(self, request_id: str) -> TimingContext:
|
||||
if not self._enabled:
|
||||
return TimingContext(enabled=False)
|
||||
|
||||
with self._lock:
|
||||
return self._ctx_by_request_id[request_id]
|
||||
|
||||
def stat(self) -> dict[str, dict[str, float]]:
|
||||
if not self._enabled:
|
||||
return {}
|
||||
|
||||
with self._lock:
|
||||
stats = {
|
||||
req_id: ctx.get_stats_dict()
|
||||
for req_id, ctx in self._ctx_by_request_id.items()
|
||||
}
|
||||
self._ctx_by_request_id.clear()
|
||||
return stats
|
||||
|
||||
@@ -85,13 +85,13 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
self._mm_cache_stats: MultiModalCacheStats | None = None
|
||||
if config.model_config.is_multimodal_model:
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY as mm_registry
|
||||
from vllm.multimodal.registry import MultiModalTimingRegistry
|
||||
|
||||
mm_processor_cache = mm_registry.processor_cache_from_config(config)
|
||||
|
||||
with set_default_torch_num_threads():
|
||||
self.mm_processor = mm_registry.create_processor(
|
||||
config.model_config,
|
||||
config.observability_config,
|
||||
tokenizer=tokenizer,
|
||||
cache=mm_processor_cache,
|
||||
)
|
||||
@@ -102,6 +102,9 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
# This is used to generate internal request ID for MM processing
|
||||
# It has no relation to the request ID for engine core
|
||||
self._mm_req_counter = AtomicCounter()
|
||||
self._mm_timing_registry = MultiModalTimingRegistry(
|
||||
config.observability_config
|
||||
)
|
||||
|
||||
def get_tokenizer(self) -> _T:
|
||||
tokenizer = self.tokenizer
|
||||
@@ -534,7 +537,7 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
tokenization_kwargs: dict[str, Any] | None,
|
||||
) -> "MultiModalInputs":
|
||||
from vllm.multimodal.parse import parse_mm_uuids
|
||||
from vllm.multimodal.processing.context import set_request_id
|
||||
from vllm.multimodal.processing import ProcessorInputs as MMProcessorInputs
|
||||
|
||||
mm_req_id = f"renderer-mm-{self._mm_req_counter.inc(1)}"
|
||||
|
||||
@@ -543,18 +546,21 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
mm_data_items = mm_processor.info.parse_mm_data(mm_data)
|
||||
mm_uuid_items = parse_mm_uuids(mm_uuids)
|
||||
|
||||
mm_uuids = self._process_mm_uuids(
|
||||
mm_uuid_items = self._process_mm_uuids(
|
||||
mm_data, mm_data_items, mm_uuid_items, mm_req_id
|
||||
)
|
||||
|
||||
with set_request_id(mm_req_id), set_default_torch_num_threads():
|
||||
mm_inputs = mm_processor.apply(
|
||||
prompt,
|
||||
mm_data_items,
|
||||
mm_uuid_items,
|
||||
hf_processor_mm_kwargs=mm_processor_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
mm_processor_inputs = MMProcessorInputs(
|
||||
prompt,
|
||||
mm_data_items,
|
||||
mm_uuid_items,
|
||||
hf_processor_mm_kwargs=mm_processor_kwargs or {},
|
||||
tokenization_kwargs=tokenization_kwargs or {},
|
||||
)
|
||||
mm_timing_ctx = self._mm_timing_registry.get(mm_req_id)
|
||||
|
||||
with set_default_torch_num_threads():
|
||||
mm_inputs = mm_processor.apply(mm_processor_inputs, mm_timing_ctx)
|
||||
|
||||
self.update_mm_cache_stats()
|
||||
|
||||
|
||||
@@ -6272,7 +6272,7 @@ class GPUModelRunner(
|
||||
self.encoder_timing_registry[req_id] = EncoderTimingStats()
|
||||
|
||||
stats = self.encoder_timing_registry[req_id]
|
||||
stats.encoder_forward_time += per_request_time
|
||||
stats.encoder_forward_secs += per_request_time
|
||||
stats.num_encoder_calls += 1
|
||||
|
||||
|
||||
@@ -6280,7 +6280,7 @@ class GPUModelRunner(
|
||||
class EncoderTimingStats:
|
||||
"""Per-request timing statistics for encoder forward pass."""
|
||||
|
||||
encoder_forward_time: float = 0.0
|
||||
encoder_forward_secs: float = 0.0
|
||||
"""Time spent in vision encoder forward pass (seconds)."""
|
||||
|
||||
num_encoder_calls: int = 0
|
||||
@@ -6288,6 +6288,6 @@ class EncoderTimingStats:
|
||||
|
||||
def to_dict(self) -> dict[str, float | int]:
|
||||
return {
|
||||
"encoder_forward_time": self.encoder_forward_time,
|
||||
"encoder_forward_secs": self.encoder_forward_secs,
|
||||
"num_encoder_calls": self.num_encoder_calls,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user