fix: Add SM120 (RTX Blackwell) support for FlashInfer CUTLASS NVFP4 MoE kernels (#33417)
Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -657,7 +657,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
@staticmethod
|
||||
def _supports_current_device() -> bool:
|
||||
return current_platform.has_device_capability((10, 0))
|
||||
p = current_platform
|
||||
return p.is_cuda() and (
|
||||
p.is_device_capability_family(100)
|
||||
or p.is_device_capability_family(110)
|
||||
or p.is_device_capability_family(120)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _supports_no_act_and_mul() -> bool:
|
||||
|
||||
@@ -54,7 +54,8 @@ class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
@staticmethod
|
||||
def _supports_current_device() -> bool:
|
||||
return current_platform.is_device_capability_family(100)
|
||||
p = current_platform
|
||||
return p.is_cuda() and p.is_device_capability_family(100)
|
||||
|
||||
@staticmethod
|
||||
def _supports_no_act_and_mul() -> bool:
|
||||
|
||||
@@ -84,11 +84,14 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
@staticmethod
|
||||
def _supports_current_device() -> bool:
|
||||
p = current_platform
|
||||
return (
|
||||
current_platform.is_cuda()
|
||||
p.is_cuda()
|
||||
and (
|
||||
current_platform.is_device_capability((9, 0))
|
||||
or current_platform.is_device_capability_family(100)
|
||||
p.is_device_capability(90)
|
||||
or p.is_device_capability_family(100)
|
||||
or p.is_device_capability_family(110)
|
||||
or p.is_device_capability_family(120)
|
||||
)
|
||||
and has_flashinfer_cutlass_fused_moe()
|
||||
)
|
||||
@@ -102,29 +105,27 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
# The following are supported by FlashInferExperts:
|
||||
# * unquantized
|
||||
# * fp8 static per-tensor on 9.0+
|
||||
# * fp8 block on 9.0
|
||||
# * nvfp4 on 10.0+
|
||||
|
||||
p = current_platform
|
||||
scheme = (weight_key, activation_key)
|
||||
# The following are supported by FlashInferExperts:
|
||||
return (
|
||||
# unquantized and fp8 static per-tensor on 9.0+
|
||||
(
|
||||
scheme
|
||||
in [
|
||||
(None, None),
|
||||
(kFp8StaticTensorSym, kFp8StaticTensorSym),
|
||||
]
|
||||
and p.has_device_capability(90)
|
||||
)
|
||||
# fp8 block-scale on 9.0
|
||||
or (
|
||||
(scheme == (kFp8Static128BlockSym, kFp8Dynamic128Sym))
|
||||
and (p.is_device_capability((9, 0)))
|
||||
scheme == (kFp8Static128BlockSym, kFp8Dynamic128Sym)
|
||||
and p.is_device_capability(90)
|
||||
)
|
||||
# nvfp4 on 10.0+
|
||||
or (
|
||||
(scheme == (kNvfp4Static, kNvfp4Dynamic))
|
||||
and (p.is_device_capability_family(100))
|
||||
scheme == (kNvfp4Static, kNvfp4Dynamic) and p.has_device_capability(100)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -30,7 +30,6 @@ from vllm.utils.torch_utils import direct_register_custom_op
|
||||
def _supports_current_device() -> bool:
|
||||
"""Supports only Blackwell-family GPUs."""
|
||||
p = current_platform
|
||||
# Add check flashinfer trtllm is available
|
||||
return p.is_cuda() and p.is_device_capability_family(100)
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
@@ -24,10 +23,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kNvfp4Static,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import (
|
||||
has_flashinfer_cutedsl_grouped_gemm_nt_masked,
|
||||
has_flashinfer_cutlass_fused_moe,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
|
||||
@@ -38,8 +33,6 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"is_flashinfer_fp4_cutlass_moe_available",
|
||||
"is_flashinfer_fp4_cutedsl_moe_available",
|
||||
"reorder_w1w3_to_w3w1",
|
||||
]
|
||||
|
||||
@@ -124,26 +117,6 @@ def is_supported_config_trtllm(
|
||||
return True, None
|
||||
|
||||
|
||||
def is_flashinfer_fp4_cutlass_moe_available() -> bool:
|
||||
"""Return `True` when FlashInfer CUTLASS NV-FP4 kernels can be used."""
|
||||
return (
|
||||
envs.VLLM_USE_FLASHINFER_MOE_FP4
|
||||
and has_flashinfer_cutlass_fused_moe()
|
||||
and current_platform.is_cuda()
|
||||
and current_platform.has_device_capability(100)
|
||||
)
|
||||
|
||||
|
||||
def is_flashinfer_fp4_cutedsl_moe_available() -> bool:
|
||||
"""Return ``True`` when FlashInfer CUTEDSL NV-FP4 kernels can be used."""
|
||||
return (
|
||||
envs.VLLM_USE_FLASHINFER_MOE_FP4
|
||||
and has_flashinfer_cutedsl_grouped_gemm_nt_masked()
|
||||
and current_platform.is_cuda()
|
||||
and current_platform.is_device_capability_family(100)
|
||||
)
|
||||
|
||||
|
||||
def reorder_w1w3_to_w3w1(
|
||||
weight: torch.Tensor, scale: torch.Tensor, dim: int = -2
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
|
||||
is_flashinfer_fp4_cutedsl_moe_available,
|
||||
is_flashinfer_fp4_cutlass_moe_available,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
is_fp4_marlin_supported,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
|
||||
cutlass_fp4_supported,
|
||||
)
|
||||
|
||||
__all__ = ["detect_nvfp4_moe_support", "NvFp4Support"]
|
||||
|
||||
_logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NvFp4Support:
|
||||
"""Result container for NV-FP4 capability probing."""
|
||||
|
||||
cutlass_supported: bool
|
||||
allow_flashinfer: bool
|
||||
use_marlin: bool
|
||||
|
||||
|
||||
def detect_nvfp4_moe_support(class_name: str = "") -> NvFp4Support:
|
||||
"""Detect platform support for NV-FP4 fused-MoE path"""
|
||||
cutlass_supported = cutlass_fp4_supported()
|
||||
|
||||
allow_flashinfer = cutlass_supported and (
|
||||
is_flashinfer_fp4_cutlass_moe_available()
|
||||
or is_flashinfer_fp4_cutedsl_moe_available()
|
||||
)
|
||||
|
||||
if allow_flashinfer:
|
||||
_logger.info_once(
|
||||
"Using FlashInfer kernels for %s.", class_name or "NVFP4 path"
|
||||
)
|
||||
else:
|
||||
if envs.VLLM_USE_FLASHINFER_MOE_FP4:
|
||||
_logger.warning_once(
|
||||
"FlashInfer kernels unavailable for %s on current platform.",
|
||||
class_name or "NVFP4 path",
|
||||
)
|
||||
|
||||
use_marlin = False
|
||||
if not cutlass_supported:
|
||||
if is_fp4_marlin_supported():
|
||||
use_marlin = True
|
||||
_logger.info_once("Falling back to Marlin FP4 MoE kernel.")
|
||||
else:
|
||||
raise ValueError(
|
||||
"Current platform does not support NVFP4 quantization. "
|
||||
"Please use Blackwell GPUs or enable FlashInfer."
|
||||
)
|
||||
|
||||
return NvFp4Support(
|
||||
cutlass_supported=cutlass_supported,
|
||||
allow_flashinfer=allow_flashinfer,
|
||||
use_marlin=use_marlin,
|
||||
)
|
||||
Reference in New Issue
Block a user