[Misc] Clean up scatter_patch_features (#15559)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-03-27 15:45:00 +08:00
committed by GitHub
parent 43ed4143c4
commit e6c9053f9e
6 changed files with 82 additions and 136 deletions

View File

@@ -49,7 +49,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptInsertion, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP, SupportsQuant)
@@ -72,17 +71,17 @@ POOLING_SIZE = 2
class MolmoImageInputs(TypedDict):
images: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_crops, num_patch, patch_dim)`"""
"""Shape: `(batch_size * num_images, num_crops, num_patch, patch_dim)`"""
image_masks: Optional[Union[torch.Tensor, list[torch.Tensor]]]
"""Shape: `(batch_size, num_crops, num_patch)`"""
"""Shape: `(batch_size * num_images, num_crops, num_patch)`"""
feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image features correspond
to patch tokens.
Shape: `(batch_size, num_crops, num_patch)`
Shape: `(batch_size * num_images, num_crops, num_patch)`
"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
@@ -90,7 +89,7 @@ class MolmoImageInputs(TypedDict):
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size, num_embeds)`
Shape: `(batch_size * num_images, num_embeds)`
"""
num_crops: Union[torch.Tensor, list[torch.Tensor]]
@@ -696,9 +695,10 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
return image_features
def forward(
self, images: torch.Tensor, image_masks: torch.Tensor
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
self,
images: torch.Tensor,
image_masks: torch.Tensor,
) -> torch.Tensor:
# image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim) # noqa: E501
batch_size, num_image = images.shape[:2]
images = images.to(device=self.device, dtype=self.dtype)
@@ -1491,6 +1491,8 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
f"Got type: {type(img_patch_id)}")
self.img_patch_id = img_patch_id.flatten().unique().item()
embed_is_patch = flatten_bn(embed_is_patch)
return MolmoImageInputs(
images=images,
image_masks=image_masks,
@@ -1502,13 +1504,17 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
def _process_image_input(
self,
image_input: MolmoImageInputs,
) -> Union[torch.Tensor, list[torch.Tensor]]:
if isinstance(image_input["images"], list):
) -> list[torch.Tensor]:
images = image_input["images"]
image_masks = image_input["image_masks"]
feat_is_patch = image_input["feat_is_patch"]
num_crops = image_input["num_crops"]
if isinstance(images, list):
# Call the vision backbone on the whole batch at once
images_flat = flatten_bn(image_input["images"], concat=True)
image_masks_flat = (None if (image_masks :=
image_input["image_masks"]) is None
else flatten_bn(image_masks, concat=True))
images_flat = flatten_bn(images, concat=True)
image_masks_flat = (None if image_masks is None else flatten_bn(
image_masks, concat=True))
image_features_flat = self.vision_backbone(
images=images_flat.unsqueeze(0),
@@ -1517,63 +1523,19 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
).squeeze(0)
# Reconstruct the batch dimension
image_features = image_features_flat.split(
image_input["num_crops"].sum(-1).tolist())
num_crops_per_image = [nc.sum().item() for nc in num_crops]
image_features = image_features_flat.split(num_crops_per_image)
else:
image_features = self.vision_backbone(
images=image_input["images"],
image_masks=image_input["image_masks"],
images=images,
image_masks=image_masks,
)
return image_features
def _get_mm_embeds(
self,
features: torch.Tensor, # Shape: (num_crop, num_patch, d)
feat_is_patch: torch.Tensor, # Shape: (num_crop, num_patch)
num_crops: torch.Tensor, # Shape: (num_images,)
embed_is_patch: torch.Tensor, # Shape: (num_embeds,)
) -> tuple[torch.Tensor, ...]:
"""
Scatter the patch features into a contiguous tensor that corresponds
to the embedding tokens defined by the multimodal processor.
Note:
The original code only considers patch tokens as feature
tokens, but our processor considers all image-related tokens
as feature tokens because the feature tokens need to be
consecutive in `input_ids`.
Example:
A simplified example for one item in the batch:
.. code-block::
Embedding tokens (from HF processor):
[<start> <patch> <patch> <col> <patch> <patch> <col> <end> ]
embed_is_patch (from HF processor):
[ False True True False True True False False ]
Encoder outputs (from model):
[ p1 p2 0 p3 p4 0 ]
feat_is_patch (from HF processor):
[ True True False True True False ]
The resulting embedding tensor is:
[ nan p1 p2 nan p3 p4 nan nan ]
"""
num_crops_per_image = num_crops.tolist()
feats_per_image = features.split(num_crops_per_image)
f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)
features = torch.cat([
# Only the features corresponding to patch tokens are relevant
return [
feats[f_is_patch]
for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image)
])
return scatter_patch_features(features, embed_is_patch)
for feats, f_is_patch in zip(image_features, feat_is_patch)
]
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
@@ -1583,13 +1545,10 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
image_features = self._process_image_input(image_input)
return flatten_2d_lists(
self._get_mm_embeds(*args) for args in zip(
image_features,
image_input["feat_is_patch"],
image_input["num_crops"],
image_input["embed_is_patch"],
))
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings(
self,