[Misc] Remove redundant num_embeds (#15443)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -76,9 +76,6 @@ class PixtralHFImagePixelInputs(TypedDict):
|
||||
Shape: `(batch_size, num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
num_embeds: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""Shape: `(batch_size, num_images)`"""
|
||||
|
||||
|
||||
class LlavaImageEmbeddingInputs(TypedDict):
|
||||
type: Literal["image_embeds"]
|
||||
@@ -358,15 +355,10 @@ class PixtralHFMultiModalProcessor(
|
||||
image_height=pixel_value.shape[-2],
|
||||
) for pixel_value in processed_outputs["pixel_values"]
|
||||
]
|
||||
num_embeds = torch.tensor([(ncols + 1) * nrows
|
||||
for ncols, nrows in tile_sizes])
|
||||
# Each image may result to masks of different sizes, so we need to
|
||||
# later use `num_embeds` to get per-image masks.
|
||||
embed_is_patch = [
|
||||
torch.tensor(([True] * ncols + [False]) * nrows)
|
||||
for ncols, nrows in tile_sizes
|
||||
]
|
||||
processed_outputs["num_embeds"] = num_embeds
|
||||
processed_outputs["embed_is_patch"] = embed_is_patch
|
||||
|
||||
return processed_outputs
|
||||
@@ -378,7 +370,6 @@ class PixtralHFMultiModalProcessor(
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
num_embeds=MultiModalFieldConfig.batched("image"),
|
||||
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
@@ -627,16 +618,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
raise ValueError("Incorrect type of embed_is_patch. "
|
||||
f"Got type: {type(embed_is_patch)}")
|
||||
|
||||
num_embeds = kwargs.pop("num_embeds")
|
||||
if not isinstance(num_embeds, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of num_embeds. "
|
||||
f"Got type: {type(num_embeds)}")
|
||||
|
||||
return PixtralHFImagePixelInputs(
|
||||
type="pixel_values_pixtral",
|
||||
pixel_values=flatten_bn(pixel_values),
|
||||
embed_is_patch=embed_is_patch,
|
||||
num_embeds=num_embeds,
|
||||
)
|
||||
|
||||
return LlavaImagePixelInputs(
|
||||
@@ -738,7 +723,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
return flatten_2d_lists(
|
||||
scatter_patch_features(*args) for args in zip(
|
||||
vision_embeddings,
|
||||
image_input["num_embeds"],
|
||||
image_input["embed_is_patch"],
|
||||
))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user