[Kernel] use flashinfer for gdn prefill (#32846)

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Author: Jiangyun Zhu
Date: 2026-02-10 01:17:25 +08:00
Committed by: GitHub
Parent: 995bbf38f1
Commit: 285bab4752


@@ -28,11 +28,15 @@ from vllm.distributed import (
 )
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.attention import Attention
 from vllm.model_executor.layers.fla.ops import (
-    chunk_gated_delta_rule,
+    chunk_gated_delta_rule as fla_chunk_gated_delta_rule,
+)
+from vllm.model_executor.layers.fla.ops import (
     fused_recurrent_gated_delta_rule,
 )
+from vllm.model_executor.layers.fla.ops.chunk import l2norm_fwd
 from vllm.model_executor.layers.fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.layernorm import (
     GemmaRMSNorm as Qwen3NextRMSNorm,
@@ -101,6 +105,113 @@ logger = init_logger(__name__)
 KVCache = tuple[torch.Tensor, torch.Tensor]
 
 
+def fi_chunk_gated_delta_rule(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    beta: torch.Tensor,
+    initial_state: torch.Tensor,
+    output_final_state: bool,
+    cu_seqlens: torch.LongTensor | None = None,
+    head_first: bool = False,
+    use_qk_l2norm_in_kernel: bool = True,
+):
+    from flashinfer.gdn_prefill import (
+        chunk_gated_delta_rule as chunk_gated_delta_rule_fi,
+    )
+
+    if use_qk_l2norm_in_kernel:
+        q = l2norm_fwd(q)
+        k = l2norm_fwd(k)
+
+    # use flashinfer implementation
+    q = q.squeeze(0).contiguous()
+    k = k.squeeze(0).contiguous()
+    v = v.squeeze(0).contiguous()
+    g = g.squeeze(0).contiguous()
+    beta = beta.squeeze(0).contiguous()
+    fi_state = initial_state.to(torch.float32)
+    fi_g = g.to(torch.float32)
+    fi_beta = beta.to(torch.float32)
+    return chunk_gated_delta_rule_fi(
+        q=q,
+        k=k,
+        v=v,
+        g=torch.exp(fi_g),
+        beta=fi_beta,
+        initial_state=fi_state,
+        output_final_state=output_final_state,
+        cu_seqlens=cu_seqlens,
+    )
+
+
+@CustomOp.register("chunk_gated_delta_rule")
+class ChunkGatedDeltaRule(CustomOp):
+    def __init__(self) -> None:
+        super().__init__()
+        if current_platform.is_cuda() and current_platform.is_device_capability(90):
+            logger.info_once(
+                "Using FlashInfer GDN prefill kernel on CUDA compute capability 90"
+            )
+            self._forward_method = self.forward_cuda
+        else:
+            self._forward_method = self.forward_native
+
+    def forward_cuda(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        g: torch.Tensor,
+        beta: torch.Tensor,
+        initial_state: torch.Tensor,
+        output_final_state: bool,
+        cu_seqlens: torch.LongTensor | None = None,
+        head_first: bool = False,
+        use_qk_l2norm_in_kernel: bool = True,
+    ):
+        return fi_chunk_gated_delta_rule(
+            q=q,
+            k=k,
+            v=v,
+            g=g,
+            beta=beta,
+            initial_state=initial_state,
+            output_final_state=output_final_state,
+            cu_seqlens=cu_seqlens,
+            head_first=head_first,
+            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
+        )
+
+    def forward_native(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        g: torch.Tensor,
+        beta: torch.Tensor,
+        initial_state: torch.Tensor,
+        output_final_state: bool,
+        cu_seqlens: torch.LongTensor | None = None,
+        head_first: bool = False,
+        use_qk_l2norm_in_kernel: bool = True,
+    ):
+        return fla_chunk_gated_delta_rule(
+            q=q,
+            k=k,
+            v=v,
+            g=g,
+            beta=beta,
+            initial_state=initial_state,
+            output_final_state=output_final_state,
+            cu_seqlens=cu_seqlens,
+            head_first=head_first,
+            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
+        )
+
+
 class Qwen3NextSparseMoeBlock(nn.Module):
     def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
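
The wrapper added above keeps the FLA calling convention: variable-length prefill batches are packed into tensors with a leading batch dimension of 1, with cu_seqlens marking sequence boundaries, and that batch dimension is squeezed away before the FlashInfer kernel is invoked. A minimal sketch of a call, with every shape, dtype, and value assumed purely for illustration (in practice they come from the model config and the scheduler):

    import torch

    T, H, K, V = 512, 16, 128, 128   # total tokens, heads, head dims (assumed)
    cu_seqlens = torch.tensor([0, 384, 512], dtype=torch.long, device="cuda")
    N = cu_seqlens.numel() - 1       # two sequences packed into one batch

    q = torch.randn(1, T, H, K, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(1, T, H, K, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(1, T, H, V, device="cuda", dtype=torch.bfloat16)
    g = -torch.rand(1, T, H, device="cuda", dtype=torch.bfloat16)    # log-space gate, <= 0
    beta = torch.rand(1, T, H, device="cuda", dtype=torch.bfloat16)  # delta-rule mixing factor
    state = torch.zeros(N, H, K, V, device="cuda", dtype=torch.float32)

    # Assuming the FlashInfer kernel mirrors the FLA return convention
    # of (output, final_state):
    out, final_state = fi_chunk_gated_delta_rule(
        q, k, v, g, beta,
        initial_state=state,
        output_final_state=True,
        cu_seqlens=cu_seqlens,
    )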
@@ -362,6 +473,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
             prefix=f"{prefix}.out_proj",
         )
+
+        self.chunk_gated_delta_rule = ChunkGatedDeltaRule()
 
         compilation_config = get_current_vllm_config().compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
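
The layer now owns a ChunkGatedDeltaRule instance, so the backend is selected once at construction rather than on every forward pass: vLLM's CustomOp routes calls through self._forward_method, which super().__init__() initializes and the __init__ above then overrides. A minimal stand-in for that pattern (not the actual CustomOp machinery; all names here are illustrative):

    class DispatchSketch:
        def __init__(self, use_flashinfer: bool) -> None:
            # Bind the implementation once; every call afterwards is a
            # single bound-method indirection with no per-call branching.
            self._forward_method = (
                self.forward_cuda if use_flashinfer else self.forward_native
            )

        def __call__(self, *args, **kwargs):
            return self._forward_method(*args, **kwargs)

        def forward_cuda(self, x):
            return x  # stand-in for the FlashInfer GDN prefill path

        def forward_native(self, x):
            return x  # stand-in for the FLA Triton path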
@@ -647,7 +760,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
             (
                 core_attn_out_non_spec,
                 last_recurrent_state,
-            ) = chunk_gated_delta_rule(
+            ) = self.chunk_gated_delta_rule(
                 q=query_non_spec,
                 k=key_non_spec,
                 v=value_non_spec,
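
Because both backends remain callable on the same instance, a quick parity check between the FlashInfer and FLA paths is possible. A hedged sketch, reusing the tensors from the earlier example; it assumes an SM90 GPU with flashinfer installed, that ChunkGatedDeltaRule can be constructed outside a full engine context, and that the FlashInfer path returns its output without the leading batch dimension (mirroring its squeezed inputs):

    op = ChunkGatedDeltaRule()
    kwargs = dict(
        q=q, k=k, v=v, g=g, beta=beta,
        initial_state=state, output_final_state=True, cu_seqlens=cu_seqlens,
    )
    out_fla, state_fla = op.forward_native(**kwargs)
    out_fi, state_fi = op.forward_cuda(**kwargs)

    # Loose tolerances: inputs are bf16 and the two kernels accumulate
    # in different orders.
    torch.testing.assert_close(
        out_fi.float(), out_fla.squeeze(0).float(), rtol=2e-2, atol=2e-2
    )
    torch.testing.assert_close(
        state_fi, state_fla.to(torch.float32), rtol=2e-2, atol=2e-2
    )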