[Refactor] Remove get_encoder_dummy_data (#32241)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2026-01-13 17:21:23 +08:00
Committed by: GitHub
Parent: 542a4059b2
Commit: eb28e8068d
6 changed files with 21 additions and 82 deletions
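
What the refactor does: the encoder-only profiling entry points (MultiModalProfiler.get_encoder_dummy_data and MultiModalRegistry.get_encoder_dummy_data) and the processor-level pad_dummy_encoder_prompt flag are deleted. In their place, BaseProcessingInfo gains a skip_prompt_length_check property that defaults to False; models whose encoder prompt is padded to a fixed length (Whisper, NemotronParse) override it to True, and the engine's prompt-length check reads the property directly. A minimal sketch of the new override point, using a hypothetical concrete info class (other abstract methods a real subclass must implement are elided):

    from collections.abc import Mapping

    from vllm.multimodal.processing import BaseProcessingInfo


    class PaddedEncoderProcessingInfo(BaseProcessingInfo):  # hypothetical subclass
        @property
        def skip_prompt_length_check(self) -> bool:
            # The encoder prompt is padded to a fixed length, so exceeding
            # the usual limit is expected rather than a user error.
            return True

        def get_supported_mm_limits(self) -> Mapping[str, int | None]:
            return {"audio": 1}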

View File

@@ -605,6 +605,10 @@ class NemotronParseProcessingInfo(BaseProcessingInfo):
             **kwargs,
         )
 
+    @property
+    def skip_prompt_length_check(self) -> bool:
+        return True  # Because the encoder prompt is padded
+
     def get_supported_mm_limits(self) -> Mapping[str, int | None]:
         return {"image": 1}
@@ -657,10 +661,6 @@ class NemotronParseMultiModalProcessor(
     ) -> str | list[int]:
         return [0]
 
-    @property
-    def pad_dummy_encoder_prompt(self) -> bool:
-        return True
-
     def _call_hf_processor(
         self,
         prompt: str,

View File

@@ -681,6 +681,10 @@ class WhisperProcessingInfo(BaseProcessingInfo):
     def get_hf_config(self) -> WhisperConfig:
         return self.ctx.get_hf_config(WhisperConfig)
 
+    @property
+    def skip_prompt_length_check(self) -> bool:
+        return True  # Because the encoder prompt is padded
+
     def get_supported_mm_limits(self) -> Mapping[str, int | None]:
         return {"audio": 1}
@@ -733,10 +737,6 @@ class WhisperMultiModalProcessor(EncDecMultiModalProcessor[WhisperProcessingInfo
             target_channels=self.info.get_target_channels(),
         )
 
-    @property
-    def pad_dummy_encoder_prompt(self) -> bool:
-        return True
-
     def create_encoder_prompt(
         self,
         prompt: str | list[int],

View File

@@ -1396,6 +1396,10 @@ class BaseProcessingInfo:
         """
         return self.ctx.get_hf_processor(**kwargs)
 
+    @property
+    def skip_prompt_length_check(self) -> bool:
+        return False
+
     @abstractmethod
     def get_supported_mm_limits(self) -> Mapping[str, int | None]:
         """
@@ -2403,10 +2407,6 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
         """
         raise NotImplementedError
 
-    @property
-    def pad_dummy_encoder_prompt(self) -> bool:
-        return False
-
     def create_decoder_prompt(
         self,
         prompt: str | list[int],
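
Taken together, the two hunks above move the opt-out from the processor subclass to its info object, which removes the need to cast to EncDecMultiModalProcessor before reading the flag. A hedged before/after sketch (the call sites are illustrative, not lifted from vLLM):

    # Before: the flag lived only on EncDecMultiModalProcessor, so generic
    # code had to cast before it could read it.
    processor = cast(EncDecMultiModalProcessor, self.processor)
    if processor.pad_dummy_encoder_prompt:
        ...

    # After: any processor's info object answers directly; the base-class
    # default of False keeps every other model's behavior unchanged.
    if self.processor.info.skip_prompt_length_check:
        ...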

View File

@@ -3,7 +3,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Mapping
 from dataclasses import dataclass, field
-from typing import Generic, NamedTuple, TypeVar, cast
+from typing import Generic, NamedTuple, TypeVar
 
 import numpy as np
 import numpy.typing as npt
@@ -19,7 +19,6 @@ from vllm.logger import init_logger
 
 from .inputs import (
     MultiModalDataDict,
-    MultiModalEncDecInputs,
     MultiModalInputs,
     MultiModalKwargsItems,
     MultiModalPlaceholderDict,
@@ -27,7 +26,6 @@ from .inputs import (
 from .processing import (
     BaseMultiModalProcessor,
     BaseProcessingInfo,
-    EncDecMultiModalProcessor,
 )
 
 logger = init_logger(__name__)
@@ -282,28 +280,6 @@ class MultiModalProfiler(Generic[_I]):
             for modality, placeholders in placeholders_by_modality.items()
         }
 
-    def get_encoder_dummy_data(
-        self,
-        seq_len: int,
-        mm_counts: Mapping[str, int] | None = None,
-        mm_options: Mapping[str, BaseDummyOptions] | None = None,
-    ) -> DummyEncoderData:
-        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts, mm_options)
-        mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)
-
-        # For encoder-decoder models, use encoder prompt token ids instead of
-        # decoder prompt to construct dummy seq_data for encoder profiling.
-        encoder_prompt_token_ids = mm_inputs["encoder_prompt_token_ids"]
-        total_len = len(encoder_prompt_token_ids)
-
-        processor = cast(EncDecMultiModalProcessor, self.processor)
-        if processor.pad_dummy_encoder_prompt:
-            num_tokens_to_pad = max(total_len, seq_len) - total_len
-            encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
-
-        return DummyEncoderData(encoder_prompt_token_ids)
-
     def get_decoder_dummy_data(
         self,
         seq_len: int,
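
The padding rule removed above is small enough to restate on its own. A self-contained sketch of the arithmetic (the helper name is ours, not vLLM's): pad the dummy encoder prompt with token id 0 up to seq_len, and never truncate a prompt that is already longer.

    def pad_to_seq_len(token_ids: list[int], seq_len: int) -> list[int]:
        # max(...) - len(...) is 0 when the prompt is already long enough.
        num_tokens_to_pad = max(len(token_ids), seq_len) - len(token_ids)
        return token_ids + [0] * num_tokens_to_pad


    assert pad_to_seq_len([7, 8, 9], 5) == [7, 8, 9, 0, 0]
    assert pad_to_seq_len([7, 8, 9], 2) == [7, 8, 9]  # longer is left as-is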

View File

@@ -18,7 +18,6 @@ from .processing import (
 from .profiling import (
     BaseDummyInputsBuilder,
     DummyDecoderData,
-    DummyEncoderData,
     MultiModalProfiler,
 )
@@ -317,43 +316,6 @@ class MultiModalRegistry:
         return dummy_data
 
-    def get_encoder_dummy_data(
-        self,
-        model_config: "ModelConfig",
-        seq_len: int,
-        mm_counts: Mapping[str, int] | None = None,
-        *,
-        cache: BaseMultiModalProcessorCache | None = None,
-        observability_config: ObservabilityConfig | None = None,
-    ) -> DummyEncoderData:
-        """
-        Create dummy data for profiling the memory usage of a model.
-
-        The model is identified by `model_config`.
-        """
-        processor = self.create_processor(
-            model_config, observability_config, cache=cache
-        )
-        profiler: MultiModalProfiler = MultiModalProfiler(processor)
-
-        # Extract configurable options from multimodal config.
-        # Only include modalities that use advanced option types so legacy
-        # count-only behavior remains unchanged.
-        mm_options = self._extract_mm_options(model_config)
-
-        dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts, mm_options)
-
-        # Having more tokens is over-conservative but otherwise fine
-        token_ids = dummy_data.prompt_token_ids
-        if len(token_ids) < seq_len:
-            logger.warning_once(
-                "Expected at least %d dummy encoder tokens for profiling, but found %d tokens instead.",  # noqa: E501
-                seq_len,
-                len(token_ids),
-            )
-
-        return dummy_data
-
     def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int:
         """
         Get the maximum length of the encoder input for encoder-decoder models.

View File

@@ -17,7 +17,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.multimodal.cache import processor_cache_from_config
 from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
 from vllm.multimodal.parse import MultiModalDataParser
-from vllm.multimodal.processing import EncDecMultiModalProcessor, set_request_id
+from vllm.multimodal.processing import set_request_id
 from vllm.multimodal.utils import argsort_mm_positions
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import _SAMPLING_EPS, SamplingParams
@@ -655,17 +655,18 @@ class InputProcessor:
         max_prompt_len = self.model_config.max_model_len
 
         if prompt_len > max_prompt_len:
-            if prompt_type == "encoder" and model_config.is_multimodal_model:
+            if model_config.is_multimodal_model:
                 mm_registry = self.input_preprocessor.mm_registry
-                mm_processor = mm_registry.create_processor(
+                model_cls = mm_registry._get_model_cls(model_config)
+                factories = model_cls._processor_factory
+                ctx = mm_registry._create_processing_ctx(
                     model_config,
                     self.vllm_config.observability_config,
                     tokenizer=tokenizer,
                 )
-                assert isinstance(mm_processor, EncDecMultiModalProcessor)
+                mm_info = factories.info(ctx)
 
-                if mm_processor.pad_dummy_encoder_prompt:
-                    return  # Skip encoder length check for Whisper
+                if mm_info.skip_prompt_length_check:
+                    return
 
             if model_config.is_multimodal_model:
                 suggestion = (
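
Net effect at the engine boundary: the length check no longer constructs a full multimodal processor (and then asserts its type) just to ask one question; it builds only the processing context and info object. Reduced to its essentials, under the simplifying assumption that the enclosing method and its error-raising tail look roughly like this:

    # Sketch of the new check inside InputProcessor (simplified).
    if prompt_len > max_prompt_len:
        if model_config.is_multimodal_model:
            model_cls = mm_registry._get_model_cls(model_config)
            ctx = mm_registry._create_processing_ctx(
                model_config, observability_config, tokenizer=tokenizer
            )
            mm_info = model_cls._processor_factory.info(ctx)
            if mm_info.skip_prompt_length_check:
                return  # e.g. Whisper: a padded encoder prompt is expected
        raise ValueError("prompt is too long")  # the usual error path follows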