[Refactor] Define MM data parser in processing info instead of processor itself (#33260)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -127,6 +127,30 @@ def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
     return feat_lengths, output_lengths
 
 
+def _qwen2audio_field_config(hf_inputs: Mapping[str, torch.Tensor]):
+    return dict(
+        audio_embeds=MultiModalFieldConfig.batched("audio"),
+        input_features=MultiModalFieldConfig.batched("audio"),
+        feature_attention_mask=MultiModalFieldConfig.batched("audio"),
+    )
+
+
+class Qwen2AudioMultiModalDataParser(MultiModalDataParser):
+    def _parse_audio_data(
+        self,
+        data: dict[str, torch.Tensor] | ModalityData[AudioItem],
+    ) -> ModalityDataItems[Any, Any] | None:
+        if isinstance(data, dict):
+            return DictEmbeddingItems(
+                data,
+                modality="audio",
+                required_fields={"audio_embeds"},
+                fields_factory=_qwen2audio_field_config,
+            )
+
+        return super()._parse_audio_data(data)
+
+
 class Qwen2AudioProcessingInfo(BaseProcessingInfo):
     def get_hf_config(self):
         return self.ctx.get_hf_config(Qwen2AudioConfig)
@@ -140,6 +164,15 @@ class Qwen2AudioProcessingInfo(BaseProcessingInfo):
         assert isinstance(feature_extractor, WhisperFeatureExtractor)
         return feature_extractor
 
+    def get_data_parser(self):
+        feature_extractor = self.get_feature_extractor()
+
+        return Qwen2AudioMultiModalDataParser(
+            target_sr=feature_extractor.sampling_rate,
+            target_channels=self.get_target_channels(),
+            expected_hidden_size=self._get_expected_hidden_size(),
+        )
+
     def get_target_channels(self) -> int:
         """Return target audio channels for Qwen2 Audio models (mono)."""
         return 1
@@ -178,38 +211,7 @@ class Qwen2AudioDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]):
         }
 
 
-def _qwen2audio_field_config(hf_inputs: Mapping[str, torch.Tensor]):
-    return dict(
-        audio_embeds=MultiModalFieldConfig.batched("audio"),
-        input_features=MultiModalFieldConfig.batched("audio"),
-        feature_attention_mask=MultiModalFieldConfig.batched("audio"),
-    )
-
-
-class Qwen2AudioMultiModalDataParser(MultiModalDataParser):
-    def _parse_audio_data(
-        self,
-        data: dict[str, torch.Tensor] | ModalityData[AudioItem],
-    ) -> ModalityDataItems[Any, Any] | None:
-        if isinstance(data, dict):
-            return DictEmbeddingItems(
-                data,
-                modality="audio",
-                required_fields={"audio_embeds"},
-                fields_factory=_qwen2audio_field_config,
-            )
-
-        return super()._parse_audio_data(data)
-
-
 class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor[Qwen2AudioProcessingInfo]):
-    def _get_data_parser(self) -> MultiModalDataParser:
-        feature_extractor = self.info.get_feature_extractor()
-        return Qwen2AudioMultiModalDataParser(
-            target_sr=feature_extractor.sampling_rate,
-            target_channels=self.info.get_target_channels(),
-        )
-
     def _call_hf_processor(
         self,
         prompt: str,
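
For context, the pattern this commit introduces is that parser construction moves out of the processor's _get_data_parser override and into get_data_parser on the ProcessingInfo, with the processor delegating to it. Below is a minimal self-contained sketch of that delegation; the class bodies are simplified stand-ins written for illustration, not vLLM's actual BaseProcessingInfo / BaseMultiModalProcessor, and only the delegation shape mirrors the diff above.

# Hedged sketch: stand-in classes, not vLLM's real base classes;
# only the delegation pattern mirrors the diff above.
from dataclasses import dataclass


class MultiModalDataParser:
    """Stand-in for vLLM's parser base class."""

    def __init__(self, target_sr: float, target_channels: int) -> None:
        self.target_sr = target_sr
        self.target_channels = target_channels


@dataclass
class ProcessingInfo:
    """Stand-in for a BaseProcessingInfo subclass."""

    sampling_rate: float

    def get_data_parser(self) -> MultiModalDataParser:
        # After this refactor, the model-specific parser is built here,
        # next to the rest of the model's processing metadata.
        return MultiModalDataParser(
            target_sr=self.sampling_rate,
            target_channels=1,  # Qwen2 Audio is mono (see get_target_channels)
        )


class Processor:
    """Stand-in for a BaseMultiModalProcessor subclass."""

    def __init__(self, info: ProcessingInfo) -> None:
        self.info = info

    def _get_data_parser(self) -> MultiModalDataParser:
        # Previously each processor built its own parser here; now it
        # simply forwards to the processing info.
        return self.info.get_data_parser()


parser = Processor(ProcessingInfo(sampling_rate=16_000))._get_data_parser()
assert parser.target_sr == 16_000 and parser.target_channels == 1

This keeps the processor free of model-specific construction logic: a model only overrides get_data_parser() on its ProcessingInfo, as Qwen2AudioProcessingInfo does in the diff.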