[Misc] Clean up scatter_patch_features (#15559)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-03-27 15:45:00 +08:00
committed by GitHub
parent 43ed4143c4
commit e6c9053f9e
6 changed files with 82 additions and 136 deletions

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Final, Generic, Optional, Protocol, TypeVar, Union, cast
import torch
@@ -154,8 +155,8 @@ def resolve_visual_encoder_outputs(
def scatter_patch_features(
features: torch.Tensor,
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]],
patches: Union[torch.Tensor, Sequence[torch.Tensor]],
embed_is_patch: Union[torch.Tensor, Sequence[torch.Tensor]],
) -> tuple[torch.Tensor, ...]:
"""
Scatter the patch features into a contiguous tensor that corresponds
@@ -165,8 +166,8 @@ def scatter_patch_features(
can be filtered out by :func`select_patch_features`.
Args:
features: The patch features, concatenated across each image.
Shape: `(num_patch, feature_depth)`
patches: The patch features for each image.
Shape: `(num_images, <patch_dims>, feature_depth)`
embed_is_patch: A boolean mask indicating which image embeddings
correspond to patch tokens for each image.
Shape: `(num_images, num_embeds)`
@@ -194,21 +195,21 @@ def scatter_patch_features(
The resulting embedding tensor is:
[ nan p1 p2 nan p3 p4 nan nan ]
"""
num_embeds_per_image = [
e_is_patch.numel() for e_is_patch in embed_is_patch
]
if isinstance(embed_is_patch, torch.Tensor):
embed_is_patch_flat = embed_is_patch.view(-1)
else:
embed_is_patch_flat = torch.cat(embed_is_patch)
if len(patches) != len(embed_is_patch):
raise ValueError(f"Inconsistent num_images: {len(patches)=} vs. "
f"{len(embed_is_patch)=}")
embeds_flat = features.new_full(
(sum(num_embeds_per_image), features.shape[-1]),
fill_value=torch.nan,
)
embeds_flat[embed_is_patch_flat] = features.flatten(0, -2)
def get_embed_one(patches_one: torch.Tensor, e_is_patch: torch.Tensor):
embed_one = patches_one.new_full(
(e_is_patch.shape[0], patches_one.shape[-1]),
fill_value=torch.nan,
)
embed_one[e_is_patch] = patches_one.flatten(0, -2)
return embed_one
return embeds_flat.split(num_embeds_per_image)
return tuple(
get_embed_one(patches_one, e_is_patch)
for patches_one, e_is_patch in zip(patches, embed_is_patch))
def select_patch_features(