[Model] Use merge_by_field_config for MM models (D-F) (#26076)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-10-02 23:17:35 +08:00
committed by GitHub
parent 7d6fb905d9
commit cc253b73d3
4 changed files with 102 additions and 180 deletions

View File

@@ -59,17 +59,14 @@ class FuyuImagePatchInputs(TensorSchema):
type: Literal["image_patches"] = "image_patches"
flat_data: Annotated[
torch.Tensor,
TensorShape("bnp", "fn"),
]
image_patches_flat: Annotated[torch.Tensor, TensorShape("bnp", "fn")]
patches_per_image: Annotated[list[int], TensorShape("bn")]
"""
The number of total patches for each image in the batch.
This is used to split the embeddings which has the first two dimensions
flattened just like `flat_data`.
flattened just like `image_patches_flat`.
"""
@@ -174,28 +171,10 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
tok_kwargs=tok_kwargs,
)
image_patches = processed_outputs.get("image_patches")
if image_patches is not None:
images = mm_data["images"]
assert isinstance(images, list)
# Original output: (1, num_images, Pn, Px * Py * C)
# New output: (num_images, Pn, Px * Py * C)
# image_patches is a list with shape:
# (1, num_images, Pn, Px * Py * C)
# before Transformers 4.53
if isinstance(image_patches, list):
assert len(image_patches) == 1
assert (isinstance(image_patches[0], torch.Tensor)
and len(image_patches[0]) == len(images))
processed_outputs["image_patches"] = image_patches[0]
# image_patches is a tensor with shape:
# (num_images, Pn, Px * Py * C)
# after Transformers 4.53
elif isinstance(image_patches, torch.Tensor):
assert len(image_patches) == len(images)
else:
raise AssertionError("This line should be unreachable.")
image_patches = processed_outputs["image_patches"]
processed_outputs["image_patches"] = flatten_bn(image_patches)
processed_outputs["patches_per_image"] = torch.tensor(
[len(p) for p in image_patches])
return processed_outputs
@@ -218,7 +197,13 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(image_patches=MultiModalFieldConfig.batched("image"))
patches_per_image = hf_inputs.get("patches_per_image", torch.empty(0))
return dict(
image_patches=MultiModalFieldConfig.flat_from_sizes(
"image", patches_per_image),
patches_per_image=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
self,
@@ -263,6 +248,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
info=FuyuProcessingInfo,
dummy_inputs=FuyuDummyInputsBuilder)
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
@@ -306,29 +292,28 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
image_patches = kwargs.pop("image_patches", None)
if image_patches is not None:
image_patches_flat = flatten_bn(image_patches)
flat_data = flatten_bn(image_patches_flat, concat=True)
patches_per_image = kwargs.pop("patches_per_image", None)
return FuyuImagePatchInputs(
type="image_patches",
flat_data=flat_data,
patches_per_image=[x.size(0) for x in image_patches_flat],
resolve_bindings={"fn": self.image_feature_size},
)
if image_patches is None:
return None
return None
return FuyuImagePatchInputs(
type="image_patches",
image_patches_flat=image_patches,
patches_per_image=patches_per_image,
resolve_bindings={"fn": self.image_feature_size},
)
def _process_image_input(
self, image_input: FuyuImagePatchInputs) -> MultiModalEmbeddings:
image_patches_flat = image_input["flat_data"]
image_patches_flat = image_input["image_patches_flat"]
patches_per_image = image_input["patches_per_image"]
assert self.vision_embed_tokens is not None
vision_embeddings_flat, _ = self.vision_embed_tokens(
image_patches_flat)
return vision_embeddings_flat.split(patches_per_image, dim=0)
return vision_embeddings_flat.split(patches_per_image.tolist(), dim=0)
def get_language_model(self) -> torch.nn.Module:
return self.language_model