[Fix] Remove unused packing_position_embedding from PaddleOCRVL for better checkpoint compatibility (#38232)

Signed-off-by: zhangyue66 <zhangyue66@baidu.com>
Author: zhang-prog
Date: 2026-03-26 23:34:49 +08:00
Committed by: GitHub
Parent: be1a85b7a2
Commit: 0f5b526040
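
Why the removal matters for checkpoint compatibility, as a minimal standalone sketch (TinyEmbeddings is a hypothetical stand-in, not the PaddleOCRVL module): a submodule that is registered but never used still appears in the model's expected state_dict, so strict loading fails against checkpoints that do not carry the extra key.

```python
import torch
import torch.nn as nn


class TinyEmbeddings(nn.Module):
    def __init__(self, num_positions: int = 16, embed_dim: int = 8) -> None:
        super().__init__()
        self.position_embedding = nn.Embedding(num_positions, embed_dim)
        # The unused table the commit removes: it bloats the expected
        # state_dict with a key that slimmed-down checkpoints do not contain.
        self.packing_position_embedding = nn.Embedding(32768, embed_dim)


model = TinyEmbeddings()
# Simulate a checkpoint exported without the unused table.
checkpoint = {
    k: v for k, v in model.state_dict().items()
    if "packing_position_embedding" not in k
}
try:
    model.load_state_dict(checkpoint)  # strict=True by default
except RuntimeError as err:
    print(f"strict load fails: {err}")
```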


@@ -409,7 +409,6 @@ class SiglipVisionEmbeddings(nn.Module):
         self.cache_position_embedding = dict()
         self.cache_position_count = dict()
         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
-        self.packing_position_embedding = nn.Embedding(32768, self.embed_dim)
         self.register_buffer(
             "position_ids",
@@ -501,24 +500,22 @@ class SiglipVisionEmbeddings(nn.Module):
             patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
             embeddings = patch_embeds.flatten(-2).squeeze(-1)
-            if interpolate_pos_encoding and image_grid_thw is not None:
-                start = 0
-                tmp_embeddings = list()
-                for image_grid in image_grid_thw:
-                    t, h, w = image_grid
-                    end = start + t * h * w
-                    image_embeddings = embeddings[start:end, :]
-                    position_embedding = (
-                        self.interpolate_pos_encoding(image_embeddings, h, w, True)
-                        .squeeze(0)
-                        .repeat(t, 1)
-                    )
-                    image_embeddings = image_embeddings + position_embedding
-                    tmp_embeddings.append(image_embeddings)
-                    start = end
-                embeddings = torch.concat(tmp_embeddings, dim=0).unsqueeze(0)
-            else:
-                embeddings = embeddings + self.packing_position_embedding(position_ids)
+            start = 0
+            tmp_embeddings = list()
+            for image_grid in image_grid_thw:
+                t, h, w = image_grid
+                end = start + t * h * w
+                image_embeddings = embeddings[start:end, :]
+                position_embedding = (
+                    self.interpolate_pos_encoding(image_embeddings, h, w, True)
+                    .squeeze(0)
+                    .repeat(t, 1)
+                )
+                image_embeddings = image_embeddings + position_embedding
+                tmp_embeddings.append(image_embeddings)
+                start = end
+            embeddings = torch.concat(tmp_embeddings, dim=0).unsqueeze(0)
             return embeddings
         else:
             raise ValueError(
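
After this change the per-image loop runs unconditionally. A standalone sketch of the packing scheme (fake_interpolate_pos_encoding and add_packed_position_embeddings are hypothetical stand-ins, not the real methods): a flat sequence of patch embeddings is split per image by its (t, h, w) grid, and each slice receives an interpolated position embedding tiled across its t frames.

```python
import torch


def fake_interpolate_pos_encoding(h: int, w: int, dim: int) -> torch.Tensor:
    # Stand-in for SiglipVisionEmbeddings.interpolate_pos_encoding, which
    # resizes the learned table to an h x w grid; here we just fabricate one.
    return torch.zeros(1, h * w, dim)


def add_packed_position_embeddings(
    embeddings: torch.Tensor,  # (total_patches, dim), images packed end to end
    image_grid_thw: list[tuple[int, int, int]],
) -> torch.Tensor:
    start = 0
    out = []
    for t, h, w in image_grid_thw:
        end = start + t * h * w
        image_embeddings = embeddings[start:end, :]
        pos = fake_interpolate_pos_encoding(h, w, embeddings.size(-1))
        pos = pos.squeeze(0).repeat(t, 1)  # tile over the t temporal frames
        out.append(image_embeddings + pos)
        start = end
    return torch.concat(out, dim=0).unsqueeze(0)  # (1, total_patches, dim)


# Two images packed into one sequence: 1x2x2 and 1x3x3 patch grids.
packed = torch.randn(4 + 9, 8)
result = add_packed_position_embeddings(packed, [(1, 2, 2), (1, 3, 3)])
print(result.shape)  # torch.Size([1, 13, 8])
```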
@@ -939,6 +936,8 @@ class SiglipVisionModel(nn.Module):
                 continue
             if "head.mlp" in name or "head.probe" in name:
                 continue
+            if "packing_position_embedding" in name:
+                continue
             if self.quant_config is not None and (
                 scale_name := self.quant_config.get_cache_scale(name)
             ):
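
The guard above covers the other direction: checkpoints that still ship the removed tensor must not trip the loader. A minimal sketch of that skip pattern (the load_weights signature here is a hypothetical simplification of vLLM-style loaders, not the real method):

```python
from collections.abc import Iterable

import torch


def load_weights(model: torch.nn.Module,
                 weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    params = dict(model.named_parameters())
    loaded: set[str] = set()
    for name, tensor in weights:
        # Mirror of the new guard: silently drop the stale key instead of
        # failing the lookup for a parameter the module no longer owns.
        if "packing_position_embedding" in name:
            continue
        params[name].data.copy_(tensor)
        loaded.add(name)
    return loaded
```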