From 1d65283e95f4d978c984df8585ca3f477166e651 Mon Sep 17 00:00:00 2001
From: Jiangyun Zhu
Date: Tue, 17 Feb 2026 17:29:27 +0800
Subject: [PATCH] Revert "[Models] Fuse Qwen3.5 GDN's qkvz_proj and ba_proj"
 (#34683)

---
 vllm/model_executor/layers/linear.py     |  34 +---
 vllm/model_executor/models/qwen3_5.py    | 198 +++++++++++++++++++----
 vllm/model_executor/models/qwen3_next.py |  37 ++---
 3 files changed, 182 insertions(+), 87 deletions(-)

diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 23035816b..bbd7267fd 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -685,13 +685,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         self,
         param: Parameter,
         loaded_weight: torch.Tensor,
-        loaded_shard_id: tuple[int, ...] | int | None = None,
+        loaded_shard_id: int | None = None,
     ):
-        if isinstance(loaded_shard_id, tuple):
-            raise NotImplementedError(
-                "Shard id with multiple indices is not supported in weight_loader, "
-                "please use weight_loader_v2 instead."
-            )
         # Special case for GGUF
         # initialize GGUF param after we know the quantize type
         is_gguf_weight = getattr(param, "is_gguf_weight", False)
@@ -830,10 +825,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         param_data.copy_(loaded_weight)

     def _load_fused_module_from_checkpoint(
-        self,
-        param: BasevLLMParameter,
-        loaded_weight: torch.Tensor,
-        output_sizes: list[int] | None = None,
+        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
     ):
         """
         Handle special case for models where MLP layers are already
@@ -847,8 +839,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):

         current_shard_offset = 0
         shard_offsets: list[tuple[int, int, int]] = []
-        output_sizes = output_sizes or self.output_sizes
-        for i, output_size in enumerate(output_sizes):
+        for i, output_size in enumerate(self.output_sizes):
             shard_offsets.append((i, current_shard_offset, output_size))
             current_shard_offset += output_size

@@ -873,30 +864,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         self,
         param: BasevLLMParameter,
         loaded_weight: torch.Tensor,
-        loaded_shard_id: tuple[int, ...] | int | None = None,
+        loaded_shard_id: int | None = None,
     ):
-        if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
+        if loaded_shard_id is None:
             if isinstance(param, PerTensorScaleParameter):
                 param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
                 return
             elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                 param.load_merged_column_weight(loaded_weight=loaded_weight)
                 return
-            output_sizes = (
-                [self.output_sizes[idx] for idx in loaded_shard_id]
-                if loaded_shard_id
-                else None
-            )
-            if isinstance(param, BlockQuantScaleParameter):
-                weight_block_size = getattr(self, "weight_block_size", None)
-                output_sizes = [
-                    adjust_block_scale_shard(weight_block_size, size, 0)[0]
-                    for size in (output_sizes or self.output_sizes)
-                ]
             # TODO: @dsikka - move to parameter.py
-            self._load_fused_module_from_checkpoint(
-                param, loaded_weight, output_sizes=output_sizes
-            )
+            self._load_fused_module_from_checkpoint(param, loaded_weight)
             return

         assert loaded_shard_id < len(self.output_sizes)

diff --git a/vllm/model_executor/models/qwen3_5.py b/vllm/model_executor/models/qwen3_5.py
index 7c355e8b0..5c76bf7ef 100644
--- a/vllm/model_executor/models/qwen3_5.py
+++ b/vllm/model_executor/models/qwen3_5.py
@@ -30,20 +30,36 @@ from collections.abc import Callable, Iterable
 import torch
 from einops import rearrange
 from torch import nn
+from transformers.activations import ACT2FN

 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (
+    CacheConfig,
+    ModelConfig,
+    SpeculativeConfig,
     VllmConfig,
+    get_current_vllm_config,
 )
 from vllm.distributed import (
+    divide,
     get_pp_group,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
 )
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import (
     GemmaRMSNorm as Qwen3_5RMSNorm,
 )
-from vllm.model_executor.layers.linear import MergedColumnParallelLinear
+from vllm.model_executor.layers.layernorm import RMSNormGated
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    MergedColumnParallelLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mamba.mamba_mixer2 import (
+    mamba_v2_sharded_weight_loader,
+)
 from vllm.model_executor.layers.mamba.mamba_utils import (
     MambaStateCopyFunc,
     MambaStateCopyFuncCalculator,
@@ -57,8 +73,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
+    sharded_weight_loader,
 )
+from vllm.model_executor.utils import set_weight_attrs
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.qwen3_5 import (
     Qwen3_5Config,
@@ -80,6 +99,7 @@ from .interfaces import (
 )
 from .qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
 from .qwen3_next import (
+    ChunkGatedDeltaRule,
     Qwen3NextAttention,
     Qwen3NextDecoderLayer,
     Qwen3NextGatedDeltaNet,
@@ -119,31 +139,154 @@ class Qwen3_5MoeProcessingInfo(Qwen3VLProcessingInfo):


 class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
+    def __init__(
+        self,
+        config: Qwen3_5TextConfig | Qwen3_5MoeTextConfig,
+        model_config: ModelConfig | None = None,
+        cache_config: CacheConfig | None = None,
+        quant_config: QuantizationConfig | None = None,
+        speculative_config: SpeculativeConfig | None = None,
+        prefix: str = "",
+    ) -> None:
+        super(Qwen3NextGatedDeltaNet, self).__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.hidden_size = config.hidden_size
+        self.num_v_heads = config.linear_num_value_heads
+        self.num_k_heads = config.linear_num_key_heads
+        self.head_k_dim = config.linear_key_head_dim
+        self.head_v_dim = config.linear_value_head_dim
+        self.key_dim = self.head_k_dim * self.num_k_heads
+        self.value_dim = self.head_v_dim * self.num_v_heads
+
+        self.conv_kernel_size = config.linear_conv_kernel_dim
+        self.layer_idx = extract_layer_index(prefix)
+        self.activation = config.hidden_act
+        self.act = ACT2FN[config.hidden_act]
+        self.layer_norm_epsilon = config.rms_norm_eps
+        self.prefix = prefix
+
+        self.config = config
+        self.model_config = model_config
+        self.cache_config = cache_config
+        self.quant_config = quant_config
+        self.speculative_config = speculative_config
+        self.num_spec = (
+            self.speculative_config.num_speculative_tokens
+            if self.speculative_config
+            else 0
+        )
+
+        # QKV
+        self.conv_dim = self.key_dim * 2 + self.value_dim
+        self.conv1d = ColumnParallelLinear(
+            input_size=self.conv_kernel_size,
+            output_size=self.conv_dim,
+            bias=False,
+            prefix=f"{prefix}.conv1d",
+        )
+        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
+
+        self.in_proj_qkv = MergedColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_sizes=[self.key_dim, self.key_dim, self.value_dim],
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.in_proj_qkv",
+        )
+        self.in_proj_z = ColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_size=self.value_dim,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.in_proj_z",
+        )
+        self.in_proj_b = ColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_size=self.num_v_heads,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.in_proj_b",
+        )
+        self.in_proj_a = ColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_size=self.num_v_heads,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.in_proj_a",
+        )
+
+        query_key_settings = (self.key_dim, 0, False)
+        value_settings = (self.value_dim, 0, False)
+
+        delattr(self.conv1d.weight, "weight_loader")
+        set_weight_attrs(
+            self.conv1d.weight,
+            {
+                "weight_loader": mamba_v2_sharded_weight_loader(
+                    [
+                        query_key_settings,
+                        query_key_settings,
+                        value_settings,
+                    ],
+                    self.tp_size,
+                    self.tp_rank,
+                )
+            },
+        )
+
+        # selective projection used to make dt, B and C input dependant
+
+        # time step projection (discretization)
+        # instantiate once and copy inv_dt in init_weights of PretrainedModel
+        self.dt_bias = nn.Parameter(
+            torch.ones(self.num_v_heads // self.tp_size),
+        )
+        self.A_log = nn.Parameter(
+            torch.empty(
+                divide(self.num_v_heads, self.tp_size),
+            )
+        )
+
+        set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)})
+        set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})
+
+        self.norm = RMSNormGated(
+            self.head_v_dim,
+            eps=self.layer_norm_epsilon,
+            group_size=None,
+            norm_before_gate=True,
+            device=current_platform.current_device(),
+            dtype=config.dtype,
+        )
+
+        self.out_proj = RowParallelLinear(
+            self.value_dim,
+            self.hidden_size,
+            bias=False,
+            input_is_parallel=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
+        )
+
+        self.chunk_gated_delta_rule = ChunkGatedDeltaRule()
+
+        compilation_config = get_current_vllm_config().compilation_config
+        if prefix in compilation_config.static_forward_context:
+            raise ValueError(f"Duplicate layer name: {prefix}")
+        compilation_config.static_forward_context[prefix] = self
+
     def fix_query_key_value_ordering(
         self,
-        mixed_qkvz: torch.Tensor,
-        mixed_ba: torch.Tensor,
+        mixed_qkv,
+        z,
+        b,
+        a,
     ):
         raise NotImplementedError(
             "Qwen3.5 Series dont need to fix query key value ordering"
         )

-    def create_qkvz_proj(
-        self,
-        hidden_size: int,
-        key_dim: int,
-        value_dim: int,
-        quant_config: QuantizationConfig | None,
-        prefix: str,
-    ) -> MergedColumnParallelLinear:
-        return MergedColumnParallelLinear(
-            input_size=hidden_size,
-            output_sizes=[key_dim, key_dim, value_dim, value_dim],
-            bias=False,
-            quant_config=quant_config,
-            prefix=prefix,
-        )
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -160,13 +303,11 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
         # ============================================================
         # Part 1: Input Projection
         # ============================================================
-        mixed_qkvz, _ = self.in_proj_qkvz(hidden_states)
-        qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size
-        z_size = self.value_dim // self.tp_size
-        mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1)
+        mixed_qkv, _ = self.in_proj_qkv(hidden_states)
+        z, _ = self.in_proj_z(hidden_states)
         z = z.reshape(z.size(0), -1, self.head_v_dim)
-        ba, _ = self.in_proj_ba(hidden_states)
-        b, a = ba.chunk(2, dim=-1)
+        b, _ = self.in_proj_b(hidden_states)
+        a, _ = self.in_proj_a(hidden_states)
         b = b.contiguous()
         a = a.contiguous()

@@ -365,18 +506,11 @@ class Qwen3_5Model(Qwen3NextModel):
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            # self attention
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
-            # mlp
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
-            # GDN
-            ("in_proj_qkvz", "in_proj_qkv", (0, 1, 2)),
-            ("in_proj_qkvz", "in_proj_z", 3),
-            ("in_proj_ba", "in_proj_b", 0),
-            ("in_proj_ba", "in_proj_a", 1),
         ]

         params_dict = dict(self.named_parameters())
diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py
index 59468c7bf..6da5bca1b 100644
--- a/vllm/model_executor/models/qwen3_next.py
+++ b/vllm/model_executor/models/qwen3_next.py
@@ -44,7 +44,6 @@ from vllm.model_executor.layers.layernorm import (
 from vllm.model_executor.layers.layernorm import RMSNormGated
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
-    MergedColumnParallelLinear,
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
@@ -407,19 +406,19 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

         # projection of the input hidden states
-        # Qwen3-Next and Qwen3.5 has a different qkv_proj layout,
-        # we need to create qkvz_proj adaptively here.
-        self.in_proj_qkvz = self.create_qkvz_proj(
-            hidden_size=self.hidden_size,
-            key_dim=self.key_dim,
-            value_dim=self.value_dim,
+        self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2
+        self.projection_size_ba = self.num_v_heads * 2
+        self.in_proj_qkvz = ColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_size=self.projection_size_qkvz,
+            bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.in_proj_qkvz",
         )
         # ba_proj doesn't support blockwise fp8 quantization.
-        self.in_proj_ba = MergedColumnParallelLinear(
+        self.in_proj_ba = ColumnParallelLinear(
             input_size=self.hidden_size,
-            output_sizes=[self.num_v_heads] * 2,
+            output_size=self.projection_size_ba,
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.in_proj_ba",
@@ -485,26 +484,10 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
             raise ValueError(f"Duplicate layer name: {prefix}")
         compilation_config.static_forward_context[prefix] = self

-    def create_qkvz_proj(
-        self,
-        hidden_size: int,
-        key_dim: int,
-        value_dim: int,
-        quant_config: QuantizationConfig | None,
-        prefix: str,
-    ) -> MergedColumnParallelLinear:
-        return MergedColumnParallelLinear(
-            input_size=hidden_size,
-            output_sizes=[sum((key_dim, key_dim, value_dim)), value_dim],
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.in_proj_qkvz",
-        )
-
     def fix_query_key_value_ordering(
         self,
-        mixed_qkvz: torch.Tensor,
-        mixed_ba: torch.Tensor,
+        mixed_qkvz,
+        mixed_ba,
     ):
         """
         Derives `query`, `key` and `value` tensors from `mixed_qkvzba`.