[ROCm][Bugfix] Fall back from CK MXFP4 MoE when GEMM dimensions are unsupported (#35893)
Signed-off-by: Li <chuali@amd.com>
This commit is contained in:
committed by
GitHub
parent
36bf213181
commit
5dc3538736
@@ -48,6 +48,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
|||||||
prepare_moe_fp4_layer_for_marlin,
|
prepare_moe_fp4_layer_for_marlin,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||||
|
CK_MXFP4_MOE_DIM_ALIGNMENT,
|
||||||
_can_support_mxfp4,
|
_can_support_mxfp4,
|
||||||
_swizzle_mxfp4,
|
_swizzle_mxfp4,
|
||||||
get_padding_alignment,
|
get_padding_alignment,
|
||||||
@@ -259,6 +260,31 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
get_current_vllm_config().compilation_config.max_cudagraph_capture_size
|
get_current_vllm_config().compilation_config.max_cudagraph_capture_size
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# CK's pre-compiled MXFP4 MoE GEMM kernel instances have dimension
|
||||||
|
# alignment requirements. Fall back to Triton when they are not met,
# or raise a ValueError if no Triton kernels are available.
|
||||||
|
if (
|
||||||
|
self.mxfp4_backend == Mxfp4Backend.CK
|
||||||
|
and moe.intermediate_size_per_partition % CK_MXFP4_MOE_DIM_ALIGNMENT != 0
|
||||||
|
):
|
||||||
|
if has_triton_kernels():
|
||||||
|
logger.warning_once(
|
||||||
|
"CK MXFP4 MoE GEMM does not support "
|
||||||
|
"intermediate_size_per_partition=%d (not a multiple of "
|
||||||
|
"%d). Falling back to Triton backend.",
|
||||||
|
moe.intermediate_size_per_partition,
|
||||||
|
CK_MXFP4_MOE_DIM_ALIGNMENT,
|
||||||
|
)
|
||||||
|
self.mxfp4_backend = Mxfp4Backend.TRITON
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"CK MXFP4 MoE GEMM does not support "
|
||||||
|
f"intermediate_size_per_partition="
|
||||||
|
f"{moe.intermediate_size_per_partition} (not a multiple "
|
||||||
|
f"of {CK_MXFP4_MOE_DIM_ALIGNMENT}) and no Triton "
|
||||||
|
f"fallback is available. Use a compatible "
|
||||||
|
f"tensor_parallel_size."
|
||||||
|
)
|
||||||
|
|
||||||
assert self.mxfp4_backend != Mxfp4Backend.NONE, (
|
assert self.mxfp4_backend != Mxfp4Backend.NONE, (
|
||||||
f"get_mxfp4_backend(with_lora_support={moe.is_lora_enabled}) found"
|
f"get_mxfp4_backend(with_lora_support={moe.is_lora_enabled}) found"
|
||||||
"no compatible MXFP4 MoE backend (FlashInfer/Marlin/Triton)."
|
"no compatible MXFP4 MoE backend (FlashInfer/Marlin/Triton)."
|
||||||
|
|||||||
@@ -32,7 +32,10 @@ from vllm.model_executor.layers.quantization.mxfp4 import (
|
|||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||||
prepare_fp8_moe_layer_for_marlin,
|
prepare_fp8_moe_layer_for_marlin,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import _swizzle_mxfp4
|
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||||
|
CK_MXFP4_MOE_DIM_ALIGNMENT,
|
||||||
|
_swizzle_mxfp4,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
|
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
|
||||||
OCP_MX_BLOCK_SIZE,
|
OCP_MX_BLOCK_SIZE,
|
||||||
OCP_MX_Scheme,
|
OCP_MX_Scheme,
|
||||||
@@ -732,6 +735,32 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
|||||||
or not self.ocp_mx_scheme.startswith("w_mxfp4")
|
or not self.ocp_mx_scheme.startswith("w_mxfp4")
|
||||||
) and (self.mxfp4_backend is None or not self.use_rocm_aiter_moe)
|
) and (self.mxfp4_backend is None or not self.use_rocm_aiter_moe)
|
||||||
|
|
||||||
|
# CK's pre-compiled MXFP4 MoE GEMM kernel instances have dimension
|
||||||
|
# alignment requirements. When violated (e.g. MiniMax-M2.1 with
|
||||||
|
# TP=4 yields intermediate_size_per_partition=384), AITER raises:
|
||||||
|
# "device_gemm ... does not support this GEMM problem".
|
||||||
|
# In that case, disable the AITER MoE path and fall back to emulation.
|
||||||
|
if (
|
||||||
|
not self.emulate
|
||||||
|
and self.use_rocm_aiter_moe
|
||||||
|
and self.ocp_mx_scheme is not None
|
||||||
|
and self.ocp_mx_scheme.startswith("w_mxfp4")
|
||||||
|
and moe.intermediate_size_per_partition % CK_MXFP4_MOE_DIM_ALIGNMENT != 0
|
||||||
|
):
|
||||||
|
logger.warning_once(
|
||||||
|
"AITER CK MXFP4 MoE GEMM does not support "
|
||||||
|
"intermediate_size_per_partition=%d (not a multiple of %d). "
|
||||||
|
"This typically happens when intermediate_size / "
|
||||||
|
"tensor_parallel_size produces an incompatible dimension. "
|
||||||
|
"Falling back to emulation mode. To avoid this overhead, "
|
||||||
|
"use a compatible tensor_parallel_size or set "
|
||||||
|
"VLLM_ROCM_USE_AITER_MOE=0.",
|
||||||
|
moe.intermediate_size_per_partition,
|
||||||
|
CK_MXFP4_MOE_DIM_ALIGNMENT,
|
||||||
|
)
|
||||||
|
self.use_rocm_aiter_moe = False
|
||||||
|
self.emulate = True
|
||||||
|
|
||||||
if self.emulate:
|
if self.emulate:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
f"The current mode (supports_mx={current_platform.supports_mx()}, "
|
f"The current mode (supports_mx={current_platform.supports_mx()}, "
|
||||||
|
|||||||
@@ -14,6 +14,13 @@ from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
# CK's pre-compiled MXFP4 MoE GEMM kernel instances require the
|
||||||
|
# intermediate_size_per_partition (i.e. intermediate_size after the TP split)
# to be a multiple of this value.
|
||||||
|
# This arises from FP4 packing (2 values per byte) combined with CK
|
||||||
|
# tile size constraints. When violated, AITER raises:
|
||||||
|
# "device_gemm ... does not support this GEMM problem".
|
||||||
|
CK_MXFP4_MOE_DIM_ALIGNMENT = 256
|
||||||
|
|
||||||
|
|
||||||
def _swizzle_mxfp4(quant_tensor, scale, num_warps):
|
def _swizzle_mxfp4(quant_tensor, scale, num_warps):
|
||||||
"""weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel"""
|
"""weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel"""
|
||||||
|
|||||||
Reference in New Issue
Block a user