[Bug][MoE] Strengthen _supports_current_device() checks in the TRTLLM FP8, NVFP4, and FlashInfer CuteDSL MoE experts (#36728)

Signed-off-by: Yifan Zong <yzong@redhat.com>
This commit is contained in:
yzong-rh
2026-03-23 17:02:57 -04:00
committed by GitHub
parent 5bf3c42d4c
commit e85f8f0932
4 changed files with 19 additions and 5 deletions

View File

@@ -23,6 +23,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform
from vllm.utils.flashinfer import (
flashinfer_cutedsl_grouped_gemm_nt_masked,
has_flashinfer_cutedsl_grouped_gemm_nt_masked,
scaled_fp4_grouped_quantize,
silu_and_mul_scaled_nvfp4_experts_quantize,
)
@@ -60,7 +61,11 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
@staticmethod
def _supports_current_device() -> bool:
p = current_platform
return p.is_cuda() and p.is_device_capability_family(100)
return (
p.is_cuda()
and p.is_device_capability_family(100)
and has_flashinfer_cutedsl_grouped_gemm_nt_masked()
)
@staticmethod
def _supports_no_act_and_mul() -> bool:

View File

@@ -27,6 +27,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kMxfp8Static,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe
logger = init_logger(__name__)
@@ -61,8 +62,11 @@ class TrtLlmFp8ExpertsBase:
def _supports_current_device() -> bool:
"""Supports only Blackwell-family GPUs."""
p = current_platform
# Add check flashinfer trtllm is available
return p.is_cuda() and p.is_device_capability_family(100)
return (
p.is_cuda()
and p.is_device_capability_family(100)
and has_flashinfer_trtllm_fused_moe()
)
@staticmethod
def _supports_no_act_and_mul() -> bool:

View File

@@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kNvfp4Static,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe
class TrtLlmNvFp4ExpertsBase:
@@ -80,7 +81,11 @@ class TrtLlmNvFp4ExpertsBase:
def _supports_current_device() -> bool:
"""Supports only Blackwell-family GPUs."""
p = current_platform
return p.is_cuda() and p.is_device_capability_family(100)
return (
p.is_cuda()
and p.is_device_capability_family(100)
and has_flashinfer_trtllm_fused_moe()
)
@staticmethod
def _supports_no_act_and_mul() -> bool: