[Perf] Disable DeepGEMM MoE by default when TP size is >= 8 (#29346)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -28,6 +28,7 @@ from vllm.model_executor.layers.fused_moe import (
|
|||||||
FusedMoeWeightScaleSupported,
|
FusedMoeWeightScaleSupported,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.config import (
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
FusedMoEParallelConfig,
|
||||||
FusedMoEQuantConfig,
|
FusedMoEQuantConfig,
|
||||||
RoutingMethodType,
|
RoutingMethodType,
|
||||||
fp8_w8a8_moe_quant_config,
|
fp8_w8a8_moe_quant_config,
|
||||||
@@ -118,7 +119,9 @@ class Fp8MoeBackend(Enum):
|
|||||||
TRITON = 6
|
TRITON = 6
|
||||||
|
|
||||||
|
|
||||||
def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
|
def get_fp8_moe_backend(
|
||||||
|
block_quant: bool, moe_parallel_config: FusedMoEParallelConfig
|
||||||
|
) -> Fp8MoeBackend:
|
||||||
"""
|
"""
|
||||||
Select the primary FP8 MoE backend
|
Select the primary FP8 MoE backend
|
||||||
Note: Shape-specific fallbacks may still occur at runtime.
|
Note: Shape-specific fallbacks may still occur at runtime.
|
||||||
@@ -159,8 +162,19 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
|
|||||||
logger.info_once("Using Marlin backend for FP8 MoE")
|
logger.info_once("Using Marlin backend for FP8 MoE")
|
||||||
return Fp8MoeBackend.MARLIN
|
return Fp8MoeBackend.MARLIN
|
||||||
|
|
||||||
# deepGEMM on supported platforms with block-quantized weights
|
# Determine if we should use DeepGEMM with block-quantized weights:
|
||||||
if envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM and block_quant:
|
# - If explicitly set by user, respect their choice
|
||||||
|
# - If not explicitly set (default), disable when TP size is >= 8
|
||||||
|
moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM
|
||||||
|
if not envs.is_set("VLLM_MOE_USE_DEEP_GEMM") and moe_parallel_config.tp_size >= 8:
|
||||||
|
moe_use_deep_gemm = False
|
||||||
|
logger.info_once(
|
||||||
|
"DeepGEMM MoE is disabled by default when TP size is >= 8. "
|
||||||
|
"Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.",
|
||||||
|
scope="local",
|
||||||
|
)
|
||||||
|
|
||||||
|
if envs.VLLM_USE_DEEP_GEMM and moe_use_deep_gemm and block_quant:
|
||||||
if not has_deep_gemm():
|
if not has_deep_gemm():
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
"DeepGEMM backend requested but not available.", scope="local"
|
"DeepGEMM backend requested but not available.", scope="local"
|
||||||
@@ -641,7 +655,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.weight_block_size = self.quant_config.weight_block_size
|
self.weight_block_size = self.quant_config.weight_block_size
|
||||||
self.block_quant: bool = self.weight_block_size is not None
|
self.block_quant: bool = self.weight_block_size is not None
|
||||||
self.fp8_backend = get_fp8_moe_backend(self.block_quant)
|
self.fp8_backend = get_fp8_moe_backend(
|
||||||
|
self.block_quant, layer.moe_parallel_config
|
||||||
|
)
|
||||||
|
|
||||||
self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
|
self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
|
||||||
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
|
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
|
||||||
|
|||||||
Reference in New Issue
Block a user