[VLM] Enable overriding whether post layernorm is used in vision encoder + fix quant args (#9217)

Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Cyrus Leung
2024-10-23 19:27:37 +08:00
committed by GitHub
parent 3ff57ebfca
commit c18e1a3418
18 changed files with 551 additions and 253 deletions

View File

@@ -113,7 +113,8 @@ class Idefics2VisionAttention(nn.Module):
self,
config: Idefics2Config,
quant_config: Optional[QuantizationConfig] = None,
):
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
@@ -130,12 +131,14 @@ class Idefics2VisionAttention(nn.Module):
self.head_dim,
self.num_heads,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.out_proj = RowParallelLinear(
self.embed_dim,
self.embed_dim,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
@@ -178,7 +181,8 @@ class Idefics2VisionMLP(nn.Module):
self,
config: Idefics2Config,
quant_config: Optional[QuantizationConfig] = None,
):
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
@@ -187,12 +191,14 @@ class Idefics2VisionMLP(nn.Module):
config.intermediate_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
)
self.fc2 = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -204,13 +210,22 @@ class Idefics2VisionMLP(nn.Module):
class Idefics2EncoderLayer(nn.Module):
def __init__(self, config: Idefics2Config):
def __init__(
self,
config: Idefics2Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = Idefics2VisionAttention(config)
self.self_attn = Idefics2VisionAttention(config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
self.mlp = Idefics2VisionMLP(config)
self.mlp = Idefics2VisionMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
@@ -245,12 +260,20 @@ class Idefics2Encoder(nn.Module):
config: Idefics2Config
"""
def __init__(self, config: Idefics2Config):
def __init__(
self,
config: Idefics2Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.layers = nn.ModuleList([
Idefics2EncoderLayer(config)
for _ in range(config.num_hidden_layers)
Idefics2EncoderLayer(config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(config.num_hidden_layers)
])
def forward(
@@ -275,12 +298,20 @@ class Idefics2Encoder(nn.Module):
class Idefics2VisionTransformer(nn.Module):
def __init__(self, config: Idefics2VisionConfig):
def __init__(
self,
config: Idefics2VisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
embed_dim = config.hidden_size
self.config = config
self.embeddings = Idefics2VisionEmbeddings(config)
self.encoder = Idefics2Encoder(config)
self.encoder = Idefics2Encoder(config,
quant_config=quant_config,
prefix=f"{prefix}.encoder")
self.post_layernorm = nn.LayerNorm(embed_dim,
eps=config.layer_norm_eps)