[Misc] Remove redundant num_embeds (#15443)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -63,9 +63,6 @@ class Gemma3ImagePixelInputs(TypedDict):
 Shape: `(batch_size, num_images, num_embeds)`
 """

-num_embeds: Union[torch.Tensor, list[torch.Tensor]]
-"""Shape: `(batch_size, num_images)`"""
-

 Gemma3ImageInputs = Gemma3ImagePixelInputs

@@ -317,11 +314,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
 tokenizer.encode(image_repl, add_special_tokens=False)
 for image_repl in image_repl_features
 ]
-num_embeds = [
-len(image_repl_feature_tokens)
-for image_repl_feature_tokens in image_repls_feature_tokens
-]
-processed_outputs["num_embeds"] = torch.tensor(num_embeds)

 vocab = tokenizer.get_vocab()
 image_token_id = vocab[tokenizer.image_token]
@@ -354,7 +346,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
 "image", num_crops + 1),
 num_crops=MultiModalFieldConfig.batched("image"),
 embed_is_patch=MultiModalFieldConfig.batched("image"),
-num_embeds=MultiModalFieldConfig.batched("image"),
 )

 def _get_prompt_updates(
@@ -583,7 +574,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
 pixel_values = kwargs.pop("pixel_values", None)
 num_crops = kwargs.pop("num_crops", None)
 embed_is_patch = kwargs.pop("embed_is_patch", None)
-num_embeds = kwargs.pop("num_embeds", None)
 image_embeds = kwargs.pop("image_embeds", None)
 assert image_embeds is None, "Gemma3 does not support image_embeds."
 if pixel_values is None:
@@ -601,10 +591,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
 raise ValueError("Incorrect type of embed_is_patch. "
 f"Got type: {type(embed_is_patch)}")

-if not isinstance(num_embeds, (torch.Tensor, list)):
-raise ValueError("Incorrect type of num_embeds. "
-f"Got type: {type(num_embeds)}")
-
 pixel_values = flatten_bn(pixel_values, concat=True)
 num_crops = flatten_bn(num_crops, concat=True)

@@ -613,7 +599,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
 pixel_values=self._validate_pixel_values(pixel_values),
 num_patches=num_crops + 1,
 embed_is_patch=embed_is_patch,
-num_embeds=num_embeds,
 )

 def _image_pixels_to_features(
@@ -656,7 +641,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
 return flatten_2d_lists(
 scatter_patch_features(*args) for args in zip(
 image_features,
-image_input["num_embeds"],
 image_input["embed_is_patch"],
 ))

|
||||
Reference in New Issue
Block a user