fix: Add SM120 (RTX Blackwell) support for FlashInfer CUTLASS NVFP4 MoE kernels (#33417)

Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
René Honig
2026-01-31 23:06:42 +01:00
committed by GitHub
parent 63c0889416
commit 079781177a
6 changed files with 22 additions and 110 deletions
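
Background for the diffs below: vLLM encodes a CUDA compute capability (major, minor) as major * 10 + minor, so datacenter Blackwell (SM100/SM103) lands in the 100 family and RTX Blackwell (SM120/SM121) in the 120 family. Here is a minimal standalone sketch of the gating this commit applies, written against plain torch rather than vLLM's Platform class; the helper names and the family semantics are assumptions for illustration, not vLLM's actual API.

# Illustrative sketch only; assumes "family" means "same major version".
import torch

def _capability() -> int:
    # Encode (major, minor) as major * 10 + minor, e.g. (12, 0) -> 120.
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor

def is_device_capability_family(family: int) -> bool:
    # 120 covers SM120, SM121, ... (any 12.x part).
    return torch.cuda.is_available() and _capability() // 10 == family // 10

def supports_nvfp4_cutlass_moe() -> bool:
    # After this commit: Blackwell families 100 and 110, plus RTX Blackwell 120.
    return any(is_device_capability_family(f) for f in (100, 110, 120))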

View File

@@ -657,7 +657,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
     @staticmethod
     def _supports_current_device() -> bool:
-        return current_platform.has_device_capability((10, 0))
+        p = current_platform
+        return p.is_cuda() and (
+            p.is_device_capability_family(100)
+            or p.is_device_capability_family(110)
+            or p.is_device_capability_family(120)
+        )
 
     @staticmethod
     def _supports_no_act_and_mul() -> bool:

View File

@@ -54,7 +54,8 @@ class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
     @staticmethod
     def _supports_current_device() -> bool:
-        return current_platform.is_device_capability_family(100)
+        p = current_platform
+        return p.is_cuda() and p.is_device_capability_family(100)
 
     @staticmethod
     def _supports_no_act_and_mul() -> bool:

View File

@@ -84,11 +84,14 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
     @staticmethod
     def _supports_current_device() -> bool:
+        p = current_platform
         return (
-            current_platform.is_cuda()
+            p.is_cuda()
             and (
-                current_platform.is_device_capability((9, 0))
-                or current_platform.is_device_capability_family(100)
+                p.is_device_capability(90)
+                or p.is_device_capability_family(100)
+                or p.is_device_capability_family(110)
+                or p.is_device_capability_family(120)
             )
             and has_flashinfer_cutlass_fused_moe()
         )
@@ -102,29 +105,27 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
         weight_key: QuantKey | None,
         activation_key: QuantKey | None,
     ) -> bool:
-        # The following are supported by FlashInferExperts:
-        # * unquantized
-        # * fp8 static per-tensor on 9.0+
-        # * fp8 block on 9.0
-        # * nvfp4 on 10.0+
         p = current_platform
         scheme = (weight_key, activation_key)
+        # The following are supported by FlashInferExperts:
         return (
+            # unquantized and fp8 static per-tensor on 9.0+
             (
                 scheme
                 in [
                     (None, None),
                     (kFp8StaticTensorSym, kFp8StaticTensorSym),
                 ]
+                and p.has_device_capability(90)
             )
+            # fp8 block-scale on 9.0
             or (
-                (scheme == (kFp8Static128BlockSym, kFp8Dynamic128Sym))
-                and (p.is_device_capability((9, 0)))
+                scheme == (kFp8Static128BlockSym, kFp8Dynamic128Sym)
+                and p.is_device_capability(90)
             )
+            # nvfp4 on 10.0+
             or (
-                (scheme == (kNvfp4Static, kNvfp4Dynamic))
-                and (p.is_device_capability_family(100))
+                scheme == (kNvfp4Static, kNvfp4Dynamic) and p.has_device_capability(100)
             )
         )
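
The rewritten is_supported_config above is effectively a small decision table over the (weight, activation) quantization scheme and the device capability. A hedged sketch of that table, with string placeholders standing in for vLLM's QuantKey constants:

# Illustrative only; scheme strings are placeholders for QuantKey objects.
def flashinfer_experts_supports(scheme: tuple, capability: int) -> bool:
    unquantized_or_fp8_tensor = scheme in [
        (None, None),
        ("fp8-static-tensor-sym", "fp8-static-tensor-sym"),
    ]
    fp8_block = scheme == ("fp8-static-128-block-sym", "fp8-dynamic-128-sym")
    nvfp4 = scheme == ("nvfp4-static", "nvfp4-dynamic")
    return (
        (unquantized_or_fp8_tensor and capability >= 90)  # has_device_capability(90)
        or (fp8_block and capability == 90)               # is_device_capability(90)
        or (nvfp4 and capability >= 100)                  # widened from the 100 family
    )

Note the NVFP4 branch: it changes from is_device_capability_family(100), which excluded SM120, to has_device_capability(100), which admits every Blackwell-generation part including RTX Blackwell.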

View File

@@ -30,7 +30,6 @@ from vllm.utils.torch_utils import direct_register_custom_op
     def _supports_current_device() -> bool:
         """Supports only Blackwell-family GPUs."""
         p = current_platform
-        # Add check flashinfer trtllm is available
         return p.is_cuda() and p.is_device_capability_family(100)

View File

@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING
 import torch
-import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
@@ -24,10 +23,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     kNvfp4Static,
 )
 from vllm.platforms import current_platform
-from vllm.utils.flashinfer import (
-    has_flashinfer_cutedsl_grouped_gemm_nt_masked,
-    has_flashinfer_cutlass_fused_moe,
-)
 
 if TYPE_CHECKING:
     from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
@@ -38,8 +33,6 @@ logger = init_logger(__name__)
 __all__ = [
-    "is_flashinfer_fp4_cutlass_moe_available",
-    "is_flashinfer_fp4_cutedsl_moe_available",
     "reorder_w1w3_to_w3w1",
 ]
@@ -124,26 +117,6 @@ def is_supported_config_trtllm(
     return True, None
 
 
-def is_flashinfer_fp4_cutlass_moe_available() -> bool:
-    """Return `True` when FlashInfer CUTLASS NV-FP4 kernels can be used."""
-    return (
-        envs.VLLM_USE_FLASHINFER_MOE_FP4
-        and has_flashinfer_cutlass_fused_moe()
-        and current_platform.is_cuda()
-        and current_platform.has_device_capability(100)
-    )
-
-
-def is_flashinfer_fp4_cutedsl_moe_available() -> bool:
-    """Return ``True`` when FlashInfer CUTEDSL NV-FP4 kernels can be used."""
-    return (
-        envs.VLLM_USE_FLASHINFER_MOE_FP4
-        and has_flashinfer_cutedsl_grouped_gemm_nt_masked()
-        and current_platform.is_cuda()
-        and current_platform.is_device_capability_family(100)
-    )
-
-
 def reorder_w1w3_to_w3w1(
     weight: torch.Tensor, scale: torch.Tensor, dim: int = -2
 ) -> tuple[torch.Tensor, torch.Tensor]:

View File

@@ -1,67 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-from dataclasses import dataclass
-
-import vllm.envs as envs
-from vllm.logger import init_logger
-from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
-    is_flashinfer_fp4_cutedsl_moe_available,
-    is_flashinfer_fp4_cutlass_moe_available,
-)
-from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
-    is_fp4_marlin_supported,
-)
-from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
-    cutlass_fp4_supported,
-)
-
-__all__ = ["detect_nvfp4_moe_support", "NvFp4Support"]
-
-_logger = init_logger(__name__)
-
-
-@dataclass(frozen=True)
-class NvFp4Support:
-    """Result container for NV-FP4 capability probing."""
-
-    cutlass_supported: bool
-    allow_flashinfer: bool
-    use_marlin: bool
-
-
-def detect_nvfp4_moe_support(class_name: str = "") -> NvFp4Support:
-    """Detect platform support for NV-FP4 fused-MoE path"""
-    cutlass_supported = cutlass_fp4_supported()
-
-    allow_flashinfer = cutlass_supported and (
-        is_flashinfer_fp4_cutlass_moe_available()
-        or is_flashinfer_fp4_cutedsl_moe_available()
-    )
-    if allow_flashinfer:
-        _logger.info_once(
-            "Using FlashInfer kernels for %s.", class_name or "NVFP4 path"
-        )
-    else:
-        if envs.VLLM_USE_FLASHINFER_MOE_FP4:
-            _logger.warning_once(
-                "FlashInfer kernels unavailable for %s on current platform.",
-                class_name or "NVFP4 path",
-            )
-
-    use_marlin = False
-    if not cutlass_supported:
-        if is_fp4_marlin_supported():
-            use_marlin = True
-            _logger.info_once("Falling back to Marlin FP4 MoE kernel.")
-        else:
-            raise ValueError(
-                "Current platform does not support NVFP4 quantization. "
-                "Please use Blackwell GPUs or enable FlashInfer."
-            )
-
-    return NvFp4Support(
-        cutlass_supported=cutlass_supported,
-        allow_flashinfer=allow_flashinfer,
-        use_marlin=use_marlin,
-    )