feat(benchmark): add encoder forward pass benchmarking to mm-processor (#31655)
Signed-off-by: Reagan <reaganjlee@gmail.com> Signed-off-by: Reagan Lee <96998476+reaganjlee@users.noreply.github.com> Co-authored-by: Hiroken. <105287758+HirokenOvo@users.noreply.github.com>
This commit is contained in:
@@ -22,6 +22,10 @@ from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from vllm.benchmarks.datasets import (
|
||||
MultiModalConversationDataset,
|
||||
VisionArenaDataset,
|
||||
)
|
||||
from vllm.benchmarks.throughput import get_requests
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.multimodal.processing.context import (
|
||||
@@ -45,26 +49,21 @@ def collect_mm_processor_stats(
|
||||
"""
|
||||
all_stats = get_timing_stats_from_engine_client(llm_engine)
|
||||
|
||||
stats_by_stage = {
|
||||
"hf_processor_time": [],
|
||||
"hashing_time": [],
|
||||
"cache_lookup_time": [],
|
||||
"prompt_update_time": [],
|
||||
"total_time": [],
|
||||
}
|
||||
stat_keys = [
|
||||
"hf_processor_time",
|
||||
"hashing_time",
|
||||
"cache_lookup_time",
|
||||
"prompt_update_time",
|
||||
"preprocessor_total_time",
|
||||
"encoder_forward_time",
|
||||
"num_encoder_calls",
|
||||
]
|
||||
stats_by_stage = {key: [] for key in stat_keys}
|
||||
|
||||
for stats_dict in all_stats.values():
|
||||
stats_by_stage["hf_processor_time"].append(
|
||||
stats_dict.get("hf_processor_time", 0.0)
|
||||
)
|
||||
stats_by_stage["hashing_time"].append(stats_dict.get("hashing_time", 0.0))
|
||||
stats_by_stage["cache_lookup_time"].append(
|
||||
stats_dict.get("cache_lookup_time", 0.0)
|
||||
)
|
||||
stats_by_stage["prompt_update_time"].append(
|
||||
stats_dict.get("prompt_update_time", 0.0)
|
||||
)
|
||||
stats_by_stage["total_time"].append(stats_dict.get("total_time", 0.0))
|
||||
for key in stat_keys:
|
||||
if key in stats_dict:
|
||||
stats_by_stage[key].append(stats_dict[key])
|
||||
|
||||
return stats_by_stage
|
||||
|
||||
@@ -88,14 +87,14 @@ def calculate_mm_processor_metrics(
|
||||
}
|
||||
continue
|
||||
|
||||
times_ms = [t * 1000 for t in times]
|
||||
is_count_metric = stage_name == "num_encoder_calls"
|
||||
values = times if is_count_metric else [t * 1000 for t in times]
|
||||
|
||||
metrics[stage_name] = {
|
||||
"mean": float(np.mean(times_ms)),
|
||||
"median": float(np.median(times_ms)),
|
||||
"std": float(np.std(times_ms)),
|
||||
**{
|
||||
f"p{p}": float(np.percentile(times_ms, p)) for p in selected_percentiles
|
||||
},
|
||||
"mean": float(np.mean(values)),
|
||||
"median": float(np.median(values)),
|
||||
"std": float(np.std(values)),
|
||||
**{f"p{p}": float(np.percentile(values, p)) for p in selected_percentiles},
|
||||
}
|
||||
|
||||
return metrics
|
||||
@@ -114,6 +113,23 @@ def validate_args(args):
|
||||
if not hasattr(args, "max_loras"):
|
||||
args.max_loras = None
|
||||
|
||||
if args.dataset_name == "hf" and not args.dataset_path:
|
||||
raise ValueError(
|
||||
"--dataset-path is required when using --dataset-name hf. "
|
||||
"For multimodal benchmarking, specify a dataset like "
|
||||
"'lmarena-ai/VisionArena-Chat'."
|
||||
)
|
||||
if args.dataset_name == "hf":
|
||||
supported_mm_datasets = (
|
||||
VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
|
||||
| MultiModalConversationDataset.SUPPORTED_DATASET_PATHS
|
||||
)
|
||||
if args.dataset_path not in supported_mm_datasets:
|
||||
raise ValueError(
|
||||
f"{args.dataset_path} is not a supported multimodal dataset. "
|
||||
f"Supported multimodal datasets are: {sorted(supported_mm_datasets)}"
|
||||
)
|
||||
|
||||
|
||||
def benchmark_multimodal_processor(
|
||||
args: argparse.Namespace,
|
||||
@@ -223,6 +239,17 @@ def benchmark_multimodal_processor(
|
||||
std_e2el_ms = 0.0
|
||||
percentiles_e2el_ms = [(p, 0.0) for p in selected_percentiles]
|
||||
|
||||
encoder_summary = {}
|
||||
if (
|
||||
"num_encoder_calls" in mm_stats_by_stage
|
||||
and mm_stats_by_stage["num_encoder_calls"]
|
||||
):
|
||||
encoder_calls = mm_stats_by_stage["num_encoder_calls"]
|
||||
encoder_summary = {
|
||||
"total_encoder_calls": int(sum(encoder_calls)),
|
||||
"num_requests_with_encoder_calls": len(encoder_calls),
|
||||
}
|
||||
|
||||
benchmark_result = {
|
||||
"completed": completed,
|
||||
"failed": failed,
|
||||
@@ -231,6 +258,7 @@ def benchmark_multimodal_processor(
|
||||
"std_e2el_ms": std_e2el_ms,
|
||||
"percentiles_e2el_ms": percentiles_e2el_ms,
|
||||
"mm_processor_stats": mm_processor_metrics,
|
||||
"encoder_summary": encoder_summary,
|
||||
}
|
||||
|
||||
return benchmark_result
|
||||
@@ -248,7 +276,7 @@ def add_cli_args(parser: argparse.ArgumentParser) -> None:
|
||||
"--dataset-name",
|
||||
type=str,
|
||||
default="random-mm",
|
||||
choices=["random-mm", "random-rerank"],
|
||||
choices=["random-mm", "hf"],
|
||||
help="Name of the dataset to benchmark on. Defaults to 'random-mm'.",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -266,6 +294,34 @@ def add_cli_args(parser: argparse.ArgumentParser) -> None:
|
||||
add_random_dataset_base_args(parser)
|
||||
add_random_multimodal_dataset_args(parser)
|
||||
|
||||
# HuggingFace dataset arguments
|
||||
parser.add_argument(
|
||||
"--dataset-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the dataset file or HuggingFace dataset name "
|
||||
"(e.g., 'yale-nlp/MMVU', 'lmarena-ai/VisionArena-Chat').",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-subset",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Subset of the HuggingFace dataset (optional).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-split",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Split of the HuggingFace dataset (e.g., 'train', 'test', 'validation').",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Output length for each request. "
|
||||
"Overrides the default output lengths from the dataset.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output-json",
|
||||
type=str,
|
||||
@@ -296,14 +352,17 @@ def main(args: argparse.Namespace) -> None:
|
||||
print("=" * 80)
|
||||
|
||||
if "mm_processor_stats" in result:
|
||||
print("\nMM Processor Timing (ms):")
|
||||
print("\nMM Processor Metrics:")
|
||||
selected_percentiles = [
|
||||
float(p) for p in getattr(args, "metric_percentiles", "99").split(",")
|
||||
]
|
||||
mm_data = []
|
||||
for stage, metrics in result["mm_processor_stats"].items():
|
||||
is_count = stage == "num_encoder_calls"
|
||||
unit = "" if is_count else " (ms)"
|
||||
|
||||
row = {
|
||||
"Stage": stage,
|
||||
"Stage": stage + unit,
|
||||
"Mean": f"{metrics['mean']:.2f}",
|
||||
"Median": f"{metrics['median']:.2f}",
|
||||
"Std": f"{metrics['std']:.2f}",
|
||||
@@ -315,6 +374,14 @@ def main(args: argparse.Namespace) -> None:
|
||||
mm_df = pd.DataFrame(mm_data)
|
||||
print(mm_df.to_string(index=False))
|
||||
|
||||
if "encoder_summary" in result and result["encoder_summary"]:
|
||||
total_calls = result["encoder_summary"]["total_encoder_calls"]
|
||||
num_requests = result["encoder_summary"]["num_requests_with_encoder_calls"]
|
||||
print(
|
||||
f"\nSummary: {total_calls} total encoder calls "
|
||||
f"across {num_requests} requests."
|
||||
)
|
||||
|
||||
if "mean_e2el_ms" in result:
|
||||
print("\nEnd-to-End Latency (ms):")
|
||||
selected_percentiles = [
|
||||
|
||||
@@ -75,8 +75,8 @@ class MultiModalProcessorTimingStats:
|
||||
prompt_update_time: float = 0.0
|
||||
"""Time spent applying prompt updates and finding placeholders (seconds)."""
|
||||
|
||||
total_time: float = 0.0
|
||||
"""Total processing time (seconds)."""
|
||||
preprocessor_total_time: float = 0.0
|
||||
"""Total preprocessing time (seconds)."""
|
||||
|
||||
def to_dict(self) -> dict[str, float]:
|
||||
"""Convert stats to a dictionary for JSON serialization."""
|
||||
@@ -85,7 +85,7 @@ class MultiModalProcessorTimingStats:
|
||||
"hashing_time": self.hashing_time,
|
||||
"cache_lookup_time": self.cache_lookup_time,
|
||||
"prompt_update_time": self.prompt_update_time,
|
||||
"total_time": self.total_time,
|
||||
"preprocessor_total_time": self.preprocessor_total_time,
|
||||
}
|
||||
|
||||
|
||||
@@ -93,13 +93,30 @@ def get_timing_stats_from_engine_client(
|
||||
engine_client: Any,
|
||||
) -> dict[str, dict[str, float]]:
|
||||
"""
|
||||
Get all timing stats from the context associated with the engine client.
|
||||
Get all multimodal timing stats from the engine client.
|
||||
|
||||
Collects both preprocessing stats (HF processor, hashing, cache lookup,
|
||||
prompt update) and encoder forward pass timing, merged by request_id.
|
||||
|
||||
Args:
|
||||
engine_client: The engine client that has input_processor.
|
||||
engine_client: The engine client (has input_processor and workers).
|
||||
|
||||
Returns:
|
||||
A dictionary mapping request_id to stats dict.
|
||||
Dictionary mapping request_id to merged stats dict containing
|
||||
both preprocessing and encoder timing metrics.
|
||||
|
||||
Example:
|
||||
{
|
||||
'request-123': {
|
||||
'hf_processor_time': 0.45,
|
||||
'hashing_time': 0.02,
|
||||
'cache_lookup_time': 0.01,
|
||||
'prompt_update_time': 0.03,
|
||||
'preprocessor_total_time': 0.51,
|
||||
'encoder_forward_time': 0.23,
|
||||
'num_encoder_calls': 1
|
||||
}
|
||||
}
|
||||
"""
|
||||
try:
|
||||
if not engine_client.vllm_config.observability_config.enable_mm_processor_stats:
|
||||
@@ -107,6 +124,7 @@ def get_timing_stats_from_engine_client(
|
||||
except (AttributeError, RuntimeError):
|
||||
return {}
|
||||
|
||||
preprocessing_stats = {}
|
||||
try:
|
||||
input_processor = engine_client.input_processor
|
||||
input_preprocessor = input_processor.input_preprocessor
|
||||
@@ -115,15 +133,68 @@ def get_timing_stats_from_engine_client(
|
||||
mm_processor = input_preprocessor._get_mm_processor()
|
||||
if mm_processor is not None and hasattr(mm_processor, "info"):
|
||||
ctx = mm_processor.info.ctx
|
||||
return ctx.get_all_timing_stats()
|
||||
preprocessing_stats = ctx.get_all_timing_stats()
|
||||
except (AttributeError, RuntimeError):
|
||||
pass
|
||||
|
||||
return {}
|
||||
encoder_stats = {}
|
||||
try:
|
||||
if hasattr(engine_client, "collective_rpc"):
|
||||
encoder_stats_results = engine_client.collective_rpc(
|
||||
"get_encoder_timing_stats"
|
||||
)
|
||||
if encoder_stats_results and len(encoder_stats_results) > 0:
|
||||
for worker_stats in encoder_stats_results:
|
||||
if not worker_stats:
|
||||
continue
|
||||
for request_id, stats_dict in worker_stats.items():
|
||||
if request_id not in encoder_stats:
|
||||
encoder_stats[request_id] = dict(stats_dict)
|
||||
else:
|
||||
# Aggregate timing metrics across workers
|
||||
current_time = encoder_stats[request_id].get(
|
||||
"encoder_forward_time", 0.0
|
||||
)
|
||||
new_time = stats_dict.get("encoder_forward_time", 0.0)
|
||||
encoder_stats[request_id]["encoder_forward_time"] = max(
|
||||
current_time, new_time
|
||||
)
|
||||
|
||||
current_calls = encoder_stats[request_id].get(
|
||||
"num_encoder_calls", 0
|
||||
)
|
||||
new_calls = stats_dict.get("num_encoder_calls", 0)
|
||||
encoder_stats[request_id]["num_encoder_calls"] = max(
|
||||
current_calls, new_calls
|
||||
)
|
||||
except (AttributeError, RuntimeError):
|
||||
pass
|
||||
|
||||
merged_stats = {}
|
||||
|
||||
for request_id, prep_dict in preprocessing_stats.items():
|
||||
merged_stats[request_id] = dict(prep_dict)
|
||||
|
||||
for request_id, enc_dict in encoder_stats.items():
|
||||
if request_id in merged_stats:
|
||||
merged_stats[request_id].update(enc_dict)
|
||||
continue
|
||||
|
||||
# In V1 engine, the request_id in encoder_stats has a suffix
|
||||
# appended to the original request_id (which is used in
|
||||
# preprocessing_stats).
|
||||
# We try to strip the suffix to find the matching request.
|
||||
possible_original_id = request_id.rpartition("-")[0]
|
||||
if possible_original_id and possible_original_id in merged_stats:
|
||||
merged_stats[possible_original_id].update(enc_dict)
|
||||
else:
|
||||
merged_stats[request_id] = dict(enc_dict)
|
||||
|
||||
return merged_stats
|
||||
|
||||
|
||||
@contextmanager
|
||||
def timed_operation(ctx: "InputProcessingContext", stage_name: str):
|
||||
def timed_preprocessor_operation(ctx: "InputProcessingContext", stage_name: str):
|
||||
"""
|
||||
Context manager to time an operation using the context's timing stats.
|
||||
|
||||
@@ -157,7 +228,7 @@ def timed_operation(ctx: "InputProcessingContext", stage_name: str):
|
||||
stats.cache_lookup_time += elapsed
|
||||
elif stage_name == "prompt_update":
|
||||
stats.prompt_update_time += elapsed
|
||||
stats.total_time += elapsed
|
||||
stats.preprocessor_total_time += elapsed
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
@@ -42,7 +42,11 @@ from ..parse import (
|
||||
MultiModalDataItems,
|
||||
MultiModalDataParser,
|
||||
)
|
||||
from .context import BaseProcessingInfo, get_current_request_id, timed_operation
|
||||
from .context import (
|
||||
BaseProcessingInfo,
|
||||
get_current_request_id,
|
||||
timed_preprocessor_operation,
|
||||
)
|
||||
from .dummy_inputs import BaseDummyInputsBuilder
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -1192,7 +1196,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
Call the HF processor on the prompt text and
|
||||
associated multi-modal data.
|
||||
"""
|
||||
with timed_operation(self.info.ctx, "hf_processor"):
|
||||
with timed_preprocessor_operation(self.info.ctx, "hf_processor"):
|
||||
return self.info.ctx.call_hf_processor(
|
||||
self.info.get_hf_processor(**mm_kwargs),
|
||||
dict(text=prompt, **mm_data),
|
||||
@@ -1545,7 +1549,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
)
|
||||
|
||||
# Use overrides if provided; fallback to data-dependent hashing.
|
||||
with timed_operation(self.info.ctx, "hashing"):
|
||||
with timed_preprocessor_operation(self.info.ctx, "hashing"):
|
||||
mm_hashes = self._hash_mm_items(
|
||||
mm_data_items,
|
||||
hf_processor_mm_kwargs,
|
||||
@@ -1592,7 +1596,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
mm_uuids=mm_uuids,
|
||||
)
|
||||
|
||||
with timed_operation(self.info.ctx, "hashing"):
|
||||
with timed_preprocessor_operation(self.info.ctx, "hashing"):
|
||||
mm_hashes = self._hash_mm_items(
|
||||
mm_data_items,
|
||||
hf_processor_mm_kwargs,
|
||||
@@ -1600,7 +1604,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
mm_uuids=mm_uuids,
|
||||
)
|
||||
|
||||
with timed_operation(self.info.ctx, "cache_lookup"):
|
||||
with timed_preprocessor_operation(self.info.ctx, "cache_lookup"):
|
||||
mm_is_cached, mm_missing_data_items = self._get_cache_missing_items(
|
||||
cache=cache,
|
||||
mm_data_items=mm_data_items,
|
||||
@@ -1635,7 +1639,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
mm_missing_kwargs,
|
||||
)
|
||||
|
||||
with timed_operation(self.info.ctx, "cache_lookup"):
|
||||
with timed_preprocessor_operation(self.info.ctx, "cache_lookup"):
|
||||
mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs(
|
||||
cache,
|
||||
mm_hashes=mm_hashes,
|
||||
@@ -1846,7 +1850,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
)
|
||||
|
||||
# NOTE: tokenization_kwargs are not required to init processor
|
||||
with timed_operation(self.info.ctx, "prompt_update"):
|
||||
with timed_preprocessor_operation(self.info.ctx, "prompt_update"):
|
||||
prompt_ids, mm_placeholders = self._maybe_apply_prompt_updates(
|
||||
mm_items=mm_items,
|
||||
prompt_ids=prompt_ids,
|
||||
|
||||
@@ -4,11 +4,13 @@
|
||||
import functools
|
||||
import gc
|
||||
import itertools
|
||||
import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator, Sequence
|
||||
from contextlib import contextmanager
|
||||
from copy import copy, deepcopy
|
||||
from dataclasses import dataclass
|
||||
from functools import reduce
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
|
||||
|
||||
@@ -546,6 +548,10 @@ class GPUModelRunner(
|
||||
# Cache the device properties.
|
||||
self._init_device_properties()
|
||||
|
||||
# Encoder timing registry for observability
|
||||
self.encoder_timing_registry: dict[str, EncoderTimingStats] = {}
|
||||
self._encoder_timing_lock = threading.Lock()
|
||||
|
||||
# Persistent buffers for CUDA graphs.
|
||||
self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32)
|
||||
self.positions = self._make_buffer(self.max_num_tokens, dtype=torch.int64)
|
||||
@@ -2212,6 +2218,12 @@ class GPUModelRunner(
|
||||
if not mm_kwargs:
|
||||
return []
|
||||
|
||||
should_time = bool(
|
||||
self.observability_config
|
||||
and self.observability_config.enable_mm_processor_stats
|
||||
and scheduler_output.scheduled_encoder_inputs
|
||||
)
|
||||
|
||||
# Batch mm inputs as much as we can: if a request in the batch has
|
||||
# multiple modalities or a different modality than the previous one,
|
||||
# we process it separately to preserve item order.
|
||||
@@ -2278,6 +2290,8 @@ class GPUModelRunner(
|
||||
)
|
||||
|
||||
encoder_outputs: list[torch.Tensor] = []
|
||||
# Track the current index in mm_kwargs/mm_lora_refs to map groups to request IDs
|
||||
current_item_idx = 0
|
||||
for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
|
||||
mm_kwargs,
|
||||
device=self.device,
|
||||
@@ -2300,22 +2314,24 @@ class GPUModelRunner(
|
||||
and num_items > 1
|
||||
):
|
||||
curr_group_outputs_lst = list[torch.Tensor]()
|
||||
for video_mm_kwargs_item in filter(
|
||||
lambda item: item.modality == "video", mm_kwargs
|
||||
):
|
||||
_, _, micro_batch_mm_inputs = next(
|
||||
group_mm_kwargs_by_modality(
|
||||
[video_mm_kwargs_item],
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
for video_idx in range(num_items):
|
||||
video_mm_kwargs_item = mm_kwargs[current_item_idx + video_idx]
|
||||
with self.timed_encoder_operation(
|
||||
should_time, mm_lora_refs, current_item_idx + video_idx, 1
|
||||
):
|
||||
_, _, micro_batch_mm_inputs = next(
|
||||
group_mm_kwargs_by_modality(
|
||||
[video_mm_kwargs_item],
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
micro_batch_outputs = model.embed_multimodal(
|
||||
**micro_batch_mm_inputs
|
||||
)
|
||||
micro_batch_outputs = model.embed_multimodal(
|
||||
**micro_batch_mm_inputs
|
||||
)
|
||||
|
||||
curr_group_outputs_lst.extend(micro_batch_outputs)
|
||||
curr_group_outputs_lst.extend(micro_batch_outputs)
|
||||
|
||||
curr_group_outputs = curr_group_outputs_lst
|
||||
else:
|
||||
@@ -2326,7 +2342,11 @@ class GPUModelRunner(
|
||||
# 2. A list or tuple (length: num_items) of tensors,
|
||||
# each of shape (feature_size, hidden_size) in case the feature
|
||||
# size is dynamic depending on the input multimodal items.
|
||||
curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
|
||||
|
||||
with self.timed_encoder_operation(
|
||||
should_time, mm_lora_refs, current_item_idx, num_items
|
||||
):
|
||||
curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
|
||||
|
||||
sanity_check_mm_encoder_outputs(
|
||||
curr_group_outputs,
|
||||
@@ -2334,6 +2354,8 @@ class GPUModelRunner(
|
||||
)
|
||||
encoder_outputs.extend(curr_group_outputs)
|
||||
|
||||
current_item_idx += num_items
|
||||
|
||||
# Cache the encoder outputs by mm_hash
|
||||
for mm_hash, output in zip(mm_hashes, encoder_outputs):
|
||||
self.encoder_cache[mm_hash] = output
|
||||
@@ -5919,3 +5941,79 @@ class GPUModelRunner(
|
||||
self.transfer_event.record()
|
||||
self.transfer_event.synchronize()
|
||||
return pinned.tolist()
|
||||
|
||||
def get_encoder_timing_stats(self) -> dict[str, dict[str, float | int]]:
|
||||
"""
|
||||
Get encoder timing stats for all requests and clear the registry.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping request_id to stats dict.
|
||||
"""
|
||||
with self._encoder_timing_lock:
|
||||
stats = {
|
||||
req_id: stats_obj.to_dict()
|
||||
for req_id, stats_obj in self.encoder_timing_registry.items()
|
||||
}
|
||||
self.encoder_timing_registry.clear()
|
||||
return stats
|
||||
|
||||
@contextmanager
|
||||
def timed_encoder_operation(
|
||||
self,
|
||||
should_time: bool,
|
||||
group_lora_refs: list[tuple[str, Any]],
|
||||
current_item_idx: int,
|
||||
num_items: int,
|
||||
):
|
||||
"""
|
||||
Context manager to time encoder forward operations.
|
||||
|
||||
Args:
|
||||
should_time: Whether timing is enabled
|
||||
group_lora_refs: Full list of (request_id, pos_info) tuples
|
||||
current_item_idx: Starting index for this group
|
||||
num_items: Number of items in this group
|
||||
"""
|
||||
if not should_time:
|
||||
yield
|
||||
return
|
||||
|
||||
group_refs = group_lora_refs[current_item_idx : current_item_idx + num_items]
|
||||
group_request_ids = {req_id for req_id, _ in group_refs}
|
||||
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.perf_counter()
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch.cuda.synchronize()
|
||||
elapsed = time.perf_counter() - start_time
|
||||
|
||||
per_request_time = elapsed / max(len(group_request_ids), 1)
|
||||
|
||||
with self._encoder_timing_lock:
|
||||
for req_id in group_request_ids:
|
||||
if req_id not in self.encoder_timing_registry:
|
||||
self.encoder_timing_registry[req_id] = EncoderTimingStats()
|
||||
|
||||
stats = self.encoder_timing_registry[req_id]
|
||||
stats.encoder_forward_time += per_request_time
|
||||
stats.num_encoder_calls += 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class EncoderTimingStats:
|
||||
"""Per-request timing statistics for encoder forward pass."""
|
||||
|
||||
encoder_forward_time: float = 0.0
|
||||
"""Time spent in vision encoder forward pass (seconds)."""
|
||||
|
||||
num_encoder_calls: int = 0
|
||||
"""Number of times encoder was called for this request."""
|
||||
|
||||
def to_dict(self) -> dict[str, float | int]:
|
||||
return {
|
||||
"encoder_forward_time": self.encoder_forward_time,
|
||||
"num_encoder_calls": self.num_encoder_calls,
|
||||
}
|
||||
|
||||
@@ -542,6 +542,10 @@ class Worker(WorkerBase):
|
||||
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
|
||||
return self.model_runner.get_supported_tasks()
|
||||
|
||||
def get_encoder_timing_stats(self) -> dict[str, dict[str, float | int]]:
|
||||
"""Get encoder timing stats from model runner."""
|
||||
return self.model_runner.get_encoder_timing_stats()
|
||||
|
||||
def annotate_profile(self, scheduler_output):
|
||||
# add trace annotation so that we can easily distinguish
|
||||
# context/generation request numbers in each iteration.
|
||||
|
||||
Reference in New Issue
Block a user