[XPU] Initial support for GDN attention on Qwen3-next/Qwen3.5 (#33657)
Signed-off-by: Yan Ma <yan.ma@intel.com>
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Co-authored-by: Chendi Xue <chendi.xue@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
@@ -560,6 +560,11 @@ class RMSNormGated(CustomOp):
|
||||
activation=self.activation,
|
||||
)
|
||||
|
||||
def forward_xpu(self, x: torch.Tensor, z: torch.Tensor | None = None) -> torch.Tensor:
    """Gated RMSNorm on XPU.

    The XPU path has no dedicated kernel here; it reuses the CUDA
    implementation unchanged, with the same optional gate tensor ``z``.
    """
    return self.forward_cuda(x, z)
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
"""
|
||||
|
||||
@@ -262,6 +262,9 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
|
||||
else 0
|
||||
)
|
||||
self.gqa_interleaved_layout = gqa_interleaved_layout
|
||||
self._forward_method = (
|
||||
self.forward_xpu if current_platform.is_xpu() else self.forward_cuda
|
||||
)
|
||||
|
||||
# QKV
|
||||
self.conv_dim = self.key_dim * 2 + self.value_dim
|
||||
@@ -493,6 +496,13 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
):
|
||||
self._forward_method(hidden_states, output)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Forward pass with three parts:
|
||||
@@ -567,6 +577,90 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
|
||||
core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
|
||||
output[:num_tokens], _ = self.out_proj(core_attn_out)
|
||||
|
||||
def forward_xpu(
    self,
    hidden_states: torch.Tensor,
    output: torch.Tensor,
):
    """
    Forward pass with three parts:
    1. Input projection
    2. Core attention (custom op)
    3. Output projection

    Writes the result into ``output[:num_tokens]`` in place; nothing is
    returned.
    """
    # assumes hidden_states is (num_tokens, hidden_dim) — size(0) is the
    # token count; TODO confirm against callers.
    num_tokens = hidden_states.size(0)

    # in_proj_qkv is only present on the LoRA-wrapped projection layout,
    # which this XPU path does not handle (per the assert message).
    assert not hasattr(self, "in_proj_qkv"), "lora isn't supported on XPU."

    # ============================================================
    # Part 1: Input Projection
    # ============================================================
    projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
    projected_states_ba, _ = self.in_proj_ba(hidden_states)

    # ============================================================
    # Part 2: Core Attention
    # ============================================================
    forward_context = get_forward_context()
    attn_metadata: AttentionMetadata = forward_context.attn_metadata
    # Pre-allocate the kernel output buffers. When attn_metadata is None
    # (presumably profiling/warmup runs — confirm), the kernel is skipped
    # and core_attn_out stays zero so Part 3 still runs.
    core_attn_out = torch.zeros(
        (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )
    # NOTE(review): z is uninitialized (empty_like) on the skip path yet is
    # still fed to self.norm below — verify that is intentional.
    z = torch.empty_like(core_attn_out)
    if attn_metadata is not None:
        # Metadata arrives keyed per layer; select this layer's entry.
        attn_metadata = attn_metadata[self.prefix]

        # TODO: xpu does not support this param yet
        spec_sequence_masks = attn_metadata.spec_sequence_masks
        assert spec_sequence_masks is None

        # Drop the middle dim of the conv weight:
        # (out_channels, 1, width) -> (out_channels, width) — inferred from
        # the view; confirm conv1d weight layout.
        conv_weights = self.conv1d.weight.view(
            self.conv1d.weight.size(0), self.conv1d.weight.size(2)
        )

        conv_state = self.kv_cache[0]
        ssm_state = self.kv_cache[1]

        # Fused XPU custom op. core_attn_out and z are passed first as
        # pre-allocated buffers — presumably filled in place by the op;
        # conv_state/ssm_state are presumably updated in place as well.
        torch.ops._xpu_C.gdn_attention(
            core_attn_out,
            z,
            projected_states_qkvz,
            projected_states_ba,
            self.num_k_heads,
            self.num_v_heads,
            self.head_k_dim,
            self.head_v_dim,
            conv_state=conv_state,
            ssm_state=ssm_state,
            conv_weights=conv_weights,
            conv_bias=self.conv1d.bias,
            activation=self.activation,
            A_log=self.A_log,
            dt_bias=self.dt_bias,
            num_prefills=attn_metadata.num_prefills,
            num_decodes=attn_metadata.num_decodes,
            has_initial_state=attn_metadata.has_initial_state,
            non_spec_query_start_loc=attn_metadata.non_spec_query_start_loc,
            non_spec_state_indices_tensor=attn_metadata.non_spec_state_indices_tensor,
            num_actual_tokens=attn_metadata.num_actual_tokens,
            tp_size=self.tp_size,
            # Kernel reorders input itself when the layout is NOT already
            # GQA-interleaved (note the negation).
            reorder_input=not self.gqa_interleaved_layout,
        )

    # ============================================================
    # Part 3: Output Projection
    # ============================================================
    z_shape_og = z.shape
    # Reshape input data into 2D tensor
    core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
    z = z.reshape(-1, z.shape[-1])
    # Gated RMSNorm over the attention output, then restore the
    # (..., heads, head_dim) shape before flattening heads.
    core_attn_out = self.norm(core_attn_out, z)
    core_attn_out = core_attn_out.reshape(z_shape_og)
    core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
    # In-place write into the caller-provided output buffer.
    output[:num_tokens], _ = self.out_proj(core_attn_out)
|
||||
|
||||
def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None:
|
||||
"""Warm up GDN prefill kernels during V1 profiling.
|
||||
|
||||
|
||||
@@ -218,6 +218,57 @@ class XPUPlatform(Platform):
|
||||
# ref. https://openucx.readthedocs.io/en/master/faq.html
|
||||
os.environ["UCX_MEMTYPE_CACHE"] = "n"
|
||||
|
||||
@classmethod
def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
    """Round the KV-cache block size up for the XPU GDN kernel.

    The XPU GDN attention kernel only supports block sizes that are a
    multiple of 64 (per the comment below); if any attention layer uses
    the GDN_ATTN backend, block_size (and the dependent mamba page
    padding) is rounded up accordingly. No-op otherwise.
    """
    super().update_block_size_for_backend(vllm_config)
    # Local imports — presumably to avoid import cycles at module load;
    # confirm against the rest of the platform module.
    from vllm.config.vllm import get_layers_from_vllm_config
    from vllm.model_executor.layers.attention_layer_base import (
        AttentionLayerBase,
    )
    from vllm.utils.math_utils import cdiv

    cache_config = vllm_config.cache_config
    # special fix for GDN since kernel only supports block size dividable by 64
    attn_layers = get_layers_from_vllm_config(
        vllm_config,
        AttentionLayerBase,  # type: ignore[type-abstract]
    )

    # One GDN layer is enough to trigger the fix; stop at the first match.
    kernel_block_size = None
    for layer in attn_layers.values():
        b = layer.get_attn_backend()
        if b.get_name() == "GDN_ATTN":
            kernel_block_size = 64
            break

    # No GDN layers: nothing to adjust.
    if kernel_block_size is None:
        return
    # Round block_size up to the next multiple of kernel_block_size.
    new_block_size = (
        cdiv(cache_config.block_size, kernel_block_size) * kernel_block_size
    )
    # Already aligned: nothing to adjust.
    if new_block_size == cache_config.block_size:
        return

    if cache_config.mamba_cache_mode == "align":
        cache_config.mamba_block_size = new_block_size
    original_mamba_page_size_padded = cache_config.mamba_page_size_padded
    if cache_config.mamba_page_size_padded is not None:
        # Rescale the padded mamba page size proportionally:
        # bytes-per-token is derived from the OLD block size, then
        # multiplied by the new one.
        attn_page_size_1_token = (
            cache_config.mamba_page_size_padded // cache_config.block_size
        )
        cache_config.mamba_page_size_padded = (
            new_block_size * attn_page_size_1_token
        )
    cache_config.block_size = new_block_size
    logger.info(
        "[XPU]Setting attention block size to %d tokens to ensure multiple of %d, "
        "set mamba_page_size_padded to %d bytes accordingly, before was %d bytes.",
        new_block_size,
        kernel_block_size,
        cache_config.mamba_page_size_padded,
        original_mamba_page_size_padded,
    )
|
||||
|
||||
@classmethod
def support_hybrid_kv_cache(cls) -> bool:
    """Whether this platform supports a hybrid KV cache (yes on XPU)."""
    return True
|
||||
|
||||
Reference in New Issue
Block a user