[Model] Use merge_by_field_config for MM models (M-N) (#26710)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
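Note: merge_by_field_config = True (added below) opts the model into the multimodal processor's per-field batch merging, so fields such as pixel_values reach the model already concatenated across the batch and the manual flatten_bn calls can be dropped. A minimal sketch of the idea, with a hypothetical helper name and assumed shapes (this is illustrative, not the vLLM implementation):

import torch

# Roughly what the removed flatten_bn(pixel_values, concat=True) call did:
# concatenate one tensor per request along the leading (batch) dimension.
def flatten_bn_like(values: list[torch.Tensor]) -> torch.Tensor:
    return torch.cat(values, dim=0)

# Two requests contributing 3 and 5 image chunks of (channels, H, W).
per_request = [torch.randn(3, 3, 336, 336), torch.randn(5, 3, 336, 336)]

# Old path: the model flattens the per-request list itself.
flat = flatten_bn_like(per_request)  # shape: (8, 3, 336, 336)

# merge_by_field_config path: the framework performs this merge per field
# before the kwargs reach the model, so the model sees `flat` directly.
assert flat.shape[0] == sum(t.shape[0] for t in per_request)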
@@ -71,7 +71,7 @@ from .interfaces import (
     SupportsPP,
 )
 from .llama4 import Llama4ForCausalLM
-from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
+from .utils import AutoWeightsLoader, maybe_prefix
 from .vision import run_dp_sharded_vision_model
 
 
@@ -86,7 +86,7 @@ class Llama4ImagePatchInputs(TensorSchema):
 
     type: Literal["pixel_values"] = "pixel_values"
 
-    flat_data: Annotated[
+    pixel_values: Annotated[
         torch.Tensor,
         TensorShape("total_num_chunks", "num_channels", "image_size", "image_size"),
     ]
@@ -96,7 +96,7 @@ class Llama4ImagePatchInputs(TensorSchema):
     The number of total patches for each image in the batch.
 
     This is used to split the embeddings which has the first two dimensions
-    flattened just like `flat_data`.
+    flattened just like `pixel_values`.
     """
 
     aspect_ratios: Annotated[torch.Tensor, TensorShape("batch_size", 2)]
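For reference, the input schema reads as follows after this change (assembled from the two hunks above; the patches_per_image annotation is elided because it is not part of the diff, and TensorSchema/TensorShape are imported elsewhere in this file):

from typing import Annotated, Literal

import torch

class Llama4ImagePatchInputs(TensorSchema):
    type: Literal["pixel_values"] = "pixel_values"

    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("total_num_chunks", "num_channels", "image_size", "image_size"),
    ]

    # patches_per_image (annotation elided): the number of total patches for
    # each image in the batch, used to split the flattened embeddings.

    aspect_ratios: Annotated[torch.Tensor, TensorShape("batch_size", 2)]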
@@ -725,6 +725,8 @@ class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
 class Llama4ForConditionalGeneration(
     nn.Module, SupportsMultiModal, SupportsPP, SupportsEagle3
 ):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
@@ -798,17 +800,12 @@ class Llama4ForConditionalGeneration(
         if pixel_values is None:
             return None
 
-        # num_images x num_chunks, channel, image_size, image_size
-        # TODO: confirm handling for variable lengths
-        flat_pixel_values = flatten_bn(pixel_values, concat=True)
-        patches_per_image = flatten_bn(kwargs.pop("patches_per_image"))
+        patches_per_image = kwargs.pop("patches_per_image")
         aspect_ratios = kwargs.pop("aspect_ratios")
-        if aspect_ratios.ndim == 3:
-            aspect_ratios = aspect_ratios.squeeze(1)
 
         return Llama4ImagePatchInputs(
             type="pixel_values",
-            flat_data=flat_pixel_values,
+            pixel_values=pixel_values,
             patches_per_image=patches_per_image,
             aspect_ratios=aspect_ratios,
         )
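A note on the dropped ndim == 3 check: under the old per-request path, aspect_ratios could arrive with an extra singleton dimension that had to be squeezed away; with per-field merging it is expected to already match the TensorShape("batch_size", 2) annotation. An illustrative sketch with assumed shapes (not code from this file):

import torch

# Old path (assumed): one (1, 2) aspect-ratio tensor per request, stacked
# into (num_images, 1, 2), hence the ndim == 3 squeeze.
stacked = torch.stack([torch.tensor([[1, 2]]), torch.tensor([[2, 1]])])
assert stacked.shape == (2, 1, 2)
squeezed = stacked.squeeze(1)  # (num_images, 2)

# Merged path: the field arrives already shaped (num_images, 2), so no
# squeeze is needed.
assert squeezed.shape == (2, 2)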
@@ -817,16 +814,16 @@ class Llama4ForConditionalGeneration(
         self, image_input: Llama4ImagePatchInputs
     ) -> MultiModalEmbeddings:
         assert self.vision_model and self.multi_modal_projector
-        flat_data = image_input["flat_data"]
+        pixel_values = image_input["pixel_values"]
         patches_per_image = image_input["patches_per_image"].tolist()
 
         # shard image input
         if self.use_data_parallel:
             vision_embeddings_flat = run_dp_sharded_vision_model(
-                flat_data, self.vision_model
+                pixel_values, self.vision_model
             )
         else:
-            vision_embeddings_flat = self.vision_model(flat_data)
+            vision_embeddings_flat = self.vision_model(pixel_values)
 
         vision_embeddings_flat = self.multi_modal_projector(vision_embeddings_flat)
 
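The schema docstring above notes that patches_per_image is used to split the flattened embeddings back per image; a minimal sketch of that downstream step with assumed shapes (not code from this file):

import torch

# Flat projector output: one row per image chunk across the whole batch.
vision_embeddings_flat = torch.randn(8, 4096)  # (total_num_chunks, hidden)
patches_per_image = [3, 5]  # chunks contributed by each image

# Recover the per-image embedding groups, in order.
per_image = torch.split(vision_embeddings_flat, patches_per_image, dim=0)
assert [e.shape[0] for e in per_image] == patches_per_image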