[GDN] add a config for gdn kernel selection (#36647)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com> Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
@@ -614,6 +614,7 @@ class EngineArgs:
     )
 
     fail_on_environ_validation: bool = False
+    gdn_prefill_backend: Literal["flashinfer", "triton"] | None = None
 
     def __post_init__(self):
         # support `EngineArgs(compilation_config={...})`
@@ -1318,6 +1319,13 @@ class EngineArgs:
             help="Shutdown timeout in seconds. 0 = abort, >0 = wait.",
         )
 
+        parser.add_argument(
+            "--gdn-prefill-backend",
+            dest="gdn_prefill_backend",
+            choices=["flashinfer", "triton"],
+            default=None,
+            help="Select GDN prefill backend.",
+        )
         return parser
 
     @classmethod
@@ -1903,6 +1911,9 @@ class EngineArgs:
             ),
         )
 
+        if self.gdn_prefill_backend is not None:
+            self.additional_config["gdn_prefill_backend"] = self.gdn_prefill_backend
+
         config = VllmConfig(
             model_config=model_config,
             cache_config=cache_config,
@@ -161,13 +161,45 @@ def fi_chunk_gated_delta_rule(
 class ChunkGatedDeltaRule(CustomOp):
     def __init__(self) -> None:
         super().__init__()
-        if current_platform.is_cuda() and current_platform.is_device_capability(90):
-            logger.info_once(
-                "Using FlashInfer GDN prefill kernel on CUDA compute capability 90"
+        backend = (
+            str(
+                get_current_vllm_config().additional_config.get(
+                    "gdn_prefill_backend", "auto"
+                )
             )
-            self._forward_method = self.forward_cuda
+            .strip()
+            .lower()
+        )
+        supports_flashinfer = (
+            current_platform.is_cuda() and current_platform.is_device_capability(90)
+        )
+
+        if backend == "flashinfer":
+            use_flashinfer = supports_flashinfer
+            if not use_flashinfer:
+                logger.warning_once(
+                    "GDN prefill backend 'flashinfer' is selected but "
+                    "cannot use this kernel on the current platform. "
+                    "Falling back to Triton/FLA."
+                )
+        elif backend == "triton":
+            use_flashinfer = False
         else:
-            self._forward_method = self.forward_native
+            use_flashinfer = supports_flashinfer
+
+        if use_flashinfer:
+            logger.info_once("Using FlashInfer GDN prefill kernel")
+            logger.info_once(
+                "FlashInfer GDN prefill kernel is JIT-compiled; first run may "
+                "take a while to compile. Set `--gdn-prefill-backend triton` to "
+                "avoid JIT compile time."
+            )
+        else:
+            logger.info_once("Using Triton/FLA GDN prefill kernel")
+
+        self._forward_method = (
+            self.forward_cuda if use_flashinfer else self.forward_native
+        )
 
     def forward_cuda(
         self,
Reference in New Issue
Block a user