Revert "[Models] Fuse Qwen3.5 GDN's qkvz_proj and ba_proj" (#34683)

Author:    Jiangyun Zhu
Date:      2026-02-17 17:29:27 +08:00
Committer: GitHub
Parent:    c464b57374
Commit:    1d65283e95

3 changed files with 182 additions and 87 deletions


@@ -685,13 +685,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         self,
         param: Parameter,
         loaded_weight: torch.Tensor,
-        loaded_shard_id: tuple[int, ...] | int | None = None,
+        loaded_shard_id: int | None = None,
     ):
-        if isinstance(loaded_shard_id, tuple):
-            raise NotImplementedError(
-                "Shard id with multiple indices is not supported in weight_loader, "
-                "please use weight_loader_v2 instead."
-            )
         # Special case for GGUF
         # initialize GGUF param after we know the quantize type
         is_gguf_weight = getattr(param, "is_gguf_weight", False)
@@ -830,10 +825,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         param_data.copy_(loaded_weight)
 
     def _load_fused_module_from_checkpoint(
-        self,
-        param: BasevLLMParameter,
-        loaded_weight: torch.Tensor,
-        output_sizes: list[int] | None = None,
+        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
     ):
         """
         Handle special case for models where MLP layers are already
@@ -847,8 +839,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         current_shard_offset = 0
         shard_offsets: list[tuple[int, int, int]] = []
-        output_sizes = output_sizes or self.output_sizes
-        for i, output_size in enumerate(output_sizes):
+        for i, output_size in enumerate(self.output_sizes):
             shard_offsets.append((i, current_shard_offset, output_size))
             current_shard_offset += output_size
@@ -873,30 +864,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         self,
         param: BasevLLMParameter,
         loaded_weight: torch.Tensor,
-        loaded_shard_id: tuple[int, ...] | int | None = None,
+        loaded_shard_id: int | None = None,
     ):
-        if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
+        if loaded_shard_id is None:
             if isinstance(param, PerTensorScaleParameter):
                 param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
                 return
             elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                 param.load_merged_column_weight(loaded_weight=loaded_weight)
                 return
-            output_sizes = (
-                [self.output_sizes[idx] for idx in loaded_shard_id]
-                if loaded_shard_id
-                else None
-            )
-            if isinstance(param, BlockQuantScaleParameter):
-                weight_block_size = getattr(self, "weight_block_size", None)
-                output_sizes = [
-                    adjust_block_scale_shard(weight_block_size, size, 0)[0]
-                    for size in (output_sizes or self.output_sizes)
-                ]
             # TODO: @dsikka - move to parameter.py
-            self._load_fused_module_from_checkpoint(
-                param, loaded_weight, output_sizes=output_sizes
-            )
+            self._load_fused_module_from_checkpoint(param, loaded_weight)
             return
 
         assert loaded_shard_id < len(self.output_sizes)
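
With the revert applied, MergedColumnParallelLinear again accepts only integer shard ids, and _load_fused_module_from_checkpoint derives its shard offsets purely from the layer's own output_sizes. A minimal standalone sketch of that offset computation, using illustrative sizes rather than values from this diff:

output_sizes = [1024, 1024, 2048]  # e.g. [key_dim, key_dim, value_dim], made up here
shard_offsets: list[tuple[int, int, int]] = []
current_shard_offset = 0
for i, output_size in enumerate(output_sizes):
    # Each shard is described as (shard_id, offset into the fused dim, size).
    shard_offsets.append((i, current_shard_offset, output_size))
    current_shard_offset += output_size
print(shard_offsets)  # [(0, 0, 1024), (1, 1024, 1024), (2, 2048, 2048)]

The fused checkpoint tensor is then sliced at these offsets and each slice is loaded as its own shard.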


@@ -30,20 +30,36 @@ from collections.abc import Callable, Iterable
 import torch
 from einops import rearrange
 from torch import nn
+from transformers.activations import ACT2FN
 
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (
+    CacheConfig,
+    ModelConfig,
+    SpeculativeConfig,
     VllmConfig,
+    get_current_vllm_config,
 )
 from vllm.distributed import (
+    divide,
     get_pp_group,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
 )
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import (
     GemmaRMSNorm as Qwen3_5RMSNorm,
 )
-from vllm.model_executor.layers.linear import MergedColumnParallelLinear
+from vllm.model_executor.layers.layernorm import RMSNormGated
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    MergedColumnParallelLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mamba.mamba_mixer2 import (
+    mamba_v2_sharded_weight_loader,
+)
 from vllm.model_executor.layers.mamba.mamba_utils import (
     MambaStateCopyFunc,
     MambaStateCopyFuncCalculator,
@@ -57,8 +73,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
+    sharded_weight_loader,
 )
+from vllm.model_executor.utils import set_weight_attrs
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.qwen3_5 import (
     Qwen3_5Config,
@@ -80,6 +99,7 @@ from .interfaces import (
 )
 from .qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
 from .qwen3_next import (
+    ChunkGatedDeltaRule,
     Qwen3NextAttention,
     Qwen3NextDecoderLayer,
     Qwen3NextGatedDeltaNet,
@@ -119,31 +139,154 @@ class Qwen3_5MoeProcessingInfo(Qwen3VLProcessingInfo):
 class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
+    def __init__(
+        self,
+        config: Qwen3_5TextConfig | Qwen3_5MoeTextConfig,
+        model_config: ModelConfig | None = None,
+        cache_config: CacheConfig | None = None,
+        quant_config: QuantizationConfig | None = None,
+        speculative_config: SpeculativeConfig | None = None,
+        prefix: str = "",
+    ) -> None:
+        super(Qwen3NextGatedDeltaNet, self).__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+
+        self.hidden_size = config.hidden_size
+        self.num_v_heads = config.linear_num_value_heads
+        self.num_k_heads = config.linear_num_key_heads
+        self.head_k_dim = config.linear_key_head_dim
+        self.head_v_dim = config.linear_value_head_dim
+        self.key_dim = self.head_k_dim * self.num_k_heads
+        self.value_dim = self.head_v_dim * self.num_v_heads
+        self.conv_kernel_size = config.linear_conv_kernel_dim
+        self.layer_idx = extract_layer_index(prefix)
+        self.activation = config.hidden_act
+        self.act = ACT2FN[config.hidden_act]
+        self.layer_norm_epsilon = config.rms_norm_eps
+        self.prefix = prefix
+
+        self.config = config
+        self.model_config = model_config
+        self.cache_config = cache_config
+        self.quant_config = quant_config
+        self.speculative_config = speculative_config
+        self.num_spec = (
+            self.speculative_config.num_speculative_tokens
+            if self.speculative_config
+            else 0
+        )
+
+        # QKV
+        self.conv_dim = self.key_dim * 2 + self.value_dim
+        self.conv1d = ColumnParallelLinear(
+            input_size=self.conv_kernel_size,
+            output_size=self.conv_dim,
+            bias=False,
+            prefix=f"{prefix}.conv1d",
+        )
+        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
+
+        self.in_proj_qkv = MergedColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_sizes=[self.key_dim, self.key_dim, self.value_dim],
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.in_proj_qkv",
+        )
+        self.in_proj_z = ColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_size=self.value_dim,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.in_proj_z",
+        )
+        self.in_proj_b = ColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_size=self.num_v_heads,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.in_proj_b",
+        )
+        self.in_proj_a = ColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_size=self.num_v_heads,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.in_proj_a",
+        )
+
+        query_key_settings = (self.key_dim, 0, False)
+        value_settings = (self.value_dim, 0, False)
+
+        delattr(self.conv1d.weight, "weight_loader")
+        set_weight_attrs(
+            self.conv1d.weight,
+            {
+                "weight_loader": mamba_v2_sharded_weight_loader(
+                    [
+                        query_key_settings,
+                        query_key_settings,
+                        value_settings,
+                    ],
+                    self.tp_size,
+                    self.tp_rank,
+                )
+            },
+        )
+
+        # selective projection used to make dt, B and C input dependant
+        # time step projection (discretization)
+        # instantiate once and copy inv_dt in init_weights of PretrainedModel
+        self.dt_bias = nn.Parameter(
+            torch.ones(self.num_v_heads // self.tp_size),
+        )
+        self.A_log = nn.Parameter(
+            torch.empty(
+                divide(self.num_v_heads, self.tp_size),
+            )
+        )
+        set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)})
+        set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})
+
+        self.norm = RMSNormGated(
+            self.head_v_dim,
+            eps=self.layer_norm_epsilon,
+            group_size=None,
+            norm_before_gate=True,
+            device=current_platform.current_device(),
+            dtype=config.dtype,
+        )
+
+        self.out_proj = RowParallelLinear(
+            self.value_dim,
+            self.hidden_size,
+            bias=False,
+            input_is_parallel=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
+        )
+
+        self.chunk_gated_delta_rule = ChunkGatedDeltaRule()
+
+        compilation_config = get_current_vllm_config().compilation_config
+        if prefix in compilation_config.static_forward_context:
+            raise ValueError(f"Duplicate layer name: {prefix}")
+        compilation_config.static_forward_context[prefix] = self
+
     def fix_query_key_value_ordering(
         self,
-        mixed_qkvz: torch.Tensor,
-        mixed_ba: torch.Tensor,
+        mixed_qkv,
+        z,
+        b,
+        a,
     ):
         raise NotImplementedError(
             "Qwen3.5 Series dont need to fix query key value ordering"
         )
 
-    def create_qkvz_proj(
-        self,
-        hidden_size: int,
-        key_dim: int,
-        value_dim: int,
-        quant_config: QuantizationConfig | None,
-        prefix: str,
-    ) -> MergedColumnParallelLinear:
-        return MergedColumnParallelLinear(
-            input_size=hidden_size,
-            output_sizes=[key_dim, key_dim, value_dim, value_dim],
-            bias=False,
-            quant_config=quant_config,
-            prefix=prefix,
-        )
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -160,13 +303,11 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
         # ============================================================
         # Part 1: Input Projection
         # ============================================================
-        mixed_qkvz, _ = self.in_proj_qkvz(hidden_states)
-        qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size
-        z_size = self.value_dim // self.tp_size
-        mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1)
+        mixed_qkv, _ = self.in_proj_qkv(hidden_states)
+        z, _ = self.in_proj_z(hidden_states)
         z = z.reshape(z.size(0), -1, self.head_v_dim)
-        ba, _ = self.in_proj_ba(hidden_states)
-        b, a = ba.chunk(2, dim=-1)
+        b, _ = self.in_proj_b(hidden_states)
+        a, _ = self.in_proj_a(hidden_states)
         b = b.contiguous()
         a = a.contiguous()
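
Post-revert, the forward pass no longer splits a fused qkvz or ba tensor: q/k/v come out of in_proj_qkv together, while z, b and a each have their own projection. A toy-sized, self-contained sketch of that layout, with plain nn.Linear standing in for the tensor-parallel linear layers and made-up dimensions:

import torch
from torch import nn

# Toy dimensions, illustrative only (not the real Qwen3.5 config values).
hidden_size, num_k_heads, num_v_heads = 64, 2, 4
head_k_dim, head_v_dim = 8, 8
key_dim, value_dim = head_k_dim * num_k_heads, head_v_dim * num_v_heads

# One projection for the packed q/k/v, separate projections for z, b, a.
in_proj_qkv = nn.Linear(hidden_size, 2 * key_dim + value_dim, bias=False)
in_proj_z = nn.Linear(hidden_size, value_dim, bias=False)
in_proj_b = nn.Linear(hidden_size, num_v_heads, bias=False)
in_proj_a = nn.Linear(hidden_size, num_v_heads, bias=False)

x = torch.randn(3, hidden_size)                        # 3 tokens
mixed_qkv = in_proj_qkv(x)                             # (3, 2*key_dim + value_dim)
z = in_proj_z(x).reshape(x.size(0), -1, head_v_dim)    # (3, num_v_heads, head_v_dim)
b, a = in_proj_b(x), in_proj_a(x)                      # (3, num_v_heads) each
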
@@ -365,18 +506,11 @@ class Qwen3_5Model(Qwen3NextModel):
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            # self attention
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
-            # mlp
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
-            # GDN
-            ("in_proj_qkvz", "in_proj_qkv", (0, 1, 2)),
-            ("in_proj_qkvz", "in_proj_z", 3),
-            ("in_proj_ba", "in_proj_b", 0),
-            ("in_proj_ba", "in_proj_a", 1),
         ]
         params_dict = dict(self.named_parameters())
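
Since the separate in_proj_qkv / in_proj_z / in_proj_b / in_proj_a parameters now line up with the checkpoint tensor names, the GDN entries disappear from stacked_params_mapping and only the attention and MLP fusions remain. As a rough sketch of how such a mapping is typically consumed by a vLLM-style load_weights loop (a hypothetical, simplified helper, not the method in this diff):

def load_stacked_weights_sketch(
    weights,                 # iterable of (checkpoint_name, tensor)
    params_dict,             # model parameter name -> parameter
    stacked_params_mapping,  # (fused_param_name, checkpoint_shard_name, shard_id)
) -> set[str]:
    loaded: set[str] = set()
    for name, loaded_weight in weights:
        for param_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name not in name:
                continue
            fused_name = name.replace(shard_name, param_name)
            param = params_dict[fused_name]
            # The per-parameter weight_loader places this shard inside
            # the fused parameter (e.g. qkv_proj or gate_up_proj).
            param.weight_loader(param, loaded_weight, shard_id)
            loaded.add(fused_name)
            break
        else:
            # Parameters without a mapping entry (now including the
            # in_proj_qkv/z/b/a weights) are handed to their own loader.
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", None)
            if weight_loader is not None:
                weight_loader(param, loaded_weight)
            else:
                param.data.copy_(loaded_weight)
            loaded.add(name)
    return loaded
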


@@ -44,7 +44,6 @@ from vllm.model_executor.layers.layernorm import (
 from vllm.model_executor.layers.layernorm import RMSNormGated
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
-    MergedColumnParallelLinear,
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
@@ -407,19 +406,19 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
 
         # projection of the input hidden states
-        # Qwen3-Next and Qwen3.5 has a different qkv_proj layout,
-        # we need to create qkvz_proj adaptively here.
-        self.in_proj_qkvz = self.create_qkvz_proj(
-            hidden_size=self.hidden_size,
-            key_dim=self.key_dim,
-            value_dim=self.value_dim,
+        self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2
+        self.projection_size_ba = self.num_v_heads * 2
+        self.in_proj_qkvz = ColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_size=self.projection_size_qkvz,
+            bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.in_proj_qkvz",
         )
         # ba_proj doesn't support blockwise fp8 quantization.
-        self.in_proj_ba = MergedColumnParallelLinear(
+        self.in_proj_ba = ColumnParallelLinear(
             input_size=self.hidden_size,
-            output_sizes=[self.num_v_heads] * 2,
+            output_size=self.projection_size_ba,
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.in_proj_ba",
@@ -485,26 +484,10 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
             raise ValueError(f"Duplicate layer name: {prefix}")
         compilation_config.static_forward_context[prefix] = self
 
-    def create_qkvz_proj(
-        self,
-        hidden_size: int,
-        key_dim: int,
-        value_dim: int,
-        quant_config: QuantizationConfig | None,
-        prefix: str,
-    ) -> MergedColumnParallelLinear:
-        return MergedColumnParallelLinear(
-            input_size=hidden_size,
-            output_sizes=[sum((key_dim, key_dim, value_dim)), value_dim],
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.in_proj_qkvz",
-        )
-
     def fix_query_key_value_ordering(
         self,
-        mixed_qkvz: torch.Tensor,
-        mixed_ba: torch.Tensor,
+        mixed_qkvz,
+        mixed_ba,
     ):
         """
         Derives `query`, `key` and `value` tensors from `mixed_qkvzba`.