From 55392bc87932da63b888e58f371fe4b67b438496 Mon Sep 17 00:00:00 2001 From: "sangho.lee" Date: Sat, 11 Oct 2025 00:28:23 -0500 Subject: [PATCH] [Bugfix][Multi Modal] Fix incorrect Molmo image processing (#26563) Signed-off-by: sanghol --- vllm/model_executor/models/molmo.py | 40 +++++++++++++++++------------ 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 734841d0d..f1dd06f3a 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -114,11 +114,11 @@ class MolmoImageInputs(TensorSchema): TensorShape("bn", "nc", "np", dynamic_dims={"nc"}), ] - feat_is_patch: Annotated[ + image_input_idx: Annotated[ Union[torch.Tensor, list[torch.Tensor]], TensorShape("bn", "nc", "tp", dynamic_dims={"nc"}), ] - # A boolean mask indicating which image features correspond to patch tokens. + # An index tensor that maps image features to their corresponding patch tokens. num_crops: Annotated[torch.Tensor, TensorShape("bn")] @@ -1177,7 +1177,7 @@ class MolmoProcessorWrapper: num_crops = torch.tensor(tilings).prod(-1) + 1 assert num_crops.sum() == len(feat_is_patch) - outputs["feat_is_patch"] = feat_is_patch + outputs["image_input_idx"] = image_input_idx outputs["num_crops"] = num_crops outputs["img_patch_id"] = self.image_patch_id @@ -1211,8 +1211,9 @@ class MolmoProcessingInfo(BaseProcessingInfo): image_token_length_w = processor.image_token_length_w image_token_length_h = processor.image_token_length_h - extra = image_token_length_w * image_token_length_h - joint = ((ncols + 1) // pooling_size) * ((nrows + 1) // pooling_size) + # Calculate total tokens: 2 for start/end + (w+1)*h for column separators + extra = 2 + (image_token_length_w + 1) * image_token_length_h + joint = 2 + ((ncols + 1) // pooling_size + 1) * ((nrows + 1) // pooling_size) return extra + joint @@ -1299,7 +1300,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): return dict( images=MultiModalFieldConfig.flat_from_sizes("image", num_crops), image_masks=MultiModalFieldConfig.flat_from_sizes("image", num_crops), - feat_is_patch=MultiModalFieldConfig.flat_from_sizes("image", num_crops), + image_input_idx=MultiModalFieldConfig.flat_from_sizes("image", num_crops), num_crops=MultiModalFieldConfig.batched("image"), img_patch_id=MultiModalFieldConfig.shared("image", num_images), ) @@ -1444,7 +1445,7 @@ class MolmoForCausalLM( ) -> Optional[MolmoImageInputs]: images = kwargs.pop("images", None) image_masks = kwargs.pop("image_masks", None) - feat_is_patch = kwargs.pop("feat_is_patch", None) + image_input_idx = kwargs.pop("image_input_idx", None) num_crops = kwargs.pop("num_crops", None) if images is None: @@ -1466,7 +1467,7 @@ class MolmoForCausalLM( return MolmoImageInputs( images=images, image_masks=image_masks, - feat_is_patch=feat_is_patch, + image_input_idx=image_input_idx, num_crops=num_crops, ) @@ -1476,7 +1477,7 @@ class MolmoForCausalLM( ) -> list[torch.Tensor]: images = image_input["images"] image_masks = image_input["image_masks"] - feat_is_patch = image_input["feat_is_patch"] + image_input_idx = image_input["image_input_idx"] num_crops = image_input["num_crops"] # Call the vision backbone on the whole batch at once @@ -1484,7 +1485,7 @@ class MolmoForCausalLM( image_masks_flat = ( None if image_masks is None else flatten_bn(image_masks, concat=True) ) - feat_is_patch_flat = flatten_bn(feat_is_patch, concat=True) + image_input_idx_flat = flatten_bn(image_input_idx, concat=True) image_features_flat = self.vision_backbone( images=images_flat.unsqueeze(0), @@ -1494,13 +1495,18 @@ class MolmoForCausalLM( ).squeeze(0) # Only the features corresponding to patch tokens are relevant - return [ - feats[f_is_patch] - for feats, f_is_patch in zip( - image_features_flat.split(num_crops.tolist()), - feat_is_patch_flat.split(num_crops.tolist()), - ) - ] + # Re-order the features using the image_input_idx tensor + results = [] + num_crops_list = num_crops.tolist() + for feats, img_idx in zip( + image_features_flat.split(num_crops_list), + image_input_idx_flat.split(num_crops_list), + ): + is_valid = img_idx >= 0 + valid_img_idx = img_idx[is_valid] + order = torch.argsort(valid_img_idx) + results.append(feats[is_valid][order]) + return results def get_language_model(self) -> torch.nn.Module: return self.model