[Misc] Remove redundant num_embeds (#15443)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-03-25 18:27:57 +08:00
committed by GitHub
parent a9e879b316
commit 5994430b84
5 changed files with 25 additions and 64 deletions

View File

@@ -76,9 +76,6 @@ class PixtralHFImagePixelInputs(TypedDict):
Shape: `(batch_size, num_images, num_embeds)`
"""
num_embeds: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`"""
class LlavaImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
@@ -358,15 +355,10 @@ class PixtralHFMultiModalProcessor(
image_height=pixel_value.shape[-2],
) for pixel_value in processed_outputs["pixel_values"]
]
num_embeds = torch.tensor([(ncols + 1) * nrows
for ncols, nrows in tile_sizes])
# Each image may result to masks of different sizes, so we need to
# later use `num_embeds` to get per-image masks.
embed_is_patch = [
torch.tensor(([True] * ncols + [False]) * nrows)
for ncols, nrows in tile_sizes
]
processed_outputs["num_embeds"] = num_embeds
processed_outputs["embed_is_patch"] = embed_is_patch
return processed_outputs
@@ -378,7 +370,6 @@ class PixtralHFMultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
num_embeds=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
@@ -627,16 +618,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
num_embeds = kwargs.pop("num_embeds")
if not isinstance(num_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_embeds. "
f"Got type: {type(num_embeds)}")
return PixtralHFImagePixelInputs(
type="pixel_values_pixtral",
pixel_values=flatten_bn(pixel_values),
embed_is_patch=embed_is_patch,
num_embeds=num_embeds,
)
return LlavaImagePixelInputs(
@@ -738,7 +723,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return flatten_2d_lists(
scatter_patch_features(*args) for args in zip(
vision_embeddings,
image_input["num_embeds"],
image_input["embed_is_patch"],
))