[Model] Extract GatedDeltaNetAttention into shared layer for Qwen3Next and Qwen3.5 (#37975)

Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: Icey <1790571317@qq.com>

Author: Xiaoshuang Wang
Date: 2026-03-27 14:13:21 +08:00
Committed by: GitHub
Parent: 2babac0bed
Commit: a8eab8f30d

3 changed files with 1053 additions and 1126 deletions
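
In short: the per-model Qwen3_5GatedDeltaNet subclass below (and its Qwen3Next counterpart) collapses into one shared GatedDeltaNetAttention layer in vllm.model_executor.layers.mamba.gdn_linear_attn, with the per-model differences expressed as constructor flags rather than method overrides. Below is a minimal self-contained sketch of that shape, hypothetical and simplified; the Qwen3.5 flag values match the decoder-layer hunk further down, while the Qwen3Next side is not visible in this extract.

# Hypothetical, simplified sketch of the refactor's shape; not the real
# vLLM class. The flag values for Qwen3.5 match the diff below.
class GatedDeltaNetAttentionSketch:
    def __init__(self, gqa_interleaved_layout: bool, create_in_proj_qkvz: bool):
        # Whether checkpoint Q/K/V/Z weights are interleaved per GQA group
        # and need reordering at load time. Qwen3.5 does not need this;
        # its removed subclass raised NotImplementedError for
        # fix_query_key_value_ordering.
        self.gqa_interleaved_layout = gqa_interleaved_layout
        # Whether to build the fused in_proj_qkvz projection. With LoRA
        # enabled, Qwen3.5 builds split in_proj_qkv / in_proj_z instead.
        self.create_in_proj_qkvz = create_in_proj_qkvz

qwen3_5_attn = GatedDeltaNetAttentionSketch(
    gqa_interleaved_layout=False,
    create_in_proj_qkvz=True,  # i.e. vllm_config.lora_config is None
)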

(File diff suppressed because it is too large.)


@@ -28,7 +28,6 @@ import typing
 from collections.abc import Callable, Iterable
 
 import torch
-from einops import rearrange
 from torch import nn
 
 from vllm.compilation.decorators import support_torch_compile
@@ -40,18 +39,14 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import (
     GemmaRMSNorm as Qwen3_5RMSNorm,
 )
-from vllm.model_executor.layers.linear import (
-    ColumnParallelLinear,
-    MergedColumnParallelLinear,
-)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mamba.gdn_linear_attn import GatedDeltaNetAttention
 from vllm.model_executor.layers.mamba.mamba_utils import (
     MambaStateCopyFunc,
     MambaStateCopyFuncCalculator,
     MambaStateDtypeCalculator,
     MambaStateShapeCalculator,
 )
-from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -85,7 +80,6 @@ from .qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
 from .qwen3_next import (
     Qwen3NextAttention,
     Qwen3NextDecoderLayer,
-    Qwen3NextGatedDeltaNet,
     Qwen3NextModel,
     Qwen3NextSparseMoeBlock,
     QwenNextMixtureOfExperts,
@@ -121,149 +115,6 @@ class Qwen3_5MoeProcessingInfo(Qwen3VLProcessingInfo):
         return self.ctx.get_hf_config(Qwen3_5MoeConfig)
 
 
-class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
-    def fix_query_key_value_ordering(
-        self,
-        mixed_qkvz: torch.Tensor,
-        mixed_ba: torch.Tensor,
-    ):
-        raise NotImplementedError(
-            "Qwen3.5 series doesn't need to fix query/key/value ordering"
-        )
-
-    def __init__(
-        self,
-        config: Qwen3_5Config,
-        vllm_config: VllmConfig,
-        prefix: str = "",
-    ) -> None:
-        create_in_proj_qkvz = vllm_config.lora_config is None
-        super().__init__(
-            config,
-            vllm_config=vllm_config,
-            prefix=prefix,
-            create_in_proj_qkvz=create_in_proj_qkvz,
-        )
-        if vllm_config.lora_config is not None:
-            # Separate in_proj_qkv (Q,K,V) and in_proj_z for LoRA compatibility.
-            # Use MergedColumnParallelLinear for in_proj_qkv because GDN can have
-            # linear_num_key_heads != linear_num_value_heads (e.g. 16 vs 32), so
-            # output sizes [key_dim, key_dim, value_dim] are not representable
-            # with a single QKVParallelLinear (which ties K and V head counts).
-            self.in_proj_qkv = MergedColumnParallelLinear(
-                input_size=self.hidden_size,
-                output_sizes=[self.key_dim, self.key_dim, self.value_dim],
-                bias=False,
-                quant_config=vllm_config.quant_config,
-                prefix=f"{prefix}.in_proj_qkv",
-            )
-            self.in_proj_z = ColumnParallelLinear(
-                input_size=self.hidden_size,
-                output_size=self.value_dim,
-                bias=False,
-                quant_config=vllm_config.quant_config,
-                prefix=f"{prefix}.in_proj_z",
-            )
-
-    def create_qkvz_proj(
-        self,
-        hidden_size: int,
-        key_dim: int,
-        value_dim: int,
-        quant_config: QuantizationConfig | None,
-        prefix: str,
-    ) -> MergedColumnParallelLinear:
-        return MergedColumnParallelLinear(
-            input_size=hidden_size,
-            output_sizes=[key_dim, key_dim, value_dim, value_dim],
-            bias=False,
-            quant_config=quant_config,
-            prefix=prefix,
-        )
-
-    def create_ba_proj(
-        self,
-        hidden_size: int,
-        num_v_heads: int,
-        quant_config: QuantizationConfig | None,
-        prefix: str,
-    ) -> MergedColumnParallelLinear:
-        # Qwen3.5 has separate in_proj_b and in_proj_a weights in the
-        # checkpoint, which are loaded into the fused in_proj_ba parameter
-        # via stacked_params_mapping with shard_id 0 and 1 respectively.
-        return MergedColumnParallelLinear(
-            input_size=hidden_size,
-            output_sizes=[num_v_heads] * 2,
-            bias=False,
-            quant_config=quant_config,
-            prefix=prefix,
-        )
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        output: torch.Tensor,
-    ):
-        """
-        Forward pass with three parts:
-        1. Input projection
-        2. Core attention (custom op)
-        3. Output projection
-        """
-        num_tokens = hidden_states.size(0)
-
-        # ============================================================
-        # Part 1: Input Projection
-        # ============================================================
-        if hasattr(self, "in_proj_qkv"):
-            # LoRA path: separate in_proj_qkv and in_proj_z
-            mixed_qkv, _ = self.in_proj_qkv(hidden_states)
-            ba, _ = self.in_proj_ba(hidden_states)
-            z, _ = self.in_proj_z(hidden_states)
-        else:
-            mixed_qkvz, _ = self.in_proj_qkvz(hidden_states)
-            ba, _ = self.in_proj_ba(hidden_states)
-            qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size
-            z_size = self.value_dim // self.tp_size
-            mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1)
-        z = z.reshape(z.size(0), -1, self.head_v_dim)
-        b, a = ba.chunk(2, dim=-1)
-        b = b.contiguous()
-        a = a.contiguous()
-
-        # ============================================================
-        # Part 2: Core Attention (Custom Op)
-        # ============================================================
-        # Note: we should not use torch.empty here like other attention
-        # backends do; see the discussion in
-        # https://github.com/vllm-project/vllm/pull/28182
-        core_attn_out = torch.zeros(
-            (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
-            dtype=hidden_states.dtype,
-            device=hidden_states.device,
-        )
-        torch.ops.vllm.gdn_attention_core(
-            mixed_qkv,
-            b,
-            a,
-            core_attn_out,
-            self.prefix,
-        )
-
-        # ============================================================
-        # Part 3: Output Projection
-        # ============================================================
-        z_shape_og = z.shape
-        # Reshape input data into 2D tensors
-        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
-        z = z.reshape(-1, z.shape[-1])
-        core_attn_out = self.norm(core_attn_out, z)
-        core_attn_out = core_attn_out.reshape(z_shape_og)
-        core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
-        output[:num_tokens], _ = self.out_proj(core_attn_out)
-
-
 class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer):
     def __init__(
         self,
@@ -282,10 +133,12 @@ class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer):
         self.layer_idx = extract_layer_index(prefix)
 
         if self.layer_type == "linear_attention":
-            self.linear_attn = Qwen3_5GatedDeltaNet(
+            self.linear_attn = GatedDeltaNetAttention(
                 config=config,
                 vllm_config=vllm_config,
                 prefix=f"{prefix}.linear_attn",
+                gqa_interleaved_layout=False,
+                create_in_proj_qkvz=vllm_config.lora_config is None,
             )
         elif self.layer_type == "full_attention":
             self.self_attn = Qwen3NextAttention(

(File diff suppressed because it is too large.)
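
Two notes on the removed Qwen3_5GatedDeltaNet above. First, the projection sizing: with the 16-vs-32 head counts cited in the LoRA comment, and head dims, tp_size, and token count assumed purely for illustration (key_dim and value_dim follow the usual heads * head_dim convention, which this extract does not show), the fused-path split in forward() works out as in the sketch below.

import torch

# Worked sizing example for the removed forward(). The 16/32 head counts
# come from the LoRA comment in the diff; head dims, tp_size, and
# num_tokens are assumed for illustration only.
num_k_heads, num_v_heads = 16, 32     # linear_num_key_heads != linear_num_value_heads
head_k_dim = head_v_dim = 128         # assumed
tp_size, num_tokens = 1, 4            # assumed

key_dim = num_k_heads * head_k_dim    # 2048
value_dim = num_v_heads * head_v_dim  # 4096

# Non-LoRA path: fused qkvz projection, then split into qkv and z,
# matching output_sizes=[key_dim, key_dim, value_dim, value_dim].
fused_width = key_dim * 2 + value_dim * 2          # 12288
mixed_qkvz = torch.randn(num_tokens, fused_width // tp_size)
qkv_size = (key_dim * 2 + value_dim) // tp_size    # 8192
z_size = value_dim // tp_size                      # 4096
mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1)
z = z.reshape(z.size(0), -1, head_v_dim)           # -> [4, 32, 128]

# LoRA path sizing: the K width (2048) differs from the V width (4096),
# which QKVParallelLinear cannot express since it ties K and V head
# counts; hence MergedColumnParallelLinear with explicit shard sizes.
qkv_output_sizes = [key_dim, key_dim, value_dim]   # [2048, 2048, 4096]

# b/a gating projections: one scalar per value head each (per TP rank).
ba = torch.randn(num_tokens, 2 * num_v_heads // tp_size)
b, a = ba.chunk(2, dim=-1)                         # each [4, 32]

# Output buffer shape handed to the gdn_attention_core custom op.
core_attn_out = torch.zeros(num_tokens, num_v_heads // tp_size, head_v_dim)
assert core_attn_out.shape == (4, 32, 128)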
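
Second, weight loading for the fused ba projection: the create_ba_proj comment says the checkpoint's separate in_proj_b and in_proj_a tensors land in the fused in_proj_ba via stacked_params_mapping with shard_id 0 and 1. vLLM conventionally encodes this as (param_name, weight_name, shard_id) tuples; below is a hypothetical excerpt consistent with that comment, since the real file's entries are not shown in this diff.

# Hypothetical excerpt, mirroring the create_ba_proj comment above:
# checkpoint weights ...in_proj_b / ...in_proj_a load into the fused
# in_proj_ba parameter as shards 0 and 1 of a MergedColumnParallelLinear.
stacked_params_mapping = [
    # (param_name, weight_name, shard_id)
    ("in_proj_ba", "in_proj_b", 0),
    ("in_proj_ba", "in_proj_a", 1),
]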