[Model] Use merge_by_field_config for MM models (M-N) (#26710)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2025-10-14 01:27:01 +08:00
Committed by: GitHub
Parent: e3b90c1ba2
Commit: afc47e4de7
11 changed files with 127 additions and 331 deletions


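Note for context: `merge_by_field_config = True` opts the model into vLLM's newer multimodal input convention, where each keyword field reaches the model with its batch dimension already merged, so per-model `flatten_bn` calls (and the extra-dimension squeezes that accompany them) can be dropped. A rough, self-contained sketch of the effect; the helper name and tensor shapes below are illustrative assumptions, not vLLM code:

import torch

def flatten_bn_sketch(per_item: list[torch.Tensor]) -> torch.Tensor:
    # Stand-in for the removed flatten_bn(..., concat=True) helper:
    # concatenate the per-request tensors along the first dimension.
    return torch.cat(per_item, dim=0)

# Before: the model receives one tensor per request and flattens them itself.
per_request = [torch.randn(2, 3, 336, 336), torch.randn(5, 3, 336, 336)]
flat_pixel_values = flatten_bn_sketch(per_request)

# After (merge_by_field_config = True): the framework hands the model a single
# already-merged tensor of shape (total_num_chunks, C, H, W).
assert flat_pixel_values.shape == (7, 3, 336, 336)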
@@ -71,7 +71,7 @@ from .interfaces import (
     SupportsPP,
 )
 from .llama4 import Llama4ForCausalLM
-from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
+from .utils import AutoWeightsLoader, maybe_prefix
 from .vision import run_dp_sharded_vision_model
@@ -86,7 +86,7 @@ class Llama4ImagePatchInputs(TensorSchema):
     type: Literal["pixel_values"] = "pixel_values"
-    flat_data: Annotated[
+    pixel_values: Annotated[
         torch.Tensor,
         TensorShape("total_num_chunks", "num_channels", "image_size", "image_size"),
     ]
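For reference, the `Annotated[..., TensorShape(...)]` pattern lets the `TensorSchema` base class check the runtime tensor against named dimensions. A toy illustration of the idea, using a stand-in marker rather than vLLM's actual TensorShape implementation:

from dataclasses import dataclass
from typing import Annotated, get_type_hints
import torch

@dataclass(frozen=True)
class ShapeSpec:
    # Stand-in for TensorShape: just records symbolic dimension names.
    dims: tuple[str, ...]

@dataclass
class ImagePatchInputsSketch:
    pixel_values: Annotated[
        torch.Tensor,
        ShapeSpec(("total_num_chunks", "num_channels", "image_size", "image_size")),
    ]

    def validate(self) -> None:
        # Compare each annotated field's rank against its declared dimensions.
        hints = get_type_hints(type(self), include_extras=True)
        for name, hint in hints.items():
            specs = [m for m in getattr(hint, "__metadata__", ()) if isinstance(m, ShapeSpec)]
            if specs:
                tensor = getattr(self, name)
                assert tensor.ndim == len(specs[0].dims), name

ImagePatchInputsSketch(torch.randn(7, 3, 336, 336)).validate()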
@@ -96,7 +96,7 @@ class Llama4ImagePatchInputs(TensorSchema):
     The number of total patches for each image in the batch.
     This is used to split the embeddings which has the first two dimensions
-    flattened just like `flat_data`.
+    flattened just like `pixel_values`.
     """
     aspect_ratios: Annotated[torch.Tensor, TensorShape("batch_size", 2)]
@@ -725,6 +725,8 @@ class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
 class Llama4ForConditionalGeneration(
     nn.Module, SupportsMultiModal, SupportsPP, SupportsEagle3
 ):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
@@ -798,17 +800,12 @@ class Llama4ForConditionalGeneration(
         if pixel_values is None:
             return None
-        # num_images x num_chunks, channel, image_size, image_size
-        # TODO: confirm handling for variable lengths
-        flat_pixel_values = flatten_bn(pixel_values, concat=True)
-        patches_per_image = flatten_bn(kwargs.pop("patches_per_image"))
+        patches_per_image = kwargs.pop("patches_per_image")
         aspect_ratios = kwargs.pop("aspect_ratios")
-        if aspect_ratios.ndim == 3:
-            aspect_ratios = aspect_ratios.squeeze(1)
         return Llama4ImagePatchInputs(
             type="pixel_values",
-            flat_data=flat_pixel_values,
+            pixel_values=pixel_values,
             patches_per_image=patches_per_image,
             aspect_ratios=aspect_ratios,
         )
@@ -817,16 +814,16 @@ class Llama4ForConditionalGeneration(
         self, image_input: Llama4ImagePatchInputs
     ) -> MultiModalEmbeddings:
         assert self.vision_model and self.multi_modal_projector
-        flat_data = image_input["flat_data"]
+        pixel_values = image_input["pixel_values"]
         patches_per_image = image_input["patches_per_image"].tolist()
         # shard image input
         if self.use_data_parallel:
             vision_embeddings_flat = run_dp_sharded_vision_model(
-                flat_data, self.vision_model
+                pixel_values, self.vision_model
             )
         else:
-            vision_embeddings_flat = self.vision_model(flat_data)
+            vision_embeddings_flat = self.vision_model(pixel_values)
         vision_embeddings_flat = self.multi_modal_projector(vision_embeddings_flat)
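Downstream of this hunk (not shown), the flattened projector output is typically split back into per-image embeddings using `patches_per_image`. A minimal sketch of that split, with made-up sizes:

import torch

hidden_size = 16
patches_per_image = [2, 5]          # chunk counts for two images (made up)
vision_embeddings_flat = torch.randn(sum(patches_per_image), hidden_size)

# Split the flat (total_num_chunks, hidden) tensor into one tensor per image,
# using the per-image chunk counts carried in Llama4ImagePatchInputs.
per_image = vision_embeddings_flat.split(patches_per_image, dim=0)
assert [t.shape[0] for t in per_image] == patches_per_image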