From 5708297e4e7f27c4d06d6e4f62d75c3d2d1fc674 Mon Sep 17 00:00:00 2001
From: Wang Kunpeng <1289706727@qq.com>
Date: Tue, 6 Jan 2026 04:03:18 +0800
Subject: [PATCH] [Misc][Model][Refactor] Pass the prefix into Linear layers
 (#31669)

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
---
 .../layers/mamba/mamba_mixer.py              | 14 ++++-
 vllm/model_executor/models/aria.py           | 20 +++++--
 vllm/model_executor/models/gpt_neox.py       |  2 +-
 vllm/model_executor/models/hunyuan_v1.py     |  1 +
 vllm/model_executor/models/jamba.py          |  1 +
 vllm/model_executor/models/jina_vl.py        | 18 ++++--
 vllm/model_executor/models/minicpm.py        |  4 ++
 vllm/model_executor/models/minicpm_eagle.py  |  1 +
 .../models/mistral_large_3_eagle.py          |  1 +
 vllm/model_executor/models/modernbert.py     | 33 ++++++++---
 vllm/model_executor/models/molmo.py          | 59 ++++++++++++++++---
 vllm/model_executor/models/olmoe.py          |  6 +-
 vllm/model_executor/models/phimoe.py         |  1 +
 vllm/model_executor/models/qwen.py           | 18 +++++-
 vllm/model_executor/models/qwen3_next.py     |  1 +
 vllm/model_executor/models/qwen_vl.py        | 32 ++++++++--
 vllm/model_executor/models/zamba2.py         |  9 ++-
 17 files changed, 181 insertions(+), 40 deletions(-)

diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py
index 0b63acf2d..a8d412784 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer.py
@@ -82,6 +82,7 @@ class MambaMixer(MambaBase, CustomOp):
             input_size=conv_kernel_size,
             output_size=intermediate_size,
             bias=use_conv_bias,
+            prefix=f"{prefix}.conv1d",
         )
         # unsqueeze to fit conv1d weights shape into the linear weights shape.
         # Can't do this in `weight_loader` since it already exists in
@@ -90,7 +91,10 @@ class MambaMixer(MambaBase, CustomOp):
         self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
 
         self.in_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2, bias=use_bias
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=use_bias,
+            prefix=f"{prefix}.in_proj",
         )
 
         # selective projection used to make dt, B and C input dependent
@@ -98,12 +102,17 @@ class MambaMixer(MambaBase, CustomOp):
             intermediate_size,
             time_step_rank + ssm_state_size * 2,
             bias=False,
+            prefix=f"{prefix}.x_proj",
         )
         # time step projection (discretization) -
         # In the forward we need to apply dt_proj without the bias,
         # as the bias is added in the selective scan kernel.
         self.dt_proj = ColumnParallelLinear(
-            time_step_rank, intermediate_size, bias=True, skip_bias_add=True
+            time_step_rank,
+            intermediate_size,
+            bias=True,
+            skip_bias_add=True,
+            prefix=f"{prefix}.dt_proj",
         )
 
         def weight_loader(param: Parameter, loaded_weight: torch.Tensor):
@@ -136,6 +145,7 @@ class MambaMixer(MambaBase, CustomOp):
             hidden_size,
             bias=use_bias,
             input_is_parallel=True,
+            prefix=f"{prefix}.out_proj",
         )
 
         self.dt_layernorm = (
diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py
index c6d7f19cb..c7f44762f 100644
--- a/vllm/model_executor/models/aria.py
+++ b/vllm/model_executor/models/aria.py
@@ -127,11 +127,16 @@ class AriaProjectorMLP(nn.Module):
         in_features: int,
         hidden_features: int,
         output_dim: int,
+        prefix: str = "",
     ) -> None:
         super().__init__()
 
