[Misc] Remove redundant num_embeds (#15443)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
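num_embeds was redundant because it always equaled the length of the corresponding per-image token sequence, which is exactly the length of that image's embed_is_patch mask. A minimal sketch of the relationship (illustrative data and variable names only, not the vLLM processor itself):

import torch

# Hypothetical per-image token IDs, as the processor would produce them.
image_token_id = 10
images_tokens = [
    torch.tensor([10, 10, 11, 10, 10, 12]),  # image 1: 6 embedding slots
    torch.tensor([10, 10, 13]),              # image 2: 3 embedding slots
]

# embed_is_patch marks which positions hold image-patch tokens.
images_embed_is_patch = [toks == image_token_id for toks in images_tokens]

# The removed num_embeds is just the mask length per image, so it can be
# derived on demand instead of being stored and validated separately.
images_num_embeds = torch.tensor([m.numel() for m in images_embed_is_patch])
print(images_num_embeds)  # tensor([6, 3])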
@@ -77,9 +77,6 @@ class PixtralImagePixelInputs(TypedDict):
     Shape: `(batch_size, num_images, num_embeds)`
     """
 
-    num_embeds: Union[torch.Tensor, list[torch.Tensor]]
-    """Shape: `(batch_size, num_images)`"""
-
 
 class PixtralProcessorAdapter:
     """
@@ -153,7 +150,6 @@ class PixtralProcessorAdapter:
         images_processed = list[torch.Tensor]()
         images_tokens = list[torch.Tensor]()
         images_embed_is_patch = list[torch.Tensor]()
-        images_num_embeds = list[int]()
 
         for image in images:
             image_inputs = self.image_processor(ImageChunk(image=image))
@@ -163,13 +159,11 @@ class PixtralProcessorAdapter:
             images_processed.append(image_processed)
             images_tokens.append(image_tokens)
             images_embed_is_patch.append(image_tokens == image_token_id)
-            images_num_embeds.append(len(image_tokens))
 
         return {
             "input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
             "images": images_processed,
             "embed_is_patch": images_embed_is_patch,
-            "num_embeds": torch.tensor(images_num_embeds),
         }
 
 
@@ -273,7 +267,6 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
         return dict(
             images=MultiModalFieldConfig.batched("image"),
             embed_is_patch=MultiModalFieldConfig.batched("image"),
-            num_embeds=MultiModalFieldConfig.batched("image"),
         )
 
     def _get_prompt_updates(
@@ -394,16 +387,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
             raise ValueError("Incorrect type of embed_is_patch. "
                              f"Got type: {type(embed_is_patch)}")
 
-        num_embeds = kwargs.pop("num_embeds")
-        if not isinstance(num_embeds, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of num_embeds. "
-                             f"Got type: {type(num_embeds)}")
-
         return PixtralImagePixelInputs(
             type="pixel_values",
             images=flatten_bn(images),
             embed_is_patch=embed_is_patch,
-            num_embeds=num_embeds,
         )
 
     def _process_image_input(
@@ -447,7 +434,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         return flatten_2d_lists(
             scatter_patch_features(*args) for args in zip(
                 image_features,
-                image_input["num_embeds"],
                 image_input["embed_is_patch"],
             ))
 
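With the field gone, scatter_patch_features is called with only the features and the embed_is_patch mask (see the last hunk above); the number of output embeddings can be read off the mask itself. A rough sketch of that idea (an assumption about the shape of the helper, not vLLM's actual implementation):

import torch

def scatter_patch_features_sketch(features: torch.Tensor,
                                  embed_is_patch: torch.Tensor) -> torch.Tensor:
    # The output length comes from the mask, so no separate num_embeds is needed.
    num_embeds = embed_is_patch.numel()
    out = features.new_zeros(num_embeds, features.shape[-1])
    out[embed_is_patch] = features  # place patch features at patch positions
    return out

# Example: 3 patch features scattered into a 5-slot embedding sequence.
feats = torch.randn(3, 8)
mask = torch.tensor([True, True, False, True, False])
print(scatter_patch_features_sketch(feats, mask).shape)  # torch.Size([5, 8])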