[Refactor] Move top-level dummy data generation to registry (#32310)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -24,32 +24,20 @@ def test_profiling(model_id: str, max_model_len: int):
|
||||
limit_mm_per_prompt=mm_counts,
|
||||
)
|
||||
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
|
||||
decoder_dummy_data = processor.dummy_inputs.get_decoder_dummy_data(
|
||||
processor,
|
||||
max_model_len,
|
||||
mm_counts=mm_counts,
|
||||
)
|
||||
dummy_mm_data = processor.dummy_inputs.get_dummy_processor_inputs(
|
||||
max_model_len,
|
||||
mm_inputs = MULTIMODAL_REGISTRY.get_dummy_mm_inputs(
|
||||
ctx.model_config,
|
||||
mm_counts=mm_counts,
|
||||
)
|
||||
|
||||
hf_config = ctx.get_hf_config(Llama4Config)
|
||||
|
||||
mm_inputs = processor.apply(
|
||||
prompt=dummy_mm_data.prompt,
|
||||
mm_data=dummy_mm_data.mm_data,
|
||||
hf_processor_mm_kwargs=dict(),
|
||||
)
|
||||
mm_data = mm_inputs["mm_kwargs"].get_data()
|
||||
|
||||
image_size = hf_config.vision_config.image_size
|
||||
patch_size = hf_config.vision_config.patch_size
|
||||
downsample_ratio = int(
|
||||
round(1.0 / (hf_config.vision_config.pixel_shuffle_ratio**2))
|
||||
)
|
||||
tokens_per_patch = ((image_size // patch_size) ** 2) // downsample_ratio
|
||||
|
||||
mm_data = mm_inputs["mm_kwargs"].get_data()
|
||||
chunks_per_image = prod(mm_data["patches_per_image"])
|
||||
total_num_patches = chunks_per_image * tokens_per_patch
|
||||
num_tiles = (
|
||||
@@ -63,6 +51,5 @@ def test_profiling(model_id: str, max_model_len: int):
|
||||
item.get_num_embeds for item in mm_inputs["mm_placeholders"]["image"]
|
||||
)
|
||||
assert total_tokens == sum(
|
||||
placeholder.length
|
||||
for placeholder in decoder_dummy_data.multi_modal_placeholders["image"]
|
||||
placeholder.length for placeholder in mm_inputs["mm_placeholders"]["image"]
|
||||
)
|
||||
|
||||
@@ -926,10 +926,10 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
|
||||
exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most")
|
||||
|
||||
with exc_ctx:
|
||||
processor.dummy_inputs.get_decoder_dummy_data(
|
||||
processor,
|
||||
model_config.max_model_len,
|
||||
MULTIMODAL_REGISTRY.get_dummy_mm_inputs(
|
||||
model_config,
|
||||
mm_counts=limit_mm_per_prompt,
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -50,6 +50,7 @@ from .parse import (
|
||||
MultiModalDataItems,
|
||||
MultiModalDataParser,
|
||||
)
|
||||
from .profiling import BaseDummyInputsBuilder
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
@@ -59,7 +60,6 @@ if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, ObservabilityConfig
|
||||
|
||||
from .cache import BaseMultiModalProcessorCache
|
||||
from .profiling import BaseDummyInputsBuilder
|
||||
else:
|
||||
PretrainedConfig = object
|
||||
BatchFeature = object
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Generic, NamedTuple, TypeVar
|
||||
from typing import TYPE_CHECKING, Generic
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
@@ -17,16 +17,14 @@ from vllm.config.multimodal import (
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .inputs import (
|
||||
MultiModalDataDict,
|
||||
MultiModalInputs,
|
||||
MultiModalKwargsItems,
|
||||
MultiModalPlaceholderDict,
|
||||
)
|
||||
from .processing import (
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
)
|
||||
from .inputs import MultiModalDataDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .processing import _I
|
||||
else:
|
||||
from typing import TypeVar
|
||||
|
||||
_I = TypeVar("_I")
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -44,23 +42,6 @@ class ProcessorInputs:
|
||||
tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)
|
||||
|
||||
|
||||
class DummyEncoderData(NamedTuple):
|
||||
"""Dummy data used for profiling."""
|
||||
|
||||
prompt_token_ids: list[int]
|
||||
|
||||
|
||||
class DummyDecoderData(NamedTuple):
|
||||
"""Dummy data used for profiling."""
|
||||
|
||||
prompt_token_ids: list[int]
|
||||
multi_modal_data: MultiModalKwargsItems
|
||||
multi_modal_placeholders: MultiModalPlaceholderDict
|
||||
|
||||
|
||||
_I = TypeVar("_I", bound=BaseProcessingInfo)
|
||||
|
||||
|
||||
class BaseDummyInputsBuilder(ABC, Generic[_I]):
|
||||
"""
|
||||
Abstract base class that constructs the dummy data to profile
|
||||
@@ -222,52 +203,3 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
|
||||
height = min(height, overrides.height)
|
||||
video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
|
||||
return [video] * num_videos
|
||||
|
||||
def get_dummy_mm_inputs(
|
||||
self,
|
||||
processor: BaseMultiModalProcessor[_I],
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int] | None = None,
|
||||
mm_options: Mapping[str, BaseDummyOptions] | None = None,
|
||||
) -> MultiModalInputs:
|
||||
if mm_counts is None:
|
||||
mm_counts = processor.allowed_mm_limits
|
||||
|
||||
processor_inputs = self.get_dummy_processor_inputs(
|
||||
seq_len,
|
||||
mm_counts=mm_counts,
|
||||
mm_options=mm_options,
|
||||
)
|
||||
|
||||
return processor.apply(
|
||||
prompt=processor_inputs.prompt,
|
||||
mm_data=processor_inputs.mm_data,
|
||||
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=processor_inputs.tokenization_kwargs,
|
||||
)
|
||||
|
||||
def get_decoder_dummy_data(
|
||||
self,
|
||||
processor: BaseMultiModalProcessor[_I],
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int] | None = None,
|
||||
mm_options: Mapping[str, BaseDummyOptions] | None = None,
|
||||
) -> DummyDecoderData:
|
||||
mm_inputs = self.get_dummy_mm_inputs(
|
||||
processor,
|
||||
seq_len,
|
||||
mm_counts=mm_counts,
|
||||
mm_options=mm_options,
|
||||
)
|
||||
|
||||
prompt_token_ids = mm_inputs["prompt_token_ids"]
|
||||
total_len = len(prompt_token_ids)
|
||||
|
||||
if total_len < seq_len:
|
||||
prompt_token_ids.extend([0] * (seq_len - total_len))
|
||||
|
||||
return DummyDecoderData(
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
multi_modal_data=mm_inputs["mm_kwargs"].require_data(),
|
||||
multi_modal_placeholders=mm_inputs["mm_placeholders"],
|
||||
)
|
||||
|
||||
@@ -10,15 +10,13 @@ from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
|
||||
|
||||
from .cache import BaseMultiModalProcessorCache
|
||||
from .inputs import MultiModalInputs
|
||||
from .processing import (
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
InputProcessingContext,
|
||||
)
|
||||
from .profiling import (
|
||||
BaseDummyInputsBuilder,
|
||||
DummyDecoderData,
|
||||
)
|
||||
from .profiling import BaseDummyInputsBuilder
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, ObservabilityConfig
|
||||
@@ -160,7 +158,6 @@ class MultiModalRegistry:
|
||||
model_config, observability_config, cache=cache
|
||||
)
|
||||
|
||||
seq_len = model_config.max_model_len
|
||||
if profiler_limits is None:
|
||||
profiler_limits = processor.allowed_mm_limits
|
||||
|
||||
@@ -169,7 +166,7 @@ class MultiModalRegistry:
|
||||
}
|
||||
|
||||
max_tokens_per_item = processor.info.get_mm_max_tokens_per_item(
|
||||
seq_len=seq_len,
|
||||
seq_len=model_config.max_model_len,
|
||||
mm_counts=mm_counts,
|
||||
)
|
||||
if max_tokens_per_item is not None:
|
||||
@@ -179,11 +176,10 @@ class MultiModalRegistry:
|
||||
if mm_counts.get(modality, 0) > 0
|
||||
}
|
||||
|
||||
mm_inputs = processor.dummy_inputs.get_dummy_mm_inputs(
|
||||
processor,
|
||||
seq_len,
|
||||
mm_inputs = self.get_dummy_mm_inputs(
|
||||
model_config,
|
||||
mm_counts=mm_counts,
|
||||
mm_options=self._extract_mm_options(model_config),
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -298,39 +294,47 @@ class MultiModalRegistry:
|
||||
|
||||
return factories.build_processor(ctx, cache=cache)
|
||||
|
||||
def get_decoder_dummy_data(
|
||||
def get_dummy_mm_inputs(
|
||||
self,
|
||||
model_config: "ModelConfig",
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int] | None = None,
|
||||
*,
|
||||
cache: BaseMultiModalProcessorCache | None = None,
|
||||
observability_config: ObservabilityConfig | None = None,
|
||||
) -> DummyDecoderData:
|
||||
processor: BaseMultiModalProcessor | None = None,
|
||||
) -> MultiModalInputs:
|
||||
"""
|
||||
Create dummy data for profiling the memory usage of a model.
|
||||
|
||||
The model is identified by `model_config`.
|
||||
"""
|
||||
processor = self.create_processor(
|
||||
model_config, observability_config, cache=cache
|
||||
)
|
||||
dummy_data = processor.dummy_inputs.get_decoder_dummy_data(
|
||||
processor,
|
||||
seq_len,
|
||||
seq_len = model_config.max_model_len
|
||||
|
||||
if processor is None:
|
||||
processor = self.create_processor(
|
||||
model_config, observability_config, cache=cache
|
||||
)
|
||||
if mm_counts is None:
|
||||
mm_counts = processor.allowed_mm_limits
|
||||
|
||||
processor_inputs = processor.dummy_inputs.get_dummy_processor_inputs(
|
||||
seq_len=seq_len,
|
||||
mm_counts=mm_counts,
|
||||
mm_options=self._extract_mm_options(model_config),
|
||||
)
|
||||
mm_inputs = processor.apply(
|
||||
prompt=processor_inputs.prompt,
|
||||
mm_data=processor_inputs.mm_data,
|
||||
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=processor_inputs.tokenization_kwargs,
|
||||
)
|
||||
|
||||
# Having more tokens is over-conservative but otherwise fine
|
||||
token_ids = dummy_data.prompt_token_ids
|
||||
if len(token_ids) < seq_len:
|
||||
raise AssertionError(
|
||||
f"Expected at least {seq_len} dummy tokens for profiling, "
|
||||
f"but found {len(token_ids)} tokens instead."
|
||||
)
|
||||
prompt_token_ids = mm_inputs["prompt_token_ids"]
|
||||
total_len = len(prompt_token_ids)
|
||||
if total_len < seq_len:
|
||||
prompt_token_ids.extend([0] * (seq_len - total_len))
|
||||
|
||||
return dummy_data
|
||||
return mm_inputs
|
||||
|
||||
def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int:
|
||||
"""
|
||||
|
||||
@@ -4192,16 +4192,18 @@ class GPUModelRunner(
|
||||
"""Dummy data for profiling and precompiling multimodal models."""
|
||||
assert self.mm_budget is not None
|
||||
|
||||
dummy_decoder_data = self.mm_registry.get_decoder_dummy_data(
|
||||
model_config=self.model_config,
|
||||
seq_len=self.max_model_len,
|
||||
# Don't use `max_items_per_batch` here to avoid redundant computation
|
||||
dummy_mm_inputs = self.mm_registry.get_dummy_mm_inputs(
|
||||
self.model_config,
|
||||
mm_counts={modality: 1},
|
||||
cache=self.mm_budget.cache,
|
||||
)
|
||||
dummy_mm_data = dummy_decoder_data.multi_modal_data
|
||||
dummy_mm_item = dummy_mm_inputs["mm_kwargs"][modality][0]
|
||||
|
||||
# We use the cache so that the item is saved to the cache,
|
||||
# but not read from the cache
|
||||
assert dummy_mm_item is not None, "Item should not already be cached"
|
||||
|
||||
# Result in the maximum GPU consumption of the model
|
||||
dummy_mm_item = dummy_mm_data[modality][0]
|
||||
dummy_mm_items = [dummy_mm_item] * max_items_per_batch
|
||||
|
||||
return next(
|
||||
|
||||
Reference in New Issue
Block a user