-        self.linear_in = ColumnParallelLinear(in_features, hidden_features, bias=False)
-        self.linear_out = RowParallelLinear(hidden_features, output_dim, bias=False)
+        self.linear_in = ColumnParallelLinear(
+            in_features, hidden_features, bias=False, prefix=f"{prefix}.linear_in"
+        )
+        self.linear_out = RowParallelLinear(
+            hidden_features, output_dim, bias=False, prefix=f"{prefix}.linear_out"
+        )
         self.act = get_act_fn("gelu_new")
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -154,7 +159,7 @@ class AriaProjector(nn.Module):
         A tensor with the shape of (batch_size, query_number, output_dim)
     """
 
-    def __init__(self, config: AriaConfig) -> None:
+    def __init__(self, config: AriaConfig, prefix: str = "") -> None:
         super().__init__()
 
         self.patch_to_query_dict = config.projector_patch_to_query_dict
@@ -174,7 +179,10 @@ class AriaProjector(nn.Module):
         self.layer_norm = nn.LayerNorm(self.in_features)
 
         self.feed_forward = AriaProjectorMLP(
-            self.in_features, self.hidden_features, self.output_dim
+            self.in_features,
+            self.hidden_features,
+            self.output_dim,
+            prefix=f"{prefix}.feed_forward",
         )
 
     def forward(
@@ -536,7 +544,9 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
             quant_config=quant_config,
             prefix=f"{prefix}.vision_tower",
         )
-        self.multi_modal_projector = AriaProjector(config)
+        self.multi_modal_projector = AriaProjector(
+            config, prefix=maybe_prefix(prefix, "multi_modal_projector")
+        )
         self.vocab_size = config.text_config.vocab_size
         self.language_model = AriaTextModel(
             vllm_config=vllm_config.with_hf_config(config.text_config),
diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py
index c4d11b488..d994e380d 100644
--- a/vllm/model_executor/models/gpt_neox.py
+++ b/vllm/model_executor/models/gpt_neox.py
@@ -166,7 +166,7 @@ class GPTNeoXLayer(nn.Module):
         self.attention = GPTNeoXAttention(
             config, cache_config, quant_config, prefix=f"{prefix}.attention"
         )
-        self.mlp = GPTNeoXMLP(config, quant_config)
+        self.mlp = GPTNeoXMLP(config, quant_config, prefix=f"{prefix}.mlp")
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py
index 0e82e84c4..b7132e662 100644
--- a/vllm/model_executor/models/hunyuan_v1.py
+++ b/vllm/model_executor/models/hunyuan_v1.py
@@ -427,6 +427,7 @@ class HunYuanSparseMoeBlock(nn.Module):
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 reduce_results=False,
+                prefix=f"{prefix}.shared_mlp",
             )
         else:
             self.shared_mlp = None
diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py
index b2ad12be1..946a9f6fc 100644
--- a/vllm/model_executor/models/jamba.py
+++ b/vllm/model_executor/models/jamba.py
@@ -78,6 +78,7 @@ class JambaMoE(nn.Module):
             bias=False,
             quant_config=None,
             params_dtype=params_dtype,
+            prefix=f"{prefix}.router",
         )
 
         self.experts = FusedMoE(
diff --git a/vllm/model_executor/models/jina_vl.py b/vllm/model_executor/models/jina_vl.py
index 8bba7b628..7be3d4778 100644
--- a/vllm/model_executor/models/jina_vl.py
+++ b/vllm/model_executor/models/jina_vl.py
@@ -27,15 +27,23 @@ logger = init_logger(__name__)
 
 class JinaVLScorer(nn.Module):
-    def __init__(self, model_config: "ModelConfig"):
+    def __init__(self, model_config: "ModelConfig", prefix: str = ""):
         super().__init__()
         config = model_config.hf_config.get_text_config()
         head_dtype = model_config.head_dtype
         self.dense = ColumnParallelLinear(
-            config.hidden_size, config.hidden_size, params_dtype=head_dtype, bias=True
+            config.hidden_size,
+            config.hidden_size,
+            params_dtype=head_dtype,
+            bias=True,
+            prefix=f"{prefix}.dense",
         )
         self.out_proj = RowParallelLinear(
-            config.hidden_size, config.num_labels, params_dtype=head_dtype, bias=True
+            config.hidden_size,
+            config.num_labels,
+            params_dtype=head_dtype,
+            bias=True,
+            prefix=f"{prefix}.out_proj",
         )
 
     def forward(self, x, **kwargs):
@@ -94,7 +102,9 @@ class JinaVLForSequenceClassification(
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None
 
-        self.score = JinaVLScorer(vllm_config.model_config)
+        self.score = JinaVLScorer(
+            vllm_config.model_config, prefix=maybe_prefix(prefix, "score")
+        )
         self.pooler = DispatchPooler(
             {
                 "token_classify": Pooler.for_token_classify(
diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py
index f104018d3..a05be794a 100644
--- a/vllm/model_executor/models/minicpm.py
+++ b/vllm/model_executor/models/minicpm.py
@@ -90,6 +90,7 @@ class MiniCPMMoE(nn.Module):
         intermediate_size: int,
         params_dtype: torch.dtype | None = None,
         tp_size: int | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.tp_size = tp_size or get_tensor_model_parallel_world_size()
@@ -108,6 +109,7 @@ class MiniCPMMoE(nn.Module):
             bias=False,
             params_dtype=self.params_dtype,
             quant_config=None,
+            prefix=f"{prefix}.gate",
         )
 
         self.ws = nn.Parameter(
@@ -352,6 +354,7 @@ class MiniCPMDecoderLayer(nn.Module):
                 hidden_act=self.config.hidden_act,
                 hidden_act_param=getattr(self.config, "hidden_act_param", 0.0),
                 quant_config=self.quant_config,
+                prefix=f"{self.prefix}.mlp",
             )
         else:
             self.mlp = MiniCPMMoE(
@@ -359,6 +362,7 @@ class MiniCPMDecoderLayer(nn.Module):
                 top_k=self.config.num_experts_per_tok,
                 hidden_size=self.config.hidden_size,
                 intermediate_size=self.config.intermediate_size,
+                prefix=f"{self.prefix}.mlp",
             )
 
     def forward(
diff --git a/vllm/model_executor/models/minicpm_eagle.py b/vllm/model_executor/models/minicpm_eagle.py
index 9f3587a6d..e9f1a91bf 100644
--- a/vllm/model_executor/models/minicpm_eagle.py
+++ b/vllm/model_executor/models/minicpm_eagle.py
@@ -108,6 +108,7 @@ class EagleMiniCPMDecoderLayer(nn.Module):
                 top_k=self.config.num_experts_per_tok,
                 hidden_size=self.config.hidden_size,
                 intermediate_size=self.config.intermediate_size,
+                prefix=f"{self.prefix}.mlp",
             )
 
     def forward(
diff --git a/vllm/model_executor/models/mistral_large_3_eagle.py b/vllm/model_executor/models/mistral_large_3_eagle.py
index 37cd4324e..830f210e7 100644
--- a/vllm/model_executor/models/mistral_large_3_eagle.py
+++ b/vllm/model_executor/models/mistral_large_3_eagle.py
@@ -67,6 +67,7 @@ class EagleMistralLarge3Model(DeepseekV2Model):
             input_is_parallel=False,
             quant_config=quant_config,
             return_bias=False,
+            prefix=maybe_prefix(prefix, "fc"),
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py
index 4655ffa7b..fb8f6a28e 100644
--- a/vllm/model_executor/models/modernbert.py
+++ b/vllm/model_executor/models/modernbert.py
@@ -63,7 +63,9 @@ class ModernBertEmbeddings(nn.Module):
 
 
 class ModernBertAttention(nn.Module):
-    def __init__(self, config: ModernBertConfig, layer_id: int | None = None):
+    def __init__(
+        self, config: ModernBertConfig, layer_id: int | None = None, prefix: str = ""
+    ):
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
@@ -80,6 +82,7 @@ class ModernBertAttention(nn.Module):
             self.head_dim,
             self.num_heads,
             bias=config.attention_bias,
+            prefix=f"{prefix}.Wqkv",
         )
 
         if layer_types := getattr(config, "layer_types", None):
@@ -117,7 +120,10 @@ class ModernBertAttention(nn.Module):
             per_layer_sliding_window=sliding_window,
         )
         self.Wo = RowParallelLinear(
-            config.hidden_size, config.hidden_size, bias=config.attention_bias
+            config.hidden_size,
+            config.hidden_size,
+            bias=config.attention_bias,
+            prefix=f"{prefix}.Wo",
         )
 
     def forward(
@@ -135,7 +141,7 @@ class ModernBertAttention(nn.Module):
 
 
 class ModernBertMLP(nn.Module):
-    def __init__(self, config: ModernBertConfig):
+    def __init__(self, config: ModernBertConfig, prefix: str = ""):
         super().__init__()
         self.config = config
         self.Wi = nn.Linear(
@@ -143,7 +149,10 @@ class ModernBertMLP(nn.Module):
         )
         self.act = nn.GELU()
         self.Wo = RowParallelLinear(
-            config.intermediate_size, config.hidden_size, bias=config.mlp_bias
+            config.intermediate_size,
+            config.hidden_size,
+            bias=config.mlp_bias,
+            prefix=f"{prefix}.Wo",
         )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -163,11 +172,13 @@ class ModernBertLayer(nn.Module):
         self.attn_norm = nn.LayerNorm(
             config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
         )
-        self.attn = ModernBertAttention(config=config, layer_id=layer_id)
+        self.attn = ModernBertAttention(
+            config=config, layer_id=layer_id, prefix=f"{prefix}.attn"
+        )
         self.mlp_norm = nn.LayerNorm(
             config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
         )
-        self.mlp = ModernBertMLP(config)
+        self.mlp = ModernBertMLP(config, prefix=f"{prefix}.mlp")
 
     def forward(
         self,
@@ -189,7 +200,11 @@ class ModernBertEncoderLayer(nn.Module):
         config = vllm_config.model_config.hf_config
         self.layers = nn.ModuleList(
             [
-                ModernBertLayer(config=config, layer_id=layer_id)
+                ModernBertLayer(
+                    config=config,
+                    layer_id=layer_id,
+                    prefix=f"{prefix}.layers.{layer_id}",
+                )
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
@@ -220,7 +235,9 @@ class ModernBertModel(nn.Module):
         config = vllm_config.model_config.hf_config
         self.config = config
         self.embeddings = ModernBertEmbeddings(config)
-        self.encoder_layer = ModernBertEncoderLayer(vllm_config)
+        self.encoder_layer = ModernBertEncoderLayer(
+            vllm_config, prefix=f"{prefix}.encoder_layer"
+        )
         self.final_norm = nn.LayerNorm(
             config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
         )
diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py
index 9c741e1f5..5ccc5653e 100644
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -142,6 +142,7 @@ class ViTMLP(nn.Module):
         self,
         config: VisionBackboneConfig,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.w1 = ColumnParallelLinear(
@@ -149,6 +150,7 @@ class ViTMLP(nn.Module):
             config.image_mlp_dim,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.w1",
         )
         # Activation function.
         assert config.image_mlp_activations == "quick_gelu"
@@ -158,6 +160,7 @@ class ViTMLP(nn.Module):
             config.image_emb_dim,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.w2",
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -176,6 +179,7 @@ class MultiHeadDotProductAttention(nn.Module):
         use_bias: bool = True,
         nlayers: int = 1,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
 
@@ -202,24 +206,28 @@ class MultiHeadDotProductAttention(nn.Module):
             self.total_num_heads * self.head_dim,
             bias=use_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.wq",
         )
         self.wk = ColumnParallelLinear(
             nlayers * self.hidden_size,
             self.total_num_kv_heads * self.head_dim,
             bias=use_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.wk",
         )
         self.wv = ColumnParallelLinear(
             nlayers * self.hidden_size,
             self.total_num_kv_heads * self.head_dim,
             bias=use_bias,
            quant_config=quant_config,
+            prefix=f"{prefix}.wv",
         )
         self.wo = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=use_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.wo",
         )
 
         self.scale = self.head_dim**-0.5
@@ -254,10 +262,15 @@ class ResidualAttentionBlock(nn.Module):
         self,
         config: VisionBackboneConfig,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
-        self.attention = MultiHeadDotProductAttention(config, quant_config=quant_config)
-        self.feed_forward = ViTMLP(config, quant_config)
+        self.attention = MultiHeadDotProductAttention(
+            config, quant_config=quant_config, prefix=f"{prefix}.attention"
+        )
+        self.feed_forward = ViTMLP(
+            config, quant_config, prefix=f"{prefix}.feed_forward"
+        )
         self.attention_norm = nn.LayerNorm(
             config.image_emb_dim,
             eps=config.image_norm_eps,
@@ -280,12 +293,15 @@ class BlockCollection(nn.Module):
         self,
         config: VisionBackboneConfig,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.resblocks = nn.ModuleList(
             [
-                ResidualAttentionBlock(config, quant_config)
-                for _ in range(config.image_num_layers)
+                ResidualAttentionBlock(
+                    config, quant_config, prefix=f"{prefix}.resblocks.{i}"
+                )
+                for i in range(config.image_num_layers)
             ]
         )
 
@@ -308,6 +324,7 @@ class VisionTransformer(nn.Module):
         self,
         config: VisionBackboneConfig,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         scale = config.image_emb_dim**-0.5
@@ -324,7 +341,9 @@ class VisionTransformer(nn.Module):
             bias=False,
         )
         self.pre_ln = nn.LayerNorm(config.image_emb_dim, eps=config.image_norm_eps)
-        self.transformer = BlockCollection(config, quant_config)
+        self.transformer = BlockCollection(
+            config, quant_config, prefix=f"{prefix}.transformer"
+        )
 
     def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor:
         cls_emb = self.positional_embedding[0:1]
@@ -419,6 +438,7 @@ class MolmoAttention(nn.Module):
             self.total_num_kv_heads,
             bias=config.qkv_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
 
         self.tp_rank: int | None = None
@@ -454,6 +474,7 @@ class MolmoAttention(nn.Module):
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
 
     def _apply_qk_norm(
@@ -493,6 +514,7 @@ class LanguageModelMLP(nn.Module):
         config: PretrainedConfig,
         input_dim: int | None = None,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -503,6 +525,7 @@ class LanguageModelMLP(nn.Module):
             [self.intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         # Activation function.
         self.act_fn = MulAndSilu()
@@ -512,6 +535,7 @@ class LanguageModelMLP(nn.Module):
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
 
     def forward(
@@ -532,6 +556,7 @@ class ImageProjectorMLP(nn.Module):
         config: PretrainedConfig,
         input_dim: int | None = None,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -542,6 +567,7 @@ class ImageProjectorMLP(nn.Module):
             [self.intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.merged_linear",
         )
         # Activation function.
         self.act_fn = SiluAndMul()
@@ -552,6 +578,7 @@ class ImageProjectorMLP(nn.Module):
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
 
     def forward(
@@ -579,7 +606,9 @@ class MolmoDecoderLayer(nn.Module):
         )
 
         # MLP block.
-        self.mlp = LanguageModelMLP(config, quant_config=quant_config)
+        self.mlp = LanguageModelMLP(
+            config, quant_config=quant_config, prefix=f"{prefix}.mlp"
+        )
 
         # LayerNorm
         assert config.layer_norm_type == "rms"
@@ -643,6 +672,7 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
         config: PretrainedConfig,
         vision_config: VisionBackboneConfig,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.vit_layers = VIT_LAYERS
@@ -651,18 +681,24 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
             (self.image_num_patch[0] + 1) // POOLING_SIZE,
             (self.image_num_patch[1] + 1) // POOLING_SIZE,
         )
-        self.image_vit = VisionTransformer(vision_config, quant_config=quant_config)
+        self.image_vit = VisionTransformer(
+            vision_config, quant_config=quant_config, prefix=f"{prefix}.image_vit"
+        )
         self.num_prefix_tokens = self.image_vit.num_prefix_tokens
         assert self.num_prefix_tokens in {0, 1}, (
             "Only 0 or 1 prefix tokens are supported"
         )
         self.image_pooling_2d = MultiHeadDotProductAttention(
-            vision_config, nlayers=len(self.vit_layers), quant_config=quant_config
+            vision_config,
+            nlayers=len(self.vit_layers),
+            quant_config=quant_config,
+            prefix=f"{prefix}.image_pooling_2d",
         )
         self.image_projector = ImageProjectorMLP(
             config,
             input_dim=vision_config.image_emb_dim,
             quant_config=quant_config,
+            prefix=f"{prefix}.image_projector",
         )
 
         image_dim = vision_config.image_emb_dim * len(self.vit_layers)
@@ -1405,7 +1441,12 @@ class MolmoForCausalLM(
         self.multimodal_config = multimodal_config
 
         vision_config = VisionBackboneConfig()
-        self.vision_backbone = MolmoVisionBackbone(config, vision_config, quant_config)
+        self.vision_backbone = MolmoVisionBackbone(
+            config,
+            vision_config,
+            quant_config,
+            prefix=maybe_prefix(prefix, "vision_backbone"),
+        )
         self.model = MolmoModel(
             vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
         )
diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py
index a5a926151..3d7aa2000 100644
--- a/vllm/model_executor/models/olmoe.py
+++ b/vllm/model_executor/models/olmoe.py
@@ -86,7 +86,11 @@ class OlmoeMoE(nn.Module):
 
         # Gate always runs at half / full precision for now.
         self.gate = ReplicatedLinear(
-            hidden_size, num_experts, bias=False, quant_config=None
+            hidden_size,
+            num_experts,
+            bias=False,
+            quant_config=None,
+            prefix=f"{prefix}.gate",
         )
 
         self.experts = FusedMoE(
diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py
index 14f73d0c6..e28772061 100644
--- a/vllm/model_executor/models/phimoe.py
+++ b/vllm/model_executor/models/phimoe.py
@@ -272,6 +272,7 @@ class PhiMoE(nn.Module):
             bias=False,
             params_dtype=params_dtype,
             quant_config=None,
+            prefix=f"{prefix}.gate",
         )
 
         self.experts = FusedMoE(
diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py
index 61a6e6780..50b53a1ff 100644
--- a/vllm/model_executor/models/qwen.py
+++ b/vllm/model_executor/models/qwen.py
@@ -56,13 +56,22 @@ class QWenMLP(nn.Module):
         intermediate_size: int,
         hidden_act: str = "silu",
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.c_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, quant_config=quant_config
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -163,7 +172,10 @@ class QWenBlock(nn.Module):
         self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
         self.mlp = QWenMLP(
-            config.hidden_size, config.intermediate_size // 2, quant_config=quant_config
+            config.hidden_size,
+            config.intermediate_size // 2,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
         )
 
     def forward(
diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py
index ccf6cc6e5..7137d3d8e 100644
--- a/vllm/model_executor/models/qwen3_next.py
+++ b/vllm/model_executor/models/qwen3_next.py
@@ -864,6 +864,7 @@ class Qwen3NextDecoderLayer(nn.Module):
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
         )
 
         self.input_layernorm = Qwen3NextRMSNorm(
diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py
index caac14716..df0733de9 100644
--- a/vllm/model_executor/models/qwen_vl.py
+++ b/vllm/model_executor/models/qwen_vl.py
@@ -109,6 +109,7 @@ class VisualAttention(nn.Module):
         bias: bool = True,
         kdim: int | None = None,
         vdim: int | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.embed_dim = embed_dim
@@ -128,8 +129,12 @@ class VisualAttention(nn.Module):
         assert self._qkv_same_embed_dim, (
             "Visual Attention implementation only supports self-attention"
         )
-        self.in_proj = ReplicatedLinear(embed_dim, 3 * embed_dim)
-        self.out_proj = ReplicatedLinear(embed_dim, embed_dim)
+        self.in_proj = ReplicatedLinear(
+            embed_dim, 3 * embed_dim, prefix=f"{prefix}.in_proj"
+        )
+        self.out_proj = ReplicatedLinear(
+            embed_dim, embed_dim, prefix=f"{prefix}.out_proj"
+        )
         self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
 
     def forward(
@@ -214,10 +219,15 @@ class QwenVLMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.c_fc = ColumnParallelLinear(
-            hidden_size, intermediate_size, bias=True, quant_config=quant_config
+            hidden_size,
+            intermediate_size,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.c_fc",
         )
         self.act_fn = get_act_fn("gelu")
         self.c_proj = RowParallelLinear(
@@ -225,6 +235,7 @@ class QwenVLMLP(nn.Module):
             hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
         )
 
     def forward(self, x):
@@ -242,17 +253,19 @@ class VisualAttentionBlock(nn.Module):
         mlp_ratio: float = 4.0,
         norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
 
         self.ln_1 = norm_layer(d_model)
         self.ln_2 = norm_layer(d_model)
         mlp_width = int(d_model * mlp_ratio)
-        self.attn = VisualAttention(d_model, n_head)
+        self.attn = VisualAttention(d_model, n_head, prefix=f"{prefix}.attn")
         self.mlp = QwenVLMLP(
             hidden_size=d_model,
             intermediate_size=mlp_width,
             quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
         )
 
     def attention(
@@ -282,6 +295,7 @@ class TransformerBlock(nn.Module):
         mlp_ratio: float = 4.0,
         norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.width = width
@@ -295,8 +309,9 @@ class TransformerBlock(nn.Module):
                     mlp_ratio,
                     norm_layer=norm_layer,
                     quant_config=quant_config,
+                    prefix=f"{prefix}.resblocks.{i}",
                 )
-                for _ in range(layers)
+                for i in range(layers)
             ]
         )
 
@@ -327,6 +342,7 @@ class VisionTransformer(nn.Module):
         output_dim: int = 512,
         image_start_id: int = 151857,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
         **kwargs,
     ):
         super().__init__()
@@ -356,6 +372,7 @@ class VisionTransformer(nn.Module):
             mlp_ratio,
             norm_layer=norm_layer,
             quant_config=quant_config,
+            prefix=f"{prefix}.transformer",
         )
 
         self.attn_pool = Resampler2(
@@ -366,6 +383,7 @@ class VisionTransformer(nn.Module):
             norm_layer=norm_layer,
             adaptive=False,
             do_post_projection=False,
+            prefix=f"{prefix}.attn_pool",
         ).to(
             device=self.positional_embedding.device,
             dtype=self.positional_embedding.dtype,
@@ -413,7 +431,9 @@ class QwenVLModel(QWenModel):
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
 
-        self.visual = VisionTransformer(**config.visual, quant_config=quant_config)
+        self.visual = VisionTransformer(
+            **config.visual, quant_config=quant_config, prefix=f"{prefix}.visual"
+        )
 
 
 @lru_cache(maxsize=1)
diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py
index fe157887e..b5132cd86 100644
--- a/vllm/model_executor/models/zamba2.py
+++ b/vllm/model_executor/models/zamba2.py
@@ -86,7 +86,13 @@ class Zamba2LoRA(nn.Module):
             B_class = MergedColumnParallelLinear
         else:
             B_class = ColumnParallelLinear
-        self.B = B_class(rank, output_dim, bias=False, quant_config=quant_config)
+        self.B = B_class(
+            rank,
+            output_dim,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.B",
+        )
 
     def forward(
         self,
@@ -346,6 +352,7 @@ class Zamba2MLP(nn.Module):
                     config.adapter_rank,
                     2 * [self.intermediate_size],
                     quant_config,
+                    prefix=f"{prefix}.gate_up_proj_adapter_list.{block_idx}",
                 )
             else:
                 gate_up_proj_adapter = nn.Identity()