[VLM] Clean up Phi-4-MM ViT implementation (#14812)

Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Author: Isotr0py (committed by GitHub)
Date:   2025-03-16 09:53:52 +08:00
Commit: def232e122 (parent 3453b964a3)

7 changed files with 316 additions and 1988 deletions

vllm/model_executor/models/idefics2_vision_model.py

@@ -113,7 +113,7 @@ class Idefics2VisionAttention(nn.Module):

     def __init__(
         self,
-        config: Idefics2Config,
+        config: Idefics2VisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
@@ -164,7 +164,7 @@ class Idefics2VisionMLP(nn.Module):

     def __init__(
         self,
-        config: Idefics2Config,
+        config: Idefics2VisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
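
The annotation change in the two hunks above reflects how the vision tower is constructed: it only ever receives the vision sub-config, never the full multimodal config. A minimal sketch of that relationship, assuming the Hugging Face transformers Idefics2 config layout (not part of this diff):

    from transformers import Idefics2Config

    config = Idefics2Config()             # full multimodal config (text + vision)
    vision_config = config.vision_config  # the Idefics2VisionConfig sub-config

    # The attributes the attention/MLP layers actually read (hidden_size,
    # num_attention_heads, ...) live on the vision sub-config, so
    # Idefics2VisionConfig is the accurate annotation.
    print(type(vision_config).__name__)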
@@ -249,16 +249,24 @@ class Idefics2Encoder(nn.Module):
         self,
         config: Idefics2Config,
         quant_config: Optional[QuantizationConfig] = None,
+        *,
+        num_hidden_layers_override: Optional[int] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
+
+        if num_hidden_layers_override is None:
+            num_hidden_layers = config.num_hidden_layers
+        else:
+            num_hidden_layers = num_hidden_layers_override
+
         self.layers = nn.ModuleList([
             Idefics2EncoderLayer(config,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.layers.{layer_idx}")
-            for layer_idx in range(config.num_hidden_layers)
+            for layer_idx in range(num_hidden_layers)
         ])

     def forward(
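
The layer-count resolution added in the hunk above can be read as a tiny standalone function. resolve_num_layers is a hypothetical name used only for illustration; the logic mirrors the diff:

    from typing import Optional

    def resolve_num_layers(config_num_hidden_layers: int,
                           num_hidden_layers_override: Optional[int] = None
                           ) -> int:
        # No override: build the full stack declared by the config.
        if num_hidden_layers_override is None:
            return config_num_hidden_layers
        # Override: build exactly that many layers (e.g. a truncated tower).
        return num_hidden_layers_override

    assert resolve_num_layers(27) == 27      # default: full depth
    assert resolve_num_layers(27, 26) == 26  # truncated tower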
@@ -287,6 +295,9 @@ class Idefics2VisionTransformer(nn.Module):
         self,
         config: Idefics2VisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        *,
+        num_hidden_layers_override: Optional[int] = None,
+        require_post_norm: bool = True,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -294,11 +305,24 @@ class Idefics2VisionTransformer(nn.Module):
         embed_dim = config.hidden_size
         self.config = config
         self.embeddings = Idefics2VisionEmbeddings(config)
-        self.encoder = Idefics2Encoder(config,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.encoder")
-        self.post_layernorm = nn.LayerNorm(embed_dim,
-                                           eps=config.layer_norm_eps)
+        self.encoder = Idefics2Encoder(
+            config,
+            quant_config=quant_config,
+            num_hidden_layers_override=num_hidden_layers_override,
+            prefix=f"{prefix}.encoder")
+
+        num_hidden_layers = config.num_hidden_layers
+        if len(self.encoder.layers) > config.num_hidden_layers:
+            raise ValueError(
+                f"The original encoder only has {num_hidden_layers} "
+                f"layers, but you requested {len(self.encoder.layers)} layers."
+            )
+
+        self.require_post_norm = require_post_norm
+        self.post_layernorm = nn.LayerNorm(
+            embed_dim,
+            eps=config.layer_norm_eps,
+        ) if require_post_norm else nn.Identity()

     def get_input_embeddings(self):
         return self.embeddings
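
Taken together, the two new keyword-only arguments let a caller build a truncated vision tower without the final post-layernorm, which is what the Phi-4-MM cleanup needs. A hedged usage sketch; the prefix string is a placeholder and the default-constructed config stands in for the checkpoint's real one:

    from transformers.models.idefics2.configuration_idefics2 import (
        Idefics2VisionConfig)

    from vllm.model_executor.models.idefics2_vision_model import (
        Idefics2VisionTransformer)

    vision_config = Idefics2VisionConfig()  # library defaults, for illustration

    vision_tower = Idefics2VisionTransformer(
        vision_config,
        num_hidden_layers_override=vision_config.num_hidden_layers - 1,
        require_post_norm=False,  # post_layernorm becomes nn.Identity()
        prefix="vision_tower",    # placeholder, not taken from this diff
    )
    assert len(vision_tower.encoder.layers) == vision_config.num_hidden_layers - 1

Requesting more layers than the config defines trips the new ValueError instead of silently initializing weights that no checkpoint can fill.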
@@ -328,7 +352,24 @@ class Idefics2VisionTransformer(nn.Module):
         ]
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
+        layer_count = len(self.encoder.layers)
+
         for name, loaded_weight in weights:
+            # skip pooling head
+            if name.startswith("head."):
+                continue
+
+            # post_layernorm is optional
+            if (name.startswith("post_layernorm.")
+                    and not self.require_post_norm):
+                continue
+
+            # omit layers when num_hidden_layers_override is set
+            if name.startswith("encoder.layers."):
+                layer_idx = int(name.split(".")[2])
+                if layer_idx >= layer_count:
+                    continue
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
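
The three skip rules added to load_weights can be summarized as a predicate over checkpoint weight names. should_skip is a hypothetical standalone rewrite for illustration, not code from the diff:

    def should_skip(name: str, layer_count: int,
                    require_post_norm: bool) -> bool:
        # The pooling head is never used by the vision tower here.
        if name.startswith("head."):
            return True
        # post_layernorm was replaced by nn.Identity(), so it has no params.
        if name.startswith("post_layernorm.") and not require_post_norm:
            return True
        # Layers beyond the override were never instantiated.
        if name.startswith("encoder.layers."):
            layer_idx = int(name.split(".")[2])
            if layer_idx >= layer_count:
                return True
        return False

    assert should_skip("head.probe", 26, True)
    assert should_skip("encoder.layers.26.mlp.fc1.weight", 26, True)
    assert not should_skip("encoder.layers.25.mlp.fc1.weight", 26, True)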