[Refactor] Remove MultiModalProfiler (#32254)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -7,7 +7,6 @@ from torch import prod
|
||||
from transformers import Llama4Config
|
||||
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.profiling import MultiModalProfiler
|
||||
|
||||
from ...utils import build_model_context
|
||||
|
||||
@@ -26,9 +25,8 @@ def test_profiling(model_id: str, max_model_len: int):
|
||||
)
|
||||
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
|
||||
profiler = MultiModalProfiler(processor)
|
||||
|
||||
decoder_dummy_data = profiler.get_decoder_dummy_data(
|
||||
decoder_dummy_data = processor.dummy_inputs.get_decoder_dummy_data(
|
||||
processor,
|
||||
max_model_len,
|
||||
mm_counts=mm_counts,
|
||||
)
|
||||
@@ -39,11 +37,12 @@ def test_profiling(model_id: str, max_model_len: int):
|
||||
|
||||
hf_config = ctx.get_hf_config(Llama4Config)
|
||||
|
||||
mm_data = processor.apply(
|
||||
mm_inputs = processor.apply(
|
||||
prompt=dummy_mm_data.prompt,
|
||||
mm_data=dummy_mm_data.mm_data,
|
||||
hf_processor_mm_kwargs=dict(),
|
||||
)["mm_kwargs"].get_data()
|
||||
)
|
||||
mm_data = mm_inputs["mm_kwargs"].get_data()
|
||||
|
||||
image_size = hf_config.vision_config.image_size
|
||||
patch_size = hf_config.vision_config.patch_size
|
||||
@@ -60,12 +59,9 @@ def test_profiling(model_id: str, max_model_len: int):
|
||||
total_num_patches.item() + num_tiles.item() + 3
|
||||
) # image start, image, image end
|
||||
|
||||
profiled_tokens = profiler.get_mm_max_tokens(
|
||||
max_model_len,
|
||||
mm_counts=mm_counts,
|
||||
assert total_num_patches == sum(
|
||||
item.get_num_embeds for item in mm_inputs["mm_placeholders"]["image"]
|
||||
)
|
||||
|
||||
assert total_num_patches == profiled_tokens["image"]
|
||||
assert total_tokens == sum(
|
||||
placeholder.length
|
||||
for placeholder in decoder_dummy_data.multi_modal_placeholders["image"]
|
||||
|
||||
@@ -22,7 +22,6 @@ from vllm.multimodal.processing import (
|
||||
iter_token_matches,
|
||||
replace_token_matches,
|
||||
)
|
||||
from vllm.multimodal.profiling import MultiModalProfiler
|
||||
|
||||
from .utils import random_image
|
||||
|
||||
@@ -924,12 +923,11 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(model_config)
|
||||
processor._supported_mm_limits = {"image": num_supported}
|
||||
|
||||
profiler = MultiModalProfiler(processor)
|
||||
|
||||
exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most")
|
||||
|
||||
with exc_ctx:
|
||||
profiler.get_decoder_dummy_data(
|
||||
processor.dummy_inputs.get_decoder_dummy_data(
|
||||
processor,
|
||||
model_config.max_model_len,
|
||||
mm_counts=limit_mm_per_prompt,
|
||||
)
|
||||
|
||||
@@ -223,70 +223,42 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
|
||||
video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
|
||||
return [video] * num_videos
|
||||
|
||||
|
||||
class MultiModalProfiler(Generic[_I]):
|
||||
"""
|
||||
Contains code for running memory profiling for multi-modal models.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
def get_dummy_mm_inputs(
|
||||
self,
|
||||
processor: BaseMultiModalProcessor[_I],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.processor = processor
|
||||
|
||||
@property
|
||||
def processing_info(self) -> BaseProcessingInfo:
|
||||
return self.processor.info
|
||||
|
||||
@property
|
||||
def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]:
|
||||
return self.processor.dummy_inputs
|
||||
|
||||
def get_mm_limits(self) -> Mapping[str, int]:
|
||||
return self.processor.allowed_mm_limits
|
||||
|
||||
def _get_dummy_mm_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int] | None = None,
|
||||
mm_options: Mapping[str, BaseDummyOptions] | None = None,
|
||||
) -> MultiModalInputs:
|
||||
if mm_counts is None:
|
||||
mm_counts = self.get_mm_limits()
|
||||
mm_counts = processor.allowed_mm_limits
|
||||
|
||||
factory = self.dummy_inputs
|
||||
processor_inputs = factory.get_dummy_processor_inputs(
|
||||
seq_len, mm_counts, mm_options
|
||||
processor_inputs = self.get_dummy_processor_inputs(
|
||||
seq_len,
|
||||
mm_counts=mm_counts,
|
||||
mm_options=mm_options,
|
||||
)
|
||||
|
||||
return self.processor.apply(
|
||||
return processor.apply(
|
||||
prompt=processor_inputs.prompt,
|
||||
mm_data=processor_inputs.mm_data,
|
||||
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=processor_inputs.tokenization_kwargs,
|
||||
)
|
||||
|
||||
def _get_mm_num_tokens(
|
||||
self,
|
||||
mm_inputs: MultiModalInputs,
|
||||
) -> Mapping[str, int]:
|
||||
placeholders_by_modality = mm_inputs["mm_placeholders"]
|
||||
|
||||
return {
|
||||
modality: sum(item.get_num_embeds for item in placeholders)
|
||||
for modality, placeholders in placeholders_by_modality.items()
|
||||
}
|
||||
|
||||
def get_decoder_dummy_data(
|
||||
self,
|
||||
processor: BaseMultiModalProcessor[_I],
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int] | None = None,
|
||||
mm_options: Mapping[str, BaseDummyOptions] | None = None,
|
||||
) -> DummyDecoderData:
|
||||
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts, mm_options)
|
||||
mm_inputs = self.get_dummy_mm_inputs(
|
||||
processor,
|
||||
seq_len,
|
||||
mm_counts=mm_counts,
|
||||
mm_options=mm_options,
|
||||
)
|
||||
|
||||
prompt_token_ids = mm_inputs["prompt_token_ids"]
|
||||
total_len = len(prompt_token_ids)
|
||||
@@ -299,29 +271,3 @@ class MultiModalProfiler(Generic[_I]):
|
||||
multi_modal_data=mm_inputs["mm_kwargs"].require_data(),
|
||||
multi_modal_placeholders=mm_inputs["mm_placeholders"],
|
||||
)
|
||||
|
||||
def get_mm_max_tokens(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int] | None = None,
|
||||
) -> Mapping[str, int]:
|
||||
"""
|
||||
Returns the maximum number of embeddings per item of each modality, excluding
|
||||
any break/text tokens in-between multimodal embeddings/encoder outputs.
|
||||
"""
|
||||
if mm_counts is None:
|
||||
mm_counts = self.get_mm_limits()
|
||||
|
||||
max_tokens_per_item = self.processing_info.get_mm_max_tokens_per_item(
|
||||
seq_len=seq_len,
|
||||
mm_counts=mm_counts,
|
||||
)
|
||||
if max_tokens_per_item is not None:
|
||||
return {
|
||||
modality: max_tokens
|
||||
for modality, max_tokens in max_tokens_per_item.items()
|
||||
if mm_counts.get(modality, 0) > 0
|
||||
}
|
||||
|
||||
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
|
||||
return self._get_mm_num_tokens(mm_inputs)
|
||||
|
||||
@@ -18,7 +18,6 @@ from .processing import (
|
||||
from .profiling import (
|
||||
BaseDummyInputsBuilder,
|
||||
DummyDecoderData,
|
||||
MultiModalProfiler,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -160,17 +159,37 @@ class MultiModalRegistry:
|
||||
processor = self.create_processor(
|
||||
model_config, observability_config, cache=cache
|
||||
)
|
||||
profiler: MultiModalProfiler = MultiModalProfiler(processor)
|
||||
|
||||
seq_len = model_config.max_model_len
|
||||
profiler_limits = (
|
||||
profiler.get_mm_limits() if profiler_limits is None else profiler_limits
|
||||
if profiler_limits is None:
|
||||
profiler_limits = processor.allowed_mm_limits
|
||||
|
||||
mm_counts = {
|
||||
modality: 1 for modality, limit in profiler_limits.items() if limit > 0
|
||||
}
|
||||
|
||||
max_tokens_per_item = processor.info.get_mm_max_tokens_per_item(
|
||||
seq_len=seq_len,
|
||||
mm_counts=mm_counts,
|
||||
)
|
||||
if max_tokens_per_item is not None:
|
||||
return {
|
||||
modality: max_tokens
|
||||
for modality, max_tokens in max_tokens_per_item.items()
|
||||
if mm_counts.get(modality, 0) > 0
|
||||
}
|
||||
|
||||
mm_inputs = processor.dummy_inputs.get_dummy_mm_inputs(
|
||||
processor,
|
||||
seq_len,
|
||||
mm_counts=mm_counts,
|
||||
mm_options=self._extract_mm_options(model_config),
|
||||
)
|
||||
|
||||
return profiler.get_mm_max_tokens(
|
||||
seq_len,
|
||||
{modality: 1 for modality, limit in profiler_limits.items() if limit > 0},
|
||||
)
|
||||
return {
|
||||
modality: sum(item.get_num_embeds for item in placeholders)
|
||||
for modality, placeholders in mm_inputs["mm_placeholders"].items()
|
||||
}
|
||||
|
||||
def get_mm_limits_per_prompt(
|
||||
self,
|
||||
@@ -189,8 +208,7 @@ class MultiModalRegistry:
|
||||
processor = self.create_processor(
|
||||
model_config, observability_config, cache=cache
|
||||
)
|
||||
profiler: MultiModalProfiler = MultiModalProfiler(processor)
|
||||
return profiler.get_mm_limits()
|
||||
return processor.allowed_mm_limits
|
||||
|
||||
def register_processor(
|
||||
self,
|
||||
@@ -297,14 +315,12 @@ class MultiModalRegistry:
|
||||
processor = self.create_processor(
|
||||
model_config, observability_config, cache=cache
|
||||
)
|
||||
profiler: MultiModalProfiler = MultiModalProfiler(processor)
|
||||
|
||||
# Extract configurable options from multimodal config.
|
||||
# Only include modalities that use advanced option types so legacy
|
||||
# count-only behavior remains unchanged.
|
||||
mm_options = self._extract_mm_options(model_config)
|
||||
|
||||
dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts, mm_options)
|
||||
dummy_data = processor.dummy_inputs.get_decoder_dummy_data(
|
||||
processor,
|
||||
seq_len,
|
||||
mm_counts=mm_counts,
|
||||
mm_options=self._extract_mm_options(model_config),
|
||||
)
|
||||
|
||||
# Having more tokens is over-conservative but otherwise fine
|
||||
token_ids = dummy_data.prompt_token_ids
|
||||
|
||||
Reference in New Issue
Block a user