[Model] Refactor BLIP/BLIP-2 to support composite model loading (#8407)

This commit is contained in:
Cyrus Leung
2024-09-22 20:24:21 +08:00
committed by GitHub
parent 0e40ac9b7b
commit 06ed2815e2
10 changed files with 112 additions and 113 deletions

View File

@@ -501,6 +501,7 @@ class SiglipVisionModel(nn.Module):
num_hidden_layers_override: Optional[int] = None,
):
super().__init__()
num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
@@ -511,10 +512,6 @@ class SiglipVisionModel(nn.Module):
num_hidden_layers_override=num_hidden_layers_override,
)
@property
def _require_post_layernorm(self) -> bool:
return self.vision_model.post_layernorm is not None
def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding
@@ -540,12 +537,12 @@ class SiglipVisionModel(nn.Module):
for name, loaded_weight in weights:
# post_layernorm is optional in SiglipVisionModel
if ("vision_model.post_layernorm" in name
and not self._require_post_layernorm):
if (name.startswith("vision_model.post_layernorm")
and self.vision_model.post_layernorm is None):
continue
# omit layers when num_hidden_layers_override is set
if "vision_model.encoder.layers." in name:
if name.startswith("vision_model.encoder.layers"):
layer_idx = int(name.split(".")[3])
if layer_idx >= layer_count:
continue