[Misc] Remove redundant num_embeds (#15443)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -69,9 +69,6 @@ class InternVLImagePixelInputs(TypedDict):
|
||||
Shape: `(batch_size, num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
num_embeds: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""Shape: `(batch_size, num_images)`"""
|
||||
|
||||
|
||||
class InternVLImageEmbeddingInputs(TypedDict):
|
||||
type: Literal["image_embeds"]
|
||||
@@ -426,7 +423,6 @@ class BaseInternVLProcessor(ABC):
|
||||
tokenizer = self.tokenizer
|
||||
image_token_id = self.image_token_id
|
||||
|
||||
num_embeds = list[int]()
|
||||
embed_is_patch = list[torch.Tensor]()
|
||||
|
||||
for pixel_values in pixel_values_lst:
|
||||
@@ -438,11 +434,9 @@ class BaseInternVLProcessor(ABC):
|
||||
add_special_tokens=False)
|
||||
|
||||
text = [t.replace('<image>', image_repl.full, 1) for t in text]
|
||||
num_embeds.append(len(feature_tokens))
|
||||
embed_is_patch.append(
|
||||
torch.tensor(feature_tokens) == image_token_id)
|
||||
|
||||
image_inputs["num_embeds"] = torch.tensor(num_embeds)
|
||||
image_inputs["embed_is_patch"] = embed_is_patch
|
||||
|
||||
text_inputs = self.tokenizer(text)
|
||||
@@ -607,7 +601,6 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
"image", image_num_patches),
|
||||
image_num_patches=MultiModalFieldConfig.batched("image"),
|
||||
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||
num_embeds=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
image_token_id=MultiModalFieldConfig.shared("image", num_images),
|
||||
)
|
||||
@@ -840,7 +833,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
pixel_values_flat = kwargs.pop("pixel_values_flat", None)
|
||||
image_num_patches = kwargs.pop("image_num_patches", None)
|
||||
embed_is_patch = kwargs.pop("embed_is_patch", None)
|
||||
num_embeds = kwargs.pop("num_embeds", None)
|
||||
image_embeds = kwargs.pop("image_embeds", None)
|
||||
|
||||
if pixel_values_flat is None and image_embeds is None:
|
||||
@@ -873,10 +865,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
raise ValueError("Incorrect type of embed_is_patch. "
|
||||
f"Got type: {type(embed_is_patch)}")
|
||||
|
||||
if not isinstance(num_embeds, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of num_embeds. "
|
||||
f"Got type: {type(num_embeds)}")
|
||||
|
||||
pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
|
||||
image_num_patches = flatten_bn(image_num_patches, concat=True)
|
||||
|
||||
@@ -886,7 +874,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
pixel_values_flat),
|
||||
num_patches=image_num_patches,
|
||||
embed_is_patch=embed_is_patch,
|
||||
num_embeds=num_embeds,
|
||||
)
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
@@ -941,7 +928,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
return flatten_2d_lists(
|
||||
scatter_patch_features(*args) for args in zip(
|
||||
image_features,
|
||||
image_input["num_embeds"],
|
||||
image_input["embed_is_patch"],
|
||||
))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user