[Refactor] Remove MultiModalProfiler (#32254)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung
2026-01-13 23:10:20 +08:00
committed by GitHub
parent 98f60e5acb
commit 252c011012
4 changed files with 57 additions and 101 deletions

View File

@@ -7,7 +7,6 @@ from torch import prod
 from transformers import Llama4Config
 
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.profiling import MultiModalProfiler
 
 from ...utils import build_model_context
@@ -26,9 +25,8 @@ def test_profiling(model_id: str, max_model_len: int):
     )
 
     processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
-    profiler = MultiModalProfiler(processor)
-    decoder_dummy_data = profiler.get_decoder_dummy_data(
+    decoder_dummy_data = processor.dummy_inputs.get_decoder_dummy_data(
+        processor,
         max_model_len,
         mm_counts=mm_counts,
     )
@@ -39,11 +37,12 @@ def test_profiling(model_id: str, max_model_len: int):
     hf_config = ctx.get_hf_config(Llama4Config)
 
-    mm_data = processor.apply(
+    mm_inputs = processor.apply(
         prompt=dummy_mm_data.prompt,
         mm_data=dummy_mm_data.mm_data,
         hf_processor_mm_kwargs=dict(),
-    )["mm_kwargs"].get_data()
+    )
+    mm_data = mm_inputs["mm_kwargs"].get_data()
 
     image_size = hf_config.vision_config.image_size
     patch_size = hf_config.vision_config.patch_size
@@ -60,12 +59,9 @@ def test_profiling(model_id: str, max_model_len: int):
         total_num_patches.item() + num_tiles.item() + 3
     )  # image start, image, image end
 
-    profiled_tokens = profiler.get_mm_max_tokens(
-        max_model_len,
-        mm_counts=mm_counts,
+    assert total_num_patches == sum(
+        item.get_num_embeds for item in mm_inputs["mm_placeholders"]["image"]
     )
-    assert total_num_patches == profiled_tokens["image"]
     assert total_tokens == sum(
         placeholder.length
         for placeholder in decoder_dummy_data.multi_modal_placeholders["image"]
     )
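
Note: the test drops the `MultiModalProfiler` wrapper and calls the dummy-inputs builder directly. A minimal sketch of the call-site migration, assuming `processor` comes from `MULTIMODAL_REGISTRY.create_processor(ctx.model_config)` as in the hunks above:

    # Before this commit: a wrapper object mediated dummy-data generation.
    profiler = MultiModalProfiler(processor)
    decoder_dummy_data = profiler.get_decoder_dummy_data(
        max_model_len, mm_counts=mm_counts
    )

    # After: the builder is invoked directly, with the processor passed
    # explicitly as the first argument.
    decoder_dummy_data = processor.dummy_inputs.get_decoder_dummy_data(
        processor,
        max_model_len,
        mm_counts=mm_counts,
    )

The per-image token count formerly obtained via `profiler.get_mm_max_tokens(...)` is now asserted directly from the `get_num_embeds` of the placeholders that `processor.apply(...)` returns.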

View File

@@ -22,7 +22,6 @@ from vllm.multimodal.processing import (
     iter_token_matches,
     replace_token_matches,
 )
-from vllm.multimodal.profiling import MultiModalProfiler
 
 from .utils import random_image
@@ -924,12 +923,11 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
     processor = MULTIMODAL_REGISTRY.create_processor(model_config)
     processor._supported_mm_limits = {"image": num_supported}
-    profiler = MultiModalProfiler(processor)
 
     exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most")
 
     with exc_ctx:
-        profiler.get_decoder_dummy_data(
+        processor.dummy_inputs.get_decoder_dummy_data(
+            processor,
             model_config.max_model_len,
             mm_counts=limit_mm_per_prompt,
         )
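
Note: the over-limit failure mode is unchanged; only the entry point moves. A condensed sketch of the invalid case, assuming the monkeypatched `_supported_mm_limits` from the hunk above:

    # Requesting more items than the model supports should still raise
    # ValueError("At most ..."); the check now runs inside
    # BaseDummyInputsBuilder.get_decoder_dummy_data rather than the profiler.
    with pytest.raises(ValueError, match="At most"):
        processor.dummy_inputs.get_decoder_dummy_data(
            processor,
            model_config.max_model_len,
            mm_counts=limit_mm_per_prompt,  # exceeds num_supported in this case
        )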

View File

@@ -223,70 +223,42 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
         video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
         return [video] * num_videos
 
-class MultiModalProfiler(Generic[_I]):
-    """
-    Contains code for running memory profiling for multi-modal models.
-    """
-
-    def __init__(
+    def get_dummy_mm_inputs(
         self,
         processor: BaseMultiModalProcessor[_I],
-    ) -> None:
-        super().__init__()
-
-        self.processor = processor
-
-    @property
-    def processing_info(self) -> BaseProcessingInfo:
-        return self.processor.info
-
-    @property
-    def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]:
-        return self.processor.dummy_inputs
-
-    def get_mm_limits(self) -> Mapping[str, int]:
-        return self.processor.allowed_mm_limits
-
-    def _get_dummy_mm_inputs(
-        self,
         seq_len: int,
         mm_counts: Mapping[str, int] | None = None,
         mm_options: Mapping[str, BaseDummyOptions] | None = None,
     ) -> MultiModalInputs:
         if mm_counts is None:
-            mm_counts = self.get_mm_limits()
+            mm_counts = processor.allowed_mm_limits
 
-        factory = self.dummy_inputs
-        processor_inputs = factory.get_dummy_processor_inputs(
-            seq_len, mm_counts, mm_options
+        processor_inputs = self.get_dummy_processor_inputs(
+            seq_len,
+            mm_counts=mm_counts,
+            mm_options=mm_options,
         )
 
-        return self.processor.apply(
+        return processor.apply(
            prompt=processor_inputs.prompt,
            mm_data=processor_inputs.mm_data,
            hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
            tokenization_kwargs=processor_inputs.tokenization_kwargs,
         )
 
-    def _get_mm_num_tokens(
-        self,
-        mm_inputs: MultiModalInputs,
-    ) -> Mapping[str, int]:
-        placeholders_by_modality = mm_inputs["mm_placeholders"]
-
-        return {
-            modality: sum(item.get_num_embeds for item in placeholders)
-            for modality, placeholders in placeholders_by_modality.items()
-        }
-
     def get_decoder_dummy_data(
         self,
+        processor: BaseMultiModalProcessor[_I],
         seq_len: int,
         mm_counts: Mapping[str, int] | None = None,
         mm_options: Mapping[str, BaseDummyOptions] | None = None,
     ) -> DummyDecoderData:
-        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts, mm_options)
+        mm_inputs = self.get_dummy_mm_inputs(
+            processor,
+            seq_len,
+            mm_counts=mm_counts,
+            mm_options=mm_options,
+        )
+
         prompt_token_ids = mm_inputs["prompt_token_ids"]
         total_len = len(prompt_token_ids)
@@ -299,29 +271,3 @@ class MultiModalProfiler(Generic[_I]):
             multi_modal_data=mm_inputs["mm_kwargs"].require_data(),
             multi_modal_placeholders=mm_inputs["mm_placeholders"],
         )
-
-    def get_mm_max_tokens(
-        self,
-        seq_len: int,
-        mm_counts: Mapping[str, int] | None = None,
-    ) -> Mapping[str, int]:
-        """
-        Returns the maximum number of embeddings per item of each modality, excluding
-        any break/text tokens in-between multimodal embeddings/encoder outputs.
-        """
-        if mm_counts is None:
-            mm_counts = self.get_mm_limits()
-
-        max_tokens_per_item = self.processing_info.get_mm_max_tokens_per_item(
-            seq_len=seq_len,
-            mm_counts=mm_counts,
-        )
-        if max_tokens_per_item is not None:
-            return {
-                modality: max_tokens
-                for modality, max_tokens in max_tokens_per_item.items()
-                if mm_counts.get(modality, 0) > 0
-            }
-
-        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
-        return self._get_mm_num_tokens(mm_inputs)
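
Note: `MultiModalProfiler` carried no state beyond the processor it wrapped, so `get_dummy_mm_inputs` and `get_decoder_dummy_data` move onto `BaseDummyInputsBuilder` with the processor as an explicit parameter, while `get_mm_max_tokens` is inlined at its only call site in `MultiModalRegistry` (next file). A sketch of the two-step logic preserved there, using the names from the removed method:

    # Preferred path: ask the model's ProcessingInfo for a static per-item bound.
    max_tokens_per_item = processor.info.get_mm_max_tokens_per_item(
        seq_len=seq_len,
        mm_counts=mm_counts,
    )

    # Fallback (formerly _get_dummy_mm_inputs + _get_mm_num_tokens): run dummy
    # inputs through the processor and count placeholder embeddings per modality.
    if max_tokens_per_item is None:
        mm_inputs = processor.dummy_inputs.get_dummy_mm_inputs(
            processor, seq_len, mm_counts=mm_counts
        )
        max_tokens_per_item = {
            modality: sum(item.get_num_embeds for item in placeholders)
            for modality, placeholders in mm_inputs["mm_placeholders"].items()
        }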

View File

@@ -18,7 +18,6 @@ from .processing import (
 from .profiling import (
     BaseDummyInputsBuilder,
     DummyDecoderData,
-    MultiModalProfiler,
 )
 
 if TYPE_CHECKING:
@@ -160,17 +159,37 @@ class MultiModalRegistry:
         processor = self.create_processor(
             model_config, observability_config, cache=cache
         )
-        profiler: MultiModalProfiler = MultiModalProfiler(processor)
 
         seq_len = model_config.max_model_len
-        profiler_limits = (
-            profiler.get_mm_limits() if profiler_limits is None else profiler_limits
-        )
-
-        return profiler.get_mm_max_tokens(
-            seq_len,
-            {modality: 1 for modality, limit in profiler_limits.items() if limit > 0},
-        )
+        if profiler_limits is None:
+            profiler_limits = processor.allowed_mm_limits
+
+        mm_counts = {
+            modality: 1 for modality, limit in profiler_limits.items() if limit > 0
+        }
+
+        max_tokens_per_item = processor.info.get_mm_max_tokens_per_item(
+            seq_len=seq_len,
+            mm_counts=mm_counts,
+        )
+        if max_tokens_per_item is not None:
+            return {
+                modality: max_tokens
+                for modality, max_tokens in max_tokens_per_item.items()
+                if mm_counts.get(modality, 0) > 0
+            }
+
+        mm_inputs = processor.dummy_inputs.get_dummy_mm_inputs(
+            processor,
+            seq_len,
+            mm_counts=mm_counts,
+            mm_options=self._extract_mm_options(model_config),
+        )
+        return {
+            modality: sum(item.get_num_embeds for item in placeholders)
+            for modality, placeholders in mm_inputs["mm_placeholders"].items()
+        }
 
     def get_mm_limits_per_prompt(
         self,
@@ -189,8 +208,7 @@ class MultiModalRegistry:
         processor = self.create_processor(
             model_config, observability_config, cache=cache
         )
-        profiler: MultiModalProfiler = MultiModalProfiler(processor)
-        return profiler.get_mm_limits()
+        return processor.allowed_mm_limits
 
     def register_processor(
         self,
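
Note: `get_mm_limits()` was a one-line delegation, so the registry now reads the property directly:

    # Before: MultiModalProfiler(processor).get_mm_limits()
    limits = processor.allowed_mm_limits  # Mapping[str, int]
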
@@ -297,14 +315,12 @@ class MultiModalRegistry:
         processor = self.create_processor(
             model_config, observability_config, cache=cache
         )
-        profiler: MultiModalProfiler = MultiModalProfiler(processor)
 
-        # Extract configurable options from multimodal config.
-        # Only include modalities that use advanced option types so legacy
-        # count-only behavior remains unchanged.
-        mm_options = self._extract_mm_options(model_config)
-
-        dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts, mm_options)
+        dummy_data = processor.dummy_inputs.get_decoder_dummy_data(
+            processor,
+            seq_len,
+            mm_counts=mm_counts,
+            mm_options=self._extract_mm_options(model_config),
+        )
 
         # Having more tokens is over-conservative but otherwise fine
         token_ids = dummy_data.prompt_token_ids
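
Net effect: every public entry point of the removed wrapper has a direct equivalent, so profiling behavior is unchanged (with `p` standing for a processor):

    MultiModalProfiler(p).get_mm_limits()              -> p.allowed_mm_limits
    MultiModalProfiler(p).get_decoder_dummy_data(...)  -> p.dummy_inputs.get_decoder_dummy_data(p, ...)
    MultiModalProfiler(p).get_mm_max_tokens(...)       -> inlined into MultiModalRegistry (hunk @@ -160,17 above)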