[Misc][Model][Refactor] Pass the prefix into Linear layers (#31669)
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
This commit is contained in:
@@ -142,6 +142,7 @@ class ViTMLP(nn.Module):
|
||||
self,
|
||||
config: VisionBackboneConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.w1 = ColumnParallelLinear(
|
||||
@@ -149,6 +150,7 @@ class ViTMLP(nn.Module):
|
||||
config.image_mlp_dim,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.w1",
|
||||
)
|
||||
# Activation function.
|
||||
assert config.image_mlp_activations == "quick_gelu"
|
||||
@@ -158,6 +160,7 @@ class ViTMLP(nn.Module):
|
||||
config.image_emb_dim,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.w2",
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -176,6 +179,7 @@ class MultiHeadDotProductAttention(nn.Module):
|
||||
use_bias: bool = True,
|
||||
nlayers: int = 1,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -202,24 +206,28 @@ class MultiHeadDotProductAttention(nn.Module):
|
||||
self.total_num_heads * self.head_dim,
|
||||
bias=use_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.wq",
|
||||
)
|
||||
self.wk = ColumnParallelLinear(
|
||||
nlayers * self.hidden_size,
|
||||
self.total_num_kv_heads * self.head_dim,
|
||||
bias=use_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.wk",
|
||||
)
|
||||
self.wv = ColumnParallelLinear(
|
||||
nlayers * self.hidden_size,
|
||||
self.total_num_kv_heads * self.head_dim,
|
||||
bias=use_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.wv",
|
||||
)
|
||||
self.wo = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
self.hidden_size,
|
||||
bias=use_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.wo",
|
||||
)
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
@@ -254,10 +262,15 @@ class ResidualAttentionBlock(nn.Module):
|
||||
self,
|
||||
config: VisionBackboneConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
-self.attention = MultiHeadDotProductAttention(config, quant_config=quant_config)
|
||||
-self.feed_forward = ViTMLP(config, quant_config)
|
||||
+self.attention = MultiHeadDotProductAttention(
|
||||
+config, quant_config=quant_config, prefix=f"{prefix}.attention"
|
||||
+)
|
||||
+self.feed_forward = ViTMLP(
|
||||
+config, quant_config, prefix=f"{prefix}.feed_forward"
|
||||
+)
|
||||
self.attention_norm = nn.LayerNorm(
|
||||
config.image_emb_dim,
|
||||
eps=config.image_norm_eps,
|
||||
@@ -280,12 +293,15 @@ class BlockCollection(nn.Module):
|
||||
self,
|
||||
config: VisionBackboneConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.resblocks = nn.ModuleList(
|
||||
[
|
||||
-ResidualAttentionBlock(config, quant_config)
|
||||
-for _ in range(config.image_num_layers)
|
||||
+ResidualAttentionBlock(
|
||||
+config, quant_config, prefix=f"{prefix}.resblocks.{i}"
|
||||
+)
|
||||
+for i in range(config.image_num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
@@ -308,6 +324,7 @@ class VisionTransformer(nn.Module):
|
||||
self,
|
||||
config: VisionBackboneConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
scale = config.image_emb_dim**-0.5
|
||||
@@ -324,7 +341,9 @@ class VisionTransformer(nn.Module):
|
||||
bias=False,
|
||||
)
|
||||
self.pre_ln = nn.LayerNorm(config.image_emb_dim, eps=config.image_norm_eps)
|
||||
-self.transformer = BlockCollection(config, quant_config)
|
||||
+self.transformer = BlockCollection(
|
||||
+config, quant_config, prefix=f"{prefix}.transformer"
|
||||
+)
|
||||
|
||||
def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor:
|
||||
cls_emb = self.positional_embedding[0:1]
|
||||
@@ -419,6 +438,7 @@ class MolmoAttention(nn.Module):
|
||||
self.total_num_kv_heads,
|
||||
bias=config.qkv_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
|
||||
self.tp_rank: int | None = None
|
||||
@@ -454,6 +474,7 @@ class MolmoAttention(nn.Module):
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
|
||||
def _apply_qk_norm(
|
||||
@@ -493,6 +514,7 @@ class LanguageModelMLP(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
input_dim: int | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@@ -503,6 +525,7 @@ class LanguageModelMLP(nn.Module):
|
||||
[self.intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
)
|
||||
# Activation function.
|
||||
self.act_fn = MulAndSilu()
|
||||
@@ -512,6 +535,7 @@ class LanguageModelMLP(nn.Module):
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -532,6 +556,7 @@ class ImageProjectorMLP(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
input_dim: int | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@@ -542,6 +567,7 @@ class ImageProjectorMLP(nn.Module):
|
||||
[self.intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.merged_linear",
|
||||
)
|
||||
# Activation function.
|
||||
self.act_fn = SiluAndMul()
|
||||
@@ -552,6 +578,7 @@ class ImageProjectorMLP(nn.Module):
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -579,7 +606,9 @@ class MolmoDecoderLayer(nn.Module):
|
||||
)
|
||||
|
||||
# MLP block.
|
||||
-self.mlp = LanguageModelMLP(config, quant_config=quant_config)
|
||||
+self.mlp = LanguageModelMLP(
|
||||
+config, quant_config=quant_config, prefix=f"{prefix}.mlp"
|
||||
+)
|
||||
|
||||
# LayerNorm
|
||||
assert config.layer_norm_type == "rms"
|
||||
@@ -643,6 +672,7 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
|
||||
config: PretrainedConfig,
|
||||
vision_config: VisionBackboneConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.vit_layers = VIT_LAYERS
|
||||
@@ -651,18 +681,24 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
|
||||
(self.image_num_patch[0] + 1) // POOLING_SIZE,
|
||||
(self.image_num_patch[1] + 1) // POOLING_SIZE,
|
||||
)
|
||||
-self.image_vit = VisionTransformer(vision_config, quant_config=quant_config)
|
||||
+self.image_vit = VisionTransformer(
|
||||
+vision_config, quant_config=quant_config, prefix=f"{prefix}.image_vit"
|
||||
+)
|
||||
self.num_prefix_tokens = self.image_vit.num_prefix_tokens
|
||||
assert self.num_prefix_tokens in {0, 1}, (
|
||||
"Only 0 or 1 prefix tokens are supported"
|
||||
)
|
||||
self.image_pooling_2d = MultiHeadDotProductAttention(
|
||||
-vision_config, nlayers=len(self.vit_layers), quant_config=quant_config
|
||||
+vision_config,
|
||||
+nlayers=len(self.vit_layers),
|
||||
+quant_config=quant_config,
|
||||
+prefix=f"{prefix}.image_pooling_2d",
|
||||
)
|
||||
self.image_projector = ImageProjectorMLP(
|
||||
config,
|
||||
input_dim=vision_config.image_emb_dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.image_projector",
|
||||
)
|
||||
|
||||
image_dim = vision_config.image_emb_dim * len(self.vit_layers)
|
||||
@@ -1405,7 +1441,12 @@ class MolmoForCausalLM(
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
vision_config = VisionBackboneConfig()
|
||||
-self.vision_backbone = MolmoVisionBackbone(config, vision_config, quant_config)
|
||||
+self.vision_backbone = MolmoVisionBackbone(
|
||||
+config,
|
||||
+vision_config,
|
||||
+quant_config,
|
||||
+prefix=maybe_prefix(prefix, "vision_backbone"),
|
||||
+)
|
||||
self.model = MolmoModel(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user