[Model] Use merge_by_field_config for MM models (G) (#26117)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -168,10 +168,8 @@ class GraniteSpeechMultiModalProcessor(
|
||||
# Calculate the number of audio tokens per entry in the batch;
|
||||
# This is used to split the batch back out after padding.
|
||||
audio_token_index = self.info.get_hf_config().audio_token_index
|
||||
processed_outputs["audio_embed_sizes"] = [
|
||||
torch.sum(indices == audio_token_index).item()
|
||||
for indices in processed_outputs["input_ids"]
|
||||
]
|
||||
processed_outputs["audio_embed_sizes"] = (
|
||||
processed_outputs["input_ids"] == audio_token_index).sum(-1)
|
||||
|
||||
return processed_outputs
|
||||
|
||||
@@ -527,6 +525,7 @@ class GraniteSpeechForConditionalGeneration(
|
||||
SupportsPP,
|
||||
SupportsLoRA,
|
||||
):
|
||||
merge_by_field_config = True
|
||||
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
|
||||
Reference in New Issue
Block a user