[Bugfix] Redo Qwen3.5/Qwen3-Next GDN projector fusion (#34697)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: JJJYmmm <92386084+JJJYmmm@users.noreply.github.com>
This commit is contained in:
Isotr0py
2026-02-19 01:46:53 +08:00
committed by GitHub
parent caeb887bf6
commit c0bd8b13da
3 changed files with 102 additions and 192 deletions

View File

@@ -685,8 +685,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self, self,
param: Parameter, param: Parameter,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
loaded_shard_id: int | None = None, loaded_shard_id: tuple[int, ...] | int | None = None,
): ):
if isinstance(loaded_shard_id, tuple):
raise NotImplementedError(
"Shard id with multiple indices is not supported in weight_loader, "
"please use weight_loader_v2 instead."
)
# Special case for GGUF # Special case for GGUF
# initialize GGUF param after we know the quantize type # initialize GGUF param after we know the quantize type
is_gguf_weight = getattr(param, "is_gguf_weight", False) is_gguf_weight = getattr(param, "is_gguf_weight", False)
@@ -770,6 +775,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if output_dim is not None: if output_dim is not None:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) shard_offset = sum(self.output_sizes[:loaded_shard_id])
shard_size = self.output_sizes[loaded_shard_id] shard_size = self.output_sizes[loaded_shard_id]
shard_offset //= self.tp_size
shard_size //= self.tp_size
if isinstance(param, BlockQuantScaleParameter): if isinstance(param, BlockQuantScaleParameter):
weight_block_size = getattr(self, "weight_block_size", None) weight_block_size = getattr(self, "weight_block_size", None)
@@ -777,9 +784,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
weight_block_size, shard_size, shard_offset weight_block_size, shard_size, shard_offset
) )
shard_offset //= self.tp_size
shard_size //= self.tp_size
# Special case for quantization. # Special case for quantization.
# If quantized, we need to adjust the offset and size to account # If quantized, we need to adjust the offset and size to account
# for the packing. # for the packing.
@@ -825,7 +829,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
def _load_fused_module_from_checkpoint( def _load_fused_module_from_checkpoint(
self, param: BasevLLMParameter, loaded_weight: torch.Tensor self,
param: BasevLLMParameter,
loaded_weight: torch.Tensor,
output_sizes: list[int] | None = None,
): ):
""" """
Handle special case for models where MLP layers are already Handle special case for models where MLP layers are already
@@ -839,7 +846,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
current_shard_offset = 0 current_shard_offset = 0
shard_offsets: list[tuple[int, int, int]] = [] shard_offsets: list[tuple[int, int, int]] = []
for i, output_size in enumerate(self.output_sizes): output_sizes = output_sizes or self.output_sizes
for i, output_size in enumerate(output_sizes):
shard_offsets.append((i, current_shard_offset, output_size)) shard_offsets.append((i, current_shard_offset, output_size))
current_shard_offset += output_size current_shard_offset += output_size
@@ -864,23 +872,38 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self, self,
param: BasevLLMParameter, param: BasevLLMParameter,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
loaded_shard_id: int | None = None, loaded_shard_id: tuple[int, ...] | int | None = None,
): ):
if loaded_shard_id is None: if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
if isinstance(param, PerTensorScaleParameter): if isinstance(param, PerTensorScaleParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0) param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
return return
elif type(param) in (RowvLLMParameter, BasevLLMParameter): elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight) param.load_merged_column_weight(loaded_weight=loaded_weight)
return return
output_sizes = (
[self.output_sizes[idx] for idx in loaded_shard_id]
if loaded_shard_id
else None
)
if isinstance(param, BlockQuantScaleParameter):
weight_block_size = getattr(self, "weight_block_size", None)
output_sizes = [
adjust_block_scale_shard(weight_block_size, size, 0)[0]
for size in (output_sizes or self.output_sizes)
]
# TODO: @dsikka - move to parameter.py # TODO: @dsikka - move to parameter.py
self._load_fused_module_from_checkpoint(param, loaded_weight) self._load_fused_module_from_checkpoint(
param, loaded_weight, output_sizes=output_sizes
)
return return
assert loaded_shard_id < len(self.output_sizes) assert loaded_shard_id < len(self.output_sizes)
shard_offset = sum(self.output_sizes[:loaded_shard_id]) shard_offset = sum(self.output_sizes[:loaded_shard_id])
shard_size = self.output_sizes[loaded_shard_id] shard_size = self.output_sizes[loaded_shard_id]
shard_offset //= self.tp_size
shard_size //= self.tp_size
if isinstance(param, BlockQuantScaleParameter): if isinstance(param, BlockQuantScaleParameter):
weight_block_size = getattr(self, "weight_block_size", None) weight_block_size = getattr(self, "weight_block_size", None)
@@ -888,9 +911,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
weight_block_size, shard_size, shard_offset weight_block_size, shard_size, shard_offset
) )
shard_offset //= self.tp_size
shard_size //= self.tp_size
param.load_merged_column_weight( param.load_merged_column_weight(
loaded_weight=loaded_weight, loaded_weight=loaded_weight,
shard_id=loaded_shard_id, shard_id=loaded_shard_id,

View File

@@ -30,36 +30,20 @@ from collections.abc import Callable, Iterable
import torch import torch
from einops import rearrange from einops import rearrange
from torch import nn from torch import nn
from transformers.activations import ACT2FN
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import ( from vllm.config import (
CacheConfig,
ModelConfig,
SpeculativeConfig,
VllmConfig, VllmConfig,
get_current_vllm_config,
) )
from vllm.distributed import ( from vllm.distributed import (
divide,
get_pp_group, get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import ( from vllm.model_executor.layers.layernorm import (
GemmaRMSNorm as Qwen3_5RMSNorm, GemmaRMSNorm as Qwen3_5RMSNorm,
) )
from vllm.model_executor.layers.layernorm import RMSNormGated from vllm.model_executor.layers.linear import MergedColumnParallelLinear
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
mamba_v2_sharded_weight_loader,
)
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc, MambaStateCopyFunc,
MambaStateCopyFuncCalculator, MambaStateCopyFuncCalculator,
@@ -73,11 +57,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
) )
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, default_weight_loader,
sharded_weight_loader,
) )
from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.qwen3_5 import ( from vllm.transformers_utils.configs.qwen3_5 import (
Qwen3_5Config, Qwen3_5Config,
@@ -99,7 +80,6 @@ from .interfaces import (
) )
from .qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP from .qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
from .qwen3_next import ( from .qwen3_next import (
ChunkGatedDeltaRule,
Qwen3NextAttention, Qwen3NextAttention,
Qwen3NextDecoderLayer, Qwen3NextDecoderLayer,
Qwen3NextGatedDeltaNet, Qwen3NextGatedDeltaNet,
@@ -139,154 +119,31 @@ class Qwen3_5MoeProcessingInfo(Qwen3VLProcessingInfo):
class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet): class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
def __init__(
self,
config: Qwen3_5TextConfig | Qwen3_5MoeTextConfig,
model_config: ModelConfig | None = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
speculative_config: SpeculativeConfig | None = None,
prefix: str = "",
) -> None:
# Build the Qwen3.5 gated-delta-net layer with four separate input
# projections (in_proj_qkv, in_proj_z, in_proj_b, in_proj_a), a conv1d
# over the q/k/v channels, a gated RMS norm and an output projection.
# Raises ValueError if `prefix` is already registered in the compilation
# config's static_forward_context.
#
# super() is anchored at Qwen3NextGatedDeltaNet, so that class's own
# __init__ is bypassed and the next class in the MRO is initialized instead.
super(Qwen3NextGatedDeltaNet, self).__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.hidden_size = config.hidden_size
self.num_v_heads = config.linear_num_value_heads
self.num_k_heads = config.linear_num_key_heads
self.head_k_dim = config.linear_key_head_dim
self.head_v_dim = config.linear_value_head_dim
# Total key / value widths across all heads (before any TP split).
self.key_dim = self.head_k_dim * self.num_k_heads
self.value_dim = self.head_v_dim * self.num_v_heads
self.conv_kernel_size = config.linear_conv_kernel_dim
self.layer_idx = extract_layer_index(prefix)
self.activation = config.hidden_act
self.act = ACT2FN[config.hidden_act]
self.layer_norm_epsilon = config.rms_norm_eps
self.prefix = prefix
self.config = config
self.model_config = model_config
self.cache_config = cache_config
self.quant_config = quant_config
self.speculative_config = speculative_config
# Number of speculative tokens; 0 when no speculative config is given.
self.num_spec = (
self.speculative_config.num_speculative_tokens
if self.speculative_config
else 0
)
# QKV
# The conv runs over the concatenated q, k and v channels.
self.conv_dim = self.key_dim * 2 + self.value_dim
self.conv1d = ColumnParallelLinear(
input_size=self.conv_kernel_size,
output_size=self.conv_dim,
bias=False,
prefix=f"{prefix}.conv1d",
)
# Insert a singleton dim at index 1 of the weight.
# NOTE(review): presumably to match a conv1d-style weight layout — confirm.
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
# q, k and v live in one merged projection with three output shards.
self.in_proj_qkv = MergedColumnParallelLinear(
input_size=self.hidden_size,
output_sizes=[self.key_dim, self.key_dim, self.value_dim],
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.in_proj_qkv",
)
# z, b and a each get their own column-parallel projection.
self.in_proj_z = ColumnParallelLinear(
input_size=self.hidden_size,
output_size=self.value_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.in_proj_z",
)
self.in_proj_b = ColumnParallelLinear(
input_size=self.hidden_size,
output_size=self.num_v_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.in_proj_b",
)
self.in_proj_a = ColumnParallelLinear(
input_size=self.hidden_size,
output_size=self.num_v_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.in_proj_a",
)
# Per-segment settings handed to mamba_v2_sharded_weight_loader below.
# NOTE(review): the meaning of the 2nd and 3rd tuple fields is defined by
# that loader and is not visible here — confirm before changing.
query_key_settings = (self.key_dim, 0, False)
value_settings = (self.value_dim, 0, False)
# Drop the loader installed by ColumnParallelLinear so the conv1d weight
# can be given a loader that shards the q, k and v segments separately.
delattr(self.conv1d.weight, "weight_loader")
set_weight_attrs(
self.conv1d.weight,
{
"weight_loader": mamba_v2_sharded_weight_loader(
[
query_key_settings,
query_key_settings,
value_settings,
],
self.tp_size,
self.tp_rank,
)
},
)
# selective projection used to make dt, B and C input dependent
# time step projection (discretization)
# instantiate once and copy inv_dt in init_weights of PretrainedModel
# Both parameters hold this rank's share of the value heads
# (num_v_heads split by tp_size).
self.dt_bias = nn.Parameter(
torch.ones(self.num_v_heads // self.tp_size),
)
self.A_log = nn.Parameter(
torch.empty(
divide(self.num_v_heads, self.tp_size),
)
)
# NOTE(review): sharded_weight_loader(0) presumably shards along dim 0 —
# confirm against weight_utils.
set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)})
set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})
self.norm = RMSNormGated(
self.head_v_dim,
eps=self.layer_norm_epsilon,
group_size=None,
norm_before_gate=True,
device=current_platform.current_device(),
dtype=config.dtype,
)
self.out_proj = RowParallelLinear(
self.value_dim,
self.hidden_size,
bias=False,
input_is_parallel=True,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.chunk_gated_delta_rule = ChunkGatedDeltaRule()
# Register this layer under its prefix in the compilation config's
# static forward context; a prefix may only be registered once.
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
def fix_query_key_value_ordering( def fix_query_key_value_ordering(
self, self,
mixed_qkv, mixed_qkvz: torch.Tensor,
z, mixed_ba: torch.Tensor,
b,
a,
): ):
raise NotImplementedError( raise NotImplementedError(
"Qwen3.5 Series dont need to fix query key value ordering" "Qwen3.5 Series dont need to fix query key value ordering"
) )
def create_qkvz_proj(
    self,
    hidden_size: int,
    key_dim: int,
    value_dim: int,
    quant_config: QuantizationConfig | None,
    prefix: str,
) -> MergedColumnParallelLinear:
    """Build the fused q/k/v/z input projection for Qwen3.5.

    The merged layer is declared with four output shards of widths
    ``key_dim``, ``key_dim``, ``value_dim`` and ``value_dim`` (in that
    order), so each shard can be addressed by its index.

    Args:
        hidden_size: Input feature size of the projection.
        key_dim: Width of each of the first two output shards.
        value_dim: Width of each of the last two output shards.
        quant_config: Optional quantization config forwarded to the layer.
        prefix: Module name prefix forwarded to the layer.

    Returns:
        The bias-free ``MergedColumnParallelLinear`` projection.
    """
    # One entry per logical shard, in the order the shards are indexed.
    shard_widths = [key_dim, key_dim, value_dim, value_dim]
    return MergedColumnParallelLinear(
        input_size=hidden_size,
        output_sizes=shard_widths,
        bias=False,
        quant_config=quant_config,
        prefix=prefix,
    )
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@@ -303,11 +160,13 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
# ============================================================ # ============================================================
# Part 1: Input Projection # Part 1: Input Projection
# ============================================================ # ============================================================
mixed_qkv, _ = self.in_proj_qkv(hidden_states) mixed_qkvz, _ = self.in_proj_qkvz(hidden_states)
z, _ = self.in_proj_z(hidden_states) qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size
z_size = self.value_dim // self.tp_size
mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1)
z = z.reshape(z.size(0), -1, self.head_v_dim) z = z.reshape(z.size(0), -1, self.head_v_dim)
b, _ = self.in_proj_b(hidden_states) ba, _ = self.in_proj_ba(hidden_states)
a, _ = self.in_proj_a(hidden_states) b, a = ba.chunk(2, dim=-1)
b = b.contiguous() b = b.contiguous()
a = a.contiguous() a = a.contiguous()
@@ -506,11 +365,18 @@ class Qwen3_5Model(Qwen3NextModel):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
# self attention
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"), ("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
# mlp
("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
# GDN
("in_proj_qkvz", "in_proj_qkv", (0, 1, 2)),
("in_proj_qkvz", "in_proj_z", 3),
("in_proj_ba", "in_proj_b", 0),
("in_proj_ba", "in_proj_a", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
@@ -657,6 +523,9 @@ class Qwen3_5ForCausalLMBase(
"v_proj", "v_proj",
], ],
"gate_up_proj": ["gate_proj", "up_proj"], "gate_up_proj": ["gate_proj", "up_proj"],
# GDN fused projections.
"in_proj_qkvz": ["in_proj_qkv", "in_proj_z"],
"in_proj_ba": ["in_proj_b", "in_proj_a"],
} }
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -676,10 +545,9 @@ class Qwen3_5ForCausalLMBase(
super().__init__() super().__init__()
self.config = config self.config = config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
# Deal with the case where the prefix is already "language_model" since self.model = Qwen3_5Model(
# Qwen/Qwen3.5-397B-A17B has naming like: model.language_model.layers.0 vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
model_prefix = prefix if "model" in prefix else "model" )
self.model = Qwen3_5Model(vllm_config=vllm_config, prefix=model_prefix)
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
if config.tie_word_embeddings: if config.tie_word_embeddings:
@@ -755,6 +623,11 @@ class Qwen3_5MoeForCausalLM(Qwen3_5ForCausalLMBase, QwenNextMixtureOfExperts):
dummy_inputs=Qwen3VLDummyInputsBuilder, dummy_inputs=Qwen3VLDummyInputsBuilder,
) )
class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid): class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid):
packed_modules_mapping = Qwen3VLForConditionalGeneration.packed_modules_mapping | {
"in_proj_qkvz": ["in_proj_qkv", "in_proj_z"],
"in_proj_ba": ["in_proj_b", "in_proj_a"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
# protocols have not __init__ method, so we need to use nn.Module.__init__ # protocols have not __init__ method, so we need to use nn.Module.__init__
nn.Module.__init__(self) nn.Module.__init__(self)

View File

@@ -44,6 +44,7 @@ from vllm.model_executor.layers.layernorm import (
from vllm.model_executor.layers.layernorm import RMSNormGated from vllm.model_executor.layers.layernorm import RMSNormGated
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear, RowParallelLinear,
@@ -406,19 +407,19 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
# projection of the input hidden states # projection of the input hidden states
self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 # Qwen3-Next and Qwen3.5 has a different qkv_proj layout,
self.projection_size_ba = self.num_v_heads * 2 # we need to create qkvz_proj adaptively here.
self.in_proj_qkvz = ColumnParallelLinear( self.in_proj_qkvz = self.create_qkvz_proj(
input_size=self.hidden_size, hidden_size=self.hidden_size,
output_size=self.projection_size_qkvz, key_dim=self.key_dim,
bias=False, value_dim=self.value_dim,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.in_proj_qkvz", prefix=f"{prefix}.in_proj_qkvz",
) )
# ba_proj doesn't support blockwise fp8 quantization. # ba_proj doesn't support blockwise fp8 quantization.
self.in_proj_ba = ColumnParallelLinear( self.in_proj_ba = MergedColumnParallelLinear(
input_size=self.hidden_size, input_size=self.hidden_size,
output_size=self.projection_size_ba, output_sizes=[self.num_v_heads] * 2,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.in_proj_ba", prefix=f"{prefix}.in_proj_ba",
@@ -484,10 +485,26 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
raise ValueError(f"Duplicate layer name: {prefix}") raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self compilation_config.static_forward_context[prefix] = self
def create_qkvz_proj(
    self,
    hidden_size: int,
    key_dim: int,
    value_dim: int,
    quant_config: QuantizationConfig | None,
    prefix: str,
) -> MergedColumnParallelLinear:
    """Build the fused qkvz input projection for Qwen3-Next.

    The layer is declared with a single output shard whose width is
    ``2 * key_dim + 2 * value_dim``; the q/k/v/z parts are not registered
    as separate shards here.

    Args:
        hidden_size: Input feature size of the projection.
        key_dim: Total key width; counted twice in the fused output.
        value_dim: Total value width; counted twice in the fused output.
        quant_config: Optional quantization config forwarded to the layer.
        prefix: Module name prefix forwarded to the layer.

    Returns:
        The bias-free ``MergedColumnParallelLinear`` projection.
    """
    # The whole qkvz block is treated as one shard of the merged layer.
    fused_width = key_dim + key_dim + value_dim + value_dim
    return MergedColumnParallelLinear(
        input_size=hidden_size,
        output_sizes=[fused_width],
        bias=False,
        quant_config=quant_config,
        prefix=prefix,
    )
def fix_query_key_value_ordering( def fix_query_key_value_ordering(
self, self,
mixed_qkvz, mixed_qkvz: torch.Tensor,
mixed_ba, mixed_ba: torch.Tensor,
): ):
""" """
Derives `query`, `key` and `value` tensors from `mixed_qkvzba`. Derives `query`, `key` and `value` tensors from `mixed_qkvzba`.