[Model] MiniCPM-V/O supports V1 (#15487)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-03-27 21:07:29 +08:00
committed by GitHub
parent 8063dfc61a
commit ac5bc615b0
4 changed files with 573 additions and 594 deletions

View File

@@ -92,8 +92,8 @@ class MolmoImageInputs(TypedDict):
Shape: `(batch_size * num_images, num_embeds)`
"""
num_crops: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`"""
num_crops: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
@dataclass
@@ -1492,6 +1492,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
self.img_patch_id = img_patch_id.flatten().unique().item()
embed_is_patch = flatten_bn(embed_is_patch)
num_crops = flatten_bn(num_crops, concat=True)
return MolmoImageInputs(
images=images,
@@ -1510,31 +1511,24 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
feat_is_patch = image_input["feat_is_patch"]
num_crops = image_input["num_crops"]
if isinstance(images, list):
# Call the vision backbone on the whole batch at once
images_flat = flatten_bn(images, concat=True)
image_masks_flat = (None if image_masks is None else flatten_bn(
image_masks, concat=True))
# Call the vision backbone on the whole batch at once
images_flat = flatten_bn(images, concat=True)
image_masks_flat = (None if image_masks is None else flatten_bn(
image_masks, concat=True))
feat_is_patch_flat = flatten_bn(feat_is_patch, concat=True)
image_features_flat = self.vision_backbone(
images=images_flat.unsqueeze(0),
image_masks=(None if image_masks_flat is None else
image_masks_flat.unsqueeze(0)),
).squeeze(0)
# Reconstruct the batch dimension
num_crops_per_image = [nc.sum().item() for nc in num_crops]
image_features = image_features_flat.split(num_crops_per_image)
else:
image_features = self.vision_backbone(
images=images,
image_masks=image_masks,
)
image_features_flat = self.vision_backbone(
images=images_flat.unsqueeze(0),
image_masks=(None if image_masks_flat is None else
image_masks_flat.unsqueeze(0)),
).squeeze(0)
# Only the features corresponding to patch tokens are relevant
return [
feats[f_is_patch]
for feats, f_is_patch in zip(image_features, feat_is_patch)
feats[f_is_patch] for feats, f_is_patch in zip(
image_features_flat.split(num_crops.tolist()),
feat_is_patch_flat.split(num_crops.tolist()),
)
]
def get_multimodal_embeddings(