[VLM] Reorganize profiling/processing-related code (#11812)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -38,11 +38,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
|
||||
NestedTensors)
|
||||
from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataParser
|
||||
from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
|
||||
MultiModalDataParser)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
MultiModalDataItems, ProcessingMixin,
|
||||
PromptReplacement)
|
||||
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
|
||||
BaseProcessingInfo, PromptReplacement)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
@@ -80,12 +80,12 @@ def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
|
||||
return feat_lengths, output_lengths
|
||||
|
||||
|
||||
class Qwen2AudioProcessingMixin(ProcessingMixin):
|
||||
class Qwen2AudioProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
def _get_hf_config(self):
|
||||
def get_hf_config(self):
|
||||
return self.ctx.get_hf_config(Qwen2AudioConfig)
|
||||
|
||||
def _get_hf_processor(
|
||||
def get_hf_processor(
|
||||
self,
|
||||
*,
|
||||
# Ignored in initialization
|
||||
@@ -93,36 +93,37 @@ class Qwen2AudioProcessingMixin(ProcessingMixin):
|
||||
) -> Qwen2AudioProcessor:
|
||||
return self.ctx.get_hf_processor(Qwen2AudioProcessor)
|
||||
|
||||
def _get_feature_extractor(
|
||||
def get_feature_extractor(
|
||||
self,
|
||||
*,
|
||||
# Ignored in initialization
|
||||
sampling_rate: Optional[int] = None,
|
||||
) -> WhisperFeatureExtractor:
|
||||
hf_processor = self._get_hf_processor(sampling_rate=sampling_rate)
|
||||
hf_processor = self.get_hf_processor(sampling_rate=sampling_rate)
|
||||
feature_extractor = hf_processor.feature_extractor # type: ignore
|
||||
assert isinstance(feature_extractor, WhisperFeatureExtractor)
|
||||
return feature_extractor
|
||||
|
||||
|
||||
class Qwen2AudioProfilingInfo(Qwen2AudioProcessingMixin, BaseProfilingInfo):
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"audio": None}
|
||||
|
||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
||||
hf_config = self._get_hf_config()
|
||||
hf_config = self.get_hf_config()
|
||||
max_source_positions = hf_config.audio_config.max_source_positions
|
||||
max_output_lengths = (max_source_positions - 2) // 2 + 1
|
||||
|
||||
return {"audio": max_output_lengths}
|
||||
|
||||
|
||||
class Qwen2AudioDummyInputsBuilder(
|
||||
BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]):
|
||||
|
||||
def get_dummy_processor_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
feature_extractor = self._get_feature_extractor()
|
||||
feature_extractor = self.info.get_feature_extractor()
|
||||
|
||||
sampling_rate = feature_extractor.sampling_rate
|
||||
audio_len = feature_extractor.chunk_length * sampling_rate
|
||||
@@ -139,14 +140,11 @@ class Qwen2AudioProfilingInfo(Qwen2AudioProcessingMixin, BaseProfilingInfo):
|
||||
)
|
||||
|
||||
|
||||
class Qwen2AudioMultiModalProcessor(Qwen2AudioProcessingMixin,
|
||||
BaseMultiModalProcessor):
|
||||
|
||||
def _get_profiling_info(self) -> BaseProfilingInfo:
|
||||
return Qwen2AudioProfilingInfo(self.ctx)
|
||||
class Qwen2AudioMultiModalProcessor(
|
||||
BaseMultiModalProcessor[Qwen2AudioProcessingInfo]):
|
||||
|
||||
def _get_data_parser(self) -> MultiModalDataParser:
|
||||
feature_extractor = self._get_feature_extractor()
|
||||
feature_extractor = self.info.get_feature_extractor()
|
||||
return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
|
||||
|
||||
def _call_hf_processor(
|
||||
@@ -161,7 +159,7 @@ class Qwen2AudioMultiModalProcessor(Qwen2AudioProcessingMixin,
|
||||
if audios:
|
||||
mm_data["audios"] = audios
|
||||
|
||||
feature_extractor = self._get_feature_extractor(**mm_kwargs)
|
||||
feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
|
||||
mm_kwargs = dict(
|
||||
**mm_kwargs,
|
||||
sampling_rate=feature_extractor.sampling_rate,
|
||||
@@ -194,7 +192,7 @@ class Qwen2AudioMultiModalProcessor(Qwen2AudioProcessingMixin,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
hf_config = self._get_hf_config()
|
||||
hf_config = self.info.get_hf_config()
|
||||
placeholder = hf_config.audio_token_index
|
||||
|
||||
feature_attention_mask = out_mm_kwargs.get("feature_attention_mask")
|
||||
@@ -234,10 +232,13 @@ class Qwen2AudioMultiModalProcessor(Qwen2AudioProcessingMixin,
|
||||
# has already performed processing for multi-audio input when the input
|
||||
# audios are short (the corresponding placeholders may take up fewer
|
||||
# tokens than the number of audio items)
|
||||
return not hasattr(self._get_hf_processor(), "audio_token")
|
||||
return not hasattr(self.info.get_hf_processor(), "audio_token")
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(Qwen2AudioMultiModalProcessor)
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
Qwen2AudioMultiModalProcessor,
|
||||
info=Qwen2AudioProcessingInfo,
|
||||
dummy_inputs=Qwen2AudioDummyInputsBuilder)
|
||||
class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsPP):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user