[VLM] Merged multi-modal processor for Pixtral (#12211)

Signed-off-by: remi <remi@mistral.ai>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Rémi Delacourt
Date: 2025-03-15 14:28:27 +01:00
Committed by: GitHub
Parent: 74bc397b0a
Commit: 61c6a5a796

9 changed files with 620 additions and 358 deletions


@@ -68,23 +68,15 @@ class PixtralHFImagePixelInputs(TypedDict):
     in which case the data is passed as a list instead of a batched tensor.
     """
 
-    feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
-    """
-    A boolean mask indicating which image features correspond
-    to patch tokens.
-
-    Shape: `(batch_size, num_crops, num_patch)`
-    """
-
     embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
     """
    A boolean mask indicating which image embeddings correspond
    to patch tokens.
 
-    Shape: `(batch_size, num_embeds)`
+    Shape: `(batch_size, num_images, num_embeds)`
    """
 
-    num_crops: Union[torch.Tensor, list[torch.Tensor]]
+    num_patches: Union[torch.Tensor, list[torch.Tensor]]
    """Shape: `(batch_size, num_images)`"""
@@ -360,16 +352,16 @@ class PixtralHFMultiModalProcessor(
                 image_height=pixel_value.shape[-2],
             ) for pixel_value in processed_outputs["pixel_values"]
         ]
 
-        num_crops = torch.tensor([(ncols + 1) * nrows
-                                  for ncols, nrows in tile_sizes])
+        num_patches = torch.tensor([(ncols + 1) * nrows
+                                    for ncols, nrows in tile_sizes])
 
         # Each image may result to masks of different sizes, so we need to
-        # flatten the list and later use `num_crops` to get per-image masks.
-        embed_is_patch = torch.tensor(
-            flatten_2d_lists([([True] * ncols + [False]) * nrows
-                              for ncols, nrows in tile_sizes]))
-        processed_outputs["num_crops"] = num_crops
+        # later use `num_patches` to get per-image masks.
+        embed_is_patch = [
+            torch.tensor(([True] * ncols + [False]) * nrows)
+            for ncols, nrows in tile_sizes
+        ]
+        processed_outputs["num_patches"] = num_patches
         processed_outputs["embed_is_patch"] = embed_is_patch
-        processed_outputs["feat_is_patch"] = embed_is_patch
 
         return processed_outputs
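The tile arithmetic here is easier to see with concrete numbers. Below is a small illustration (mine, not part of the diff) for a single image split into a 3x2 tile grid; each row of `ncols` patch positions is followed by one non-patch position, which I read as Pixtral's row-break token:

```python
import torch

# Illustration: one image split into ncols=3, nrows=2 tiles.
ncols, nrows = 3, 2

# Per-row pattern is [True] * ncols + [False], repeated nrows times.
embed_is_patch = torch.tensor(([True] * ncols + [False]) * nrows)
# -> tensor([True, True, True, False, True, True, True, False])

# Despite its name, `num_patches` counts *every* embedding position for
# the image, break positions included; that is what makes it the right
# split size for the nan-padded buffer later in this diff.
num_patches = (ncols + 1) * nrows                   # 8
assert embed_is_patch.numel() == num_patches
assert int(embed_is_patch.sum()) == ncols * nrows   # 6 actual patches
```

Building `embed_is_patch` as one tensor per image, rather than one flattened tensor for the whole batch, is what lets the field config below treat it as a plain batched field.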
@@ -378,14 +370,10 @@ class PixtralHFMultiModalProcessor(
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        num_crops = hf_inputs.get("num_crops", torch.empty(0)).view(-1)
         return dict(
-            feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
-                "image", num_crops),
-            embed_is_patch=MultiModalFieldConfig.flat_from_sizes(
-                "image", num_crops),
-            num_crops=MultiModalFieldConfig.batched("image"),
             pixel_values=MultiModalFieldConfig.batched("image"),
+            num_patches=MultiModalFieldConfig.batched("image"),
+            embed_is_patch=MultiModalFieldConfig.batched("image"),
             image_embeds=MultiModalFieldConfig.batched("image"),
         )
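This tracks the layout change above: `flat_from_sizes` describes a tensor in which all images are concatenated along dim 0 and must be re-split using a size vector (here the old `num_crops`), while `batched` describes data that is already indexed per image. A rough sketch of the two recovery patterns, under my reading of these field configs rather than vLLM's actual implementation:

```python
import torch

sizes = [8, 4]  # per-image lengths, e.g. the old num_crops vector

# flat_from_sizes-style layout: one flat tensor for all images,
# recovered per image by splitting along dim 0 with the size vector.
flat = torch.arange(12)
per_image_a = flat.split(sizes)          # tensors of length 8 and 4

# batched-style layout: already a per-image sequence, so element i
# is simply the data for image i -- no size vector needed.
batched = [torch.arange(8), torch.arange(4)]
per_image_b = list(batched)
```

Since `embed_is_patch` is now produced as a list of per-image masks, the `batched` config applies directly and the `num_crops` bookkeeping disappears.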
@@ -628,27 +616,21 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
                                  f"Got type: {type(pixel_values)}")
 
             if self.config.vision_config.model_type == "pixtral":
-                feat_is_patch = kwargs.pop("feat_is_patch")
-                if not isinstance(feat_is_patch, (torch.Tensor, list)):
-                    raise ValueError("Incorrect type of feat_is_patch. "
-                                     f"Got type: {type(feat_is_patch)}")
-
                 embed_is_patch = kwargs.pop("embed_is_patch")
                 if not isinstance(embed_is_patch, (torch.Tensor, list)):
                     raise ValueError("Incorrect type of embed_is_patch. "
                                      f"Got type: {type(embed_is_patch)}")
 
-                num_crops = kwargs.pop("num_crops")
-                if not isinstance(num_crops, (torch.Tensor, list)):
-                    raise ValueError("Incorrect type of num_crops. "
-                                     f"Got type: {type(num_crops)}")
+                num_patches = kwargs.pop("num_patches")
+                if not isinstance(num_patches, (torch.Tensor, list)):
+                    raise ValueError("Incorrect type of num_patches. "
+                                     f"Got type: {type(num_patches)}")
 
                 return PixtralHFImagePixelInputs(
                     type="pixel_values_pixtral",
                     pixel_values=flatten_bn(pixel_values),
-                    feat_is_patch=feat_is_patch,
                     embed_is_patch=embed_is_patch,
-                    num_crops=num_crops,
+                    num_patches=num_patches,
                 )
 
             return LlavaImagePixelInputs(
@@ -687,21 +669,26 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                             PixtralHFVisionModel],
         pixel_values: Union[torch.Tensor, list[torch.Tensor]],
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
         image_features = vision_tower(pixel_values)
-        return self._select_image_features(
-            image_features,
-            strategy=self.config.vision_feature_select_strategy,
-        )
+
+        def select_features(leaf: torch.Tensor):
+            return self._select_image_features(
+                leaf,
+                strategy=self.config.vision_feature_select_strategy,
+            )
+
+        return cast(
+            Union[torch.Tensor, tuple[torch.Tensor, ...]],
+            json_map_leaves(select_features, image_features),
+        )
 
     def _process_image_pixels(
         self,
         inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         assert self.vision_tower is not None
 
         pixel_values = inputs["pixel_values"]
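With the Pixtral tower now able to return a tuple of per-image feature tensors, `json_map_leaves` lets the same selection logic serve both that case and the single batched tensor from CLIP/SigLIP. A minimal stand-in for the helper (an assumption about its behavior, not vLLM's code):

```python
import torch
from typing import Callable, Union

NestedTensors = Union[torch.Tensor, tuple[torch.Tensor, ...]]

def map_leaves(fn: Callable[[torch.Tensor], torch.Tensor],
               value: NestedTensors) -> NestedTensors:
    # Apply `fn` to every tensor leaf, preserving container structure.
    if isinstance(value, torch.Tensor):
        return fn(value)
    return tuple(map_leaves(fn, v) for v in value)

# e.g. a CLS-dropping select strategy that works for both layouts:
drop_first = lambda t: t[..., 1:, :]

single = torch.randn(2, 5, 8)                    # batched tensor path
nested = (torch.randn(5, 8), torch.randn(9, 8))  # per-image tuple path

assert map_leaves(drop_first, single).shape == (2, 4, 8)
assert map_leaves(drop_first, nested)[1].shape == (8, 8)
```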
@@ -731,45 +718,30 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     def _get_mm_embeds(
         self,
-        features: torch.Tensor,  # Shape: (num_crop, num_patch, d)
-        feat_is_patch: torch.Tensor,  # Shape: (num_crop, num_patch)
-        num_crops: torch.Tensor,  # Shape: (num_images,)
-        embed_is_patch: torch.Tensor,  # Shape: (num_embeds,)
-    ) -> list[torch.Tensor]:
+        features: torch.Tensor,  # Shape: (num_patch, d)
+        num_patches: torch.Tensor,  # Shape: (num_images,)
+        embed_is_patch: torch.Tensor,  # Shape: (num_images, num_embeds)
+    ) -> tuple[torch.Tensor, ...]:
         """Scatter the patch features into a contiguous tensor that corresponds
         to the embedding tokens defined by the multimodal processor.
 
         Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment.
         """
 
-        # Insert columns of nan values according to `feat_is_patch`. This work
+        # Insert columns of nan values according to `embed_is_patch`. This work
         # ideally should be done in `_process_image_input`, but
         # `_process_image_input` is used in both V0 and V1 path. It's safer to
         # put the logic here.
         # FIXME: Move this logic to `_process_image_input` when v0 is
         # deprecated. Merge this function with `Molmo._get_mm_embeds`.
-        feat_is_patch = feat_is_patch.view(-1)
-        embed_is_patch = embed_is_patch.view(-1)
-        expanded_embedding = torch.full(
-            (sum(num_crops), *features.shape[1:]),
-            torch.nan,
-            dtype=features.dtype).to(features.device)
-        expanded_embedding[feat_is_patch] = features
-        num_crops_per_image = num_crops.tolist()
-        feats_per_image = expanded_embedding.split(num_crops_per_image)
-        f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)
-        embed_dim = expanded_embedding.shape[-1]
-        num_embeds = embed_is_patch.shape[0]
-        embeds_in_batch = list[torch.Tensor]()
-        for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image):
-            embeds = feats.new_full((num_embeds, embed_dim), torch.nan)
-            embeds[embed_is_patch] = feats[f_is_patch]
-            embeds_in_batch.append(embeds)
-        return embeds_in_batch
+        num_patches_per_image: list[int] = num_patches.tolist()
+
+        embeds_flat = features.new_full(
+            (sum(num_patches_per_image), *features.shape[1:]),
+            fill_value=torch.nan,
+        )
+        embeds_flat[embed_is_patch.view(-1)] = features
+
+        return embeds_flat.split(num_patches_per_image)
 
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
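The rewritten scatter is small enough to reproduce standalone. The sketch below (illustrative shapes, mine) handles one call covering two images; the diff flattens the mask with `.view(-1)`, for which `torch.cat` over per-image masks is the ragged-list equivalent:

```python
import torch

d = 4
# Image 0: 2x1 tile grid -> mask [T, T, F]; image 1: 1x2 grid -> [T, F, T, F].
embed_is_patch = [
    torch.tensor([True, True, False]),
    torch.tensor([True, False, True, False]),
]
num_patches = torch.tensor([3, 4])   # mask lengths, break slots included
features = torch.randn(4, d)         # the 4 true patch features, in order

num_patches_per_image = num_patches.tolist()
embeds_flat = features.new_full(
    (sum(num_patches_per_image), d), fill_value=torch.nan)
embeds_flat[torch.cat(embed_is_patch)] = features

per_image = embeds_flat.split(num_patches_per_image)
# per_image[0]: rows 0-1 are features, row 2 is nan (a break slot)
# per_image[1]: rows 0 and 2 are features, rows 1 and 3 are nan
```

Compared with the old version, there is no intermediate per-crop `expanded_embedding` and no per-image Python loop: one nan-filled buffer, one masked scatter, one split.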
@@ -784,12 +756,12 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
             # The path is used for pixtral (V0 only) and llava (V0/V1)
             return vision_embeddings
 
-        nested_emb = [
+        return flatten_2d_lists(
             self._get_mm_embeds(*args) for args in zip(
-                vision_embeddings, image_input["feat_is_patch"],
-                image_input["num_crops"], image_input["embed_is_patch"])
-        ]
-        return flatten_2d_lists(nested_emb)
+                vision_embeddings,
+                image_input["num_patches"],
+                image_input["embed_is_patch"],
+            ))
 
     def get_input_embeddings(
         self,
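Each `_get_mm_embeds` call now returns a tuple of per-image tensors, so the only remaining step is to flatten one level of nesting. A toy equivalent of `flatten_2d_lists` (my assumption about its behavior, matching how it is used here):

```python
import torch

def flatten_2d_lists(lists):
    # Flatten exactly one level of nesting, preserving order.
    return [item for sub in lists for item in sub]

per_call = [
    (torch.zeros(3, 4), torch.zeros(4, 4)),  # call 1: two images
    (torch.zeros(8, 4),),                    # call 2: one image
]
assert len(flatten_2d_lists(per_call)) == 3
```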
@@ -805,9 +777,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
             )
             inputs_embeds = merge_multimodal_embeddings(
-                input_ids, inputs_embeds, cast(NestedTensors,
-                                               patch_embeddings),
-                self.config.image_token_index)
+                input_ids,
+                inputs_embeds,
+                cast(NestedTensors, patch_embeddings),
+                self.config.image_token_index,
+            )
         return inputs_embeds
 
     def forward(
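For context, `merge_multimodal_embeddings` then writes those per-image embeddings into the token-embedding sequence at the placeholder positions. A deliberately simplified sketch of that step (my own reduction; the real helper also handles nested inputs and multiple placeholder ids, and the placeholder id below is made up):

```python
import torch

def merge_sketch(input_ids, inputs_embeds, mm_embeds, image_token_id):
    # Rows whose token id is the image placeholder receive the
    # multimodal embeddings, in order; all other rows are untouched.
    mask = input_ids == image_token_id
    out = inputs_embeds.clone()
    out[mask] = torch.cat(list(mm_embeds)).to(out.dtype)
    return out

ids = torch.tensor([1, 32000, 32000, 2])  # 32000: hypothetical placeholder id
embeds = torch.zeros(4, 8)
mm = (torch.ones(2, 8),)                  # one image, two patch embeddings
merged = merge_sketch(ids, embeds, mm, image_token_id=32000)
assert torch.all(merged[1:3] == 1) and torch.all(merged[0] == 0)
```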