feat(benchmark): add encoder forward pass benchmarking to mm-processor (#31655)
Signed-off-by: Reagan <reaganjlee@gmail.com>
Signed-off-by: Reagan Lee <96998476+reaganjlee@users.noreply.github.com>
Co-authored-by: Hiroken. <105287758+HirokenOvo@users.noreply.github.com>
@@ -22,6 +22,10 @@ from typing import Any

 import numpy as np

+from vllm.benchmarks.datasets import (
+    MultiModalConversationDataset,
+    VisionArenaDataset,
+)
 from vllm.benchmarks.throughput import get_requests
 from vllm.engine.arg_utils import EngineArgs
 from vllm.multimodal.processing.context import (
@@ -45,26 +49,21 @@ def collect_mm_processor_stats(
     """
     all_stats = get_timing_stats_from_engine_client(llm_engine)

-    stats_by_stage = {
-        "hf_processor_time": [],
-        "hashing_time": [],
-        "cache_lookup_time": [],
-        "prompt_update_time": [],
-        "total_time": [],
-    }
+    stat_keys = [
+        "hf_processor_time",
+        "hashing_time",
+        "cache_lookup_time",
+        "prompt_update_time",
+        "preprocessor_total_time",
+        "encoder_forward_time",
+        "num_encoder_calls",
+    ]
+    stats_by_stage = {key: [] for key in stat_keys}

     for stats_dict in all_stats.values():
-        stats_by_stage["hf_processor_time"].append(
-            stats_dict.get("hf_processor_time", 0.0)
-        )
-        stats_by_stage["hashing_time"].append(stats_dict.get("hashing_time", 0.0))
-        stats_by_stage["cache_lookup_time"].append(
-            stats_dict.get("cache_lookup_time", 0.0)
-        )
-        stats_by_stage["prompt_update_time"].append(
-            stats_dict.get("prompt_update_time", 0.0)
-        )
-        stats_by_stage["total_time"].append(stats_dict.get("total_time", 0.0))
+        for key in stat_keys:
+            if key in stats_dict:
+                stats_by_stage[key].append(stats_dict[key])

     return stats_by_stage

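
To see what the refactored collection produces, here is a toy run with invented values; all key names come from this diff. Note the deliberate behavior change: a request that lacks a key (e.g. no encoder call ever ran for it) is now skipped for that stage instead of being zero-filled via `.get(key, 0.0)`.

    stat_keys = ["hf_processor_time", "encoder_forward_time", "num_encoder_calls"]
    all_stats = {
        "req-1": {"hf_processor_time": 0.45,
                  "encoder_forward_time": 0.23, "num_encoder_calls": 1},
        "req-2": {"hf_processor_time": 0.40},  # no encoder call for this request
    }

    stats_by_stage = {key: [] for key in stat_keys}
    for stats_dict in all_stats.values():
        for key in stat_keys:
            if key in stats_dict:
                stats_by_stage[key].append(stats_dict[key])

    # stats_by_stage["encoder_forward_time"] == [0.23]; req-2 contributes nothing.
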
@@ -88,14 +87,14 @@ def calculate_mm_processor_metrics(
             }
             continue

-        times_ms = [t * 1000 for t in times]
+        is_count_metric = stage_name == "num_encoder_calls"
+        values = times if is_count_metric else [t * 1000 for t in times]

         metrics[stage_name] = {
-            "mean": float(np.mean(times_ms)),
-            "median": float(np.median(times_ms)),
-            "std": float(np.std(times_ms)),
-            **{
-                f"p{p}": float(np.percentile(times_ms, p)) for p in selected_percentiles
-            },
+            "mean": float(np.mean(values)),
+            "median": float(np.median(values)),
+            "std": float(np.std(values)),
+            **{f"p{p}": float(np.percentile(values, p)) for p in selected_percentiles},
         }

     return metrics
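
A self-contained sketch of the count-versus-time split above (`summarize` is a hypothetical helper, values invented): `num_encoder_calls` is summarized raw, while timing stages are converted from seconds to milliseconds first.

    import numpy as np

    selected_percentiles = [99.0]

    def summarize(stage_name: str, times: list[float]) -> dict[str, float]:
        # Counts stay raw; timings are converted from seconds to ms.
        is_count_metric = stage_name == "num_encoder_calls"
        values = times if is_count_metric else [t * 1000 for t in times]
        return {
            "mean": float(np.mean(values)),
            "median": float(np.median(values)),
            "std": float(np.std(values)),
            **{f"p{p}": float(np.percentile(values, p)) for p in selected_percentiles},
        }

    print(summarize("hf_processor_time", [0.45, 0.52]))  # values in ms
    print(summarize("num_encoder_calls", [1.0, 2.0]))    # raw call counts
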
@@ -114,6 +113,23 @@ def validate_args(args):
     if not hasattr(args, "max_loras"):
         args.max_loras = None

+    if args.dataset_name == "hf" and not args.dataset_path:
+        raise ValueError(
+            "--dataset-path is required when using --dataset-name hf. "
+            "For multimodal benchmarking, specify a dataset like "
+            "'lmarena-ai/VisionArena-Chat'."
+        )
+    if args.dataset_name == "hf":
+        supported_mm_datasets = (
+            VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
+            | MultiModalConversationDataset.SUPPORTED_DATASET_PATHS
+        )
+        if args.dataset_path not in supported_mm_datasets:
+            raise ValueError(
+                f"{args.dataset_path} is not a supported multimodal dataset. "
+                f"Supported multimodal datasets are: {sorted(supported_mm_datasets)}"
+            )
+

 def benchmark_multimodal_processor(
     args: argparse.Namespace,
@@ -223,6 +239,17 @@ def benchmark_multimodal_processor(
         std_e2el_ms = 0.0
         percentiles_e2el_ms = [(p, 0.0) for p in selected_percentiles]

+    encoder_summary = {}
+    if (
+        "num_encoder_calls" in mm_stats_by_stage
+        and mm_stats_by_stage["num_encoder_calls"]
+    ):
+        encoder_calls = mm_stats_by_stage["num_encoder_calls"]
+        encoder_summary = {
+            "total_encoder_calls": int(sum(encoder_calls)),
+            "num_requests_with_encoder_calls": len(encoder_calls),
+        }
+
     benchmark_result = {
         "completed": completed,
         "failed": failed,
@@ -231,6 +258,7 @@ def benchmark_multimodal_processor(
         "std_e2el_ms": std_e2el_ms,
         "percentiles_e2el_ms": percentiles_e2el_ms,
         "mm_processor_stats": mm_processor_metrics,
+        "encoder_summary": encoder_summary,
     }

     return benchmark_result
@@ -248,7 +276,7 @@ def add_cli_args(parser: argparse.ArgumentParser) -> None:
         "--dataset-name",
         type=str,
         default="random-mm",
-        choices=["random-mm", "random-rerank"],
+        choices=["random-mm", "hf"],
         help="Name of the dataset to benchmark on. Defaults to 'random-mm'.",
     )
     parser.add_argument(
@@ -266,6 +294,34 @@ def add_cli_args(parser: argparse.ArgumentParser) -> None:
     add_random_dataset_base_args(parser)
     add_random_multimodal_dataset_args(parser)

+    # HuggingFace dataset arguments
+    parser.add_argument(
+        "--dataset-path",
+        type=str,
+        default=None,
+        help="Path to the dataset file or HuggingFace dataset name "
+        "(e.g., 'yale-nlp/MMVU', 'lmarena-ai/VisionArena-Chat').",
+    )
+    parser.add_argument(
+        "--hf-subset",
+        type=str,
+        default=None,
+        help="Subset of the HuggingFace dataset (optional).",
+    )
+    parser.add_argument(
+        "--hf-split",
+        type=str,
+        default=None,
+        help="Split of the HuggingFace dataset (e.g., 'train', 'test', 'validation').",
+    )
+    parser.add_argument(
+        "--output-len",
+        type=int,
+        default=None,
+        help="Output length for each request. "
+        "Overrides the default output lengths from the dataset.",
+    )
+
     parser.add_argument(
         "--output-json",
         type=str,
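
For illustration, a standalone argparse stand-in (not the real parser, which is built by add_cli_args) showing how the new HF flags combine:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset-name", type=str, default="random-mm",
                        choices=["random-mm", "hf"])
    parser.add_argument("--dataset-path", type=str, default=None)
    parser.add_argument("--hf-subset", type=str, default=None)
    parser.add_argument("--hf-split", type=str, default=None)
    parser.add_argument("--output-len", type=int, default=None)

    args = parser.parse_args(
        ["--dataset-name", "hf",
         "--dataset-path", "lmarena-ai/VisionArena-Chat",
         "--hf-split", "train"]
    )
    assert args.dataset_path == "lmarena-ai/VisionArena-Chat"
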
@@ -296,14 +352,17 @@ def main(args: argparse.Namespace) -> None:
     print("=" * 80)

     if "mm_processor_stats" in result:
-        print("\nMM Processor Timing (ms):")
+        print("\nMM Processor Metrics:")
         selected_percentiles = [
             float(p) for p in getattr(args, "metric_percentiles", "99").split(",")
         ]
         mm_data = []
         for stage, metrics in result["mm_processor_stats"].items():
+            is_count = stage == "num_encoder_calls"
+            unit = "" if is_count else " (ms)"
+
             row = {
-                "Stage": stage,
+                "Stage": stage + unit,
                 "Mean": f"{metrics['mean']:.2f}",
                 "Median": f"{metrics['median']:.2f}",
                 "Std": f"{metrics['std']:.2f}",
@@ -315,6 +374,14 @@ def main(args: argparse.Namespace) -> None:
         mm_df = pd.DataFrame(mm_data)
         print(mm_df.to_string(index=False))

+        if "encoder_summary" in result and result["encoder_summary"]:
+            total_calls = result["encoder_summary"]["total_encoder_calls"]
+            num_requests = result["encoder_summary"]["num_requests_with_encoder_calls"]
+            print(
+                f"\nSummary: {total_calls} total encoder calls "
+                f"across {num_requests} requests."
+            )
+
     if "mean_e2el_ms" in result:
         print("\nEnd-to-End Latency (ms):")
         selected_percentiles = [

@@ -75,8 +75,8 @@ class MultiModalProcessorTimingStats:
     prompt_update_time: float = 0.0
     """Time spent applying prompt updates and finding placeholders (seconds)."""

-    total_time: float = 0.0
-    """Total processing time (seconds)."""
+    preprocessor_total_time: float = 0.0
+    """Total preprocessing time (seconds)."""

     def to_dict(self) -> dict[str, float]:
         """Convert stats to a dictionary for JSON serialization."""
@@ -85,7 +85,7 @@ class MultiModalProcessorTimingStats:
             "hashing_time": self.hashing_time,
             "cache_lookup_time": self.cache_lookup_time,
             "prompt_update_time": self.prompt_update_time,
-            "total_time": self.total_time,
+            "preprocessor_total_time": self.preprocessor_total_time,
         }

@@ -93,13 +93,30 @@ def get_timing_stats_from_engine_client(
     engine_client: Any,
 ) -> dict[str, dict[str, float]]:
     """
-    Get all timing stats from the context associated with the engine client.
+    Get all multimodal timing stats from the engine client.
+
+    Collects both preprocessing stats (HF processor, hashing, cache lookup,
+    prompt update) and encoder forward pass timing, merged by request_id.

     Args:
-        engine_client: The engine client that has input_processor.
+        engine_client: The engine client (has input_processor and workers).

     Returns:
-        A dictionary mapping request_id to stats dict.
+        Dictionary mapping request_id to merged stats dict containing
+        both preprocessing and encoder timing metrics.
+
+    Example:
+        {
+            'request-123': {
+                'hf_processor_time': 0.45,
+                'hashing_time': 0.02,
+                'cache_lookup_time': 0.01,
+                'prompt_update_time': 0.03,
+                'preprocessor_total_time': 0.51,
+                'encoder_forward_time': 0.23,
+                'num_encoder_calls': 1
+            }
+        }
     """
     try:
         if not engine_client.vllm_config.observability_config.enable_mm_processor_stats:
@@ -107,6 +124,7 @@ def get_timing_stats_from_engine_client(
     except (AttributeError, RuntimeError):
         return {}

+    preprocessing_stats = {}
     try:
         input_processor = engine_client.input_processor
         input_preprocessor = input_processor.input_preprocessor
@@ -115,15 +133,68 @@ def get_timing_stats_from_engine_client(
         mm_processor = input_preprocessor._get_mm_processor()
         if mm_processor is not None and hasattr(mm_processor, "info"):
             ctx = mm_processor.info.ctx
-            return ctx.get_all_timing_stats()
+            preprocessing_stats = ctx.get_all_timing_stats()
     except (AttributeError, RuntimeError):
         pass

-    return {}
+    encoder_stats = {}
+    try:
+        if hasattr(engine_client, "collective_rpc"):
+            encoder_stats_results = engine_client.collective_rpc(
+                "get_encoder_timing_stats"
+            )
+            if encoder_stats_results and len(encoder_stats_results) > 0:
+                for worker_stats in encoder_stats_results:
+                    if not worker_stats:
+                        continue
+                    for request_id, stats_dict in worker_stats.items():
+                        if request_id not in encoder_stats:
+                            encoder_stats[request_id] = dict(stats_dict)
+                        else:
+                            # Aggregate timing metrics across workers
+                            current_time = encoder_stats[request_id].get(
+                                "encoder_forward_time", 0.0
+                            )
+                            new_time = stats_dict.get("encoder_forward_time", 0.0)
+                            encoder_stats[request_id]["encoder_forward_time"] = max(
+                                current_time, new_time
+                            )
+
+                            current_calls = encoder_stats[request_id].get(
+                                "num_encoder_calls", 0
+                            )
+                            new_calls = stats_dict.get("num_encoder_calls", 0)
+                            encoder_stats[request_id]["num_encoder_calls"] = max(
+                                current_calls, new_calls
+                            )
+    except (AttributeError, RuntimeError):
+        pass
+
+    merged_stats = {}
+
+    for request_id, prep_dict in preprocessing_stats.items():
+        merged_stats[request_id] = dict(prep_dict)
+
+    for request_id, enc_dict in encoder_stats.items():
+        if request_id in merged_stats:
+            merged_stats[request_id].update(enc_dict)
+            continue
+
+        # In V1 engine, the request_id in encoder_stats has a suffix
+        # appended to the original request_id (which is used in
+        # preprocessing_stats).
+        # We try to strip the suffix to find the matching request.
+        possible_original_id = request_id.rpartition("-")[0]
+        if possible_original_id and possible_original_id in merged_stats:
+            merged_stats[possible_original_id].update(enc_dict)
+        else:
+            merged_stats[request_id] = dict(enc_dict)
+
+    return merged_stats


 @contextmanager
-def timed_operation(ctx: "InputProcessingContext", stage_name: str):
+def timed_preprocessor_operation(ctx: "InputProcessingContext", stage_name: str):
     """
     Context manager to time an operation using the context's timing stats.

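
The suffix-stripping fallback above hinges on str.rpartition; a runnable sketch with made-up request IDs (the merge logic mirrors the diff):

    preprocessing_stats = {"request-123": {"hf_processor_time": 0.45}}
    encoder_stats = {
        "request-123-0": {"encoder_forward_time": 0.23, "num_encoder_calls": 1}
    }

    merged_stats = {rid: dict(d) for rid, d in preprocessing_stats.items()}
    for request_id, enc_dict in encoder_stats.items():
        if request_id in merged_stats:
            merged_stats[request_id].update(enc_dict)
            continue
        # "request-123-0".rpartition("-")[0] -> "request-123"
        possible_original_id = request_id.rpartition("-")[0]
        if possible_original_id and possible_original_id in merged_stats:
            merged_stats[possible_original_id].update(enc_dict)
        else:
            merged_stats[request_id] = dict(enc_dict)

    # merged_stats["request-123"] now holds preprocessing and encoder metrics.
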
@@ -157,7 +228,7 @@ def timed_operation(ctx: "InputProcessingContext", stage_name: str):
         stats.cache_lookup_time += elapsed
     elif stage_name == "prompt_update":
         stats.prompt_update_time += elapsed
-    stats.total_time += elapsed
+    stats.preprocessor_total_time += elapsed


 _T = TypeVar("_T")

@@ -42,7 +42,11 @@ from ..parse import (
     MultiModalDataItems,
     MultiModalDataParser,
 )
-from .context import BaseProcessingInfo, get_current_request_id, timed_operation
+from .context import (
+    BaseProcessingInfo,
+    get_current_request_id,
+    timed_preprocessor_operation,
+)
 from .dummy_inputs import BaseDummyInputsBuilder

 if TYPE_CHECKING:
@@ -1192,7 +1196,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         Call the HF processor on the prompt text and
         associated multi-modal data.
         """
-        with timed_operation(self.info.ctx, "hf_processor"):
+        with timed_preprocessor_operation(self.info.ctx, "hf_processor"):
             return self.info.ctx.call_hf_processor(
                 self.info.get_hf_processor(**mm_kwargs),
                 dict(text=prompt, **mm_data),
@@ -1545,7 +1549,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         )

         # Use overrides if provided; fallback to data-dependent hashing.
-        with timed_operation(self.info.ctx, "hashing"):
+        with timed_preprocessor_operation(self.info.ctx, "hashing"):
             mm_hashes = self._hash_mm_items(
                 mm_data_items,
                 hf_processor_mm_kwargs,
@@ -1592,7 +1596,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             mm_uuids=mm_uuids,
         )

-        with timed_operation(self.info.ctx, "hashing"):
+        with timed_preprocessor_operation(self.info.ctx, "hashing"):
             mm_hashes = self._hash_mm_items(
                 mm_data_items,
                 hf_processor_mm_kwargs,
@@ -1600,7 +1604,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
                 mm_uuids=mm_uuids,
             )

-        with timed_operation(self.info.ctx, "cache_lookup"):
+        with timed_preprocessor_operation(self.info.ctx, "cache_lookup"):
             mm_is_cached, mm_missing_data_items = self._get_cache_missing_items(
                 cache=cache,
                 mm_data_items=mm_data_items,
@@ -1635,7 +1639,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
                 mm_missing_kwargs,
             )

-        with timed_operation(self.info.ctx, "cache_lookup"):
+        with timed_preprocessor_operation(self.info.ctx, "cache_lookup"):
             mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs(
                 cache,
                 mm_hashes=mm_hashes,
@@ -1846,7 +1850,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         )

         # NOTE: tokenization_kwargs are not required to init processor
-        with timed_operation(self.info.ctx, "prompt_update"):
+        with timed_preprocessor_operation(self.info.ctx, "prompt_update"):
             prompt_ids, mm_placeholders = self._maybe_apply_prompt_updates(
                 mm_items=mm_items,
                 prompt_ids=prompt_ids,

@@ -4,11 +4,13 @@
 import functools
 import gc
 import itertools
+import threading
 import time
 from collections import defaultdict
 from collections.abc import Iterator, Sequence
 from contextlib import contextmanager
 from copy import copy, deepcopy
+from dataclasses import dataclass
 from functools import reduce
 from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
@@ -546,6 +548,10 @@ class GPUModelRunner(
         # Cache the device properties.
         self._init_device_properties()

+        # Encoder timing registry for observability
+        self.encoder_timing_registry: dict[str, EncoderTimingStats] = {}
+        self._encoder_timing_lock = threading.Lock()
+
         # Persistent buffers for CUDA graphs.
         self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32)
         self.positions = self._make_buffer(self.max_num_tokens, dtype=torch.int64)
@@ -2212,6 +2218,12 @@ class GPUModelRunner(
         if not mm_kwargs:
             return []

+        should_time = bool(
+            self.observability_config
+            and self.observability_config.enable_mm_processor_stats
+            and scheduler_output.scheduled_encoder_inputs
+        )
+
         # Batch mm inputs as much as we can: if a request in the batch has
         # multiple modalities or a different modality than the previous one,
         # we process it separately to preserve item order.
@@ -2278,6 +2290,8 @@ class GPUModelRunner(
         )

         encoder_outputs: list[torch.Tensor] = []
+        # Track the current index in mm_kwargs/mm_lora_refs to map groups to request IDs
+        current_item_idx = 0
         for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
             mm_kwargs,
             device=self.device,
@@ -2300,22 +2314,24 @@ class GPUModelRunner(
             and num_items > 1
         ):
             curr_group_outputs_lst = list[torch.Tensor]()
-            for video_mm_kwargs_item in filter(
-                lambda item: item.modality == "video", mm_kwargs
-            ):
-                _, _, micro_batch_mm_inputs = next(
-                    group_mm_kwargs_by_modality(
-                        [video_mm_kwargs_item],
-                        device=self.device,
-                        pin_memory=self.pin_memory,
-                    )
-                )
-
-                micro_batch_outputs = model.embed_multimodal(
-                    **micro_batch_mm_inputs
-                )
+            for video_idx in range(num_items):
+                video_mm_kwargs_item = mm_kwargs[current_item_idx + video_idx]
+                with self.timed_encoder_operation(
+                    should_time, mm_lora_refs, current_item_idx + video_idx, 1
+                ):
+                    _, _, micro_batch_mm_inputs = next(
+                        group_mm_kwargs_by_modality(
+                            [video_mm_kwargs_item],
+                            device=self.device,
+                            pin_memory=self.pin_memory,
+                        )
+                    )
+
+                    micro_batch_outputs = model.embed_multimodal(
+                        **micro_batch_mm_inputs
+                    )

                 curr_group_outputs_lst.extend(micro_batch_outputs)

             curr_group_outputs = curr_group_outputs_lst
         else:
@@ -2326,7 +2342,11 @@ class GPUModelRunner(
             # 2. A list or tuple (length: num_items) of tensors,
             # each of shape (feature_size, hidden_size) in case the feature
             # size is dynamic depending on the input multimodal items.
-            curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
+            with self.timed_encoder_operation(
+                should_time, mm_lora_refs, current_item_idx, num_items
+            ):
+                curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)

         sanity_check_mm_encoder_outputs(
             curr_group_outputs,
@@ -2334,6 +2354,8 @@ class GPUModelRunner(
             )
             encoder_outputs.extend(curr_group_outputs)

+            current_item_idx += num_items
+
         # Cache the encoder outputs by mm_hash
         for mm_hash, output in zip(mm_hashes, encoder_outputs):
             self.encoder_cache[mm_hash] = output
@@ -5919,3 +5941,79 @@ class GPUModelRunner(
         self.transfer_event.record()
         self.transfer_event.synchronize()
         return pinned.tolist()
+
+    def get_encoder_timing_stats(self) -> dict[str, dict[str, float | int]]:
+        """
+        Get encoder timing stats for all requests and clear the registry.
+
+        Returns:
+            Dictionary mapping request_id to stats dict.
+        """
+        with self._encoder_timing_lock:
+            stats = {
+                req_id: stats_obj.to_dict()
+                for req_id, stats_obj in self.encoder_timing_registry.items()
+            }
+            self.encoder_timing_registry.clear()
+            return stats
+
+    @contextmanager
+    def timed_encoder_operation(
+        self,
+        should_time: bool,
+        group_lora_refs: list[tuple[str, Any]],
+        current_item_idx: int,
+        num_items: int,
+    ):
+        """
+        Context manager to time encoder forward operations.
+
+        Args:
+            should_time: Whether timing is enabled
+            group_lora_refs: Full list of (request_id, pos_info) tuples
+            current_item_idx: Starting index for this group
+            num_items: Number of items in this group
+        """
+        if not should_time:
+            yield
+            return
+
+        group_refs = group_lora_refs[current_item_idx : current_item_idx + num_items]
+        group_request_ids = {req_id for req_id, _ in group_refs}
+
+        torch.cuda.synchronize()
+        start_time = time.perf_counter()
+
+        try:
+            yield
+        finally:
+            torch.cuda.synchronize()
+            elapsed = time.perf_counter() - start_time
+
+            per_request_time = elapsed / max(len(group_request_ids), 1)
+
+            with self._encoder_timing_lock:
+                for req_id in group_request_ids:
+                    if req_id not in self.encoder_timing_registry:
+                        self.encoder_timing_registry[req_id] = EncoderTimingStats()
+
+                    stats = self.encoder_timing_registry[req_id]
+                    stats.encoder_forward_time += per_request_time
+                    stats.num_encoder_calls += 1
+
+
+@dataclass
+class EncoderTimingStats:
+    """Per-request timing statistics for encoder forward pass."""
+
+    encoder_forward_time: float = 0.0
+    """Time spent in vision encoder forward pass (seconds)."""
+
+    num_encoder_calls: int = 0
+    """Number of times encoder was called for this request."""
+
+    def to_dict(self) -> dict[str, float | int]:
+        return {
+            "encoder_forward_time": self.encoder_forward_time,
+            "num_encoder_calls": self.num_encoder_calls,
+        }
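
Stripped of the model-runner plumbing (CUDA synchronization, the registry lock, LoRA-ref bookkeeping), the timing pattern is a plain context manager around the forward pass. A minimal CPU-only sketch, with time.sleep standing in for model.embed_multimodal:

    import time
    from contextlib import contextmanager
    from dataclasses import dataclass

    @dataclass
    class EncoderTimingStats:
        encoder_forward_time: float = 0.0
        num_encoder_calls: int = 0

    registry: dict[str, EncoderTimingStats] = {}

    @contextmanager
    def timed_encoder_operation(request_ids: set[str]):
        start = time.perf_counter()
        try:
            yield
        finally:
            elapsed = time.perf_counter() - start
            # Attribute an equal share of the batched forward to each request.
            per_request = elapsed / max(len(request_ids), 1)
            for req_id in request_ids:
                stats = registry.setdefault(req_id, EncoderTimingStats())
                stats.encoder_forward_time += per_request
                stats.num_encoder_calls += 1

    with timed_encoder_operation({"req-1", "req-2"}):
        time.sleep(0.01)  # stand-in for the encoder forward pass

    assert registry["req-1"].num_encoder_calls == 1
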

@@ -542,6 +542,10 @@ class Worker(WorkerBase):
     def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
         return self.model_runner.get_supported_tasks()

+    def get_encoder_timing_stats(self) -> dict[str, dict[str, float | int]]:
+        """Get encoder timing stats from model runner."""
+        return self.model_runner.get_encoder_timing_stats()
+
     def annotate_profile(self, scheduler_output):
         # add trace annotation so that we can easily distinguish
         # context/generation request numbers in each iteration.
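
End to end, the benchmark side gathers one such dict per worker via collective_rpc("get_encoder_timing_stats") and merges them; a toy merge reproducing the max-based aggregation from get_timing_stats_from_engine_client (max rather than a sum, presumably because parallel workers time the same forward):

    worker_results = [
        {"req-1": {"encoder_forward_time": 0.20, "num_encoder_calls": 1}},
        {"req-1": {"encoder_forward_time": 0.23, "num_encoder_calls": 1}},
    ]

    encoder_stats = {}
    for worker_stats in worker_results:
        for request_id, stats_dict in worker_stats.items():
            if request_id not in encoder_stats:
                encoder_stats[request_id] = dict(stats_dict)
            else:
                merged = encoder_stats[request_id]
                merged["encoder_forward_time"] = max(
                    merged.get("encoder_forward_time", 0.0),
                    stats_dict.get("encoder_forward_time", 0.0),
                )
                merged["num_encoder_calls"] = max(
                    merged.get("num_encoder_calls", 0),
                    stats_dict.get("num_encoder_calls", 0),
                )

    # encoder_stats["req-1"] == {"encoder_forward_time": 0.23, "num_encoder_calls": 1}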