[Model] Support multi-image for Molmo (#15438)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-03-26 11:26:33 +08:00
committed by GitHub
parent e42389f9d7
commit 997c8811d6
4 changed files with 39 additions and 35 deletions

View File

@@ -155,7 +155,7 @@ def resolve_visual_encoder_outputs(
def scatter_patch_features(
features: torch.Tensor,
- embed_is_patch: torch.Tensor,
+ embed_is_patch: Union[torch.Tensor, list[torch.Tensor]],
) -> tuple[torch.Tensor, ...]:
"""
Scatter the patch features into a contiguous tensor that corresponds
@@ -194,14 +194,19 @@ def scatter_patch_features(
The resulting embedding tensor is:
[ nan p1 p2 nan p3 p4 nan nan ]
"""
- num_images, num_embeds = embed_is_patch.shape
- num_embeds_per_image = [num_embeds] * num_images
+ num_embeds_per_image = [
+ e_is_patch.numel() for e_is_patch in embed_is_patch
+ ]
+ if isinstance(embed_is_patch, torch.Tensor):
+ embed_is_patch_flat = embed_is_patch.view(-1)
+ else:
+ embed_is_patch_flat = torch.cat(embed_is_patch)
embeds_flat = features.new_full(
(sum(num_embeds_per_image), features.shape[-1]),
fill_value=torch.nan,
)
- embeds_flat[embed_is_patch.view(-1)] = features.flatten(0, -2)
+ embeds_flat[embed_is_patch_flat] = features.flatten(0, -2)
return embeds_flat.split(num_embeds_per_image)