[Model] MiniCPM-V/O supports V1 (#15487)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -92,8 +92,8 @@ class MolmoImageInputs(TypedDict):
|
||||
Shape: `(batch_size * num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
num_crops: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""Shape: `(batch_size, num_images)`"""
|
||||
num_crops: torch.Tensor
|
||||
"""Shape: `(batch_size * num_images)`"""
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -1492,6 +1492,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
self.img_patch_id = img_patch_id.flatten().unique().item()
|
||||
|
||||
embed_is_patch = flatten_bn(embed_is_patch)
|
||||
num_crops = flatten_bn(num_crops, concat=True)
|
||||
|
||||
return MolmoImageInputs(
|
||||
images=images,
|
||||
@@ -1510,31 +1511,24 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
feat_is_patch = image_input["feat_is_patch"]
|
||||
num_crops = image_input["num_crops"]
|
||||
|
||||
if isinstance(images, list):
|
||||
# Call the vision backbone on the whole batch at once
|
||||
images_flat = flatten_bn(images, concat=True)
|
||||
image_masks_flat = (None if image_masks is None else flatten_bn(
|
||||
image_masks, concat=True))
|
||||
# Call the vision backbone on the whole batch at once
|
||||
images_flat = flatten_bn(images, concat=True)
|
||||
image_masks_flat = (None if image_masks is None else flatten_bn(
|
||||
image_masks, concat=True))
|
||||
feat_is_patch_flat = flatten_bn(feat_is_patch, concat=True)
|
||||
|
||||
image_features_flat = self.vision_backbone(
|
||||
images=images_flat.unsqueeze(0),
|
||||
image_masks=(None if image_masks_flat is None else
|
||||
image_masks_flat.unsqueeze(0)),
|
||||
).squeeze(0)
|
||||
|
||||
# Reconstruct the batch dimension
|
||||
num_crops_per_image = [nc.sum().item() for nc in num_crops]
|
||||
image_features = image_features_flat.split(num_crops_per_image)
|
||||
else:
|
||||
image_features = self.vision_backbone(
|
||||
images=images,
|
||||
image_masks=image_masks,
|
||||
)
|
||||
image_features_flat = self.vision_backbone(
|
||||
images=images_flat.unsqueeze(0),
|
||||
image_masks=(None if image_masks_flat is None else
|
||||
image_masks_flat.unsqueeze(0)),
|
||||
).squeeze(0)
|
||||
|
||||
# Only the features corresponding to patch tokens are relevant
|
||||
return [
|
||||
feats[f_is_patch]
|
||||
for feats, f_is_patch in zip(image_features, feat_is_patch)
|
||||
feats[f_is_patch] for feats, f_is_patch in zip(
|
||||
image_features_flat.split(num_crops.tolist()),
|
||||
feat_is_patch_flat.split(num_crops.tolist()),
|
||||
)
|
||||
]
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
|
||||
Reference in New Issue
Block a user