[Bugfix] Fix missing post_layernorm in CLIP (#8155)

This commit is contained in:
Cyrus Leung
2024-09-10 16:22:50 +08:00
committed by GitHub
parent a1d874224d
commit da1a844e61
2 changed files with 42 additions and 19 deletions

View File

@@ -355,6 +355,19 @@ class CLIPVisionTransformer(nn.Module):
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override)
if len(self.encoder.layers) > config.num_hidden_layers:
raise ValueError(
f"The original encoder only has {config.num_hidden_layers} "
f"layers, but you requested {len(self.encoder.layers)} layers."
)
elif len(self.encoder.layers) == config.num_hidden_layers:
self.post_layernorm = nn.LayerNorm(embed_dim,
eps=config.layer_norm_eps)
else:
# post_layernorm is unused when we extract intermediate features
# In this case, we can skip it to conserve memory
self.post_layernorm = None
def forward(
self,
pixel_values: torch.Tensor,
@@ -364,7 +377,10 @@ class CLIPVisionTransformer(nn.Module):
hidden_states = self.pre_layrnorm(hidden_states)
hidden_states = self.encoder(inputs_embeds=hidden_states)
return hidden_states
if self.post_layernorm is None:
return hidden_states
return self.post_layernorm(hidden_states)
class CLIPVisionModel(nn.Module):
@@ -386,9 +402,12 @@ class CLIPVisionModel(nn.Module):
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override)
def forward(self, pixel_values: Optional[torch.Tensor] = None):
@property
def _require_post_layernorm(self) -> bool:
return self.vision_model.post_layernorm is not None
return self.vision_model(pixel_values=pixel_values)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
return self.vision_model(pixel_values)
@property
def device(self):
@@ -408,8 +427,10 @@ class CLIPVisionModel(nn.Module):
for name, loaded_weight in weights:
# post_layernorm is not needed in CLIPVisionModel
if "vision_model.post_layernorm" in name:
if ("vision_model.post_layernorm" in name
and not self._require_post_layernorm):
continue
# omit layers when num_hidden_layers_override is set
if "vision_model.encoder.layers." in name:
layer_idx = int(name.split(".")[3])