[Fix] Remove unused packing_position_embedding from PaddleOCRVL for better checkpoint compatibility (#38232)
Signed-off-by: zhangyue66 <zhangyue66@baidu.com>
This commit is contained in:
@@ -409,7 +409,6 @@ class SiglipVisionEmbeddings(nn.Module):
|
||||
self.cache_position_embedding = dict()
|
||||
self.cache_position_count = dict()
|
||||
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
||||
self.packing_position_embedding = nn.Embedding(32768, self.embed_dim)
|
||||
|
||||
self.register_buffer(
|
||||
"position_ids",
|
||||
@@ -501,24 +500,22 @@ class SiglipVisionEmbeddings(nn.Module):
|
||||
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
|
||||
embeddings = patch_embeds.flatten(-2).squeeze(-1)
|
||||
|
||||
if interpolate_pos_encoding and image_grid_thw is not None:
|
||||
start = 0
|
||||
tmp_embeddings = list()
|
||||
for image_grid in image_grid_thw:
|
||||
t, h, w = image_grid
|
||||
end = start + t * h * w
|
||||
image_embeddings = embeddings[start:end, :]
|
||||
position_embedding = (
|
||||
self.interpolate_pos_encoding(image_embeddings, h, w, True)
|
||||
.squeeze(0)
|
||||
.repeat(t, 1)
|
||||
)
|
||||
image_embeddings = image_embeddings + position_embedding
|
||||
tmp_embeddings.append(image_embeddings)
|
||||
start = end
|
||||
embeddings = torch.concat(tmp_embeddings, dim=0).unsqueeze(0)
|
||||
else:
|
||||
embeddings = embeddings + self.packing_position_embedding(position_ids)
|
||||
start = 0
|
||||
tmp_embeddings = list()
|
||||
for image_grid in image_grid_thw:
|
||||
t, h, w = image_grid
|
||||
end = start + t * h * w
|
||||
image_embeddings = embeddings[start:end, :]
|
||||
position_embedding = (
|
||||
self.interpolate_pos_encoding(image_embeddings, h, w, True)
|
||||
.squeeze(0)
|
||||
.repeat(t, 1)
|
||||
)
|
||||
image_embeddings = image_embeddings + position_embedding
|
||||
tmp_embeddings.append(image_embeddings)
|
||||
start = end
|
||||
embeddings = torch.concat(tmp_embeddings, dim=0).unsqueeze(0)
|
||||
|
||||
return embeddings
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -939,6 +936,8 @@ class SiglipVisionModel(nn.Module):
|
||||
continue
|
||||
if "head.mlp" in name or "head.probe" in name:
|
||||
continue
|
||||
if "packing_position_embedding" in name:
|
||||
continue
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user