[Refactor] Remove get_encoder_dummy_data (#32241)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -605,6 +605,10 @@ class NemotronParseProcessingInfo(BaseProcessingInfo):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def skip_prompt_length_check(self) -> bool:
|
||||
return True # Because the encoder prompt is padded
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
return {"image": 1}
|
||||
|
||||
@@ -657,10 +661,6 @@ class NemotronParseMultiModalProcessor(
|
||||
) -> str | list[int]:
|
||||
return [0]
|
||||
|
||||
@property
|
||||
def pad_dummy_encoder_prompt(self) -> bool:
|
||||
return True
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
|
||||
@@ -681,6 +681,10 @@ class WhisperProcessingInfo(BaseProcessingInfo):
|
||||
def get_hf_config(self) -> WhisperConfig:
|
||||
return self.ctx.get_hf_config(WhisperConfig)
|
||||
|
||||
@property
|
||||
def skip_prompt_length_check(self) -> bool:
|
||||
return True # Because the encoder prompt is padded
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
return {"audio": 1}
|
||||
|
||||
@@ -733,10 +737,6 @@ class WhisperMultiModalProcessor(EncDecMultiModalProcessor[WhisperProcessingInfo
|
||||
target_channels=self.info.get_target_channels(),
|
||||
)
|
||||
|
||||
@property
|
||||
def pad_dummy_encoder_prompt(self) -> bool:
|
||||
return True
|
||||
|
||||
def create_encoder_prompt(
|
||||
self,
|
||||
prompt: str | list[int],
|
||||
|
||||
@@ -1396,6 +1396,10 @@ class BaseProcessingInfo:
|
||||
"""
|
||||
return self.ctx.get_hf_processor(**kwargs)
|
||||
|
||||
@property
|
||||
def skip_prompt_length_check(self) -> bool:
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
"""
|
||||
@@ -2403,10 +2407,6 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def pad_dummy_encoder_prompt(self) -> bool:
|
||||
return False
|
||||
|
||||
def create_decoder_prompt(
|
||||
self,
|
||||
prompt: str | list[int],
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Generic, NamedTuple, TypeVar, cast
|
||||
from typing import Generic, NamedTuple, TypeVar
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
@@ -19,7 +19,6 @@ from vllm.logger import init_logger
|
||||
|
||||
from .inputs import (
|
||||
MultiModalDataDict,
|
||||
MultiModalEncDecInputs,
|
||||
MultiModalInputs,
|
||||
MultiModalKwargsItems,
|
||||
MultiModalPlaceholderDict,
|
||||
@@ -27,7 +26,6 @@ from .inputs import (
|
||||
from .processing import (
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
EncDecMultiModalProcessor,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -282,28 +280,6 @@ class MultiModalProfiler(Generic[_I]):
|
||||
for modality, placeholders in placeholders_by_modality.items()
|
||||
}
|
||||
|
||||
def get_encoder_dummy_data(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int] | None = None,
|
||||
mm_options: Mapping[str, BaseDummyOptions] | None = None,
|
||||
) -> DummyEncoderData:
|
||||
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts, mm_options)
|
||||
mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)
|
||||
|
||||
# For encoder-decoder models, use encoder prompt token ids instead of
|
||||
# decoder prompt to construct dummy seq_data for encoder profiling.
|
||||
encoder_prompt_token_ids = mm_inputs["encoder_prompt_token_ids"]
|
||||
|
||||
total_len = len(encoder_prompt_token_ids)
|
||||
|
||||
processor = cast(EncDecMultiModalProcessor, self.processor)
|
||||
if processor.pad_dummy_encoder_prompt:
|
||||
num_tokens_to_pad = max(total_len, seq_len) - total_len
|
||||
encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
|
||||
|
||||
return DummyEncoderData(encoder_prompt_token_ids)
|
||||
|
||||
def get_decoder_dummy_data(
|
||||
self,
|
||||
seq_len: int,
|
||||
|
||||
@@ -18,7 +18,6 @@ from .processing import (
|
||||
from .profiling import (
|
||||
BaseDummyInputsBuilder,
|
||||
DummyDecoderData,
|
||||
DummyEncoderData,
|
||||
MultiModalProfiler,
|
||||
)
|
||||
|
||||
@@ -317,43 +316,6 @@ class MultiModalRegistry:
|
||||
|
||||
return dummy_data
|
||||
|
||||
def get_encoder_dummy_data(
|
||||
self,
|
||||
model_config: "ModelConfig",
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int] | None = None,
|
||||
*,
|
||||
cache: BaseMultiModalProcessorCache | None = None,
|
||||
observability_config: ObservabilityConfig | None = None,
|
||||
) -> DummyEncoderData:
|
||||
"""
|
||||
Create dummy data for profiling the memory usage of a model.
|
||||
|
||||
The model is identified by `model_config`.
|
||||
"""
|
||||
processor = self.create_processor(
|
||||
model_config, observability_config, cache=cache
|
||||
)
|
||||
profiler: MultiModalProfiler = MultiModalProfiler(processor)
|
||||
|
||||
# Extract configurable options from multimodal config.
|
||||
# Only include modalities that use advanced option types so legacy
|
||||
# count-only behavior remains unchanged.
|
||||
mm_options = self._extract_mm_options(model_config)
|
||||
|
||||
dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts, mm_options)
|
||||
|
||||
# Having more tokens is over-conservative but otherwise fine
|
||||
token_ids = dummy_data.prompt_token_ids
|
||||
if len(token_ids) < seq_len:
|
||||
logger.warning_once(
|
||||
"Expected at least %d dummy encoder tokens for profiling, but found %d tokens instead.", # noqa: E501
|
||||
seq_len,
|
||||
len(token_ids),
|
||||
)
|
||||
|
||||
return dummy_data
|
||||
|
||||
def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int:
|
||||
"""
|
||||
Get the maximum length of the encoder input for encoder-decoder models.
|
||||
|
||||
@@ -17,7 +17,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.multimodal.cache import processor_cache_from_config
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
|
||||
from vllm.multimodal.parse import MultiModalDataParser
|
||||
from vllm.multimodal.processing import EncDecMultiModalProcessor, set_request_id
|
||||
from vllm.multimodal.processing import set_request_id
|
||||
from vllm.multimodal.utils import argsort_mm_positions
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import _SAMPLING_EPS, SamplingParams
|
||||
@@ -655,17 +655,18 @@ class InputProcessor:
|
||||
|
||||
max_prompt_len = self.model_config.max_model_len
|
||||
if prompt_len > max_prompt_len:
|
||||
if prompt_type == "encoder" and model_config.is_multimodal_model:
|
||||
if model_config.is_multimodal_model:
|
||||
mm_registry = self.input_preprocessor.mm_registry
|
||||
mm_processor = mm_registry.create_processor(
|
||||
model_cls = mm_registry._get_model_cls(model_config)
|
||||
factories = model_cls._processor_factory
|
||||
ctx = mm_registry._create_processing_ctx(
|
||||
model_config,
|
||||
self.vllm_config.observability_config,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
assert isinstance(mm_processor, EncDecMultiModalProcessor)
|
||||
mm_info = factories.info(ctx)
|
||||
|
||||
if mm_processor.pad_dummy_encoder_prompt:
|
||||
return # Skip encoder length check for Whisper
|
||||
if mm_info.skip_prompt_length_check:
|
||||
return
|
||||
|
||||
if model_config.is_multimodal_model:
|
||||
suggestion = (
|
||||
|
||||
Reference in New Issue
Block a user