[Model] Use merge_by_field_config for MM models (D-F) (#26076)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py
--- a/vllm/model_executor/models/fuyu.py
+++ b/vllm/model_executor/models/fuyu.py
@@ -59,17 +59,14 @@ class FuyuImagePatchInputs(TensorSchema):
 
     type: Literal["image_patches"] = "image_patches"
 
-    flat_data: Annotated[
-        torch.Tensor,
-        TensorShape("bnp", "fn"),
-    ]
+    image_patches_flat: Annotated[torch.Tensor, TensorShape("bnp", "fn")]
 
-    patches_per_image: Annotated[list[int], TensorShape("bn")]
+    patches_per_image: Annotated[torch.Tensor, TensorShape("bn")]
     """
     The number of total patches for each image in the batch.
 
     This is used to split the embeddings which has the first two dimensions
-    flattened just like `flat_data`.
+    flattened just like `image_patches_flat`.
     """
 
 
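Annotations like `TensorShape("bnp", "fn")` declare symbolic dimensions that TensorSchema validates when the inputs object is constructed; `fn` is bound at runtime via `resolve_bindings` (see the last hunk). A minimal sketch of the mechanism, using a hand-rolled checker rather than vLLM's actual TensorSchema:

import torch

def check_shape(t: torch.Tensor, dims: tuple[str, ...],
                bindings: dict[str, int]) -> None:
    # Each entry in `dims` names one dimension; bound names must match
    # the given size, unbound names (e.g. "bnp") accept any size.
    assert t.ndim == len(dims), f"expected {len(dims)} dims, got {t.ndim}"
    for size, name in zip(t.shape, dims):
        if name in bindings and size != bindings[name]:
            raise ValueError(f"dim {name!r}: expected {bindings[name]}, got {size}")

# 7 patches total across all images; 2700 = 30 * 30 * 3 patch features.
image_patches_flat = torch.randn(7, 2700)
check_shape(image_patches_flat, ("bnp", "fn"), {"fn": 2700})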
@@ -174,28 +171,10 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
             tok_kwargs=tok_kwargs,
         )
 
-        image_patches = processed_outputs.get("image_patches")
-        if image_patches is not None:
-            images = mm_data["images"]
-            assert isinstance(images, list)
-
-            # Original output: (1, num_images, Pn, Px * Py * C)
-            # New output: (num_images, Pn, Px * Py * C)
-            # image_patches is a list with shape:
-            # (1, num_images, Pn, Px * Py * C)
-            # before Transformers 4.53
-            if isinstance(image_patches, list):
-                assert len(image_patches) == 1
-                assert (isinstance(image_patches[0], torch.Tensor)
-                        and len(image_patches[0]) == len(images))
-                processed_outputs["image_patches"] = image_patches[0]
-            # image_patches is a tensor with shape:
-            # (num_images, Pn, Px * Py * C)
-            # after Transformers 4.53
-            elif isinstance(image_patches, torch.Tensor):
-                assert len(image_patches) == len(images)
-            else:
-                raise AssertionError("This line should be unreachable.")
+        image_patches = processed_outputs["image_patches"]
+        processed_outputs["image_patches"] = flatten_bn(image_patches)
+        processed_outputs["patches_per_image"] = torch.tensor(
+            [len(p) for p in image_patches])
 
         return processed_outputs
 
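The Transformers-version branching is gone: the processor now unconditionally flattens the patch batch and records per-image patch counts. Roughly what this does to the `(num_images, Pn, Px * Py * C)` output (illustrative shapes and a plain-torch stand-in, not vLLM's `flatten_bn` itself):

import torch

# Assumed processor output: 2 images, 3 patches each, 2700 features/patch.
image_patches = torch.randn(2, 3, 2700)

flat = image_patches.flatten(0, 1)  # (6, 2700): image and patch dims merged
patches_per_image = torch.tensor([len(p) for p in image_patches])
print(flat.shape, patches_per_image)  # torch.Size([6, 2700]) tensor([3, 3])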
@@ -218,7 +197,13 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        return dict(image_patches=MultiModalFieldConfig.batched("image"))
+        patches_per_image = hf_inputs.get("patches_per_image", torch.empty(0))
+
+        return dict(
+            image_patches=MultiModalFieldConfig.flat_from_sizes(
+                "image", patches_per_image),
+            patches_per_image=MultiModalFieldConfig.batched("image"),
+        )
 
     def _get_prompt_updates(
         self,
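`flat_from_sizes("image", patches_per_image)` declares `image_patches` as a flat concatenation whose per-image extents come from `patches_per_image`, so the framework can slice and re-merge per-image data; `batched("image")` keeps one entry per image. The slicing is essentially a sized split (plain-torch sketch of the idea, with assumed sizes):

import torch

flat = torch.randn(7, 2700)              # 3 + 4 patches, concatenated
patches_per_image = torch.tensor([3, 4])

per_image = flat.split(patches_per_image.tolist(), dim=0)
assert [p.shape[0] for p in per_image] == [3, 4]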
@@ -263,6 +248,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
                                         info=FuyuProcessingInfo,
                                         dummy_inputs=FuyuDummyInputsBuilder)
 class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
 
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
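With `merge_by_field_config = True`, batched multimodal kwargs reach the model already merged according to `_get_mm_fields_config` above, which is why the manual `flatten_bn` calls disappear from `_parse_and_validate_image_input` in the next hunk. A sketch of what the model-side kwargs now look like (shapes are assumed for illustration):

import torch

# e.g. a batch containing 2 images with 3 and 4 patches respectively:
kwargs = {
    "image_patches": torch.randn(7, 2700),      # flat across all images
    "patches_per_image": torch.tensor([3, 4]),  # one count per image
}
assert kwargs["image_patches"].shape[0] == int(kwargs["patches_per_image"].sum())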
@@ -306,29 +292,28 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
         image_patches = kwargs.pop("image_patches", None)
-        if image_patches is not None:
-            image_patches_flat = flatten_bn(image_patches)
-            flat_data = flatten_bn(image_patches_flat, concat=True)
-
-            return FuyuImagePatchInputs(
-                type="image_patches",
-                flat_data=flat_data,
-                patches_per_image=[x.size(0) for x in image_patches_flat],
-                resolve_bindings={"fn": self.image_feature_size},
-            )
-
-        return None
+        if image_patches is None:
+            return None
+
+        patches_per_image = kwargs.pop("patches_per_image", None)
+
+        return FuyuImagePatchInputs(
+            type="image_patches",
+            image_patches_flat=image_patches,
+            patches_per_image=patches_per_image,
+            resolve_bindings={"fn": self.image_feature_size},
+        )
 
     def _process_image_input(
             self, image_input: FuyuImagePatchInputs) -> MultiModalEmbeddings:
-        image_patches_flat = image_input["flat_data"]
+        image_patches_flat = image_input["image_patches_flat"]
         patches_per_image = image_input["patches_per_image"]
 
         assert self.vision_embed_tokens is not None
         vision_embeddings_flat, _ = self.vision_embed_tokens(
             image_patches_flat)
 
-        return vision_embeddings_flat.split(patches_per_image, dim=0)
+        return vision_embeddings_flat.split(patches_per_image.tolist(), dim=0)
 
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
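The `.tolist()` is required because `patches_per_image` is now a tensor, and `Tensor.split` takes an int or a list of ints for sectioned splits, not a tensor. The final recombination step, standalone (4096 is an assumed hidden size):

import torch

patches_per_image = torch.tensor([3, 4])
vision_embeddings_flat = torch.randn(7, 4096)

per_image_embeds = vision_embeddings_flat.split(patches_per_image.tolist(), dim=0)
print([tuple(e.shape) for e in per_image_embeds])  # [(3, 4096), (4, 4096)]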