[GDN] add a config for gdn kernel selection (#36647)

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Jiangyun Zhu
2026-03-16 00:40:17 +08:00
committed by GitHub
parent a3e2e250f0
commit 697e4ff352
2 changed files with 48 additions and 5 deletions

View File

@@ -614,6 +614,7 @@ class EngineArgs:
)
fail_on_environ_validation: bool = False
gdn_prefill_backend: Literal["flashinfer", "triton"] | None = None
def __post_init__(self):
# support `EngineArgs(compilation_config={...})`
@@ -1318,6 +1319,13 @@ class EngineArgs:
help="Shutdown timeout in seconds. 0 = abort, >0 = wait.",
)
parser.add_argument(
"--gdn-prefill-backend",
dest="gdn_prefill_backend",
choices=["flashinfer", "triton"],
default=None,
help="Select GDN prefill backend.",
)
return parser
@classmethod
@@ -1903,6 +1911,9 @@ class EngineArgs:
),
)
if self.gdn_prefill_backend is not None:
self.additional_config["gdn_prefill_backend"] = self.gdn_prefill_backend
config = VllmConfig(
model_config=model_config,
cache_config=cache_config,

View File

@@ -161,13 +161,45 @@ def fi_chunk_gated_delta_rule(
class ChunkGatedDeltaRule(CustomOp):
def __init__(self) -> None:
super().__init__()
-        if current_platform.is_cuda() and current_platform.is_device_capability(90):
-            logger.info_once(
-                "Using FlashInfer GDN prefill kernel on CUDA compute capability 90"
-            )
-            self._forward_method = self.forward_cuda
+        backend = (
+            str(
+                get_current_vllm_config().additional_config.get(
+                    "gdn_prefill_backend", "auto"
+                )
+            )
+            .strip()
+            .lower()
+        )
supports_flashinfer = (
current_platform.is_cuda() and current_platform.is_device_capability(90)
)
if backend == "flashinfer":
use_flashinfer = supports_flashinfer
if not use_flashinfer:
logger.warning_once(
"GDN prefill backend 'flashinfer' is selected but "
"cannot use this kernel on the current platform. "
"Falling back to Triton/FLA."
)
elif backend == "triton":
use_flashinfer = False
-            self._forward_method = self.forward_native
+        else:
+            use_flashinfer = supports_flashinfer
if use_flashinfer:
logger.info_once("Using FlashInfer GDN prefill kernel")
logger.info_once(
"FlashInfer GDN prefill kernel is JIT-compiled; first run may "
"take a while to compile. Set `--gdn-prefill-backend triton` to "
"avoid JIT compile time."
)
else:
logger.info_once("Using Triton/FLA GDN prefill kernel")
self._forward_method = (
self.forward_cuda if use_flashinfer else self.forward_native
)
def forward_cuda(
self,