[VLM] Reorganize profiling/processing-related code (#11812)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung authored 2025-01-08 18:59:58 +08:00, committed by GitHub
parent f12141170a
commit 2a0596bc48
23 changed files with 833 additions and 760 deletions


@@ -38,11 +38,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
                                     NestedTensors)
-from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataParser
+from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
+                                   MultiModalDataParser)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessingMixin,
-                                        PromptReplacement)
-from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
+                                        BaseProcessingInfo, PromptReplacement)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsMultiModal, SupportsPP
@@ -80,12 +80,12 @@ def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
     return feat_lengths, output_lengths
 
 
-class Qwen2AudioProcessingMixin(ProcessingMixin):
+class Qwen2AudioProcessingInfo(BaseProcessingInfo):
 
-    def _get_hf_config(self):
+    def get_hf_config(self):
         return self.ctx.get_hf_config(Qwen2AudioConfig)
 
-    def _get_hf_processor(
+    def get_hf_processor(
         self,
         *,
         # Ignored in initialization
@@ -93,36 +93,37 @@ class Qwen2AudioProcessingMixin(ProcessingMixin):
     ) -> Qwen2AudioProcessor:
         return self.ctx.get_hf_processor(Qwen2AudioProcessor)
 
-    def _get_feature_extractor(
+    def get_feature_extractor(
         self,
         *,
         # Ignored in initialization
         sampling_rate: Optional[int] = None,
     ) -> WhisperFeatureExtractor:
-        hf_processor = self._get_hf_processor(sampling_rate=sampling_rate)
+        hf_processor = self.get_hf_processor(sampling_rate=sampling_rate)
         feature_extractor = hf_processor.feature_extractor  # type: ignore
         assert isinstance(feature_extractor, WhisperFeatureExtractor)
         return feature_extractor
 
-
-class Qwen2AudioProfilingInfo(Qwen2AudioProcessingMixin, BaseProfilingInfo):
-
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"audio": None}
 
     def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
-        hf_config = self._get_hf_config()
+        hf_config = self.get_hf_config()
         max_source_positions = hf_config.audio_config.max_source_positions
         max_output_lengths = (max_source_positions - 2) // 2 + 1
 
         return {"audio": max_output_lengths}
 
 
+class Qwen2AudioDummyInputsBuilder(
+        BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]):
+
     def get_dummy_processor_inputs(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
-        feature_extractor = self._get_feature_extractor()
+        feature_extractor = self.info.get_feature_extractor()
 
         sampling_rate = feature_extractor.sampling_rate
         audio_len = feature_extractor.chunk_length * sampling_rate
@@ -139,14 +140,11 @@ class Qwen2AudioProfilingInfo(Qwen2AudioProcessingMixin, BaseProfilingInfo):
         )
 
 
-class Qwen2AudioMultiModalProcessor(Qwen2AudioProcessingMixin,
-                                    BaseMultiModalProcessor):
-
-    def _get_profiling_info(self) -> BaseProfilingInfo:
-        return Qwen2AudioProfilingInfo(self.ctx)
+class Qwen2AudioMultiModalProcessor(
+        BaseMultiModalProcessor[Qwen2AudioProcessingInfo]):
 
     def _get_data_parser(self) -> MultiModalDataParser:
-        feature_extractor = self._get_feature_extractor()
+        feature_extractor = self.info.get_feature_extractor()
         return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
 
     def _call_hf_processor(
@@ -161,7 +159,7 @@ class Qwen2AudioMultiModalProcessor(Qwen2AudioProcessingMixin,
         if audios:
             mm_data["audios"] = audios
 
-            feature_extractor = self._get_feature_extractor(**mm_kwargs)
+            feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
             mm_kwargs = dict(
                 **mm_kwargs,
                 sampling_rate=feature_extractor.sampling_rate,
@@ -194,7 +192,7 @@ class Qwen2AudioMultiModalProcessor(Qwen2AudioProcessingMixin,
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
-        hf_config = self._get_hf_config()
+        hf_config = self.info.get_hf_config()
         placeholder = hf_config.audio_token_index
 
         feature_attention_mask = out_mm_kwargs.get("feature_attention_mask")
@@ -234,10 +232,13 @@ class Qwen2AudioMultiModalProcessor(Qwen2AudioProcessingMixin,
         # has already performed processing for multi-audio input when the input
         # audios are short (the corresponding placeholders may take up fewer
         # tokens than the number of audio items)
-        return not hasattr(self._get_hf_processor(), "audio_token")
+        return not hasattr(self.info.get_hf_processor(), "audio_token")
 
 
-@MULTIMODAL_REGISTRY.register_processor(Qwen2AudioMultiModalProcessor)
+@MULTIMODAL_REGISTRY.register_processor(
+    Qwen2AudioMultiModalProcessor,
+    info=Qwen2AudioProcessingInfo,
+    dummy_inputs=Qwen2AudioDummyInputsBuilder)
 class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
                                          SupportsPP):
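
After this reorganization, each multi-modal model contributes three pieces: a ProcessingInfo class (HF config/processor lookups plus the supported-modality and max-token queries), a DummyInputsBuilder used for memory profiling, and a MultiModalProcessor parameterized by the info class. The registry composes them through the info= and dummy_inputs= arguments of register_processor, and the builder and processor read shared state through self.info instead of inheriting from a common mixin. Below is a minimal, self-contained sketch of that wiring; it uses simplified stand-ins for the vLLM base classes, and the constructor signatures, the MyAudio* names, the plain-dict return value, and the 750-token figure are illustrative assumptions, not the actual vLLM API.

# Sketch only: simplified stand-ins for the real base classes in
# vllm.multimodal.processing and vllm.multimodal.profiling.
from typing import Generic, Mapping, Optional, TypeVar


class BaseProcessingInfo:
    """Stand-in: shared HF config/processor lookups for processing and profiling."""


_I = TypeVar("_I", bound=BaseProcessingInfo)


class BaseDummyInputsBuilder(Generic[_I]):
    """Stand-in: builds dummy inputs for profiling, reading state via self.info."""

    def __init__(self, info: _I) -> None:
        self.info = info


class BaseMultiModalProcessor(Generic[_I]):
    """Stand-in: applies the HF processor, also reading state via self.info."""

    def __init__(self, info: _I) -> None:
        self.info = info


class MyAudioProcessingInfo(BaseProcessingInfo):
    # Mirrors Qwen2AudioProcessingInfo: one place that knows the model's
    # modalities and per-item token limits.

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"audio": None}

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        return {"audio": 750}  # placeholder value for the sketch


class MyAudioDummyInputsBuilder(BaseDummyInputsBuilder[MyAudioProcessingInfo]):
    # Mirrors Qwen2AudioDummyInputsBuilder: pulls everything it needs from self.info.

    def get_dummy_processor_inputs(self, seq_len: int,
                                   mm_counts: Mapping[str, int]) -> dict:
        max_tokens = self.info.get_mm_max_tokens_per_item(seq_len)
        return {"prompt": "<audio>" * mm_counts["audio"], "max_tokens": max_tokens}


class MyAudioMultiModalProcessor(BaseMultiModalProcessor[MyAudioProcessingInfo]):
    # Mirrors Qwen2AudioMultiModalProcessor: parameterized by the info class.
    pass


if __name__ == "__main__":
    # The registry performs this composition in vLLM; done by hand here.
    info = MyAudioProcessingInfo()
    dummy_inputs = MyAudioDummyInputsBuilder(info)
    processor = MyAudioMultiModalProcessor(info)
    print(dummy_inputs.get_dummy_processor_inputs(seq_len=4096,
                                                  mm_counts={"audio": 1}))

The point of the split visible in the diff is that profiling-time and processing-time code no longer share behavior through inheritance; they share only the info object, so the registry can construct each piece independently from the info= and dummy_inputs= arguments.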