feat(benchmark): add encoder forward pass benchmarking to mm-processor (#31655)

Signed-off-by: Reagan <reaganjlee@gmail.com>
Signed-off-by: Reagan Lee <96998476+reaganjlee@users.noreply.github.com>
Co-authored-by: Hiroken. <105287758+HirokenOvo@users.noreply.github.com>
Author: Reagan Lee
Date: 2026-01-24 00:24:44 -08:00 (committed by GitHub)
Commit: 06b557ecd9 (parent: 81c2a889ce)

5 changed files with 303 additions and 59 deletions


@@ -22,6 +22,10 @@ from typing import Any
 import numpy as np
+from vllm.benchmarks.datasets import (
+    MultiModalConversationDataset,
+    VisionArenaDataset,
+)
 from vllm.benchmarks.throughput import get_requests
 from vllm.engine.arg_utils import EngineArgs
 from vllm.multimodal.processing.context import (
@@ -45,26 +49,21 @@ def collect_mm_processor_stats(
""" """
all_stats = get_timing_stats_from_engine_client(llm_engine) all_stats = get_timing_stats_from_engine_client(llm_engine)
stats_by_stage = { stat_keys = [
"hf_processor_time": [], "hf_processor_time",
"hashing_time": [], "hashing_time",
"cache_lookup_time": [], "cache_lookup_time",
"prompt_update_time": [], "prompt_update_time",
"total_time": [], "preprocessor_total_time",
} "encoder_forward_time",
"num_encoder_calls",
]
stats_by_stage = {key: [] for key in stat_keys}
for stats_dict in all_stats.values(): for stats_dict in all_stats.values():
stats_by_stage["hf_processor_time"].append( for key in stat_keys:
stats_dict.get("hf_processor_time", 0.0) if key in stats_dict:
) stats_by_stage[key].append(stats_dict[key])
stats_by_stage["hashing_time"].append(stats_dict.get("hashing_time", 0.0))
stats_by_stage["cache_lookup_time"].append(
stats_dict.get("cache_lookup_time", 0.0)
)
stats_by_stage["prompt_update_time"].append(
stats_dict.get("prompt_update_time", 0.0)
)
stats_by_stage["total_time"].append(stats_dict.get("total_time", 0.0))
return stats_by_stage return stats_by_stage
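
Note on the refactor above: a stage that is missing from a request's stats dict is now skipped instead of being padded with `0.0` (the old `.get(key, 0.0)` behavior), so the per-stage lists can have different lengths. A minimal standalone sketch, with invented sample values shaped like the output of `get_timing_stats_from_engine_client`:

```python
all_stats = {
    "request-1": {
        "hf_processor_time": 0.45,
        "encoder_forward_time": 0.23,
        "num_encoder_calls": 1,
    },
    "request-2": {"hf_processor_time": 0.52},  # no encoder stats recorded
}
stat_keys = ["hf_processor_time", "encoder_forward_time", "num_encoder_calls"]

stats_by_stage = {key: [] for key in stat_keys}
for stats_dict in all_stats.values():
    for key in stat_keys:
        if key in stats_dict:
            stats_by_stage[key].append(stats_dict[key])

# Missing keys are skipped rather than zero-padded, so lengths can differ:
assert stats_by_stage["hf_processor_time"] == [0.45, 0.52]
assert stats_by_stage["encoder_forward_time"] == [0.23]
```

This matters for the new encoder metrics, which only exist for requests that actually hit the encoder.
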
@@ -88,14 +87,14 @@ def calculate_mm_processor_metrics(
             }
             continue

-        times_ms = [t * 1000 for t in times]
+        is_count_metric = stage_name == "num_encoder_calls"
+        values = times if is_count_metric else [t * 1000 for t in times]
         metrics[stage_name] = {
-            "mean": float(np.mean(times_ms)),
-            "median": float(np.median(times_ms)),
-            "std": float(np.std(times_ms)),
-            **{
-                f"p{p}": float(np.percentile(times_ms, p)) for p in selected_percentiles
-            },
+            "mean": float(np.mean(values)),
+            "median": float(np.median(values)),
+            "std": float(np.std(values)),
+            **{f"p{p}": float(np.percentile(values, p)) for p in selected_percentiles},
         }

     return metrics
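
A quick standalone check of the count-vs-timing split (sample values invented): `num_encoder_calls` stays in raw counts, while timing stages are converted from seconds to milliseconds.

```python
import numpy as np

samples = {
    "encoder_forward_time": [0.021, 0.034, 0.029],  # seconds
    "num_encoder_calls": [1, 1, 2],                 # raw counts
}
selected_percentiles = [50.0, 99.0]

metrics = {}
for stage_name, times in samples.items():
    is_count_metric = stage_name == "num_encoder_calls"
    values = times if is_count_metric else [t * 1000 for t in times]
    metrics[stage_name] = {
        "mean": float(np.mean(values)),
        "median": float(np.median(values)),
        "std": float(np.std(values)),
        **{f"p{p}": float(np.percentile(values, p)) for p in selected_percentiles},
    }

print(metrics["encoder_forward_time"]["mean"])  # 28.0 -> milliseconds
print(metrics["num_encoder_calls"]["p99.0"])    # 1.98 -> calls, not ms
```
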
@@ -114,6 +113,23 @@ def validate_args(args):
     if not hasattr(args, "max_loras"):
         args.max_loras = None

+    if args.dataset_name == "hf" and not args.dataset_path:
+        raise ValueError(
+            "--dataset-path is required when using --dataset-name hf. "
+            "For multimodal benchmarking, specify a dataset like "
+            "'lmarena-ai/VisionArena-Chat'."
+        )
+    if args.dataset_name == "hf":
+        supported_mm_datasets = (
+            VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
+            | MultiModalConversationDataset.SUPPORTED_DATASET_PATHS
+        )
+        if args.dataset_path not in supported_mm_datasets:
+            raise ValueError(
+                f"{args.dataset_path} is not a supported multimodal dataset. "
+                f"Supported multimodal datasets are: {sorted(supported_mm_datasets)}"
+            )
+

 def benchmark_multimodal_processor(
     args: argparse.Namespace,
@@ -223,6 +239,17 @@ def benchmark_multimodal_processor(
         std_e2el_ms = 0.0
         percentiles_e2el_ms = [(p, 0.0) for p in selected_percentiles]

+    encoder_summary = {}
+    if (
+        "num_encoder_calls" in mm_stats_by_stage
+        and mm_stats_by_stage["num_encoder_calls"]
+    ):
+        encoder_calls = mm_stats_by_stage["num_encoder_calls"]
+        encoder_summary = {
+            "total_encoder_calls": int(sum(encoder_calls)),
+            "num_requests_with_encoder_calls": len(encoder_calls),
+        }
+
     benchmark_result = {
         "completed": completed,
         "failed": failed,
@@ -231,6 +258,7 @@ def benchmark_multimodal_processor(
"std_e2el_ms": std_e2el_ms, "std_e2el_ms": std_e2el_ms,
"percentiles_e2el_ms": percentiles_e2el_ms, "percentiles_e2el_ms": percentiles_e2el_ms,
"mm_processor_stats": mm_processor_metrics, "mm_processor_stats": mm_processor_metrics,
"encoder_summary": encoder_summary,
} }
return benchmark_result return benchmark_result
@@ -248,7 +276,7 @@ def add_cli_args(parser: argparse.ArgumentParser) -> None:
"--dataset-name", "--dataset-name",
type=str, type=str,
default="random-mm", default="random-mm",
choices=["random-mm", "random-rerank"], choices=["random-mm", "hf"],
help="Name of the dataset to benchmark on. Defaults to 'random-mm'.", help="Name of the dataset to benchmark on. Defaults to 'random-mm'.",
) )
parser.add_argument( parser.add_argument(
@@ -266,6 +294,34 @@ def add_cli_args(parser: argparse.ArgumentParser) -> None:
     add_random_dataset_base_args(parser)
     add_random_multimodal_dataset_args(parser)

+    # HuggingFace dataset arguments
+    parser.add_argument(
+        "--dataset-path",
+        type=str,
+        default=None,
+        help="Path to the dataset file or HuggingFace dataset name "
+        "(e.g., 'yale-nlp/MMVU', 'lmarena-ai/VisionArena-Chat').",
+    )
+    parser.add_argument(
+        "--hf-subset",
+        type=str,
+        default=None,
+        help="Subset of the HuggingFace dataset (optional).",
+    )
+    parser.add_argument(
+        "--hf-split",
+        type=str,
+        default=None,
+        help="Split of the HuggingFace dataset (e.g., 'train', 'test', 'validation').",
+    )
+    parser.add_argument(
+        "--output-len",
+        type=int,
+        default=None,
+        help="Output length for each request. "
+        "Overrides the default output lengths from the dataset.",
+    )
     parser.add_argument(
         "--output-json",
         type=str,
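
For reference, the invocation shape the new flags enable; this sketch wires just the additions into a bare parser (the real benchmark registers many more arguments):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--dataset-name", type=str, default="random-mm",
                    choices=["random-mm", "hf"])
parser.add_argument("--dataset-path", type=str, default=None)
parser.add_argument("--hf-subset", type=str, default=None)
parser.add_argument("--hf-split", type=str, default=None)
parser.add_argument("--output-len", type=int, default=None)

# e.g. benchmarking against VisionArena-Chat:
args = parser.parse_args([
    "--dataset-name", "hf",
    "--dataset-path", "lmarena-ai/VisionArena-Chat",
    "--hf-split", "train",
])
assert args.dataset_name == "hf" and args.output_len is None
```
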
@@ -296,14 +352,17 @@ def main(args: argparse.Namespace) -> None:
print("=" * 80) print("=" * 80)
if "mm_processor_stats" in result: if "mm_processor_stats" in result:
print("\nMM Processor Timing (ms):") print("\nMM Processor Metrics:")
selected_percentiles = [ selected_percentiles = [
float(p) for p in getattr(args, "metric_percentiles", "99").split(",") float(p) for p in getattr(args, "metric_percentiles", "99").split(",")
] ]
mm_data = [] mm_data = []
for stage, metrics in result["mm_processor_stats"].items(): for stage, metrics in result["mm_processor_stats"].items():
is_count = stage == "num_encoder_calls"
unit = "" if is_count else " (ms)"
row = { row = {
"Stage": stage, "Stage": stage + unit,
"Mean": f"{metrics['mean']:.2f}", "Mean": f"{metrics['mean']:.2f}",
"Median": f"{metrics['median']:.2f}", "Median": f"{metrics['median']:.2f}",
"Std": f"{metrics['std']:.2f}", "Std": f"{metrics['std']:.2f}",
@@ -315,6 +374,14 @@ def main(args: argparse.Namespace) -> None:
         mm_df = pd.DataFrame(mm_data)
         print(mm_df.to_string(index=False))

+        if "encoder_summary" in result and result["encoder_summary"]:
+            total_calls = result["encoder_summary"]["total_encoder_calls"]
+            num_requests = result["encoder_summary"]["num_requests_with_encoder_calls"]
+            print(
+                f"\nSummary: {total_calls} total encoder calls "
+                f"across {num_requests} requests."
+            )
+
     if "mean_e2el_ms" in result:
         print("\nEnd-to-End Latency (ms):")
         selected_percentiles = [


@@ -75,8 +75,8 @@ class MultiModalProcessorTimingStats:
     prompt_update_time: float = 0.0
     """Time spent applying prompt updates and finding placeholders (seconds)."""

-    total_time: float = 0.0
-    """Total processing time (seconds)."""
+    preprocessor_total_time: float = 0.0
+    """Total preprocessing time (seconds)."""

     def to_dict(self) -> dict[str, float]:
         """Convert stats to a dictionary for JSON serialization."""
@@ -85,7 +85,7 @@ class MultiModalProcessorTimingStats:
"hashing_time": self.hashing_time, "hashing_time": self.hashing_time,
"cache_lookup_time": self.cache_lookup_time, "cache_lookup_time": self.cache_lookup_time,
"prompt_update_time": self.prompt_update_time, "prompt_update_time": self.prompt_update_time,
"total_time": self.total_time, "preprocessor_total_time": self.preprocessor_total_time,
} }
@@ -93,13 +93,30 @@ def get_timing_stats_from_engine_client(
     engine_client: Any,
 ) -> dict[str, dict[str, float]]:
     """
-    Get all timing stats from the context associated with the engine client.
+    Get all multimodal timing stats from the engine client.
+
+    Collects both preprocessing stats (HF processor, hashing, cache lookup,
+    prompt update) and encoder forward pass timing, merged by request_id.

     Args:
-        engine_client: The engine client that has input_processor.
+        engine_client: The engine client (has input_processor and workers).

     Returns:
-        A dictionary mapping request_id to stats dict.
+        Dictionary mapping request_id to merged stats dict containing
+        both preprocessing and encoder timing metrics.
+
+    Example:
+        {
+            'request-123': {
+                'hf_processor_time': 0.45,
+                'hashing_time': 0.02,
+                'cache_lookup_time': 0.01,
+                'prompt_update_time': 0.03,
+                'preprocessor_total_time': 0.51,
+                'encoder_forward_time': 0.23,
+                'num_encoder_calls': 1
+            }
+        }
     """
     try:
         if not engine_client.vllm_config.observability_config.enable_mm_processor_stats:
@@ -107,6 +124,7 @@ def get_timing_stats_from_engine_client(
     except (AttributeError, RuntimeError):
         return {}

+    preprocessing_stats = {}
     try:
         input_processor = engine_client.input_processor
         input_preprocessor = input_processor.input_preprocessor
@@ -115,15 +133,68 @@ def get_timing_stats_from_engine_client(
         mm_processor = input_preprocessor._get_mm_processor()
         if mm_processor is not None and hasattr(mm_processor, "info"):
             ctx = mm_processor.info.ctx
-            return ctx.get_all_timing_stats()
+            preprocessing_stats = ctx.get_all_timing_stats()
     except (AttributeError, RuntimeError):
         pass

-    return {}
+    encoder_stats = {}
+    try:
+        if hasattr(engine_client, "collective_rpc"):
+            encoder_stats_results = engine_client.collective_rpc(
+                "get_encoder_timing_stats"
+            )
+            if encoder_stats_results and len(encoder_stats_results) > 0:
+                for worker_stats in encoder_stats_results:
+                    if not worker_stats:
+                        continue
+                    for request_id, stats_dict in worker_stats.items():
+                        if request_id not in encoder_stats:
+                            encoder_stats[request_id] = dict(stats_dict)
+                        else:
+                            # Aggregate timing metrics across workers
+                            current_time = encoder_stats[request_id].get(
+                                "encoder_forward_time", 0.0
+                            )
+                            new_time = stats_dict.get("encoder_forward_time", 0.0)
+                            encoder_stats[request_id]["encoder_forward_time"] = max(
+                                current_time, new_time
+                            )
+                            current_calls = encoder_stats[request_id].get(
+                                "num_encoder_calls", 0
+                            )
+                            new_calls = stats_dict.get("num_encoder_calls", 0)
+                            encoder_stats[request_id]["num_encoder_calls"] = max(
+                                current_calls, new_calls
+                            )
+    except (AttributeError, RuntimeError):
+        pass
+
+    merged_stats = {}
+    for request_id, prep_dict in preprocessing_stats.items():
+        merged_stats[request_id] = dict(prep_dict)
+
+    for request_id, enc_dict in encoder_stats.items():
+        if request_id in merged_stats:
+            merged_stats[request_id].update(enc_dict)
+            continue
+        # In V1 engine, the request_id in encoder_stats has a suffix
+        # appended to the original request_id (which is used in
+        # preprocessing_stats).
+        # We try to strip the suffix to find the matching request.
+        possible_original_id = request_id.rpartition("-")[0]
+        if possible_original_id and possible_original_id in merged_stats:
+            merged_stats[possible_original_id].update(enc_dict)
+        else:
+            merged_stats[request_id] = dict(enc_dict)
+
+    return merged_stats
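
A standalone sketch of the merge logic above, including the V1 suffix-stripping fallback (request IDs invented):

```python
preprocessing_stats = {"request-123": {"hf_processor_time": 0.45}}
# The V1 worker-side ID carries a suffix, e.g. "request-123-0".
encoder_stats = {"request-123-0": {"encoder_forward_time": 0.23,
                                   "num_encoder_calls": 1}}

merged = {rid: dict(d) for rid, d in preprocessing_stats.items()}
for rid, enc in encoder_stats.items():
    if rid in merged:
        merged[rid].update(enc)
        continue
    # "request-123-0".rpartition("-")[0] == "request-123"
    original = rid.rpartition("-")[0]
    if original and original in merged:
        merged[original].update(enc)
    else:
        merged[rid] = dict(enc)

assert merged == {
    "request-123": {
        "hf_processor_time": 0.45,
        "encoder_forward_time": 0.23,
        "num_encoder_calls": 1,
    }
}
```

The fallback strips only the last `-`-delimited segment and is therefore best-effort; unmatched encoder entries are kept under their worker-side IDs rather than dropped.
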
 @contextmanager
-def timed_operation(ctx: "InputProcessingContext", stage_name: str):
+def timed_preprocessor_operation(ctx: "InputProcessingContext", stage_name: str):
     """
     Context manager to time an operation using the context's timing stats.
@@ -157,7 +228,7 @@ def timed_operation(ctx: "InputProcessingContext", stage_name: str):
             stats.cache_lookup_time += elapsed
         elif stage_name == "prompt_update":
             stats.prompt_update_time += elapsed
-        stats.total_time += elapsed
+        stats.preprocessor_total_time += elapsed


 _T = TypeVar("_T")
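
For readers less familiar with the pattern, a minimal standalone analogue of `timed_preprocessor_operation` (simplified; not the real `InputProcessingContext` API): each stage's elapsed time lands in its own field and in the running `preprocessor_total_time`.

```python
import time
from contextlib import contextmanager
from dataclasses import dataclass, field

@dataclass
class _TimingStats:  # hypothetical stand-in for MultiModalProcessorTimingStats
    by_stage: dict[str, float] = field(default_factory=dict)
    preprocessor_total_time: float = 0.0

@contextmanager
def _timed_stage(stats: _TimingStats, stage_name: str):
    start = time.perf_counter()
    try:
        yield
    finally:
        # Accumulate even if the timed block raises.
        elapsed = time.perf_counter() - start
        stats.by_stage[stage_name] = stats.by_stage.get(stage_name, 0.0) + elapsed
        stats.preprocessor_total_time += elapsed

stats = _TimingStats()
with _timed_stage(stats, "hf_processor"):
    time.sleep(0.01)  # stand-in for the real HF processor call
assert stats.preprocessor_total_time >= stats.by_stage["hf_processor"] > 0.0
```
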


@@ -42,7 +42,11 @@ from ..parse import (
     MultiModalDataItems,
     MultiModalDataParser,
 )
-from .context import BaseProcessingInfo, get_current_request_id, timed_operation
+from .context import (
+    BaseProcessingInfo,
+    get_current_request_id,
+    timed_preprocessor_operation,
+)
 from .dummy_inputs import BaseDummyInputsBuilder

 if TYPE_CHECKING:
@@ -1192,7 +1196,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         Call the HF processor on the prompt text and
         associated multi-modal data.
         """
-        with timed_operation(self.info.ctx, "hf_processor"):
+        with timed_preprocessor_operation(self.info.ctx, "hf_processor"):
             return self.info.ctx.call_hf_processor(
                 self.info.get_hf_processor(**mm_kwargs),
                 dict(text=prompt, **mm_data),
@@ -1545,7 +1549,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         )

         # Use overrides if provided; fallback to data-dependent hashing.
-        with timed_operation(self.info.ctx, "hashing"):
+        with timed_preprocessor_operation(self.info.ctx, "hashing"):
             mm_hashes = self._hash_mm_items(
                 mm_data_items,
                 hf_processor_mm_kwargs,
@@ -1592,7 +1596,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             mm_uuids=mm_uuids,
         )

-        with timed_operation(self.info.ctx, "hashing"):
+        with timed_preprocessor_operation(self.info.ctx, "hashing"):
             mm_hashes = self._hash_mm_items(
                 mm_data_items,
                 hf_processor_mm_kwargs,
@@ -1600,7 +1604,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             mm_uuids=mm_uuids,
         )

-        with timed_operation(self.info.ctx, "cache_lookup"):
+        with timed_preprocessor_operation(self.info.ctx, "cache_lookup"):
             mm_is_cached, mm_missing_data_items = self._get_cache_missing_items(
                 cache=cache,
                 mm_data_items=mm_data_items,
@@ -1635,7 +1639,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             mm_missing_kwargs,
         )

-        with timed_operation(self.info.ctx, "cache_lookup"):
+        with timed_preprocessor_operation(self.info.ctx, "cache_lookup"):
             mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs(
                 cache,
                 mm_hashes=mm_hashes,
@@ -1846,7 +1850,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         )

         # NOTE: tokenization_kwargs are not required to init processor
-        with timed_operation(self.info.ctx, "prompt_update"):
+        with timed_preprocessor_operation(self.info.ctx, "prompt_update"):
             prompt_ids, mm_placeholders = self._maybe_apply_prompt_updates(
                 mm_items=mm_items,
                 prompt_ids=prompt_ids,


@@ -4,11 +4,13 @@
 import functools
 import gc
 import itertools
+import threading
 import time
 from collections import defaultdict
 from collections.abc import Iterator, Sequence
 from contextlib import contextmanager
 from copy import copy, deepcopy
+from dataclasses import dataclass
 from functools import reduce
 from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
@@ -546,6 +548,10 @@ class GPUModelRunner(
         # Cache the device properties.
         self._init_device_properties()

+        # Encoder timing registry for observability
+        self.encoder_timing_registry: dict[str, EncoderTimingStats] = {}
+        self._encoder_timing_lock = threading.Lock()
+
         # Persistent buffers for CUDA graphs.
         self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32)
         self.positions = self._make_buffer(self.max_num_tokens, dtype=torch.int64)
@@ -2212,6 +2218,12 @@ class GPUModelRunner(
         if not mm_kwargs:
             return []

+        should_time = bool(
+            self.observability_config
+            and self.observability_config.enable_mm_processor_stats
+            and scheduler_output.scheduled_encoder_inputs
+        )
+
         # Batch mm inputs as much as we can: if a request in the batch has
         # multiple modalities or a different modality than the previous one,
         # we process it separately to preserve item order.
@@ -2278,6 +2290,8 @@ class GPUModelRunner(
         )

         encoder_outputs: list[torch.Tensor] = []
+        # Track the current index in mm_kwargs/mm_lora_refs to map groups to request IDs
+        current_item_idx = 0
         for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
             mm_kwargs,
             device=self.device,
@@ -2300,22 +2314,24 @@ class GPUModelRunner(
                 and num_items > 1
             ):
                 curr_group_outputs_lst = list[torch.Tensor]()
-                for video_mm_kwargs_item in filter(
-                    lambda item: item.modality == "video", mm_kwargs
-                ):
-                    _, _, micro_batch_mm_inputs = next(
-                        group_mm_kwargs_by_modality(
-                            [video_mm_kwargs_item],
-                            device=self.device,
-                            pin_memory=self.pin_memory,
-                        )
-                    )
-
-                    micro_batch_outputs = model.embed_multimodal(
-                        **micro_batch_mm_inputs
-                    )
-
-                    curr_group_outputs_lst.extend(micro_batch_outputs)
+                for video_idx in range(num_items):
+                    video_mm_kwargs_item = mm_kwargs[current_item_idx + video_idx]
+                    with self.timed_encoder_operation(
+                        should_time, mm_lora_refs, current_item_idx + video_idx, 1
+                    ):
+                        _, _, micro_batch_mm_inputs = next(
+                            group_mm_kwargs_by_modality(
+                                [video_mm_kwargs_item],
+                                device=self.device,
+                                pin_memory=self.pin_memory,
+                            )
+                        )
+
+                        micro_batch_outputs = model.embed_multimodal(
+                            **micro_batch_mm_inputs
+                        )
+
+                        curr_group_outputs_lst.extend(micro_batch_outputs)

                 curr_group_outputs = curr_group_outputs_lst
             else:
@@ -2326,7 +2342,11 @@ class GPUModelRunner(
                 # 2. A list or tuple (length: num_items) of tensors,
                 # each of shape (feature_size, hidden_size) in case the feature
                 # size is dynamic depending on the input multimodal items.
-                curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
+                with self.timed_encoder_operation(
+                    should_time, mm_lora_refs, current_item_idx, num_items
+                ):
+                    curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)

             sanity_check_mm_encoder_outputs(
                 curr_group_outputs,
@@ -2334,6 +2354,8 @@ class GPUModelRunner(
             )

             encoder_outputs.extend(curr_group_outputs)
+            current_item_idx += num_items

         # Cache the encoder outputs by mm_hash
         for mm_hash, output in zip(mm_hashes, encoder_outputs):
             self.encoder_cache[mm_hash] = output
@@ -5919,3 +5941,79 @@ class GPUModelRunner(
             self.transfer_event.record()
             self.transfer_event.synchronize()
             return pinned.tolist()
+
+    def get_encoder_timing_stats(self) -> dict[str, dict[str, float | int]]:
+        """
+        Get encoder timing stats for all requests and clear the registry.
+
+        Returns:
+            Dictionary mapping request_id to stats dict.
+        """
+        with self._encoder_timing_lock:
+            stats = {
+                req_id: stats_obj.to_dict()
+                for req_id, stats_obj in self.encoder_timing_registry.items()
+            }
+            self.encoder_timing_registry.clear()
+            return stats
+
+    @contextmanager
+    def timed_encoder_operation(
+        self,
+        should_time: bool,
+        group_lora_refs: list[tuple[str, Any]],
+        current_item_idx: int,
+        num_items: int,
+    ):
+        """
+        Context manager to time encoder forward operations.
+
+        Args:
+            should_time: Whether timing is enabled
+            group_lora_refs: Full list of (request_id, pos_info) tuples
+            current_item_idx: Starting index for this group
+            num_items: Number of items in this group
+        """
+        if not should_time:
+            yield
+            return
+
+        group_refs = group_lora_refs[current_item_idx : current_item_idx + num_items]
+        group_request_ids = {req_id for req_id, _ in group_refs}
+
+        torch.cuda.synchronize()
+        start_time = time.perf_counter()
+        try:
+            yield
+        finally:
+            torch.cuda.synchronize()
+            elapsed = time.perf_counter() - start_time
+            per_request_time = elapsed / max(len(group_request_ids), 1)
+            with self._encoder_timing_lock:
+                for req_id in group_request_ids:
+                    if req_id not in self.encoder_timing_registry:
+                        self.encoder_timing_registry[req_id] = EncoderTimingStats()
+                    stats = self.encoder_timing_registry[req_id]
+                    stats.encoder_forward_time += per_request_time
+                    stats.num_encoder_calls += 1
+
+
+@dataclass
+class EncoderTimingStats:
+    """Per-request timing statistics for encoder forward pass."""
+
+    encoder_forward_time: float = 0.0
+    """Time spent in vision encoder forward pass (seconds)."""
+
+    num_encoder_calls: int = 0
+    """Number of times encoder was called for this request."""
+
+    def to_dict(self) -> dict[str, float | int]:
+        return {
+            "encoder_forward_time": self.encoder_forward_time,
+            "num_encoder_calls": self.num_encoder_calls,
+        }
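
The synchronize-before-and-after pattern in `timed_encoder_operation` is what makes a host-side `perf_counter` measurement meaningful: CUDA kernels launch asynchronously, so without both barriers the timer would mostly capture launch overhead. A hypothetical isolated sketch of the same idea (not part of the diff):

```python
import time
import torch

def timed_gpu_call(fn):
    """Time fn() on the GPU by bracketing it with device synchronization."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # drain previously queued kernels
    start = time.perf_counter()
    result = fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for fn()'s kernels to finish
    return result, time.perf_counter() - start

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(256, 256, device=device)
_, elapsed = timed_gpu_call(lambda: x @ x)
print(f"matmul took {elapsed * 1000:.3f} ms")
```

Two caveats visible in the diff: the elapsed time of a batched group is split evenly across its request IDs, which is an approximation when items differ in size, and the extra `torch.cuda.synchronize()` calls have a cost, which is presumably why timing is gated on `enable_mm_processor_stats`.
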


@@ -542,6 +542,10 @@ class Worker(WorkerBase):
     def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
         return self.model_runner.get_supported_tasks()

+    def get_encoder_timing_stats(self) -> dict[str, dict[str, float | int]]:
+        """Get encoder timing stats from model runner."""
+        return self.model_runner.get_encoder_timing_stats()
+
     def annotate_profile(self, scheduler_output):
         # add trace annotation so that we can easily distinguish
         # context/generation request numbers in each iteration.
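
This worker method is the target of the `collective_rpc("get_encoder_timing_stats")` call in `get_timing_stats_from_engine_client`. The client merges per-worker results with `max()` rather than a sum, presumably because tensor-parallel ranks each time the same forward pass and summing would double-count. A toy illustration of that merge (values invented):

```python
# Results as returned by two workers for the same request.
worker_results = [
    {"req-1": {"encoder_forward_time": 0.21, "num_encoder_calls": 1}},
    {"req-1": {"encoder_forward_time": 0.23, "num_encoder_calls": 1}},
]

agg: dict[str, dict[str, float | int]] = {}
for worker_stats in worker_results:
    for req_id, s in worker_stats.items():
        if req_id not in agg:
            agg[req_id] = dict(s)
        else:
            for key in ("encoder_forward_time", "num_encoder_calls"):
                agg[req_id][key] = max(agg[req_id][key], s.get(key, 0))

# The slowest rank's number wins; calls are not double-counted.
assert agg["req-1"] == {"encoder_forward_time": 0.23, "num_encoder_calls": 1}
```
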