[UX] Add --language-model-only for hybrid models (#34120)

Signed-off-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Roger Wang
2026-02-09 06:57:33 -08:00
committed by GitHub
parent d0d97e2974
commit 64a9c2528b
3 changed files with 19 additions and 3 deletions

View File

@@ -297,6 +297,7 @@ class ModelConfig:
multimodal_config: MultiModalConfig | None = None multimodal_config: MultiModalConfig | None = None
"""Configuration for multimodal model. If `None`, this will be inferred """Configuration for multimodal model. If `None`, this will be inferred
from the architecture of `self.model`.""" from the architecture of `self.model`."""
language_model_only: InitVar[bool] = False
limit_mm_per_prompt: InitVar[dict[str, int | dict[str, int]] | None] = None limit_mm_per_prompt: InitVar[dict[str, int | dict[str, int]] | None] = None
enable_mm_embeds: InitVar[bool | None] = None enable_mm_embeds: InitVar[bool | None] = None
media_io_kwargs: InitVar[dict[str, dict[str, Any]] | None] = None media_io_kwargs: InitVar[dict[str, dict[str, Any]] | None] = None
@@ -411,6 +412,7 @@ class ModelConfig:
def __post_init__( def __post_init__(
self, self,
# Multimodal config init vars # Multimodal config init vars
language_model_only: bool,
limit_mm_per_prompt: dict[str, int | dict[str, int]] | None, limit_mm_per_prompt: dict[str, int | dict[str, int]] | None,
enable_mm_embeds: bool | None, enable_mm_embeds: bool | None,
media_io_kwargs: dict[str, dict[str, Any]] | None, media_io_kwargs: dict[str, dict[str, Any]] | None,
@@ -576,6 +578,7 @@ class ModelConfig:
mm_encoder_tp_mode = "weights" mm_encoder_tp_mode = "weights"
mm_config_kwargs = dict( mm_config_kwargs = dict(
language_model_only=language_model_only,
limit_per_prompt=limit_mm_per_prompt, limit_per_prompt=limit_mm_per_prompt,
enable_mm_embeds=enable_mm_embeds, enable_mm_embeds=enable_mm_embeds,
media_io_kwargs=media_io_kwargs, media_io_kwargs=media_io_kwargs,

View File

@@ -54,8 +54,12 @@ DummyOptions: TypeAlias = (
class MultiModalConfig: class MultiModalConfig:
"""Controls the behavior of multimodal models.""" """Controls the behavior of multimodal models."""
language_model_only: bool = False
"""If True, disables all multimodal inputs by setting all modality limits
to 0. Equivalent to setting --limit-mm-per-prompt to 0 for every
modality."""
limit_per_prompt: dict[str, DummyOptions] = Field(default_factory=dict) limit_per_prompt: dict[str, DummyOptions] = Field(default_factory=dict)
"""The maximum number of input items and options allowed per """The maximum number of input items and options allowed per
prompt for each modality. prompt for each modality.
Defaults to 999 for each modality. Defaults to 999 for each modality.
@@ -63,11 +67,11 @@ class MultiModalConfig:
{"image": 16, "video": 2} {"image": 16, "video": 2}
Configurable format (with options): Configurable format (with options):
{"video": {"count": 1, "num_frames": 32, "width": 512, "height": 512}, {"video": {"count": 1, "num_frames": 32, "width": 512, "height": 512},
"image": {"count": 5, "width": 512, "height": 512}} "image": {"count": 5, "width": 512, "height": 512}}
Mixed format (combining both): Mixed format (combining both):
{"image": 16, "video": {"count": 1, "num_frames": 32, "width": 512, {"image": 16, "video": {"count": 1, "num_frames": 32, "width": 512,
"height": 512}} "height": 512}}
""" """
enable_mm_embeds: bool = False enable_mm_embeds: bool = False
@@ -215,6 +219,7 @@ class MultiModalConfig:
the final hidden states. the final hidden states.
""" """
factors: list[Any] = [ factors: list[Any] = [
self.language_model_only,
self.mm_encoder_attn_backend.name self.mm_encoder_attn_backend.name
if self.mm_encoder_attn_backend is not None if self.mm_encoder_attn_backend is not None
else None, else None,
@@ -228,6 +233,9 @@ class MultiModalConfig:
Get the maximum number of input items allowed per prompt Get the maximum number of input items allowed per prompt
for the given modality (backward compatible). for the given modality (backward compatible).
""" """
if self.language_model_only:
return 0
limit_data = self.limit_per_prompt.get(modality) limit_data = self.limit_per_prompt.get(modality)
if limit_data is None: if limit_data is None:

View File

@@ -454,6 +454,7 @@ class EngineArgs:
allow_deprecated_quantization: bool = ModelConfig.allow_deprecated_quantization allow_deprecated_quantization: bool = ModelConfig.allow_deprecated_quantization
enforce_eager: bool = ModelConfig.enforce_eager enforce_eager: bool = ModelConfig.enforce_eager
disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
language_model_only: bool = MultiModalConfig.language_model_only
limit_mm_per_prompt: dict[str, int | dict[str, int]] = get_field( limit_mm_per_prompt: dict[str, int | dict[str, int]] = get_field(
MultiModalConfig, "limit_per_prompt" MultiModalConfig, "limit_per_prompt"
) )
@@ -975,6 +976,9 @@ class EngineArgs:
title="MultiModalConfig", title="MultiModalConfig",
description=MultiModalConfig.__doc__, description=MultiModalConfig.__doc__,
) )
multimodal_group.add_argument(
"--language-model-only", **multimodal_kwargs["language_model_only"]
)
multimodal_group.add_argument( multimodal_group.add_argument(
"--limit-mm-per-prompt", **multimodal_kwargs["limit_per_prompt"] "--limit-mm-per-prompt", **multimodal_kwargs["limit_per_prompt"]
) )
@@ -1291,6 +1295,7 @@ class EngineArgs:
skip_tokenizer_init=self.skip_tokenizer_init, skip_tokenizer_init=self.skip_tokenizer_init,
enable_prompt_embeds=self.enable_prompt_embeds, enable_prompt_embeds=self.enable_prompt_embeds,
served_model_name=self.served_model_name, served_model_name=self.served_model_name,
language_model_only=self.language_model_only,
limit_mm_per_prompt=self.limit_mm_per_prompt, limit_mm_per_prompt=self.limit_mm_per_prompt,
enable_mm_embeds=self.enable_mm_embeds, enable_mm_embeds=self.enable_mm_embeds,
interleave_mm_strings=self.interleave_mm_strings, interleave_mm_strings=self.interleave_mm_strings,