[ROCm] fix num_stages for default moe config to avoid triton OutOfResource error (#17744)
Signed-off-by: Hongxia Yang <hongxia.yang@amd.com>
This commit is contained in:
@@ -747,13 +747,15 @@ def get_default_config(
|
|||||||
if dtype == "fp8_w8a8" and block_shape is not None:
|
if dtype == "fp8_w8a8" and block_shape is not None:
|
||||||
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
|
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
|
||||||
# BLOCK_SIZE_K must be divisible by block_shape[1]
|
# BLOCK_SIZE_K must be divisible by block_shape[1]
|
||||||
|
# num_stages=3 can cause triton.runtime.errors.OutOfResources
|
||||||
|
# on ROCm, set it to 2 instead.
|
||||||
config = {
|
config = {
|
||||||
"BLOCK_SIZE_M": 64,
|
"BLOCK_SIZE_M": 64,
|
||||||
"BLOCK_SIZE_N": block_shape[0],
|
"BLOCK_SIZE_N": block_shape[0],
|
||||||
"BLOCK_SIZE_K": block_shape[1],
|
"BLOCK_SIZE_K": block_shape[1],
|
||||||
"GROUP_SIZE_M": 32,
|
"GROUP_SIZE_M": 32,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 3,
|
"num_stages": 3 if not current_platform.is_rocm() else 2,
|
||||||
}
|
}
|
||||||
elif dtype in ["int4_w4a16", "int8_w8a16"] and block_shape is not None:
|
elif dtype in ["int4_w4a16", "int8_w8a16"] and block_shape is not None:
|
||||||
# moe wna16 kernels
|
# moe wna16 kernels
|
||||||
|
|||||||
Reference in New Issue
Block a user