[Model] Use merge_by_field_config for MM models (Qwen series) (#27546)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -313,6 +313,8 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor[Qwen2AudioProcessing
|
||||
dummy_inputs=Qwen2AudioDummyInputsBuilder,
|
||||
)
|
||||
class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
merge_by_field_config = True
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
|
||||
if modality.startswith("audio"):
|
||||
@@ -346,16 +348,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports
|
||||
self.language_model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def _validate_and_reshape_mm_tensor(
|
||||
self, mm_input: object, name: str
|
||||
) -> torch.Tensor:
|
||||
if not isinstance(mm_input, (torch.Tensor, list)):
|
||||
raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
|
||||
if isinstance(mm_input, torch.Tensor):
|
||||
return mm_input.reshape(-1, *mm_input.shape[2:])
|
||||
else:
|
||||
return torch.concat(mm_input)
|
||||
|
||||
def _parse_and_validate_audio_input(
|
||||
self, **kwargs: object
|
||||
) -> Qwen2AudioInputs | None:
|
||||
@@ -367,24 +359,11 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports
|
||||
return None
|
||||
|
||||
if audio_embeds is not None:
|
||||
if not isinstance(audio_embeds, (torch.Tensor, list)):
|
||||
raise ValueError(
|
||||
f"Incorrect type of audio embeds. Got type: {type(audio_embeds)}"
|
||||
)
|
||||
audio_embeds = self._validate_and_reshape_mm_tensor(
|
||||
audio_embeds, "audio_embeds"
|
||||
)
|
||||
return Qwen2AudioEmbeddingInputs(
|
||||
type="audio_embeds", audio_embeds=audio_embeds
|
||||
)
|
||||
|
||||
if input_features is not None:
|
||||
input_features = self._validate_and_reshape_mm_tensor(
|
||||
input_features, "input_features"
|
||||
)
|
||||
feature_attention_mask = self._validate_and_reshape_mm_tensor(
|
||||
feature_attention_mask, "feature_attention_mask"
|
||||
)
|
||||
return Qwen2AudioFeatureInputs(
|
||||
type="audio_features",
|
||||
input_features=input_features,
|
||||
|
||||
Reference in New Issue
Block a user