[Misc] Clean up scatter_patch_features (#15559)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -49,7 +49,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
PromptInsertion, PromptUpdate)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import flatten_2d_lists
|
||||
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP, SupportsQuant)
|
||||
@@ -72,17 +71,17 @@ POOLING_SIZE = 2
|
||||
|
||||
class MolmoImageInputs(TypedDict):
|
||||
images: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""Shape: `(batch_size, num_crops, num_patch, patch_dim)`"""
|
||||
"""Shape: `(batch_size * num_images, num_crops, num_patch, patch_dim)`"""
|
||||
|
||||
image_masks: Optional[Union[torch.Tensor, list[torch.Tensor]]]
|
||||
"""Shape: `(batch_size, num_crops, num_patch)`"""
|
||||
"""Shape: `(batch_size * num_images, num_crops, num_patch)`"""
|
||||
|
||||
feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A boolean mask indicating which image features correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size, num_crops, num_patch)`
|
||||
Shape: `(batch_size * num_images, num_crops, num_patch)`
|
||||
"""
|
||||
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
@@ -90,7 +89,7 @@ class MolmoImageInputs(TypedDict):
|
||||
A boolean mask indicating which image embeddings correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size, num_embeds)`
|
||||
Shape: `(batch_size * num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
num_crops: Union[torch.Tensor, list[torch.Tensor]]
|
||||
@@ -696,9 +695,10 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
|
||||
return image_features
|
||||
|
||||
def forward(
|
||||
self, images: torch.Tensor, image_masks: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
|
||||
self,
|
||||
images: torch.Tensor,
|
||||
image_masks: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim) # noqa: E501
|
||||
batch_size, num_image = images.shape[:2]
|
||||
images = images.to(device=self.device, dtype=self.dtype)
|
||||
@@ -1491,6 +1491,8 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
f"Got type: {type(img_patch_id)}")
|
||||
self.img_patch_id = img_patch_id.flatten().unique().item()
|
||||
|
||||
embed_is_patch = flatten_bn(embed_is_patch)
|
||||
|
||||
return MolmoImageInputs(
|
||||
images=images,
|
||||
image_masks=image_masks,
|
||||
@@ -1502,13 +1504,17 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
def _process_image_input(
|
||||
self,
|
||||
image_input: MolmoImageInputs,
|
||||
) -> Union[torch.Tensor, list[torch.Tensor]]:
|
||||
if isinstance(image_input["images"], list):
|
||||
) -> list[torch.Tensor]:
|
||||
images = image_input["images"]
|
||||
image_masks = image_input["image_masks"]
|
||||
feat_is_patch = image_input["feat_is_patch"]
|
||||
num_crops = image_input["num_crops"]
|
||||
|
||||
if isinstance(images, list):
|
||||
# Call the vision backbone on the whole batch at once
|
||||
images_flat = flatten_bn(image_input["images"], concat=True)
|
||||
image_masks_flat = (None if (image_masks :=
|
||||
image_input["image_masks"]) is None
|
||||
else flatten_bn(image_masks, concat=True))
|
||||
images_flat = flatten_bn(images, concat=True)
|
||||
image_masks_flat = (None if image_masks is None else flatten_bn(
|
||||
image_masks, concat=True))
|
||||
|
||||
image_features_flat = self.vision_backbone(
|
||||
images=images_flat.unsqueeze(0),
|
||||
@@ -1517,63 +1523,19 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
).squeeze(0)
|
||||
|
||||
# Reconstruct the batch dimension
|
||||
image_features = image_features_flat.split(
|
||||
image_input["num_crops"].sum(-1).tolist())
|
||||
num_crops_per_image = [nc.sum().item() for nc in num_crops]
|
||||
image_features = image_features_flat.split(num_crops_per_image)
|
||||
else:
|
||||
image_features = self.vision_backbone(
|
||||
images=image_input["images"],
|
||||
image_masks=image_input["image_masks"],
|
||||
images=images,
|
||||
image_masks=image_masks,
|
||||
)
|
||||
|
||||
return image_features
|
||||
|
||||
def _get_mm_embeds(
|
||||
self,
|
||||
features: torch.Tensor, # Shape: (num_crop, num_patch, d)
|
||||
feat_is_patch: torch.Tensor, # Shape: (num_crop, num_patch)
|
||||
num_crops: torch.Tensor, # Shape: (num_images,)
|
||||
embed_is_patch: torch.Tensor, # Shape: (num_embeds,)
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""
|
||||
Scatter the patch features into a contiguous tensor that corresponds
|
||||
to the embedding tokens defined by the multimodal processor.
|
||||
|
||||
Note:
|
||||
The original code only considers patch tokens as feature
|
||||
tokens, but our processor considers all image-related tokens
|
||||
as feature tokens because the feature tokens need to be
|
||||
consecutive in `input_ids`.
|
||||
|
||||
Example:
|
||||
A simplified example for one item in the batch:
|
||||
|
||||
.. code-block::
|
||||
|
||||
Embedding tokens (from HF processor):
|
||||
[<start> <patch> <patch> <col> <patch> <patch> <col> <end> ]
|
||||
|
||||
embed_is_patch (from HF processor):
|
||||
[ False True True False True True False False ]
|
||||
|
||||
Encoder outputs (from model):
|
||||
[ p1 p2 0 p3 p4 0 ]
|
||||
|
||||
feat_is_patch (from HF processor):
|
||||
[ True True False True True False ]
|
||||
|
||||
The resulting embedding tensor is:
|
||||
[ nan p1 p2 nan p3 p4 nan nan ]
|
||||
"""
|
||||
num_crops_per_image = num_crops.tolist()
|
||||
feats_per_image = features.split(num_crops_per_image)
|
||||
f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)
|
||||
|
||||
features = torch.cat([
|
||||
# Only the features corresponding to patch tokens are relevant
|
||||
return [
|
||||
feats[f_is_patch]
|
||||
for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image)
|
||||
])
|
||||
|
||||
return scatter_patch_features(features, embed_is_patch)
|
||||
for feats, f_is_patch in zip(image_features, feat_is_patch)
|
||||
]
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
||||
@@ -1583,13 +1545,10 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
|
||||
image_features = self._process_image_input(image_input)
|
||||
|
||||
return flatten_2d_lists(
|
||||
self._get_mm_embeds(*args) for args in zip(
|
||||
image_features,
|
||||
image_input["feat_is_patch"],
|
||||
image_input["num_crops"],
|
||||
image_input["embed_is_patch"],
|
||||
))
|
||||
return scatter_patch_features(
|
||||
image_features,
|
||||
image_input["embed_is_patch"],
|
||||
)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user