[VLM] Merged multi-modal processor for Pixtral (#12211)

Signed-off-by: remi <remi@mistral.ai>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -68,23 +68,15 @@ class PixtralHFImagePixelInputs(TypedDict):
|
||||
in which case the data is passed as a list instead of a batched tensor.
|
||||
"""
|
||||
|
||||
feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A boolean mask indicating which image features correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size, num_crops, num_patch)`
|
||||
"""
|
||||
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A boolean mask indicating which image embeddings correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size, num_embeds)`
|
||||
Shape: `(batch_size, num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
num_crops: Union[torch.Tensor, list[torch.Tensor]]
|
||||
num_patches: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""Shape: `(batch_size, num_images)`"""
|
||||
|
||||
|
||||
@@ -360,16 +352,16 @@ class PixtralHFMultiModalProcessor(
|
||||
image_height=pixel_value.shape[-2],
|
||||
) for pixel_value in processed_outputs["pixel_values"]
|
||||
]
|
||||
num_crops = torch.tensor([(ncols + 1) * nrows
|
||||
for ncols, nrows in tile_sizes])
|
||||
num_patches = torch.tensor([(ncols + 1) * nrows
|
||||
for ncols, nrows in tile_sizes])
|
||||
# Each image may result to masks of different sizes, so we need to
|
||||
# flatten the list and later use `num_crops` to get per-image masks.
|
||||
embed_is_patch = torch.tensor(
|
||||
flatten_2d_lists([([True] * ncols + [False]) * nrows
|
||||
for ncols, nrows in tile_sizes]))
|
||||
processed_outputs["num_crops"] = num_crops
|
||||
# later use `num_patches` to get per-image masks.
|
||||
embed_is_patch = [
|
||||
torch.tensor(([True] * ncols + [False]) * nrows)
|
||||
for ncols, nrows in tile_sizes
|
||||
]
|
||||
processed_outputs["num_patches"] = num_patches
|
||||
processed_outputs["embed_is_patch"] = embed_is_patch
|
||||
processed_outputs["feat_is_patch"] = embed_is_patch
|
||||
|
||||
return processed_outputs
|
||||
|
||||
@@ -378,14 +370,10 @@ class PixtralHFMultiModalProcessor(
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
num_crops = hf_inputs.get("num_crops", torch.empty(0)).view(-1)
|
||||
return dict(
|
||||
feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", num_crops),
|
||||
embed_is_patch=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", num_crops),
|
||||
num_crops=MultiModalFieldConfig.batched("image"),
|
||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
num_patches=MultiModalFieldConfig.batched("image"),
|
||||
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
@@ -628,27 +616,21 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
if self.config.vision_config.model_type == "pixtral":
|
||||
feat_is_patch = kwargs.pop("feat_is_patch")
|
||||
if not isinstance(feat_is_patch, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of feat_is_patch. "
|
||||
f"Got type: {type(feat_is_patch)}")
|
||||
|
||||
embed_is_patch = kwargs.pop("embed_is_patch")
|
||||
if not isinstance(embed_is_patch, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of embed_is_patch. "
|
||||
f"Got type: {type(embed_is_patch)}")
|
||||
|
||||
num_crops = kwargs.pop("num_crops")
|
||||
if not isinstance(num_crops, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of num_crops. "
|
||||
f"Got type: {type(num_crops)}")
|
||||
num_patches = kwargs.pop("num_patches")
|
||||
if not isinstance(num_patches, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of num_patches. "
|
||||
f"Got type: {type(num_patches)}")
|
||||
|
||||
return PixtralHFImagePixelInputs(
|
||||
type="pixel_values_pixtral",
|
||||
pixel_values=flatten_bn(pixel_values),
|
||||
feat_is_patch=feat_is_patch,
|
||||
embed_is_patch=embed_is_patch,
|
||||
num_crops=num_crops,
|
||||
num_patches=num_patches,
|
||||
)
|
||||
|
||||
return LlavaImagePixelInputs(
|
||||
@@ -687,21 +669,26 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
|
||||
PixtralHFVisionModel],
|
||||
pixel_values: Union[torch.Tensor, list[torch.Tensor]],
|
||||
) -> torch.Tensor:
|
||||
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
|
||||
# NOTE: we skip the step to select the vision feature layer since
|
||||
# this is already done inside the vision tower
|
||||
image_features = vision_tower(pixel_values)
|
||||
|
||||
return self._select_image_features(
|
||||
image_features,
|
||||
strategy=self.config.vision_feature_select_strategy,
|
||||
def select_features(leaf: torch.Tensor):
|
||||
return self._select_image_features(
|
||||
leaf,
|
||||
strategy=self.config.vision_feature_select_strategy,
|
||||
)
|
||||
|
||||
return cast(
|
||||
Union[torch.Tensor, tuple[torch.Tensor, ...]],
|
||||
json_map_leaves(select_features, image_features),
|
||||
)
|
||||
|
||||
def _process_image_pixels(
|
||||
self,
|
||||
inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
|
||||
assert self.vision_tower is not None
|
||||
|
||||
pixel_values = inputs["pixel_values"]
|
||||
@@ -731,45 +718,30 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
def _get_mm_embeds(
|
||||
self,
|
||||
features: torch.Tensor, # Shape: (num_crop, num_patch, d)
|
||||
feat_is_patch: torch.Tensor, # Shape: (num_crop, num_patch)
|
||||
num_crops: torch.Tensor, # Shape: (num_images,)
|
||||
embed_is_patch: torch.Tensor, # Shape: (num_embeds,)
|
||||
) -> list[torch.Tensor]:
|
||||
features: torch.Tensor, # Shape: (num_patch, d)
|
||||
num_patches: torch.Tensor, # Shape: (num_images,)
|
||||
embed_is_patch: torch.Tensor, # Shape: (num_images, num_embeds)
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""Scatter the patch features into a contiguous tensor that corresponds
|
||||
to the embedding tokens defined by the multimodal processor.
|
||||
|
||||
Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment.
|
||||
"""
|
||||
|
||||
# Insert columns of nan values according to `feat_is_patch`. This work
|
||||
# Insert columns of nan values according to `embed_is_patch`. This work
|
||||
# ideally should be done in `_process_image_input`, but
|
||||
# `_process_image_input` is used in both V0 and V1 path. It's safer to
|
||||
# put the logic here.
|
||||
# FIXME: Move this logic to `_process_image_input` when v0 is
|
||||
# deprecated. Merge this function with `Molmo._get_mm_embeds`.
|
||||
feat_is_patch = feat_is_patch.view(-1)
|
||||
embed_is_patch = embed_is_patch.view(-1)
|
||||
expanded_embedding = torch.full(
|
||||
(sum(num_crops), *features.shape[1:]),
|
||||
torch.nan,
|
||||
dtype=features.dtype).to(features.device)
|
||||
expanded_embedding[feat_is_patch] = features
|
||||
num_patches_per_image: list[int] = num_patches.tolist()
|
||||
|
||||
num_crops_per_image = num_crops.tolist()
|
||||
feats_per_image = expanded_embedding.split(num_crops_per_image)
|
||||
f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)
|
||||
embeds_flat = features.new_full(
|
||||
(sum(num_patches_per_image), *features.shape[1:]),
|
||||
fill_value=torch.nan,
|
||||
)
|
||||
embeds_flat[embed_is_patch.view(-1)] = features
|
||||
|
||||
embed_dim = expanded_embedding.shape[-1]
|
||||
num_embeds = embed_is_patch.shape[0]
|
||||
|
||||
embeds_in_batch = list[torch.Tensor]()
|
||||
for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image):
|
||||
embeds = feats.new_full((num_embeds, embed_dim), torch.nan)
|
||||
embeds[embed_is_patch] = feats[f_is_patch]
|
||||
embeds_in_batch.append(embeds)
|
||||
|
||||
return embeds_in_batch
|
||||
return embeds_flat.split(num_patches_per_image)
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
||||
@@ -784,12 +756,12 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
# The path is used for pixtral (V0 only) and llava (V0/V1)
|
||||
return vision_embeddings
|
||||
|
||||
nested_emb = [
|
||||
return flatten_2d_lists(
|
||||
self._get_mm_embeds(*args) for args in zip(
|
||||
vision_embeddings, image_input["feat_is_patch"],
|
||||
image_input["num_crops"], image_input["embed_is_patch"])
|
||||
]
|
||||
return flatten_2d_lists(nested_emb)
|
||||
vision_embeddings,
|
||||
image_input["num_patches"],
|
||||
image_input["embed_is_patch"],
|
||||
))
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
@@ -805,9 +777,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
)
|
||||
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, cast(NestedTensors,
|
||||
patch_embeddings),
|
||||
self.config.image_token_index)
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
cast(NestedTensors, patch_embeddings),
|
||||
self.config.image_token_index,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
|
||||
Reference in New Issue
Block a user