[Model] Refactor BLIP/BLIP-2 to support composite model loading (#8407)
This commit is contained in:
@@ -501,6 +501,7 @@ class SiglipVisionModel(nn.Module):
|
||||
num_hidden_layers_override: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
num_heads = config.num_attention_heads
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
|
||||
@@ -511,10 +512,6 @@ class SiglipVisionModel(nn.Module):
|
||||
num_hidden_layers_override=num_hidden_layers_override,
|
||||
)
|
||||
|
||||
@property
|
||||
def _require_post_layernorm(self) -> bool:
|
||||
return self.vision_model.post_layernorm is not None
|
||||
|
||||
def get_input_embeddings(self) -> nn.Module:
|
||||
return self.vision_model.embeddings.patch_embedding
|
||||
|
||||
@@ -540,12 +537,12 @@ class SiglipVisionModel(nn.Module):
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
# post_layernorm is optional in SiglipVisionModel
|
||||
if ("vision_model.post_layernorm" in name
|
||||
and not self._require_post_layernorm):
|
||||
if (name.startswith("vision_model.post_layernorm")
|
||||
and self.vision_model.post_layernorm is None):
|
||||
continue
|
||||
|
||||
# omit layers when num_hidden_layers_override is set
|
||||
if "vision_model.encoder.layers." in name:
|
||||
if name.startswith("vision_model.encoder.layers"):
|
||||
layer_idx = int(name.split(".")[3])
|
||||
if layer_idx >= layer_count:
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user