From 079781177ae4c9fba429bf093cae73cf4cfae7a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Honig?= <5851246+renehonig@users.noreply.github.com> Date: Sat, 31 Jan 2026 23:06:42 +0100 Subject: [PATCH] fix: Add SM120 (RTX Blackwell) support for FlashInfer CUTLASS NVFP4 MoE kernels (#33417) Signed-off-by: mgoin Co-authored-by: mgoin --- .../layers/fused_moe/cutlass_moe.py | 7 +- .../fused_moe/flashinfer_cutedsl_moe.py | 3 +- .../fused_moe/flashinfer_cutlass_moe.py | 27 ++++---- .../layers/fused_moe/flashinfer_trtllm_moe.py | 1 - .../quantization/utils/flashinfer_fp4_moe.py | 27 -------- .../quantization/utils/nvfp4_moe_support.py | 67 ------------------- 6 files changed, 22 insertions(+), 110 deletions(-) delete mode 100644 vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 86edbe303..74f05a2c0 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -657,7 +657,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): @staticmethod def _supports_current_device() -> bool: - return current_platform.has_device_capability((10, 0)) + p = current_platform + return p.is_cuda() and ( + p.is_device_capability_family(100) + or p.is_device_capability_family(110) + or p.is_device_capability_family(120) + ) @staticmethod def _supports_no_act_and_mul() -> bool: diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py index 036ee2a2e..2ad949577 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py @@ -54,7 +54,8 @@ class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute): @staticmethod def _supports_current_device() -> bool: - return current_platform.is_device_capability_family(100) + p = current_platform + return p.is_cuda() and p.is_device_capability_family(100) @staticmethod def _supports_no_act_and_mul() -> bool: diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index faa654ea3..7c27da46f 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -84,11 +84,14 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): @staticmethod def _supports_current_device() -> bool: + p = current_platform return ( - current_platform.is_cuda() + p.is_cuda() and ( - current_platform.is_device_capability((9, 0)) - or current_platform.is_device_capability_family(100) + p.is_device_capability(90) + or p.is_device_capability_family(100) + or p.is_device_capability_family(110) + or p.is_device_capability_family(120) ) and has_flashinfer_cutlass_fused_moe() ) @@ -102,29 +105,27 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): weight_key: QuantKey | None, activation_key: QuantKey | None, ) -> bool: - # The following are supported by FlashInferExperts: - # * unquantized - # * fp8 static per-tensor on 9.0+ - # * fp8 block on 9.0 - # * nvfp4 on 10.0+ - p = current_platform scheme = (weight_key, activation_key) + # The following are supported by FlashInferExperts: return ( + # unquantized and fp8 static per-tensor on 9.0+ ( scheme in [ (None, None), (kFp8StaticTensorSym, kFp8StaticTensorSym), ] + and p.has_device_capability(90) ) + # fp8 block-scale on 9.0 or ( - (scheme == (kFp8Static128BlockSym, kFp8Dynamic128Sym)) - and (p.is_device_capability((9, 0))) + scheme == (kFp8Static128BlockSym, kFp8Dynamic128Sym) + and p.is_device_capability(90) ) + # nvfp4 on 10.0+ or ( - (scheme == (kNvfp4Static, kNvfp4Dynamic)) - and (p.is_device_capability_family(100)) + scheme == (kNvfp4Static, kNvfp4Dynamic) and p.has_device_capability(100) ) ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index 6b140ea3a..a066535c5 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -30,7 +30,6 @@ from vllm.utils.torch_utils import direct_register_custom_op def _supports_current_device() -> bool: """Supports only Blackwell-family GPUs.""" p = current_platform - # Add check flashinfer trtllm is available return p.is_cuda() and p.is_device_capability_family(100) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index 6f3d19e09..4783ca5e0 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -6,7 +6,6 @@ from typing import TYPE_CHECKING import torch -import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.logger import init_logger @@ -24,10 +23,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( kNvfp4Static, ) from vllm.platforms import current_platform -from vllm.utils.flashinfer import ( - has_flashinfer_cutedsl_grouped_gemm_nt_masked, - has_flashinfer_cutlass_fused_moe, -) if TYPE_CHECKING: from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( @@ -38,8 +33,6 @@ logger = init_logger(__name__) __all__ = [ - "is_flashinfer_fp4_cutlass_moe_available", - "is_flashinfer_fp4_cutedsl_moe_available", "reorder_w1w3_to_w3w1", ] @@ -124,26 +117,6 @@ def is_supported_config_trtllm( return True, None -def is_flashinfer_fp4_cutlass_moe_available() -> bool: - """Return `True` when FlashInfer CUTLASS NV-FP4 kernels can be used.""" - return ( - envs.VLLM_USE_FLASHINFER_MOE_FP4 - and has_flashinfer_cutlass_fused_moe() - and current_platform.is_cuda() - and current_platform.has_device_capability(100) - ) - - -def is_flashinfer_fp4_cutedsl_moe_available() -> bool: - """Return ``True`` when FlashInfer CUTEDSL NV-FP4 kernels can be used.""" - return ( - envs.VLLM_USE_FLASHINFER_MOE_FP4 - and has_flashinfer_cutedsl_grouped_gemm_nt_masked() - and current_platform.is_cuda() - and current_platform.is_device_capability_family(100) - ) - - def reorder_w1w3_to_w3w1( weight: torch.Tensor, scale: torch.Tensor, dim: int = -2 ) -> tuple[torch.Tensor, torch.Tensor]: diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py b/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py deleted file mode 100644 index 199a81c42..000000000 --- a/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py +++ /dev/null @@ -1,67 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from dataclasses import dataclass - -import vllm.envs as envs -from vllm.logger import init_logger -from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( - is_flashinfer_fp4_cutedsl_moe_available, - is_flashinfer_fp4_cutlass_moe_available, -) -from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - is_fp4_marlin_supported, -) -from vllm.model_executor.layers.quantization.utils.nvfp4_utils import ( - cutlass_fp4_supported, -) - -__all__ = ["detect_nvfp4_moe_support", "NvFp4Support"] - -_logger = init_logger(__name__) - - -@dataclass(frozen=True) -class NvFp4Support: - """Result container for NV-FP4 capability probing.""" - - cutlass_supported: bool - allow_flashinfer: bool - use_marlin: bool - - -def detect_nvfp4_moe_support(class_name: str = "") -> NvFp4Support: - """Detect platform support for NV-FP4 fused-MoE path""" - cutlass_supported = cutlass_fp4_supported() - - allow_flashinfer = cutlass_supported and ( - is_flashinfer_fp4_cutlass_moe_available() - or is_flashinfer_fp4_cutedsl_moe_available() - ) - - if allow_flashinfer: - _logger.info_once( - "Using FlashInfer kernels for %s.", class_name or "NVFP4 path" - ) - else: - if envs.VLLM_USE_FLASHINFER_MOE_FP4: - _logger.warning_once( - "FlashInfer kernels unavailable for %s on current platform.", - class_name or "NVFP4 path", - ) - - use_marlin = False - if not cutlass_supported: - if is_fp4_marlin_supported(): - use_marlin = True - _logger.info_once("Falling back to Marlin FP4 MoE kernel.") - else: - raise ValueError( - "Current platform does not support NVFP4 quantization. " - "Please use Blackwell GPUs or enable FlashInfer." - ) - - return NvFp4Support( - cutlass_supported=cutlass_supported, - allow_flashinfer=allow_flashinfer, - use_marlin=use_marlin, - )