[ROCm][Bugfix] Fall back from CK MXFP4 MoE when GEMM dimensions are unsupported (#35893)
Signed-off-by: Li <chuali@amd.com>
This commit is contained in:
committed by
GitHub
parent
36bf213181
commit
5dc3538736
@@ -48,6 +48,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
prepare_moe_fp4_layer_for_marlin,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
CK_MXFP4_MOE_DIM_ALIGNMENT,
|
||||
_can_support_mxfp4,
|
||||
_swizzle_mxfp4,
|
||||
get_padding_alignment,
|
||||
@@ -259,6 +260,31 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
get_current_vllm_config().compilation_config.max_cudagraph_capture_size
|
||||
)
|
||||
|
||||
# CK's pre-compiled MXFP4 MoE GEMM kernel instances have dimension
|
||||
# alignment requirements. Fall back to Triton when not met.
|
||||
if (
|
||||
self.mxfp4_backend == Mxfp4Backend.CK
|
||||
and moe.intermediate_size_per_partition % CK_MXFP4_MOE_DIM_ALIGNMENT != 0
|
||||
):
|
||||
if has_triton_kernels():
|
||||
logger.warning_once(
|
||||
"CK MXFP4 MoE GEMM does not support "
|
||||
"intermediate_size_per_partition=%d (not a multiple of "
|
||||
"%d). Falling back to Triton backend.",
|
||||
moe.intermediate_size_per_partition,
|
||||
CK_MXFP4_MOE_DIM_ALIGNMENT,
|
||||
)
|
||||
self.mxfp4_backend = Mxfp4Backend.TRITON
|
||||
else:
|
||||
raise ValueError(
|
||||
f"CK MXFP4 MoE GEMM does not support "
|
||||
f"intermediate_size_per_partition="
|
||||
f"{moe.intermediate_size_per_partition} (not a multiple "
|
||||
f"of {CK_MXFP4_MOE_DIM_ALIGNMENT}) and no Triton "
|
||||
f"fallback is available. Use a compatible "
|
||||
f"tensor_parallel_size."
|
||||
)
|
||||
|
||||
assert self.mxfp4_backend != Mxfp4Backend.NONE, (
|
||||
f"get_mxfp4_backend(with_lora_support={moe.is_lora_enabled}) found"
|
||||
"no compatible MXFP4 MoE backend (FlashInfer/Marlin/Triton)."
|
||||
|
||||
@@ -32,7 +32,10 @@ from vllm.model_executor.layers.quantization.mxfp4 import (
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
prepare_fp8_moe_layer_for_marlin,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import _swizzle_mxfp4
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
CK_MXFP4_MOE_DIM_ALIGNMENT,
|
||||
_swizzle_mxfp4,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
|
||||
OCP_MX_BLOCK_SIZE,
|
||||
OCP_MX_Scheme,
|
||||
@@ -732,6 +735,32 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
or not self.ocp_mx_scheme.startswith("w_mxfp4")
|
||||
) and (self.mxfp4_backend is None or not self.use_rocm_aiter_moe)
|
||||
|
||||
# CK's pre-compiled MXFP4 MoE GEMM kernel instances have dimension
|
||||
# alignment requirements. When violated (e.g. MiniMax-M2.1 with
|
||||
# TP=4 yields intermediate_size_per_partition=384), AITER raises:
|
||||
# "device_gemm ... does not support this GEMM problem".
|
||||
# Fall back to emulation in that case.
|
||||
if (
|
||||
not self.emulate
|
||||
and self.use_rocm_aiter_moe
|
||||
and self.ocp_mx_scheme is not None
|
||||
and self.ocp_mx_scheme.startswith("w_mxfp4")
|
||||
and moe.intermediate_size_per_partition % CK_MXFP4_MOE_DIM_ALIGNMENT != 0
|
||||
):
|
||||
logger.warning_once(
|
||||
"AITER CK MXFP4 MoE GEMM does not support "
|
||||
"intermediate_size_per_partition=%d (not a multiple of %d). "
|
||||
"This typically happens when intermediate_size / "
|
||||
"tensor_parallel_size produces an incompatible dimension. "
|
||||
"Falling back to emulation mode. To avoid this overhead, "
|
||||
"use a compatible tensor_parallel_size or set "
|
||||
"VLLM_ROCM_USE_AITER_MOE=0.",
|
||||
moe.intermediate_size_per_partition,
|
||||
CK_MXFP4_MOE_DIM_ALIGNMENT,
|
||||
)
|
||||
self.use_rocm_aiter_moe = False
|
||||
self.emulate = True
|
||||
|
||||
if self.emulate:
|
||||
logger.warning_once(
|
||||
f"The current mode (supports_mx={current_platform.supports_mx()}, "
|
||||
|
||||
@@ -14,6 +14,13 @@ from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# CK's pre-compiled MXFP4 MoE GEMM kernel instances require the
|
||||
# intermediate_size (after TP split) to be a multiple of this value.
|
||||
# This arises from FP4 packing (2 values per byte) combined with CK
|
||||
# tile size constraints. When violated, AITER raises:
|
||||
# "device_gemm ... does not support this GEMM problem".
|
||||
CK_MXFP4_MOE_DIM_ALIGNMENT = 256
|
||||
|
||||
|
||||
def _swizzle_mxfp4(quant_tensor, scale, num_warps):
|
||||
"""weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel"""
|
||||
|
||||
Reference in New Issue
Block a user