[Kernel] use flashinfer for gdn prefill (#32846)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
This commit is contained in:
@@ -28,11 +28,15 @@ from vllm.distributed import (
|
||||
)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.fla.ops import (
|
||||
chunk_gated_delta_rule,
|
||||
chunk_gated_delta_rule as fla_chunk_gated_delta_rule,
|
||||
)
|
||||
from vllm.model_executor.layers.fla.ops import (
|
||||
fused_recurrent_gated_delta_rule,
|
||||
)
|
||||
from vllm.model_executor.layers.fla.ops.chunk import l2norm_fwd
|
||||
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
|
||||
from vllm.model_executor.layers.layernorm import (
|
||||
GemmaRMSNorm as Qwen3NextRMSNorm,
|
||||
@@ -101,6 +105,113 @@ logger = init_logger(__name__)
|
||||
KVCache = tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
def fi_chunk_gated_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
    head_first: bool = False,
    use_qk_l2norm_in_kernel: bool = True,
):
    """Run the chunked gated delta rule through FlashInfer's GDN prefill kernel.

    The signature mirrors the FLA ``chunk_gated_delta_rule`` entry point so
    the two can be swapped by the caller.

    Args:
        q: Query tensor.
        k: Key tensor.
        v: Value tensor.
        g: Gate tensor; it is cast to float32 and exponentiated before being
            handed to FlashInfer.
        beta: Beta tensor; cast to float32 before the kernel call.
        initial_state: Initial recurrent state; cast to float32 before the
            kernel call.
        output_final_state: Forwarded unchanged to the FlashInfer kernel.
        cu_seqlens: Forwarded unchanged to the FlashInfer kernel.
        head_first: Accepted for signature compatibility only; this function
            never reads it.
        use_qk_l2norm_in_kernel: When true, ``q`` and ``k`` are passed through
            ``l2norm_fwd`` before the kernel call.

    Returns:
        Whatever ``flashinfer.gdn_prefill.chunk_gated_delta_rule`` returns,
        without post-processing.
    """
    # Imported lazily so that flashinfer is only required when this path runs.
    from flashinfer.gdn_prefill import (
        chunk_gated_delta_rule as flashinfer_gdn_prefill,
    )

    if use_qk_l2norm_in_kernel:
        q, k = l2norm_fwd(q), l2norm_fwd(k)

    # Drop dim 0 when it has size 1 and force a contiguous layout.
    # NOTE(review): presumably dim 0 is a singleton batch dimension used with
    # varlen (cu_seqlens) inputs -- confirm against the caller.
    q, k, v, g, beta = [
        tensor.squeeze(0).contiguous() for tensor in (q, k, v, g, beta)
    ]

    state_fp32 = initial_state.to(torch.float32)
    gate_fp32 = g.to(torch.float32)
    beta_fp32 = beta.to(torch.float32)

    # NOTE(review): the gate is exponentiated here, presumably because the
    # FlashInfer kernel expects it in linear rather than log space -- confirm
    # against the FlashInfer documentation.
    return flashinfer_gdn_prefill(
        q=q,
        k=k,
        v=v,
        g=torch.exp(gate_fp32),
        beta=beta_fp32,
        initial_state=state_fp32,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )
|
||||
|
||||
|
||||
@CustomOp.register("chunk_gated_delta_rule")
class ChunkGatedDeltaRule(CustomOp):
    """Custom op that picks a chunked gated delta rule implementation.

    On CUDA devices for which ``current_platform.is_device_capability(90)``
    holds, calls are routed to the FlashInfer-backed ``forward_cuda``; on
    every other platform they are routed to the FLA-backed ``forward_native``.
    """

    def __init__(self) -> None:
        """Choose the forward implementation once, at construction time."""
        super().__init__()

        flashinfer_supported = (
            current_platform.is_cuda()
            and current_platform.is_device_capability(90)
        )
        if not flashinfer_supported:
            self._forward_method = self.forward_native
            return

        logger.info_once(
            "Using FlashInfer GDN prefill kernel on CUDA compute capability 90"
        )
        self._forward_method = self.forward_cuda

    def forward_cuda(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        initial_state: torch.Tensor,
        output_final_state: bool,
        cu_seqlens: torch.LongTensor | None = None,
        head_first: bool = False,
        use_qk_l2norm_in_kernel: bool = True,
    ):
        """Delegate to ``fi_chunk_gated_delta_rule`` with unchanged arguments."""
        call_kwargs = dict(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
            head_first=head_first,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        )
        return fi_chunk_gated_delta_rule(**call_kwargs)

    def forward_native(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        initial_state: torch.Tensor,
        output_final_state: bool,
        cu_seqlens: torch.LongTensor | None = None,
        head_first: bool = False,
        use_qk_l2norm_in_kernel: bool = True,
    ):
        """Delegate to the FLA ``chunk_gated_delta_rule`` with unchanged arguments."""
        call_kwargs = dict(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
            head_first=head_first,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        )
        return fla_chunk_gated_delta_rule(**call_kwargs)
|
||||
|
||||
|
||||
class Qwen3NextSparseMoeBlock(nn.Module):
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
@@ -362,6 +473,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
prefix=f"{prefix}.out_proj",
|
||||
)
|
||||
|
||||
self.chunk_gated_delta_rule = ChunkGatedDeltaRule()
|
||||
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
@@ -647,7 +760,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
(
|
||||
core_attn_out_non_spec,
|
||||
last_recurrent_state,
|
||||
) = chunk_gated_delta_rule(
|
||||
) = self.chunk_gated_delta_rule(
|
||||
q=query_non_spec,
|
||||
k=key_non_spec,
|
||||
v=value_non_spec,
|
||||
|
||||
Reference in New Issue
Block a user