[Refactor] Define MM data parser in processing info instead of processor itself (#33260)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-29 13:55:17 +08:00
committed by GitHub
parent 07ea184f00
commit 51550179fc
34 changed files with 399 additions and 347 deletions

View File

@@ -127,6 +127,30 @@ def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
return feat_lengths, output_lengths
def _qwen2audio_field_config(hf_inputs: Mapping[str, torch.Tensor]):
    """Return the multimodal field config for Qwen2-Audio HF outputs.

    Every tensor produced by the HF processor (or passed in as precomputed
    embeddings) is batched along the "audio" modality. The ``hf_inputs``
    mapping is accepted for interface compatibility but is not inspected.
    """
    batched_keys = ("audio_embeds", "input_features", "feature_attention_mask")
    return {key: MultiModalFieldConfig.batched("audio") for key in batched_keys}
class Qwen2AudioMultiModalDataParser(MultiModalDataParser):
    """Data parser that additionally accepts precomputed audio embeddings
    supplied as a dict (keyed by field name, requiring "audio_embeds")."""

    def _parse_audio_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[AudioItem],
    ) -> ModalityDataItems[Any, Any] | None:
        # Non-dict input is raw audio data; defer to the base parser.
        if not isinstance(data, dict):
            return super()._parse_audio_data(data)
        # A dict is interpreted as precomputed embedding tensors.
        return DictEmbeddingItems(
            data,
            modality="audio",
            required_fields={"audio_embeds"},
            fields_factory=_qwen2audio_field_config,
        )
class Qwen2AudioProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
    """Return the model's HF config, narrowed to ``Qwen2AudioConfig``."""
    return self.ctx.get_hf_config(Qwen2AudioConfig)
@@ -140,6 +164,15 @@ class Qwen2AudioProcessingInfo(BaseProcessingInfo):
assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor
def get_data_parser(self):
    """Construct the Qwen2-Audio data parser.

    The parser resamples incoming audio to the feature extractor's
    sampling rate, downmixes to the target channel count, and validates
    precomputed embeddings against the expected hidden size.
    """
    extractor = self.get_feature_extractor()
    sampling_rate = extractor.sampling_rate
    return Qwen2AudioMultiModalDataParser(
        target_sr=sampling_rate,
        target_channels=self.get_target_channels(),
        expected_hidden_size=self._get_expected_hidden_size(),
    )
def get_target_channels(self) -> int:
    """Number of audio channels Qwen2-Audio expects: 1 (mono)."""
    MONO = 1
    return MONO
@@ -178,38 +211,7 @@ class Qwen2AudioDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2AudioProcessingIn
}
def _qwen2audio_field_config(hf_inputs: Mapping[str, torch.Tensor]):
    """Return the multimodal field config for Qwen2-Audio tensors.

    All fields are batched along the "audio" modality; ``hf_inputs`` is
    accepted for interface compatibility and not inspected here.
    """
    return dict(
        audio_embeds=MultiModalFieldConfig.batched("audio"),
        input_features=MultiModalFieldConfig.batched("audio"),
        feature_attention_mask=MultiModalFieldConfig.batched("audio"),
    )
class Qwen2AudioMultiModalDataParser(MultiModalDataParser):
    # Extends the default parser so callers may pass precomputed audio
    # embeddings as a dict instead of raw audio data.
    def _parse_audio_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[AudioItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse audio input; a dict is treated as precomputed embeddings."""
        if isinstance(data, dict):
            # Dict inputs must provide "audio_embeds"; enforced by
            # DictEmbeddingItems via required_fields.
            return DictEmbeddingItems(
                data,
                modality="audio",
                required_fields={"audio_embeds"},
                fields_factory=_qwen2audio_field_config,
            )
        # Anything else (raw waveforms, etc.) uses the base implementation.
        return super()._parse_audio_data(data)
class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor[Qwen2AudioProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
    """Build the parser used to normalize incoming audio inputs."""
    sampling_rate = self.info.get_feature_extractor().sampling_rate
    channels = self.info.get_target_channels()
    return Qwen2AudioMultiModalDataParser(
        target_sr=sampling_rate,
        target_channels=channels,
    )
def _call_hf_processor(
self,
prompt: str,