[Model] Use merge_by_field_config for MM models (H-L) (#26230)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-10-05 11:54:17 +08:00
committed by GitHub
parent 119f00630b
commit 59a85c366e
6 changed files with 29 additions and 161 deletions

View File

@@ -53,7 +53,7 @@ from .idefics2_vision_model import (
# yapf: enable
from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
from .llama import LlamaModel
from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
from .utils import AutoWeightsLoader, maybe_prefix
class Idefics3ImagePixelInputs(TensorSchema):
@@ -67,7 +67,7 @@ class Idefics3ImagePixelInputs(TensorSchema):
"""
type: Literal["pixel_values"]
pixel_values: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
pixel_attention_mask: torch.Tensor
pixel_attention_mask: Annotated[torch.Tensor, TensorShape("bnp", "h", "w")]
num_patches: Annotated[torch.Tensor, TensorShape("bn")]
@@ -569,6 +569,8 @@ class Idefics3Model(nn.Module):
dummy_inputs=Idefics3DummyInputsBuilder)
class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsLoRA):
merge_by_field_config = True
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -621,37 +623,21 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
return None
if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return Idefics3ImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds, concat=True),
data=image_embeds,
)
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
pixel_attention_mask = kwargs.pop("pixel_attention_mask")
if not isinstance(pixel_attention_mask, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel_attention_mask. "
f"Got type: {type(pixel_attention_mask)}")
num_patches = kwargs.pop("num_patches")
if not isinstance(num_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_patches. "
f"Got type: {type(num_patches)}")
expected_h = expected_w = self.config.vision_config.image_size
return Idefics3ImagePixelInputs(
type="pixel_values",
pixel_values=flatten_bn(pixel_values, concat=True),
pixel_attention_mask=flatten_bn(pixel_attention_mask,
concat=True),
num_patches=flatten_bn(num_patches, concat=True),
pixel_values=pixel_values,
pixel_attention_mask=pixel_attention_mask,
num_patches=num_patches,
resolve_bindings={
"h": expected_h,
"w": expected_w