From 06b557ecd9881289afcfa43458fb07329c955921 Mon Sep 17 00:00:00 2001 From: Reagan Lee <96998476+reaganjlee@users.noreply.github.com> Date: Sat, 24 Jan 2026 00:24:44 -0800 Subject: [PATCH] feat(benchmark): add encoder forward pass benchmarking to mm-processor (#31655) Signed-off-by: Reagan Signed-off-by: Reagan Lee <96998476+reaganjlee@users.noreply.github.com> Co-authored-by: Hiroken. <105287758+HirokenOvo@users.noreply.github.com> --- vllm/benchmarks/mm_processor.py | 123 +++++++++++++++++------ vllm/multimodal/processing/context.py | 91 +++++++++++++++-- vllm/multimodal/processing/processor.py | 18 ++-- vllm/v1/worker/gpu_model_runner.py | 126 +++++++++++++++++++++--- vllm/v1/worker/gpu_worker.py | 4 + 5 files changed, 303 insertions(+), 59 deletions(-) diff --git a/vllm/benchmarks/mm_processor.py b/vllm/benchmarks/mm_processor.py index a5512f1e5..e6e205396 100644 --- a/vllm/benchmarks/mm_processor.py +++ b/vllm/benchmarks/mm_processor.py @@ -22,6 +22,10 @@ from typing import Any import numpy as np +from vllm.benchmarks.datasets import ( + MultiModalConversationDataset, + VisionArenaDataset, +) from vllm.benchmarks.throughput import get_requests from vllm.engine.arg_utils import EngineArgs from vllm.multimodal.processing.context import ( @@ -45,26 +49,21 @@ def collect_mm_processor_stats( """ all_stats = get_timing_stats_from_engine_client(llm_engine) - stats_by_stage = { - "hf_processor_time": [], - "hashing_time": [], - "cache_lookup_time": [], - "prompt_update_time": [], - "total_time": [], - } + stat_keys = [ + "hf_processor_time", + "hashing_time", + "cache_lookup_time", + "prompt_update_time", + "preprocessor_total_time", + "encoder_forward_time", + "num_encoder_calls", + ] + stats_by_stage = {key: [] for key in stat_keys} for stats_dict in all_stats.values(): - stats_by_stage["hf_processor_time"].append( - stats_dict.get("hf_processor_time", 0.0) - ) - stats_by_stage["hashing_time"].append(stats_dict.get("hashing_time", 0.0)) - stats_by_stage["cache_lookup_time"].append( - stats_dict.get("cache_lookup_time", 0.0) - ) - stats_by_stage["prompt_update_time"].append( - stats_dict.get("prompt_update_time", 0.0) - ) - stats_by_stage["total_time"].append(stats_dict.get("total_time", 0.0)) + for key in stat_keys: + if key in stats_dict: + stats_by_stage[key].append(stats_dict[key]) return stats_by_stage @@ -88,14 +87,14 @@ def calculate_mm_processor_metrics( } continue - times_ms = [t * 1000 for t in times] + is_count_metric = stage_name == "num_encoder_calls" + values = times if is_count_metric else [t * 1000 for t in times] + metrics[stage_name] = { - "mean": float(np.mean(times_ms)), - "median": float(np.median(times_ms)), - "std": float(np.std(times_ms)), - **{ - f"p{p}": float(np.percentile(times_ms, p)) for p in selected_percentiles - }, + "mean": float(np.mean(values)), + "median": float(np.median(values)), + "std": float(np.std(values)), + **{f"p{p}": float(np.percentile(values, p)) for p in selected_percentiles}, } return metrics @@ -114,6 +113,23 @@ def validate_args(args): if not hasattr(args, "max_loras"): args.max_loras = None + if args.dataset_name == "hf" and not args.dataset_path: + raise ValueError( + "--dataset-path is required when using --dataset-name hf. " + "For multimodal benchmarking, specify a dataset like " + "'lmarena-ai/VisionArena-Chat'." 
+ ) + if args.dataset_name == "hf": + supported_mm_datasets = ( + VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys() + | MultiModalConversationDataset.SUPPORTED_DATASET_PATHS + ) + if args.dataset_path not in supported_mm_datasets: + raise ValueError( + f"{args.dataset_path} is not a supported multimodal dataset. " + f"Supported multimodal datasets are: {sorted(supported_mm_datasets)}" + ) + def benchmark_multimodal_processor( args: argparse.Namespace, @@ -223,6 +239,17 @@ def benchmark_multimodal_processor( std_e2el_ms = 0.0 percentiles_e2el_ms = [(p, 0.0) for p in selected_percentiles] + encoder_summary = {} + if ( + "num_encoder_calls" in mm_stats_by_stage + and mm_stats_by_stage["num_encoder_calls"] + ): + encoder_calls = mm_stats_by_stage["num_encoder_calls"] + encoder_summary = { + "total_encoder_calls": int(sum(encoder_calls)), + "num_requests_with_encoder_calls": len(encoder_calls), + } + benchmark_result = { "completed": completed, "failed": failed, @@ -231,6 +258,7 @@ def benchmark_multimodal_processor( "std_e2el_ms": std_e2el_ms, "percentiles_e2el_ms": percentiles_e2el_ms, "mm_processor_stats": mm_processor_metrics, + "encoder_summary": encoder_summary, } return benchmark_result @@ -248,7 +276,7 @@ def add_cli_args(parser: argparse.ArgumentParser) -> None: "--dataset-name", type=str, default="random-mm", - choices=["random-mm", "random-rerank"], + choices=["random-mm", "hf"], help="Name of the dataset to benchmark on. Defaults to 'random-mm'.", ) parser.add_argument( @@ -266,6 +294,34 @@ def add_cli_args(parser: argparse.ArgumentParser) -> None: add_random_dataset_base_args(parser) add_random_multimodal_dataset_args(parser) + # HuggingFace dataset arguments + parser.add_argument( + "--dataset-path", + type=str, + default=None, + help="Path to the dataset file or HuggingFace dataset name " + "(e.g., 'yale-nlp/MMVU', 'lmarena-ai/VisionArena-Chat').", + ) + parser.add_argument( + "--hf-subset", + type=str, + default=None, + help="Subset of the HuggingFace dataset (optional).", + ) + parser.add_argument( + "--hf-split", + type=str, + default=None, + help="Split of the HuggingFace dataset (e.g., 'train', 'test', 'validation').", + ) + parser.add_argument( + "--output-len", + type=int, + default=None, + help="Output length for each request. " + "Overrides the default output lengths from the dataset.", + ) + parser.add_argument( "--output-json", type=str, @@ -296,14 +352,17 @@ def main(args: argparse.Namespace) -> None: print("=" * 80) if "mm_processor_stats" in result: - print("\nMM Processor Timing (ms):") + print("\nMM Processor Metrics:") selected_percentiles = [ float(p) for p in getattr(args, "metric_percentiles", "99").split(",") ] mm_data = [] for stage, metrics in result["mm_processor_stats"].items(): + is_count = stage == "num_encoder_calls" + unit = "" if is_count else " (ms)" + row = { - "Stage": stage, + "Stage": stage + unit, "Mean": f"{metrics['mean']:.2f}", "Median": f"{metrics['median']:.2f}", "Std": f"{metrics['std']:.2f}", @@ -315,6 +374,14 @@ def main(args: argparse.Namespace) -> None: mm_df = pd.DataFrame(mm_data) print(mm_df.to_string(index=False)) + if "encoder_summary" in result and result["encoder_summary"]: + total_calls = result["encoder_summary"]["total_encoder_calls"] + num_requests = result["encoder_summary"]["num_requests_with_encoder_calls"] + print( + f"\nSummary: {total_calls} total encoder calls " + f"across {num_requests} requests." 
+ ) + if "mean_e2el_ms" in result: print("\nEnd-to-End Latency (ms):") selected_percentiles = [ diff --git a/vllm/multimodal/processing/context.py b/vllm/multimodal/processing/context.py index 8606a659c..d4894a984 100644 --- a/vllm/multimodal/processing/context.py +++ b/vllm/multimodal/processing/context.py @@ -75,8 +75,8 @@ class MultiModalProcessorTimingStats: prompt_update_time: float = 0.0 """Time spent applying prompt updates and finding placeholders (seconds).""" - total_time: float = 0.0 - """Total processing time (seconds).""" + preprocessor_total_time: float = 0.0 + """Total preprocessing time (seconds).""" def to_dict(self) -> dict[str, float]: """Convert stats to a dictionary for JSON serialization.""" @@ -85,7 +85,7 @@ class MultiModalProcessorTimingStats: "hashing_time": self.hashing_time, "cache_lookup_time": self.cache_lookup_time, "prompt_update_time": self.prompt_update_time, - "total_time": self.total_time, + "preprocessor_total_time": self.preprocessor_total_time, } @@ -93,13 +93,30 @@ def get_timing_stats_from_engine_client( engine_client: Any, ) -> dict[str, dict[str, float]]: """ - Get all timing stats from the context associated with the engine client. + Get all multimodal timing stats from the engine client. + + Collects both preprocessing stats (HF processor, hashing, cache lookup, + prompt update) and encoder forward pass timing, merged by request_id. Args: - engine_client: The engine client that has input_processor. + engine_client: The engine client (has input_processor and workers). Returns: - A dictionary mapping request_id to stats dict. + Dictionary mapping request_id to merged stats dict containing + both preprocessing and encoder timing metrics. + + Example: + { + 'request-123': { + 'hf_processor_time': 0.45, + 'hashing_time': 0.02, + 'cache_lookup_time': 0.01, + 'prompt_update_time': 0.03, + 'preprocessor_total_time': 0.51, + 'encoder_forward_time': 0.23, + 'num_encoder_calls': 1 + } + } """ try: if not engine_client.vllm_config.observability_config.enable_mm_processor_stats: @@ -107,6 +124,7 @@ def get_timing_stats_from_engine_client( except (AttributeError, RuntimeError): return {} + preprocessing_stats = {} try: input_processor = engine_client.input_processor input_preprocessor = input_processor.input_preprocessor @@ -115,15 +133,68 @@ def get_timing_stats_from_engine_client( mm_processor = input_preprocessor._get_mm_processor() if mm_processor is not None and hasattr(mm_processor, "info"): ctx = mm_processor.info.ctx - return ctx.get_all_timing_stats() + preprocessing_stats = ctx.get_all_timing_stats() except (AttributeError, RuntimeError): pass - return {} + encoder_stats = {} + try: + if hasattr(engine_client, "collective_rpc"): + encoder_stats_results = engine_client.collective_rpc( + "get_encoder_timing_stats" + ) + if encoder_stats_results and len(encoder_stats_results) > 0: + for worker_stats in encoder_stats_results: + if not worker_stats: + continue + for request_id, stats_dict in worker_stats.items(): + if request_id not in encoder_stats: + encoder_stats[request_id] = dict(stats_dict) + else: + # Aggregate timing metrics across workers + current_time = encoder_stats[request_id].get( + "encoder_forward_time", 0.0 + ) + new_time = stats_dict.get("encoder_forward_time", 0.0) + encoder_stats[request_id]["encoder_forward_time"] = max( + current_time, new_time + ) + + current_calls = encoder_stats[request_id].get( + "num_encoder_calls", 0 + ) + new_calls = stats_dict.get("num_encoder_calls", 0) + encoder_stats[request_id]["num_encoder_calls"] = 
max( + current_calls, new_calls + ) + except (AttributeError, RuntimeError): + pass + + merged_stats = {} + + for request_id, prep_dict in preprocessing_stats.items(): + merged_stats[request_id] = dict(prep_dict) + + for request_id, enc_dict in encoder_stats.items(): + if request_id in merged_stats: + merged_stats[request_id].update(enc_dict) + continue + + # In V1 engine, the request_id in encoder_stats has a suffix + # appended to the original request_id (which is used in + # preprocessing_stats). + # We try to strip the suffix to find the matching request. + possible_original_id = request_id.rpartition("-")[0] + if possible_original_id and possible_original_id in merged_stats: + merged_stats[possible_original_id].update(enc_dict) + else: + merged_stats[request_id] = dict(enc_dict) + + return merged_stats @contextmanager -def timed_operation(ctx: "InputProcessingContext", stage_name: str): +def timed_preprocessor_operation(ctx: "InputProcessingContext", stage_name: str): """ Context manager to time an operation using the context's timing stats. @@ -157,7 +228,7 @@ def timed_operation(ctx: "InputProcessingContext", stage_name: str): stats.cache_lookup_time += elapsed elif stage_name == "prompt_update": stats.prompt_update_time += elapsed - stats.total_time += elapsed + stats.preprocessor_total_time += elapsed _T = TypeVar("_T") diff --git a/vllm/multimodal/processing/processor.py b/vllm/multimodal/processing/processor.py index 3aee5ed8b..1b8039f76 100644 --- a/vllm/multimodal/processing/processor.py +++ b/vllm/multimodal/processing/processor.py @@ -42,7 +42,11 @@ from ..parse import ( MultiModalDataItems, MultiModalDataParser, ) -from .context import BaseProcessingInfo, get_current_request_id, timed_operation +from .context import ( + BaseProcessingInfo, + get_current_request_id, + timed_preprocessor_operation, +) from .dummy_inputs import BaseDummyInputsBuilder if TYPE_CHECKING: @@ -1192,7 +1196,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): Call the HF processor on the prompt text and associated multi-modal data. """ - with timed_operation(self.info.ctx, "hf_processor"): + with timed_preprocessor_operation(self.info.ctx, "hf_processor"): return self.info.ctx.call_hf_processor( self.info.get_hf_processor(**mm_kwargs), dict(text=prompt, **mm_data), @@ -1545,7 +1549,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ) # Use overrides if provided; fallback to data-dependent hashing. 
- with timed_operation(self.info.ctx, "hashing"): + with timed_preprocessor_operation(self.info.ctx, "hashing"): mm_hashes = self._hash_mm_items( mm_data_items, hf_processor_mm_kwargs, @@ -1592,7 +1596,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_uuids=mm_uuids, ) - with timed_operation(self.info.ctx, "hashing"): + with timed_preprocessor_operation(self.info.ctx, "hashing"): mm_hashes = self._hash_mm_items( mm_data_items, hf_processor_mm_kwargs, @@ -1600,7 +1604,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_uuids=mm_uuids, ) - with timed_operation(self.info.ctx, "cache_lookup"): + with timed_preprocessor_operation(self.info.ctx, "cache_lookup"): mm_is_cached, mm_missing_data_items = self._get_cache_missing_items( cache=cache, mm_data_items=mm_data_items, @@ -1635,7 +1639,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_missing_kwargs, ) - with timed_operation(self.info.ctx, "cache_lookup"): + with timed_preprocessor_operation(self.info.ctx, "cache_lookup"): mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs( cache, mm_hashes=mm_hashes, @@ -1846,7 +1850,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ) # NOTE: tokenization_kwargs are not required to init processor - with timed_operation(self.info.ctx, "prompt_update"): + with timed_preprocessor_operation(self.info.ctx, "prompt_update"): prompt_ids, mm_placeholders = self._maybe_apply_prompt_updates( mm_items=mm_items, prompt_ids=prompt_ids, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b46fc175d..60d1c88d2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4,11 +4,13 @@ import functools import gc import itertools +import threading import time from collections import defaultdict from collections.abc import Iterator, Sequence from contextlib import contextmanager from copy import copy, deepcopy +from dataclasses import dataclass from functools import reduce from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast @@ -546,6 +548,10 @@ class GPUModelRunner( # Cache the device properties. self._init_device_properties() + # Encoder timing registry for observability + self.encoder_timing_registry: dict[str, EncoderTimingStats] = {} + self._encoder_timing_lock = threading.Lock() + # Persistent buffers for CUDA graphs. self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32) self.positions = self._make_buffer(self.max_num_tokens, dtype=torch.int64) @@ -2212,6 +2218,12 @@ class GPUModelRunner( if not mm_kwargs: return [] + should_time = bool( + self.observability_config + and self.observability_config.enable_mm_processor_stats + and scheduler_output.scheduled_encoder_inputs + ) + # Batch mm inputs as much as we can: if a request in the batch has # multiple modalities or a different modality than the previous one, # we process it separately to preserve item order. 
@@ -2278,6 +2290,8 @@ class GPUModelRunner( ) encoder_outputs: list[torch.Tensor] = [] + # Track the current index in mm_kwargs/mm_lora_refs to map groups to request IDs + current_item_idx = 0 for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( mm_kwargs, device=self.device, @@ -2300,22 +2314,24 @@ class GPUModelRunner( and num_items > 1 ): curr_group_outputs_lst = list[torch.Tensor]() - for video_mm_kwargs_item in filter( - lambda item: item.modality == "video", mm_kwargs - ): - _, _, micro_batch_mm_inputs = next( - group_mm_kwargs_by_modality( - [video_mm_kwargs_item], - device=self.device, - pin_memory=self.pin_memory, + for video_idx in range(num_items): + video_mm_kwargs_item = mm_kwargs[current_item_idx + video_idx] + with self.timed_encoder_operation( + should_time, mm_lora_refs, current_item_idx + video_idx, 1 + ): + _, _, micro_batch_mm_inputs = next( + group_mm_kwargs_by_modality( + [video_mm_kwargs_item], + device=self.device, + pin_memory=self.pin_memory, + ) ) - ) - micro_batch_outputs = model.embed_multimodal( - **micro_batch_mm_inputs - ) + micro_batch_outputs = model.embed_multimodal( + **micro_batch_mm_inputs + ) - curr_group_outputs_lst.extend(micro_batch_outputs) + curr_group_outputs_lst.extend(micro_batch_outputs) curr_group_outputs = curr_group_outputs_lst else: @@ -2326,7 +2342,11 @@ class GPUModelRunner( # 2. A list or tuple (length: num_items) of tensors, # each of shape (feature_size, hidden_size) in case the feature # size is dynamic depending on the input multimodal items. - curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) + + with self.timed_encoder_operation( + should_time, mm_lora_refs, current_item_idx, num_items + ): + curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) sanity_check_mm_encoder_outputs( curr_group_outputs, @@ -2334,6 +2354,8 @@ class GPUModelRunner( ) encoder_outputs.extend(curr_group_outputs) + current_item_idx += num_items + # Cache the encoder outputs by mm_hash for mm_hash, output in zip(mm_hashes, encoder_outputs): self.encoder_cache[mm_hash] = output @@ -5919,3 +5941,79 @@ class GPUModelRunner( self.transfer_event.record() self.transfer_event.synchronize() return pinned.tolist() + + def get_encoder_timing_stats(self) -> dict[str, dict[str, float | int]]: + """ + Get encoder timing stats for all requests and clear the registry. + + Returns: + Dictionary mapping request_id to stats dict. + """ + with self._encoder_timing_lock: + stats = { + req_id: stats_obj.to_dict() + for req_id, stats_obj in self.encoder_timing_registry.items() + } + self.encoder_timing_registry.clear() + return stats + + @contextmanager + def timed_encoder_operation( + self, + should_time: bool, + group_lora_refs: list[tuple[str, Any]], + current_item_idx: int, + num_items: int, + ): + """ + Context manager to time encoder forward operations. 
+ + Args: + should_time: Whether timing is enabled + group_lora_refs: Full list of (request_id, pos_info) tuples + current_item_idx: Starting index for this group + num_items: Number of items in this group + """ + if not should_time: + yield + return + + group_refs = group_lora_refs[current_item_idx : current_item_idx + num_items] + group_request_ids = {req_id for req_id, _ in group_refs} + + torch.cuda.synchronize() + start_time = time.perf_counter() + + try: + yield + finally: + torch.cuda.synchronize() + elapsed = time.perf_counter() - start_time + + per_request_time = elapsed / max(len(group_request_ids), 1) + + with self._encoder_timing_lock: + for req_id in group_request_ids: + if req_id not in self.encoder_timing_registry: + self.encoder_timing_registry[req_id] = EncoderTimingStats() + + stats = self.encoder_timing_registry[req_id] + stats.encoder_forward_time += per_request_time + stats.num_encoder_calls += 1 + + +@dataclass +class EncoderTimingStats: + """Per-request timing statistics for encoder forward pass.""" + + encoder_forward_time: float = 0.0 + """Time spent in vision encoder forward pass (seconds).""" + + num_encoder_calls: int = 0 + """Number of times encoder was called for this request.""" + + def to_dict(self) -> dict[str, float | int]: + return { + "encoder_forward_time": self.encoder_forward_time, + "num_encoder_calls": self.num_encoder_calls, + } diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index db4cb45e2..4c02c0598 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -542,6 +542,10 @@ class Worker(WorkerBase): def get_supported_tasks(self) -> tuple[SupportedTask, ...]: return self.model_runner.get_supported_tasks() + def get_encoder_timing_stats(self) -> dict[str, dict[str, float | int]]: + """Get encoder timing stats from model runner.""" + return self.model_runner.get_encoder_timing_stats() + def annotate_profile(self, scheduler_output): # add trace annotation so that we can easily distinguish # context/generation request numbers in each iteration.
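
For reference, a minimal sketch (not part of the diff above) of how the merged per-request stats end up as the reported metrics. The request IDs and timing values below are hypothetical; the merge and reduction mirror get_timing_stats_from_engine_client() and calculate_mm_processor_metrics() from this patch, where count metrics such as num_encoder_calls stay unitless while timing metrics are converted to milliseconds.

import numpy as np

# Hypothetical inputs: preprocessing stats are keyed by the original request_id,
# while encoder stats reported by the workers may carry a V1 engine suffix
# (e.g. "request-123-0"), which the merge strips when no exact match exists.
preprocessing_stats = {
    "request-123": {"hf_processor_time": 0.045, "preprocessor_total_time": 0.051},
}
encoder_stats = {
    "request-123-0": {"encoder_forward_time": 0.023, "num_encoder_calls": 1},
}

# Merge encoder stats into preprocessing stats by request_id.
merged_stats = {rid: dict(d) for rid, d in preprocessing_stats.items()}
for request_id, enc_dict in encoder_stats.items():
    if request_id in merged_stats:
        merged_stats[request_id].update(enc_dict)
        continue
    possible_original_id = request_id.rpartition("-")[0]
    if possible_original_id and possible_original_id in merged_stats:
        merged_stats[possible_original_id].update(enc_dict)
    else:
        merged_stats[request_id] = dict(enc_dict)

# Reduce per-stage values the same way calculate_mm_processor_metrics() does:
# call counts stay raw, timings are reported in milliseconds.
stats_by_stage: dict[str, list[float]] = {}
for stats_dict in merged_stats.values():
    for key, value in stats_dict.items():
        stats_by_stage.setdefault(key, []).append(value)

for stage, values in stats_by_stage.items():
    if stage != "num_encoder_calls":
        values = [t * 1000 for t in values]
    print(stage, {
        "mean": float(np.mean(values)),
        "p99": float(np.percentile(values, 99)),
    })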