[GDN] add a config for gdn kernel selection (#36647)

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Jiangyun Zhu
2026-03-16 00:40:17 +08:00
committed by GitHub
parent a3e2e250f0
commit 697e4ff352
2 changed files with 48 additions and 5 deletions

View File

@@ -161,13 +161,45 @@ def fi_chunk_gated_delta_rule(
class ChunkGatedDeltaRule(CustomOp):
def __init__(self) -> None:
super().__init__()
if current_platform.is_cuda() and current_platform.is_device_capability(90):
logger.info_once(
"Using FlashInfer GDN prefill kernel on CUDA compute capability 90"
backend = (
str(
get_current_vllm_config().additional_config.get(
"gdn_prefill_backend", "auto"
)
)
self._forward_method = self.forward_cuda
.strip()
.lower()
)
supports_flashinfer = (
current_platform.is_cuda() and current_platform.is_device_capability(90)
)
if backend == "flashinfer":
use_flashinfer = supports_flashinfer
if not use_flashinfer:
logger.warning_once(
"GDN prefill backend 'flashinfer' is selected but "
"cannot use this kernel on the current platform. "
"Falling back to Triton/FLA."
)
elif backend == "triton":
use_flashinfer = False
else:
self._forward_method = self.forward_native
use_flashinfer = supports_flashinfer
if use_flashinfer:
logger.info_once("Using FlashInfer GDN prefill kernel")
logger.info_once(
"FlashInfer GDN prefill kernel is JIT-compiled; first run may "
"take a while to compile. Set `--gdn-prefill-backend triton` to "
"avoid JIT compile time."
)
else:
logger.info_once("Using Triton/FLA GDN prefill kernel")
self._forward_method = (
self.forward_cuda if use_flashinfer else self.forward_native
)
def forward_cuda(
self,