[Misc] Remove redundant num_embeds (#15443)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-03-25 18:27:57 +08:00
committed by GitHub
parent a9e879b316
commit 5994430b84
5 changed files with 25 additions and 64 deletions

View File

@@ -77,9 +77,6 @@ class PixtralImagePixelInputs(TypedDict):
Shape: `(batch_size, num_images, num_embeds)`
"""
num_embeds: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`"""
class PixtralProcessorAdapter:
"""
@@ -153,7 +150,6 @@ class PixtralProcessorAdapter:
images_processed = list[torch.Tensor]()
images_tokens = list[torch.Tensor]()
images_embed_is_patch = list[torch.Tensor]()
images_num_embeds = list[int]()
for image in images:
image_inputs = self.image_processor(ImageChunk(image=image))
@@ -163,13 +159,11 @@ class PixtralProcessorAdapter:
images_processed.append(image_processed)
images_tokens.append(image_tokens)
images_embed_is_patch.append(image_tokens == image_token_id)
images_num_embeds.append(len(image_tokens))
return {
"input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
"images": images_processed,
"embed_is_patch": images_embed_is_patch,
"num_embeds": torch.tensor(images_num_embeds),
}
@@ -273,7 +267,6 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
return dict(
images=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
num_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
@@ -394,16 +387,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
num_embeds = kwargs.pop("num_embeds")
if not isinstance(num_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_embeds. "
f"Got type: {type(num_embeds)}")
return PixtralImagePixelInputs(
type="pixel_values",
images=flatten_bn(images),
embed_is_patch=embed_is_patch,
num_embeds=num_embeds,
)
def _process_image_input(
@@ -447,7 +434,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
return flatten_2d_lists(
scatter_patch_features(*args) for args in zip(
image_features,
image_input["num_embeds"],
image_input["embed_is_patch"],
))