diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 700713e32..8fac21687 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -614,6 +614,7 @@ class EngineArgs: ) fail_on_environ_validation: bool = False + gdn_prefill_backend: Literal["flashinfer", "triton"] | None = None def __post_init__(self): # support `EngineArgs(compilation_config={...})` @@ -1318,6 +1319,13 @@ class EngineArgs: help="Shutdown timeout in seconds. 0 = abort, >0 = wait.", ) + parser.add_argument( + "--gdn-prefill-backend", + dest="gdn_prefill_backend", + choices=["flashinfer", "triton"], + default=None, + help="Select GDN prefill backend.", + ) return parser @classmethod @@ -1903,6 +1911,9 @@ class EngineArgs: ), ) + if self.gdn_prefill_backend is not None: + self.additional_config["gdn_prefill_backend"] = self.gdn_prefill_backend + config = VllmConfig( model_config=model_config, cache_config=cache_config, diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index cfd4c7a56..bbe30c719 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -161,13 +161,45 @@ def fi_chunk_gated_delta_rule( class ChunkGatedDeltaRule(CustomOp): def __init__(self) -> None: super().__init__() - if current_platform.is_cuda() and current_platform.is_device_capability(90): - logger.info_once( - "Using FlashInfer GDN prefill kernel on CUDA compute capability 90" + backend = ( + str( + get_current_vllm_config().additional_config.get( + "gdn_prefill_backend", "auto" + ) ) - self._forward_method = self.forward_cuda + .strip() + .lower() + ) + supports_flashinfer = ( + current_platform.is_cuda() and current_platform.is_device_capability(90) + ) + + if backend == "flashinfer": + use_flashinfer = supports_flashinfer + if not use_flashinfer: + logger.warning_once( + "GDN prefill backend 'flashinfer' is selected but " + "cannot use this kernel on the current platform. " + "Falling back to Triton/FLA." + ) + elif backend == "triton": + use_flashinfer = False else: - self._forward_method = self.forward_native + use_flashinfer = supports_flashinfer + + if use_flashinfer: + logger.info_once("Using FlashInfer GDN prefill kernel") + logger.info_once( + "FlashInfer GDN prefill kernel is JIT-compiled; first run may " + "take a while to compile. Set `--gdn-prefill-backend triton` to " + "avoid JIT compile time." + ) + else: + logger.info_once("Using Triton/FLA GDN prefill kernel") + + self._forward_method = ( + self.forward_cuda if use_flashinfer else self.forward_native + ) def forward_cuda( self,