[XPU] Initial support for GDN attention on Qwen3-next/Qwen3.5 (#33657)

Signed-off-by: Yan Ma <yan.ma@intel.com>
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Co-authored-by: Chendi Xue <chendi.xue@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
Yan Ma
2026-04-03 08:59:11 +08:00
committed by GitHub
parent 05e68e1f81
commit ee3cf45739
3 changed files with 150 additions and 0 deletions

View File

@@ -560,6 +560,11 @@ class RMSNormGated(CustomOp):
activation=self.activation,
)
def forward_xpu(
    self, x: torch.Tensor, z: torch.Tensor | None = None
) -> torch.Tensor:
    """Gated RMSNorm forward on XPU.

    Delegates to :meth:`forward_cuda`; the ops used there are expected to
    dispatch correctly on XPU tensors as well.
    NOTE(review): assumes every kernel reached through forward_cuda has an
    XPU registration -- confirm against the XPU op registry.
    """
    return self.forward_cuda(x, z)
class LayerNorm(nn.Module):
"""

View File

@@ -262,6 +262,9 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
else 0
)
self.gqa_interleaved_layout = gqa_interleaved_layout
self._forward_method = (
self.forward_xpu if current_platform.is_xpu() else self.forward_cuda
)
# QKV
self.conv_dim = self.key_dim * 2 + self.value_dim
@@ -493,6 +496,13 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
):
self._forward_method(hidden_states, output)
def forward_cuda(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
):
"""
Forward pass with three parts:
@@ -567,6 +577,90 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
output[:num_tokens], _ = self.out_proj(core_attn_out)
def forward_xpu(
    self,
    hidden_states: torch.Tensor,
    output: torch.Tensor,
):
    """
    Forward pass with three parts:
    1. Input projection
    2. Core attention (custom op)
    3. Output projection

    Writes the result into ``output`` in place (only the first
    ``num_tokens`` rows are touched); returns nothing.
    """
    num_tokens = hidden_states.size(0)
    # LoRA-enabled models expose a split `in_proj_qkv` projection, which
    # the fused XPU path below does not handle.
    assert not hasattr(self, "in_proj_qkv"), "lora isn't supported on XPU."
    # ============================================================
    # Part 1: Input Projection
    # ============================================================
    projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
    projected_states_ba, _ = self.in_proj_ba(hidden_states)
    # ============================================================
    # Part 2: Core Attention
    # ============================================================
    forward_context = get_forward_context()
    attn_metadata: AttentionMetadata = forward_context.attn_metadata
    # Output and gate buffers; the custom op fills both in place.
    # zeros (not empty) for core_attn_out so the no-metadata path below
    # still feeds defined values into the output projection.
    core_attn_out = torch.zeros(
        (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )
    z = torch.empty_like(core_attn_out)
    # attn_metadata is None presumably during profiling/warmup runs; in
    # that case skip the kernel and fall through to Part 3.
    if attn_metadata is not None:
        attn_metadata = attn_metadata[self.prefix]
        # TODO: xpu does not support this param yet
        spec_sequence_masks = attn_metadata.spec_sequence_masks
        assert spec_sequence_masks is None
        # Flatten the conv weight to 2D for the kernel.
        # NOTE(review): assumes conv1d.weight is (out_ch, 1, width) so
        # dropping dim 1 loses nothing -- confirm.
        conv_weights = self.conv1d.weight.view(
            self.conv1d.weight.size(0), self.conv1d.weight.size(2)
        )
        # kv_cache slot 0 holds the conv state, slot 1 the SSM state.
        conv_state = self.kv_cache[0]
        ssm_state = self.kv_cache[1]
        # Fused XPU GDN kernel; given the arguments it receives, it
        # covers the conv + gated-delta core and updates both cached
        # states, writing core_attn_out and z in place.
        torch.ops._xpu_C.gdn_attention(
            core_attn_out,
            z,
            projected_states_qkvz,
            projected_states_ba,
            self.num_k_heads,
            self.num_v_heads,
            self.head_k_dim,
            self.head_v_dim,
            conv_state=conv_state,
            ssm_state=ssm_state,
            conv_weights=conv_weights,
            conv_bias=self.conv1d.bias,
            activation=self.activation,
            A_log=self.A_log,
            dt_bias=self.dt_bias,
            num_prefills=attn_metadata.num_prefills,
            num_decodes=attn_metadata.num_decodes,
            has_initial_state=attn_metadata.has_initial_state,
            non_spec_query_start_loc=attn_metadata.non_spec_query_start_loc,
            non_spec_state_indices_tensor=attn_metadata.non_spec_state_indices_tensor,
            num_actual_tokens=attn_metadata.num_actual_tokens,
            tp_size=self.tp_size,
            reorder_input=not self.gqa_interleaved_layout,
        )
    # ============================================================
    # Part 3: Output Projection
    # ============================================================
    z_shape_og = z.shape
    # Reshape input data into 2D tensor
    core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
    z = z.reshape(-1, z.shape[-1])
    # Gated RMSNorm over the attention output, gated by z.
    core_attn_out = self.norm(core_attn_out, z)
    core_attn_out = core_attn_out.reshape(z_shape_og)
    # Merge the head and head-dim axes back into a single hidden axis.
    core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
    output[:num_tokens], _ = self.out_proj(core_attn_out)
def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None:
"""Warm up GDN prefill kernels during V1 profiling.

View File

@@ -218,6 +218,57 @@ class XPUPlatform(Platform):
# ref. https://openucx.readthedocs.io/en/master/faq.html
os.environ["UCX_MEMTYPE_CACHE"] = "n"
@classmethod
def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
    """Round the KV-cache block size up for the XPU GDN kernel.

    The XPU GDN attention kernel only supports block sizes divisible by
    64, so if any attention layer uses the GDN_ATTN backend the cache
    block size is rounded up to the next multiple of 64 and the dependent
    mamba sizes are rescaled accordingly.

    Args:
        vllm_config: the engine configuration; its ``cache_config`` is
            mutated in place.
    """
    super().update_block_size_for_backend(vllm_config)
    from vllm.config.vllm import get_layers_from_vllm_config
    from vllm.model_executor.layers.attention_layer_base import (
        AttentionLayerBase,
    )
    from vllm.utils.math_utils import cdiv

    cache_config = vllm_config.cache_config
    # Special fix for GDN since the kernel only supports block sizes
    # divisible by 64.
    attn_layers = get_layers_from_vllm_config(
        vllm_config,
        AttentionLayerBase,  # type: ignore[type-abstract]
    )
    kernel_block_size = None
    for layer in attn_layers.values():
        if layer.get_attn_backend().get_name() == "GDN_ATTN":
            kernel_block_size = 64
            break
    if kernel_block_size is None:
        # No GDN layer in this model; nothing to adjust.
        return
    new_block_size = (
        cdiv(cache_config.block_size, kernel_block_size) * kernel_block_size
    )
    if new_block_size == cache_config.block_size:
        # Already a multiple of the kernel block size.
        return
    if cache_config.mamba_cache_mode == "align":
        cache_config.mamba_block_size = new_block_size
    original_mamba_page_size_padded = cache_config.mamba_page_size_padded
    if cache_config.mamba_page_size_padded is not None:
        # Keep the padded page size proportional: bytes-per-token stays
        # fixed while the tokens-per-block count grows.
        attn_page_size_1_token = (
            cache_config.mamba_page_size_padded // cache_config.block_size
        )
        cache_config.mamba_page_size_padded = (
            new_block_size * attn_page_size_1_token
        )
    cache_config.block_size = new_block_size
    # BUGFIX: use %s for the two page-size arguments -- both are None when
    # no padding was configured, and %d would raise TypeError here.
    logger.info(
        "[XPU]Setting attention block size to %d tokens to ensure multiple of %d, "
        "set mamba_page_size_padded to %s bytes accordingly, before was %s bytes.",
        new_block_size,
        kernel_block_size,
        cache_config.mamba_page_size_padded,
        original_mamba_page_size_padded,
    )
@classmethod
def support_hybrid_kv_cache(cls) -> bool:
    """Return True: the XPU platform supports hybrid KV caches."""
    return True