[MM] Allow skipping memory profiling for multimodal models. (#22950)
Signed-off-by: Roger Wang <hey@rogerw.me>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@@ -388,6 +388,10 @@ class ModelConfig:
     interleave_mm_strings: bool = False
     """Enable fully interleaved support for multimodal prompts, while using
     --chat-template-content-format=string. Defaults to False."""
+    skip_mm_profiling: bool = False
+    """When enabled, skips multimodal memory profiling and only profiles with
+    language backbone model during engine initialization.
+    """
     media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
     """Additional args passed to process media inputs, keyed by modalities.
     For example, to set num_frames for video, set
@@ -837,7 +841,8 @@ class ModelConfig:
             media_io_kwargs=self.media_io_kwargs,
             mm_processor_kwargs=self.mm_processor_kwargs,
             mm_processor_cache_gb=self.mm_processor_cache_gb,
-            interleave_mm_strings=self.interleave_mm_strings)
+            interleave_mm_strings=self.interleave_mm_strings,
+            skip_mm_profiling=self.skip_mm_profiling)

         return None

@@ -2511,6 +2516,16 @@ class MultiModalConfig:
     Enable fully interleaved support for multimodal prompts.
     """

+    skip_mm_profiling: bool = False
+    """
+    When enabled, skips multimodal memory profiling and only profiles with
+    language backbone model during engine initialization.
+
+    This reduces engine startup time but shifts the responsibility to users for
+    estimating the peak memory usage of the activation of multimodal encoder and
+    embedding cache.
+    """
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
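
(Context, not part of the diff.) A minimal usage sketch of the new option from the offline entrypoint, assuming the field is forwarded through the standard `LLM(...)`/`EngineArgs` kwargs path like the neighbouring multimodal options; the model name is purely illustrative:

from vllm import LLM

# Sketch only: skip encoder/encoder-cache memory profiling at startup.
# With profiling skipped, the caller is responsible for leaving enough
# GPU headroom for multimodal encoder activations and the embedding cache.
llm = LLM(
    model="Qwen/Qwen2-VL-7B-Instruct",  # illustrative multimodal model
    skip_mm_profiling=True,
)
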
@@ -350,6 +350,7 @@ class EngineArgs:
         MultiModalConfig.mm_processor_kwargs
     disable_mm_preprocessor_cache: bool = False  # DEPRECATED
     mm_processor_cache_gb: int = MultiModalConfig.mm_processor_cache_gb
+    skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
     # LoRA fields
     enable_lora: bool = False
     enable_lora_bias: bool = LoRAConfig.bias_enabled
@@ -716,6 +717,8 @@ class EngineArgs:
         multimodal_group.add_argument(
             "--interleave-mm-strings",
             **multimodal_kwargs["interleave_mm_strings"])
+        multimodal_group.add_argument("--skip-mm-profiling",
+                                      **multimodal_kwargs["skip_mm_profiling"])

         # LoRA related configs
         lora_kwargs = get_kwargs(LoRAConfig)
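
(Context, not part of the diff.) The programmatic equivalent of passing `--skip-mm-profiling` on the command line is setting the corresponding `EngineArgs` field; a small sketch, with the model name again only illustrative:

from vllm.engine.arg_utils import EngineArgs

# Sketch: same effect as launching with --skip-mm-profiling.
engine_args = EngineArgs(
    model="Qwen/Qwen2-VL-7B-Instruct",  # illustrative
    skip_mm_profiling=True,
)
assert engine_args.skip_mm_profiling is True
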
@@ -918,6 +921,7 @@ class EngineArgs:
             limit_mm_per_prompt=self.limit_mm_per_prompt,
             interleave_mm_strings=self.interleave_mm_strings,
             media_io_kwargs=self.media_io_kwargs,
+            skip_mm_profiling=self.skip_mm_profiling,
             use_async_output_proc=not self.disable_async_output_proc,
             config_format=self.config_format,
             mm_processor_kwargs=self.mm_processor_kwargs,
@@ -2479,50 +2479,56 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
         if self.supports_mm_inputs:
-            mm_budget = self.mm_budget
-            assert mm_budget is not None
-
-            # TODO: handle encoder-decoder models once we support them.
-            if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
-                # NOTE: Currently model is profiled with a single non-text
-                # modality with the max possible input tokens even when
-                # it supports multiple.
-                (
-                    dummy_modality,
-                    max_tokens,
-                ) = mm_budget.get_modality_with_max_tokens()
-                (
-                    max_mm_items_per_prompt,
-                    max_mm_items_per_batch,
-                ) = mm_budget.get_max_items(dummy_modality, max_tokens)
-
-                logger.info(
-                    "Encoder cache will be initialized with a budget of "
-                    "%s tokens, and profiled with %s %s items of the maximum "
-                    "feature size.",
-                    encoder_budget,
-                    max_mm_items_per_batch,
-                    dummy_modality,
-                )
-
-                # Create dummy batch of multimodal inputs.
-                batched_dummy_mm_inputs = self._get_mm_dummy_batch(
-                    dummy_modality,
-                    max_mm_items_per_batch,
-                )
-
-                # Run multimodal encoder.
-                dummy_encoder_outputs = self.model.get_multimodal_embeddings(
-                    **batched_dummy_mm_inputs)
-
-                sanity_check_mm_encoder_outputs(
-                    dummy_encoder_outputs,
-                    expected_num_items=max_mm_items_per_batch,
-                )
-
-                # Cache the dummy encoder outputs.
-                self.encoder_cache["tmp"] = dict(
-                    enumerate(dummy_encoder_outputs))
+            if self.model_config.multimodal_config.skip_mm_profiling:
+                logger.info(
+                    "Skipping memory profiling for multimodal encoder and "
+                    "encoder cache.")
+            else:
+                mm_budget = self.mm_budget
+                assert mm_budget is not None
+
+                # TODO: handle encoder-decoder models once we support them.
+                if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
+                    # NOTE: Currently model is profiled with a single non-text
+                    # modality with the max possible input tokens even when
+                    # it supports multiple.
+                    (
+                        dummy_modality,
+                        max_tokens,
+                    ) = mm_budget.get_modality_with_max_tokens()
+                    (
+                        max_mm_items_per_prompt,
+                        max_mm_items_per_batch,
+                    ) = mm_budget.get_max_items(dummy_modality, max_tokens)
+
+                    logger.info(
+                        "Encoder cache will be initialized with a budget of "
+                        "%s tokens, and profiled with %s %s items of the "
+                        "maximum feature size.",
+                        encoder_budget,
+                        max_mm_items_per_batch,
+                        dummy_modality,
+                    )
+
+                    # Create dummy batch of multimodal inputs.
+                    batched_dummy_mm_inputs = self._get_mm_dummy_batch(
+                        dummy_modality,
+                        max_mm_items_per_batch,
+                    )
+
+                    # Run multimodal encoder.
+                    dummy_encoder_outputs = \
+                        self.model.get_multimodal_embeddings(
+                            **batched_dummy_mm_inputs)
+
+                    sanity_check_mm_encoder_outputs(
+                        dummy_encoder_outputs,
+                        expected_num_items=max_mm_items_per_batch,
+                    )
+
+                    # Cache the dummy encoder outputs.
+                    self.encoder_cache["tmp"] = dict(
+                        enumerate(dummy_encoder_outputs))

         # Add `is_profile` here to pre-allocate communication buffers
         hidden_states, last_hidden_states \
@@ -1529,60 +1529,66 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     ) -> None:
         # Profile with multimodal encoder & encoder cache.
         if self.supports_mm_inputs:
-            mm_budget = self.mm_budget
-            assert mm_budget is not None
-
-            # TODO: handle encoder-decoder models once we support them.
-            if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
-                # NOTE: Currently model is profiled with a single non-text
-                # modality with the max possible input tokens even when
-                # it supports multiple.
-                (
-                    dummy_modality,
-                    max_tokens,
-                ) = mm_budget.get_modality_with_max_tokens()
-                (
-                    max_mm_items_per_prompt,
-                    max_mm_items_per_batch,
-                ) = mm_budget.get_max_items(dummy_modality, max_tokens)
-
-                logger.info(
-                    "Encoder cache will be initialized with a budget of "
-                    "%s tokens, and profiled with %s %s items of the maximum "
-                    "feature size.",
-                    encoder_budget,
-                    max_mm_items_per_batch,
-                    dummy_modality,
-                )
-
-                # Create dummy batch of multimodal inputs.
-                batched_dummy_mm_inputs = self._get_mm_dummy_batch(
-                    dummy_modality,
-                    max_mm_items_per_batch,
-                )
-
-                # Run multimodal encoder.
-                # Isolate encoder graph from post-processing to minimize
-                # impact of recompilation until it's fixed.
-                start = time.perf_counter()
-                xm.mark_step()
-                dummy_encoder_outputs = self.model.get_multimodal_embeddings(
-                    **batched_dummy_mm_inputs)
-                xm.mark_step()
-                xm.wait_device_ops()
-                end = time.perf_counter()
-                logger.info(
-                    "Multimodal Encoder profiling finished in in %.2f [secs].",
-                    end - start)
-
-                sanity_check_mm_encoder_outputs(
-                    dummy_encoder_outputs,
-                    expected_num_items=max_mm_items_per_batch,
-                )
-
-                # Cache the dummy encoder outputs.
-                self.encoder_cache["tmp"] = dict(
-                    enumerate(dummy_encoder_outputs))
+            if self.model_config.multimodal_config.skip_mm_profiling:
+                logger.info(
+                    "Skipping memory profiling for multimodal encoder and "
+                    "encoder cache.")
+            else:
+                mm_budget = self.mm_budget
+                assert mm_budget is not None
+
+                # TODO: handle encoder-decoder models once we support them.
+                if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
+                    # NOTE: Currently model is profiled with a single non-text
+                    # modality with the max possible input tokens even when
+                    # it supports multiple.
+                    (
+                        dummy_modality,
+                        max_tokens,
+                    ) = mm_budget.get_modality_with_max_tokens()
+                    (
+                        max_mm_items_per_prompt,
+                        max_mm_items_per_batch,
+                    ) = mm_budget.get_max_items(dummy_modality, max_tokens)
+
+                    logger.info(
+                        "Encoder cache will be initialized with a budget of "
+                        "%s tokens, and profiled with %s %s items of the "
+                        "maximum feature size.",
+                        encoder_budget,
+                        max_mm_items_per_batch,
+                        dummy_modality,
+                    )
+
+                    # Create dummy batch of multimodal inputs.
+                    batched_dummy_mm_inputs = self._get_mm_dummy_batch(
+                        dummy_modality,
+                        max_mm_items_per_batch,
+                    )
+
+                    # Run multimodal encoder.
+                    # Isolate encoder graph from post-processing to minimize
+                    # impact of recompilation until it's fixed.
+                    start = time.perf_counter()
+                    xm.mark_step()
+                    dummy_encoder_outputs = \
+                        self.model.get_multimodal_embeddings(
+                            **batched_dummy_mm_inputs)
+                    xm.mark_step()
+                    xm.wait_device_ops()
+                    end = time.perf_counter()
+                    logger.info(
+                        "Multimodal Encoder profiling finished in %.2f [secs].",
+                        end - start)
+
+                    sanity_check_mm_encoder_outputs(
+                        dummy_encoder_outputs,
+                        expected_num_items=max_mm_items_per_batch,
+                    )
+
+                    # Cache the dummy encoder outputs.
+                    self.encoder_cache["tmp"] = dict(
+                        enumerate(dummy_encoder_outputs))

         # Trigger compilation for general shape.
         self._dummy_run(num_tokens, self.num_reqs_max_model_len,