[Model] Extract GatedDeltaNetAttention into shared layer for Qwen3Next and Qwen3.5 (#37975)
Signed-off-by: wxsIcey <1790571317@qq.com> Signed-off-by: Icey <1790571317@qq.com>
This commit is contained in:
1046
vllm/model_executor/layers/mamba/gdn_linear_attn.py
Normal file
1046
vllm/model_executor/layers/mamba/gdn_linear_attn.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -28,7 +28,6 @@ import typing
|
||||
from collections.abc import Callable, Iterable
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
@@ -40,18 +39,14 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.layernorm import (
|
||||
GemmaRMSNorm as Qwen3_5RMSNorm,
|
||||
)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mamba.gdn_linear_attn import GatedDeltaNetAttention
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateCopyFunc,
|
||||
MambaStateCopyFuncCalculator,
|
||||
MambaStateDtypeCalculator,
|
||||
MambaStateShapeCalculator,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
@@ -85,7 +80,6 @@ from .qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
|
||||
from .qwen3_next import (
|
||||
Qwen3NextAttention,
|
||||
Qwen3NextDecoderLayer,
|
||||
Qwen3NextGatedDeltaNet,
|
||||
Qwen3NextModel,
|
||||
Qwen3NextSparseMoeBlock,
|
||||
QwenNextMixtureOfExperts,
|
||||
@@ -121,149 +115,6 @@ class Qwen3_5MoeProcessingInfo(Qwen3VLProcessingInfo):
|
||||
return self.ctx.get_hf_config(Qwen3_5MoeConfig)
|
||||
|
||||
|
||||
class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
    """Qwen3.5 variant of the Qwen3-Next gated DeltaNet linear-attention layer.

    Compared with the parent class this variant:

    * builds ``in_proj_qkvz`` / ``in_proj_ba`` as ``MergedColumnParallelLinear``
      layers (see ``create_qkvz_proj`` and ``create_ba_proj``);
    * when a LoRA config is present, skips the fused ``in_proj_qkvz`` and
      creates separate ``in_proj_qkv`` and ``in_proj_z`` projections instead;
    * never reorders the projected q/k/v/z tensors, so
      ``fix_query_key_value_ordering`` is disabled.
    """

    def __init__(
        self,
        config: Qwen3_5Config,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> None:
        """Build the layer, choosing between fused and LoRA-friendly projections.

        Args:
            config: Qwen3.5 model configuration, forwarded to the parent class.
            vllm_config: Engine configuration; its ``lora_config`` selects the
                projection layout and its ``quant_config`` is passed to the
                projections created here.
            prefix: Module-name prefix used to name the sub-layers.
        """
        # The fused q/k/v/z projection is only requested when LoRA is disabled.
        fused_qkvz = vllm_config.lora_config is None
        super().__init__(
            config,
            vllm_config=vllm_config,
            prefix=prefix,
            create_in_proj_qkvz=fused_qkvz,
        )

        if vllm_config.lora_config is None:
            return

        # LoRA compatibility: keep Q/K/V in one projection and Z in another.
        # MergedColumnParallelLinear is used for in_proj_qkv because GDN may
        # have linear_num_key_heads != linear_num_value_heads (e.g. 16 vs 32);
        # the resulting output sizes [key_dim, key_dim, value_dim] cannot be
        # expressed by a single QKVParallelLinear, which ties the K and V
        # head counts together.
        self.in_proj_qkv = MergedColumnParallelLinear(
            input_size=self.hidden_size,
            output_sizes=[self.key_dim, self.key_dim, self.value_dim],
            bias=False,
            quant_config=vllm_config.quant_config,
            prefix=f"{prefix}.in_proj_qkv",
        )
        self.in_proj_z = ColumnParallelLinear(
            input_size=self.hidden_size,
            output_size=self.value_dim,
            bias=False,
            quant_config=vllm_config.quant_config,
            prefix=f"{prefix}.in_proj_z",
        )

    def fix_query_key_value_ordering(
        self,
        mixed_qkvz: torch.Tensor,
        mixed_ba: torch.Tensor,
    ):
        """Reject reordering requests; this variant never reorders q/k/v/z.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Qwen3.5 Series dont need to fix query key value ordering"
        )

    def create_qkvz_proj(
        self,
        hidden_size: int,
        key_dim: int,
        value_dim: int,
        quant_config: QuantizationConfig | None,
        prefix: str,
    ) -> MergedColumnParallelLinear:
        """Create the fused, bias-free projection producing q, k, v and z.

        The merged output is laid out as ``[key_dim, key_dim, value_dim,
        value_dim]``.
        """
        fused_output_sizes = [key_dim, key_dim, value_dim, value_dim]
        return MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=fused_output_sizes,
            bias=False,
            quant_config=quant_config,
            prefix=prefix,
        )

    def create_ba_proj(
        self,
        hidden_size: int,
        num_v_heads: int,
        quant_config: QuantizationConfig | None,
        prefix: str,
    ) -> MergedColumnParallelLinear:
        """Create the fused, bias-free projection producing b and a.

        Qwen3.5 checkpoints store ``in_proj_b`` and ``in_proj_a`` as separate
        weights; they are loaded into this fused ``in_proj_ba`` parameter via
        ``stacked_params_mapping`` using shard ids 0 and 1 respectively.
        """
        fused_output_sizes = [num_v_heads, num_v_heads]
        return MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=fused_output_sizes,
            bias=False,
            quant_config=quant_config,
            prefix=prefix,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
    ):
        """Run the layer and write the result into ``output``.

        The computation has three stages: input projection, the core
        attention custom op, and the output projection.

        Args:
            hidden_states: Layer input; its first dimension is used as the
                token count.
            output: Pre-allocated destination; the first ``num_tokens`` rows
                are overwritten with the projected result.
        """
        num_tokens = hidden_states.shape[0]

        # ------------------------------------------------------------
        # Stage 1: input projection
        # ------------------------------------------------------------
        if not hasattr(self, "in_proj_qkv"):
            # Default path: one fused projection, then split q/k/v from z
            # using the per-rank (tensor-parallel) widths.
            mixed_qkvz, _ = self.in_proj_qkvz(hidden_states)
            ba, _ = self.in_proj_ba(hidden_states)
            qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size
            z_size = self.value_dim // self.tp_size
            mixed_qkv, z = torch.split(mixed_qkvz, [qkv_size, z_size], dim=-1)
        else:
            # LoRA path: q/k/v and z come from separate projections.
            mixed_qkv, _ = self.in_proj_qkv(hidden_states)
            ba, _ = self.in_proj_ba(hidden_states)
            z, _ = self.in_proj_z(hidden_states)

        # Expose the head dimension of z: (tokens, heads, head_v_dim).
        z = z.reshape(z.shape[0], -1, self.head_v_dim)

        # The two halves of the fused b/a projection are handed to the
        # custom op as contiguous tensors.
        b_half, a_half = torch.chunk(ba, 2, dim=-1)
        b = b_half.contiguous()
        a = a_half.contiguous()

        # ------------------------------------------------------------
        # Stage 2: core attention (custom op)
        # ------------------------------------------------------------
        # Unlike other attention backends this buffer must be zero-filled
        # rather than created with torch.empty; see the discussion in
        # https://github.com/vllm-project/vllm/pull/28182
        attn_shape = (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim)
        core_attn_out = torch.zeros(
            attn_shape,
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        # The op's return value is unused; it is expected to fill
        # core_attn_out in place.
        torch.ops.vllm.gdn_attention_core(
            mixed_qkv,
            b,
            a,
            core_attn_out,
            self.prefix,
        )

        # ------------------------------------------------------------
        # Stage 3: output projection
        # ------------------------------------------------------------
        gate_shape = z.shape
        # self.norm is applied on 2D inputs, so collapse the leading dims.
        flat_attn = core_attn_out.reshape(-1, core_attn_out.size(-1))
        flat_gate = z.reshape(-1, z.size(-1))
        normed = self.norm(flat_attn, flat_gate)
        # Restore the 3D layout, then merge heads into one feature axis.
        normed = normed.reshape(gate_shape)
        merged = rearrange(normed, "... h d -> ... (h d)")
        projected, _ = self.out_proj(merged)
        output[:num_tokens] = projected
|
||||
|
||||
|
||||
class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -282,10 +133,12 @@ class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer):
|
||||
self.layer_idx = extract_layer_index(prefix)
|
||||
|
||||
if self.layer_type == "linear_attention":
|
||||
self.linear_attn = Qwen3_5GatedDeltaNet(
|
||||
self.linear_attn = GatedDeltaNetAttention(
|
||||
config=config,
|
||||
vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.linear_attn",
|
||||
gqa_interleaved_layout=False,
|
||||
create_in_proj_qkvz=vllm_config.lora_config is None,
|
||||
)
|
||||
elif self.layer_type == "full_attention":
|
||||
self.self_attn = Qwen3NextAttention(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user