[VLM] Enable overriding whether post layernorm is used in vision encoder + fix quant args (#9217)
Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -113,7 +113,8 @@ class Idefics2VisionAttention(nn.Module):
|
||||
self,
|
||||
config: Idefics2Config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
-):
|
||||
+prefix: str = "",
|
||||
+) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
@@ -130,12 +131,14 @@ class Idefics2VisionAttention(nn.Module):
|
||||
self.head_dim,
|
||||
self.num_heads,
|
||||
quant_config=quant_config,
|
||||
+prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
self.out_proj = RowParallelLinear(
|
||||
self.embed_dim,
|
||||
self.embed_dim,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
+prefix=f"{prefix}.out_proj",
|
||||
)
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||
@@ -178,7 +181,8 @@ class Idefics2VisionMLP(nn.Module):
|
||||
self,
|
||||
config: Idefics2Config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
-):
|
||||
+prefix: str = "",
|
||||
+) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_fn = get_act_fn(config.hidden_act)
|
||||
@@ -187,12 +191,14 @@ class Idefics2VisionMLP(nn.Module):
|
||||
config.intermediate_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
+prefix=f"{prefix}.fc1",
|
||||
)
|
||||
self.fc2 = RowParallelLinear(
|
||||
config.intermediate_size,
|
||||
config.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
+prefix=f"{prefix}.fc2",
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
@@ -204,13 +210,22 @@ class Idefics2VisionMLP(nn.Module):
|
||||
|
||||
class Idefics2EncoderLayer(nn.Module):
|
||||
|
||||
-def __init__(self, config: Idefics2Config):
|
||||
+def __init__(
|
||||
+self,
|
||||
+config: Idefics2Config,
|
||||
+quant_config: Optional[QuantizationConfig] = None,
|
||||
+prefix: str = "",
|
||||
+) -> None:
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
-self.self_attn = Idefics2VisionAttention(config)
|
||||
+self.self_attn = Idefics2VisionAttention(config,
|
||||
+quant_config=quant_config,
|
||||
+prefix=f"{prefix}.self_attn")
|
||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
-self.mlp = Idefics2VisionMLP(config)
|
||||
+self.mlp = Idefics2VisionMLP(config,
|
||||
+quant_config=quant_config,
|
||||
+prefix=f"{prefix}.mlp")
|
||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
@@ -245,12 +260,20 @@ class Idefics2Encoder(nn.Module):
|
||||
config: Idefics2Config
|
||||
"""
|
||||
|
||||
-def __init__(self, config: Idefics2Config):
|
||||
+def __init__(
|
||||
+self,
|
||||
+config: Idefics2Config,
|
||||
+quant_config: Optional[QuantizationConfig] = None,
|
||||
+prefix: str = "",
|
||||
+) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.layers = nn.ModuleList([
|
||||
-Idefics2EncoderLayer(config)
|
||||
-for _ in range(config.num_hidden_layers)
|
||||
+Idefics2EncoderLayer(config,
|
||||
+quant_config=quant_config,
|
||||
+prefix=f"{prefix}.layers.{layer_idx}")
|
||||
+for layer_idx in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
def forward(
|
||||
@@ -275,12 +298,20 @@ class Idefics2Encoder(nn.Module):
|
||||
|
||||
class Idefics2VisionTransformer(nn.Module):
|
||||
|
||||
-def __init__(self, config: Idefics2VisionConfig):
|
||||
+def __init__(
|
||||
+self,
|
||||
+config: Idefics2VisionConfig,
|
||||
+quant_config: Optional[QuantizationConfig] = None,
|
||||
+prefix: str = "",
|
||||
+) -> None:
|
||||
super().__init__()
|
||||
|
||||
embed_dim = config.hidden_size
|
||||
self.config = config
|
||||
self.embeddings = Idefics2VisionEmbeddings(config)
|
||||
-self.encoder = Idefics2Encoder(config)
|
||||
+self.encoder = Idefics2Encoder(config,
|
||||
+quant_config=quant_config,
|
||||
+prefix=f"{prefix}.encoder")
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user