[Model] Support multi-image for Molmo (#15438)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -155,7 +155,7 @@ def resolve_visual_encoder_outputs(
|
||||
|
||||
def scatter_patch_features(
|
||||
features: torch.Tensor,
|
||||
embed_is_patch: torch.Tensor,
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]],
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""
|
||||
Scatter the patch features into a contiguous tensor that corresponds
|
||||
@@ -194,14 +194,19 @@ def scatter_patch_features(
|
||||
The resulting embedding tensor is:
|
||||
[ nan p1 p2 nan p3 p4 nan nan ]
|
||||
"""
|
||||
num_images, num_embeds = embed_is_patch.shape
|
||||
num_embeds_per_image = [num_embeds] * num_images
|
||||
num_embeds_per_image = [
|
||||
e_is_patch.numel() for e_is_patch in embed_is_patch
|
||||
]
|
||||
if isinstance(embed_is_patch, torch.Tensor):
|
||||
embed_is_patch_flat = embed_is_patch.view(-1)
|
||||
else:
|
||||
embed_is_patch_flat = torch.cat(embed_is_patch)
|
||||
|
||||
embeds_flat = features.new_full(
|
||||
(sum(num_embeds_per_image), features.shape[-1]),
|
||||
fill_value=torch.nan,
|
||||
)
|
||||
embeds_flat[embed_is_patch.view(-1)] = features.flatten(0, -2)
|
||||
embeds_flat[embed_is_patch_flat] = features.flatten(0, -2)
|
||||
|
||||
return embeds_flat.split(num_embeds_per_image)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user