[GDN] add a config for gdn kernel selection (#36647)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com> Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
@@ -161,13 +161,45 @@ def fi_chunk_gated_delta_rule(
|
||||
class ChunkGatedDeltaRule(CustomOp):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
if current_platform.is_cuda() and current_platform.is_device_capability(90):
|
||||
logger.info_once(
|
||||
"Using FlashInfer GDN prefill kernel on CUDA compute capability 90"
|
||||
backend = (
|
||||
str(
|
||||
get_current_vllm_config().additional_config.get(
|
||||
"gdn_prefill_backend", "auto"
|
||||
)
|
||||
)
|
||||
self._forward_method = self.forward_cuda
|
||||
.strip()
|
||||
.lower()
|
||||
)
|
||||
supports_flashinfer = (
|
||||
current_platform.is_cuda() and current_platform.is_device_capability(90)
|
||||
)
|
||||
|
||||
if backend == "flashinfer":
|
||||
use_flashinfer = supports_flashinfer
|
||||
if not use_flashinfer:
|
||||
logger.warning_once(
|
||||
"GDN prefill backend 'flashinfer' is selected but "
|
||||
"cannot use this kernel on the current platform. "
|
||||
"Falling back to Triton/FLA."
|
||||
)
|
||||
elif backend == "triton":
|
||||
use_flashinfer = False
|
||||
else:
|
||||
self._forward_method = self.forward_native
|
||||
use_flashinfer = supports_flashinfer
|
||||
|
||||
if use_flashinfer:
|
||||
logger.info_once("Using FlashInfer GDN prefill kernel")
|
||||
logger.info_once(
|
||||
"FlashInfer GDN prefill kernel is JIT-compiled; first run may "
|
||||
"take a while to compile. Set `--gdn-prefill-backend triton` to "
|
||||
"avoid JIT compile time."
|
||||
)
|
||||
else:
|
||||
logger.info_once("Using Triton/FLA GDN prefill kernel")
|
||||
|
||||
self._forward_method = (
|
||||
self.forward_cuda if use_flashinfer else self.forward_native
|
||||
)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user