[Misc][Model][Refactor] Pass the prefix into Linear layers (#31669)

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Wang Kunpeng
2026-01-06 04:03:18 +08:00
committed by GitHub
parent 02dbb933cb
commit 5708297e4e
17 changed files with 181 additions and 40 deletions
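
Context for the change: vLLM's parallel Linear layers (ColumnParallelLinear, RowParallelLinear, and related variants) take a `prefix` argument recording the layer's dotted module path, and quantization configs use that name to decide per-layer behavior, e.g. leaving specific layers unquantized. Before this commit, Molmo's vision backbone and MLP modules dropped the prefix, so their Linear layers all registered under an empty name. The diff below threads `prefix` through every intermediate module so each layer reports its real path. A minimal sketch of the pattern, using illustrative Toy* names that are not part of the diff:

import torch.nn as nn

class ToyMLP(nn.Module):
    # Stand-in for ViTMLP / LanguageModelMLP: in the real code these are
    # ColumnParallelLinear / RowParallelLinear calls that also take quant_config.
    def __init__(self, dim: int, prefix: str = ""):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim)  # vLLM: ColumnParallelLinear(..., prefix=f"{prefix}.fc1")
        self.fc2 = nn.Linear(dim, dim)  # vLLM: RowParallelLinear(..., prefix=f"{prefix}.fc2")

class ToyBlock(nn.Module):
    def __init__(self, dim: int, prefix: str = ""):
        super().__init__()
        # Each parent appends the child's attribute name, so the prefix that
        # reaches a Linear layer matches that layer's state-dict path.
        self.mlp = ToyMLP(dim, prefix=f"{prefix}.mlp")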

@@ -142,6 +142,7 @@ class ViTMLP(nn.Module):
         self,
         config: VisionBackboneConfig,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.w1 = ColumnParallelLinear(
@@ -149,6 +150,7 @@ class ViTMLP(nn.Module):
             config.image_mlp_dim,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.w1",
         )
         # Activation function.
         assert config.image_mlp_activations == "quick_gelu"
@@ -158,6 +160,7 @@ class ViTMLP(nn.Module):
             config.image_emb_dim,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.w2",
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -176,6 +179,7 @@ class MultiHeadDotProductAttention(nn.Module):
         use_bias: bool = True,
         nlayers: int = 1,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
@@ -202,24 +206,28 @@ class MultiHeadDotProductAttention(nn.Module):
             self.total_num_heads * self.head_dim,
             bias=use_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.wq",
         )
         self.wk = ColumnParallelLinear(
             nlayers * self.hidden_size,
             self.total_num_kv_heads * self.head_dim,
             bias=use_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.wk",
         )
         self.wv = ColumnParallelLinear(
             nlayers * self.hidden_size,
             self.total_num_kv_heads * self.head_dim,
             bias=use_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.wv",
         )
         self.wo = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=use_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.wo",
         )
 
         self.scale = self.head_dim**-0.5
@@ -254,10 +262,15 @@ class ResidualAttentionBlock(nn.Module):
         self,
         config: VisionBackboneConfig,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
-        self.attention = MultiHeadDotProductAttention(config, quant_config=quant_config)
-        self.feed_forward = ViTMLP(config, quant_config)
+        self.attention = MultiHeadDotProductAttention(
+            config, quant_config=quant_config, prefix=f"{prefix}.attention"
+        )
+        self.feed_forward = ViTMLP(
+            config, quant_config, prefix=f"{prefix}.feed_forward"
+        )
         self.attention_norm = nn.LayerNorm(
             config.image_emb_dim,
             eps=config.image_norm_eps,
@@ -280,12 +293,15 @@ class BlockCollection(nn.Module):
         self,
         config: VisionBackboneConfig,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.resblocks = nn.ModuleList(
             [
-                ResidualAttentionBlock(config, quant_config)
-                for _ in range(config.image_num_layers)
+                ResidualAttentionBlock(
+                    config, quant_config, prefix=f"{prefix}.resblocks.{i}"
+                )
+                for i in range(config.image_num_layers)
             ]
         )
@@ -308,6 +324,7 @@ class VisionTransformer(nn.Module):
         self,
         config: VisionBackboneConfig,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         scale = config.image_emb_dim**-0.5
@@ -324,7 +341,9 @@ class VisionTransformer(nn.Module):
             bias=False,
         )
         self.pre_ln = nn.LayerNorm(config.image_emb_dim, eps=config.image_norm_eps)
-        self.transformer = BlockCollection(config, quant_config)
+        self.transformer = BlockCollection(
+            config, quant_config, prefix=f"{prefix}.transformer"
+        )
 
     def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor:
         cls_emb = self.positional_embedding[0:1]
@@ -419,6 +438,7 @@ class MolmoAttention(nn.Module):
             self.total_num_kv_heads,
             bias=config.qkv_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.tp_rank: int | None = None
@@ -454,6 +474,7 @@ class MolmoAttention(nn.Module):
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
 
     def _apply_qk_norm(
@@ -493,6 +514,7 @@ class LanguageModelMLP(nn.Module):
         config: PretrainedConfig,
         input_dim: int | None = None,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -503,6 +525,7 @@ class LanguageModelMLP(nn.Module):
             [self.intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         # Activation function.
         self.act_fn = MulAndSilu()
@@ -512,6 +535,7 @@ class LanguageModelMLP(nn.Module):
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
 
     def forward(
@@ -532,6 +556,7 @@ class ImageProjectorMLP(nn.Module):
         config: PretrainedConfig,
         input_dim: int | None = None,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -542,6 +567,7 @@ class ImageProjectorMLP(nn.Module):
             [self.intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.merged_linear",
         )
         # Activation function.
         self.act_fn = SiluAndMul()
@@ -552,6 +578,7 @@ class ImageProjectorMLP(nn.Module):
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
 
     def forward(
@@ -579,7 +606,9 @@ class MolmoDecoderLayer(nn.Module):
         )
         # MLP block.
-        self.mlp = LanguageModelMLP(config, quant_config=quant_config)
+        self.mlp = LanguageModelMLP(
+            config, quant_config=quant_config, prefix=f"{prefix}.mlp"
+        )
         # LayerNorm
         assert config.layer_norm_type == "rms"
@@ -643,6 +672,7 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
         config: PretrainedConfig,
         vision_config: VisionBackboneConfig,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.vit_layers = VIT_LAYERS
@@ -651,18 +681,24 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
             (self.image_num_patch[0] + 1) // POOLING_SIZE,
             (self.image_num_patch[1] + 1) // POOLING_SIZE,
         )
-        self.image_vit = VisionTransformer(vision_config, quant_config=quant_config)
+        self.image_vit = VisionTransformer(
+            vision_config, quant_config=quant_config, prefix=f"{prefix}.image_vit"
+        )
         self.num_prefix_tokens = self.image_vit.num_prefix_tokens
         assert self.num_prefix_tokens in {0, 1}, (
             "Only 0 or 1 prefix tokens are supported"
         )
         self.image_pooling_2d = MultiHeadDotProductAttention(
-            vision_config, nlayers=len(self.vit_layers), quant_config=quant_config
+            vision_config,
+            nlayers=len(self.vit_layers),
+            quant_config=quant_config,
+            prefix=f"{prefix}.image_pooling_2d",
         )
         self.image_projector = ImageProjectorMLP(
             config,
             input_dim=vision_config.image_emb_dim,
             quant_config=quant_config,
+            prefix=f"{prefix}.image_projector",
         )
 
         image_dim = vision_config.image_emb_dim * len(self.vit_layers)
@@ -1405,7 +1441,12 @@ class MolmoForCausalLM(
         self.multimodal_config = multimodal_config
         vision_config = VisionBackboneConfig()
-        self.vision_backbone = MolmoVisionBackbone(config, vision_config, quant_config)
+        self.vision_backbone = MolmoVisionBackbone(
+            config,
+            vision_config,
+            quant_config,
+            prefix=maybe_prefix(prefix, "vision_backbone"),
+        )
         self.model = MolmoModel(
             vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
         )
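
With these changes the prefixes compose down the module tree. A rough sketch of the resulting name for one layer, assuming maybe_prefix joins names with a dot and returns the bare name when the incoming prefix is empty (the sketch re-implements that assumed behavior rather than importing it):

def maybe_prefix(prefix: str, name: str) -> str:
    # Assumed behavior of vLLM's helper: drop an empty prefix, otherwise dot-join.
    return name if not prefix else f"{prefix}.{name}"

# Following the constructors touched in this diff, the query projection of
# vision block 0 ends up registered as:
prefix = maybe_prefix("", "vision_backbone")  # MolmoForCausalLM
prefix = f"{prefix}.image_vit"                # MolmoVisionBackbone
prefix = f"{prefix}.transformer"              # VisionTransformer
prefix = f"{prefix}.resblocks.0"              # BlockCollection
prefix = f"{prefix}.attention"                # ResidualAttentionBlock
print(f"{prefix}.wq")
# -> vision_backbone.image_vit.transformer.resblocks.0.attention.wq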