[Model] Support multi-image for Molmo (#15438)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -57,7 +57,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
from .vision import select_patch_features
|
||||
from .vision import scatter_patch_features, select_patch_features
|
||||
|
||||
# TODO: hard-coded for now. Consider making it configurable.
|
||||
VIT_LAYERS = [-2, -9]
|
||||
@@ -71,13 +71,13 @@ POOLING_SIZE = 2
|
||||
|
||||
|
||||
class MolmoImageInputs(TypedDict):
|
||||
images: Union[torch.Tensor, List[torch.Tensor]]
|
||||
images: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""Shape: `(batch_size, num_crops, num_patch, patch_dim)`"""
|
||||
|
||||
image_masks: Optional[Union[torch.Tensor, List[torch.Tensor]]]
|
||||
image_masks: Optional[Union[torch.Tensor, list[torch.Tensor]]]
|
||||
"""Shape: `(batch_size, num_crops, num_patch)`"""
|
||||
|
||||
feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
|
||||
feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A boolean mask indicating which image features correspond
|
||||
to patch tokens.
|
||||
@@ -85,7 +85,7 @@ class MolmoImageInputs(TypedDict):
|
||||
Shape: `(batch_size, num_crops, num_patch)`
|
||||
"""
|
||||
|
||||
embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A boolean mask indicating which image embeddings correspond
|
||||
to patch tokens.
|
||||
@@ -93,7 +93,7 @@ class MolmoImageInputs(TypedDict):
|
||||
Shape: `(batch_size, num_embeds)`
|
||||
"""
|
||||
|
||||
num_crops: torch.Tensor
|
||||
num_crops: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""Shape: `(batch_size, num_images)`"""
|
||||
|
||||
|
||||
@@ -1144,13 +1144,7 @@ class MolmoProcessorWrapper:
|
||||
|
||||
image_input_idx = outputs.pop("image_input_idx", None)
|
||||
if image_input_idx is not None:
|
||||
input_is_patch = input_ids == self.image_patch_id
|
||||
image_input_idx_flat: torch.Tensor = image_input_idx.view(-1)
|
||||
image_valid_flat = image_input_idx_flat >= 0
|
||||
feat_is_patch_flat = image_valid_flat.clone()
|
||||
feat_is_patch_flat[image_valid_flat] = (
|
||||
input_is_patch[image_input_idx_flat[image_valid_flat]])
|
||||
feat_is_patch = feat_is_patch_flat.view(*image_input_idx.shape)
|
||||
feat_is_patch = image_input_idx >= 0
|
||||
|
||||
input_is_embed = torch.isin(
|
||||
input_ids,
|
||||
@@ -1165,6 +1159,17 @@ class MolmoProcessorWrapper:
|
||||
embed_is_patch = embed_ids == self.image_patch_id
|
||||
assert embed_is_patch.sum() == feat_is_patch.sum()
|
||||
|
||||
# image_tokens = extra_joint + joint
|
||||
# Both `extra_joint` and `joint` have `im_start_id` and `im_end_id`
|
||||
embed_start = torch.nonzero(embed_ids == self.im_start_id)[::2, 0]
|
||||
embed_end = torch.nonzero(embed_ids == self.im_end_id)[1::2, 0]
|
||||
assert len(embed_start) == len(embed_end) == len(images)
|
||||
|
||||
embed_is_patch = [
|
||||
embed_is_patch[start:end + 1]
|
||||
for start, end in zip(embed_start, embed_end)
|
||||
]
|
||||
|
||||
tilings = [
|
||||
self.select_tiling(
|
||||
image_width=image.size[0],
|
||||
@@ -1180,7 +1185,7 @@ class MolmoProcessorWrapper:
|
||||
outputs["num_crops"] = num_crops
|
||||
outputs["img_patch_id"] = self.image_patch_id
|
||||
|
||||
return BatchFeature(outputs, tensor_type=return_tensors)
|
||||
return BatchFeature(outputs)
|
||||
|
||||
|
||||
class MolmoProcessingInfo(BaseProcessingInfo):
|
||||
@@ -1190,9 +1195,7 @@ class MolmoProcessingInfo(BaseProcessingInfo):
|
||||
return MolmoProcessorWrapper(processor)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
# TODO: Investigate different `embed_is_patch` between cache/no-cache
|
||||
# in multi-image case
|
||||
return {"image": 1}
|
||||
return {"image": None}
|
||||
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
@@ -1325,7 +1328,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
|
||||
"image", num_crops),
|
||||
feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", num_crops),
|
||||
embed_is_patch=MultiModalFieldConfig.shared("image", num_images),
|
||||
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||
num_crops=MultiModalFieldConfig.batched("image"),
|
||||
img_patch_id=MultiModalFieldConfig.shared("image", num_images),
|
||||
)
|
||||
@@ -1499,7 +1502,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
def _process_image_input(
|
||||
self,
|
||||
image_input: MolmoImageInputs,
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
) -> Union[torch.Tensor, list[torch.Tensor]]:
|
||||
if isinstance(image_input["images"], list):
|
||||
# Call the vision backbone on the whole batch at once
|
||||
images_flat = flatten_bn(image_input["images"], concat=True)
|
||||
@@ -1530,7 +1533,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
feat_is_patch: torch.Tensor, # Shape: (num_crop, num_patch)
|
||||
num_crops: torch.Tensor, # Shape: (num_images,)
|
||||
embed_is_patch: torch.Tensor, # Shape: (num_embeds,)
|
||||
) -> list[torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""
|
||||
Scatter the patch features into a contiguous tensor that corresponds
|
||||
to the embedding tokens defined by the multimodal processor.
|
||||
@@ -1565,16 +1568,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
feats_per_image = features.split(num_crops_per_image)
|
||||
f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)
|
||||
|
||||
_, _, embed_dim = features.shape
|
||||
(num_embeds, ) = embed_is_patch.shape
|
||||
features = torch.cat([
|
||||
feats[f_is_patch]
|
||||
for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image)
|
||||
])
|
||||
|
||||
embeds_in_batch = list[torch.Tensor]()
|
||||
for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image):
|
||||
embeds = feats.new_full((num_embeds, embed_dim), torch.nan)
|
||||
embeds[embed_is_patch] = feats[f_is_patch]
|
||||
embeds_in_batch.append(embeds)
|
||||
|
||||
return embeds_in_batch
|
||||
return scatter_patch_features(features, embed_is_patch)
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
||||
|
||||
Reference in New Issue
Block a user