[ROCm][Bugfix] Fall back from CK MXFP4 MoE when GEMM dimensions are unsupported (#35893)

Signed-off-by: Li <chuali@amd.com>
This commit is contained in:
Chuan (Richard) Li
2026-03-04 00:30:54 -08:00
committed by GitHub
parent 36bf213181
commit 5dc3538736
3 changed files with 63 additions and 1 deletions

View File

@@ -48,6 +48,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_moe_fp4_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
CK_MXFP4_MOE_DIM_ALIGNMENT,
_can_support_mxfp4,
_swizzle_mxfp4,
get_padding_alignment,
@@ -259,6 +260,31 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
get_current_vllm_config().compilation_config.max_cudagraph_capture_size
)
# CK's pre-compiled MXFP4 MoE GEMM kernel instances have dimension
# alignment requirements. Fall back to Triton when not met.
if (
self.mxfp4_backend == Mxfp4Backend.CK
and moe.intermediate_size_per_partition % CK_MXFP4_MOE_DIM_ALIGNMENT != 0
):
if has_triton_kernels():
logger.warning_once(
"CK MXFP4 MoE GEMM does not support "
"intermediate_size_per_partition=%d (not a multiple of "
"%d). Falling back to Triton backend.",
moe.intermediate_size_per_partition,
CK_MXFP4_MOE_DIM_ALIGNMENT,
)
self.mxfp4_backend = Mxfp4Backend.TRITON
else:
raise ValueError(
f"CK MXFP4 MoE GEMM does not support "
f"intermediate_size_per_partition="
f"{moe.intermediate_size_per_partition} (not a multiple "
f"of {CK_MXFP4_MOE_DIM_ALIGNMENT}) and no Triton "
f"fallback is available. Use a compatible "
f"tensor_parallel_size."
)
assert self.mxfp4_backend != Mxfp4Backend.NONE, (
f"get_mxfp4_backend(with_lora_support={moe.is_lora_enabled}) found"
"no compatible MXFP4 MoE backend (FlashInfer/Marlin/Triton)."

View File

@@ -32,7 +32,10 @@ from vllm.model_executor.layers.quantization.mxfp4 import (
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_fp8_moe_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import _swizzle_mxfp4
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
CK_MXFP4_MOE_DIM_ALIGNMENT,
_swizzle_mxfp4,
)
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
OCP_MX_BLOCK_SIZE,
OCP_MX_Scheme,
@@ -732,6 +735,32 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
or not self.ocp_mx_scheme.startswith("w_mxfp4")
) and (self.mxfp4_backend is None or not self.use_rocm_aiter_moe)
# CK's pre-compiled MXFP4 MoE GEMM kernel instances have dimension
# alignment requirements. When violated (e.g. MiniMax-M2.1 with
# TP=4 yields intermediate_size_per_partition=384), AITER raises:
# "device_gemm ... does not support this GEMM problem".
# Fall back to emulation in that case.
if (
not self.emulate
and self.use_rocm_aiter_moe
and self.ocp_mx_scheme is not None
and self.ocp_mx_scheme.startswith("w_mxfp4")
and moe.intermediate_size_per_partition % CK_MXFP4_MOE_DIM_ALIGNMENT != 0
):
logger.warning_once(
"AITER CK MXFP4 MoE GEMM does not support "
"intermediate_size_per_partition=%d (not a multiple of %d). "
"This typically happens when intermediate_size / "
"tensor_parallel_size produces an incompatible dimension. "
"Falling back to emulation mode. To avoid this overhead, "
"use a compatible tensor_parallel_size or set "
"VLLM_ROCM_USE_AITER_MOE=0.",
moe.intermediate_size_per_partition,
CK_MXFP4_MOE_DIM_ALIGNMENT,
)
self.use_rocm_aiter_moe = False
self.emulate = True
if self.emulate:
logger.warning_once(
f"The current mode (supports_mx={current_platform.supports_mx()}, "

View File

@@ -14,6 +14,13 @@ from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_
logger = init_logger(__name__)
# CK's pre-compiled MXFP4 MoE GEMM kernel instances require the
# intermediate_size (after TP split) to be a multiple of this value.
# This arises from FP4 packing (2 values per byte) combined with CK
# tile size constraints. When violated, AITER raises:
# "device_gemm ... does not support this GEMM problem".
CK_MXFP4_MOE_DIM_ALIGNMENT = 256
def _swizzle_mxfp4(quant_tensor, scale, num_warps):
"""weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel"""