[Bug][MoE] Strengthen _supports_current_device() checks in the TRTLLM FP8, NVFP4, and FlashInfer CuteDSL MoE experts (#36728)
Signed-off-by: Yifan Zong <yzong@redhat.com>
This commit is contained in:
@@ -23,6 +23,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 from vllm.platforms import current_platform
 from vllm.utils.flashinfer import (
     flashinfer_cutedsl_grouped_gemm_nt_masked,
+    has_flashinfer_cutedsl_grouped_gemm_nt_masked,
     scaled_fp4_grouped_quantize,
     silu_and_mul_scaled_nvfp4_experts_quantize,
 )
@@ -60,7 +61,11 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
     @staticmethod
     def _supports_current_device() -> bool:
         p = current_platform
-        return p.is_cuda() and p.is_device_capability_family(100)
+        return (
+            p.is_cuda()
+            and p.is_device_capability_family(100)
+            and has_flashinfer_cutedsl_grouped_gemm_nt_masked()
+        )
 
     @staticmethod
     def _supports_no_act_and_mul() -> bool:
||||
@@ -27,6 +27,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     kMxfp8Static,
 )
 from vllm.platforms import current_platform
+from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe
 
 logger = init_logger(__name__)
 
||||
@@ -61,8 +62,11 @@ class TrtLlmFp8ExpertsBase:
     def _supports_current_device() -> bool:
         """Supports only Blackwell-family GPUs."""
         p = current_platform
-        # Add check flashinfer trtllm is available
-        return p.is_cuda() and p.is_device_capability_family(100)
+        return (
+            p.is_cuda()
+            and p.is_device_capability_family(100)
+            and has_flashinfer_trtllm_fused_moe()
+        )
 
     @staticmethod
     def _supports_no_act_and_mul() -> bool:
||||
@@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     kNvfp4Static,
 )
 from vllm.platforms import current_platform
+from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe
 
 
 class TrtLlmNvFp4ExpertsBase:
@@ -80,7 +81,11 @@ class TrtLlmNvFp4ExpertsBase:
     def _supports_current_device() -> bool:
         """Supports only Blackwell-family GPUs."""
         p = current_platform
-        return p.is_cuda() and p.is_device_capability_family(100)
+        return (
+            p.is_cuda()
+            and p.is_device_capability_family(100)
+            and has_flashinfer_trtllm_fused_moe()
+        )
 
     @staticmethod
     def _supports_no_act_and_mul() -> bool:
||||
Reference in New Issue
Block a user