[ROCm] Fix MoE kernel test failures on gfx950 (#37833)
Signed-off-by: Andreas Karatzas <akaratza@amd.com> Signed-off-by: Matthew Wong <Matthew.Wong2@amd.com> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: Matthew Wong <Matthew.Wong2@amd.com>
This commit is contained in:
@@ -19,6 +19,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
get_fp8_min_max,
|
||||
kFp8Dynamic128Sym,
|
||||
kFp8Static128BlockSym,
|
||||
)
|
||||
@@ -117,7 +118,10 @@ def _silu_mul_fp8_quant_deep_gemm(
|
||||
gate = gate * (1.0 / (1.0 + tl.exp(-gate)))
|
||||
y = gate * up
|
||||
|
||||
y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max
|
||||
# Use multiply-by-reciprocal to match PyTorch's tensor/scalar
|
||||
# division precision (Triton GPU fast-division for constexpr
|
||||
# divisors can introduce 1-ULP error).
|
||||
y_s = tl.maximum(tl.max(tl.abs(y)), eps) * (1.0 / fp8_max)
|
||||
if ceil_ue8m0:
|
||||
y_s = tl.exp2(tl.ceil(tl.log2(y_s)))
|
||||
|
||||
@@ -190,7 +194,7 @@ def persistent_masked_m_silu_mul_quant(
|
||||
|
||||
tokens_per_expert = tokens_per_expert.to(device=y.device, dtype=torch.int32)
|
||||
|
||||
fp8_dtype = torch.float8_e4m3fn
|
||||
fp8_dtype = current_platform.fp8_dtype()
|
||||
y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)
|
||||
|
||||
ys_shape, ys_strides, ys_dtype = scales_shape_stride_dtype(E, T, G, quant_scale_fmt)
|
||||
@@ -210,11 +214,14 @@ def persistent_masked_m_silu_mul_quant(
|
||||
device_id=y.device.index
|
||||
).to_int()
|
||||
|
||||
if cuda_arch >= 80:
|
||||
if current_platform.is_cuda() and cuda_arch >= 80:
|
||||
torch.ops._C.persistent_masked_m_silu_mul_quant(
|
||||
y, tokens_per_expert, y_q, y_s, ceil_ue8m0
|
||||
)
|
||||
else:
|
||||
# Triton fallback for ROCm -- the C++ kernel is guarded by
|
||||
# #ifndef USE_ROCM in activation_kernels.cu.
|
||||
# https://github.com/ROCm/aiter/issues/2420
|
||||
stride_cnt_e = tokens_per_expert.stride()[0]
|
||||
|
||||
# Static grid over experts and H-groups.
|
||||
@@ -224,13 +231,11 @@ def persistent_masked_m_silu_mul_quant(
|
||||
stride_i_e, stride_i_t, stride_i_h = y.stride()
|
||||
stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride()
|
||||
|
||||
f_info = torch.finfo(fp8_dtype)
|
||||
fp8_max = f_info.max
|
||||
fp8_min = f_info.min
|
||||
fp8_min, fp8_max = get_fp8_min_max()
|
||||
eps: float = 1e-10
|
||||
assert y_s.dtype == torch.float32, (
|
||||
"_silu_mul_fp8_quant_deep_gemm does"
|
||||
"not support {y_s.dtype} scales. Only torch.float32 supported."
|
||||
"_silu_mul_fp8_quant_deep_gemm Triton fallback does not "
|
||||
f"support {y_s.dtype} scales. Only torch.float32 supported."
|
||||
)
|
||||
_silu_mul_fp8_quant_deep_gemm[grid](
|
||||
y,
|
||||
|
||||
@@ -253,10 +253,16 @@ def triton_kernel_moe_forward(
|
||||
logits = gating_output
|
||||
if sm_first:
|
||||
logits = torch.softmax(logits, dim=-1)
|
||||
sparse_logits = topk_fn(logits, topk, apply_softmax=not sm_first)
|
||||
# sparse_logits.indx contains global expert IDs – remap to local.
|
||||
topk_ids = expert_map[sparse_logits.indx.to(torch.long)]
|
||||
topk_weights = sparse_logits.vals
|
||||
topk_result = topk_fn(logits, topk, apply_softmax=not sm_first)
|
||||
# topk may return a tuple (vals, indx, bitmatrix) or a
|
||||
# SparseMatrix depending on the triton_kernels version.
|
||||
if isinstance(topk_result, tuple):
|
||||
topk_weights, topk_ids_raw, _ = topk_result
|
||||
else:
|
||||
topk_weights = topk_result.vals
|
||||
topk_ids_raw = topk_result.indx
|
||||
# topk_ids_raw contains global expert IDs - remap to local.
|
||||
topk_ids = expert_map[topk_ids_raw.to(torch.long)]
|
||||
local_num_experts = w1.shape[0]
|
||||
routing_data, gather_idx, scatter_idx = make_routing_data(
|
||||
topk_ids, topk_weights, local_num_experts
|
||||
@@ -422,8 +428,13 @@ def triton_kernel_fused_mxfp4_w4a8_experts(
|
||||
assert quant_config.w1_bias is None or quant_config.w1_bias.dtype == torch.float32
|
||||
assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32
|
||||
|
||||
# Shape check, only check non-mxfp4
|
||||
assert hidden_states.shape[-1] == w1.shape[-2]
|
||||
# Shape check: when weights are padded (e.g. hidden_size padded for
|
||||
# GFX950 swizzle), unpadded_K_w1 carries the original dimension.
|
||||
expected_K_w1 = unpadded_K_w1 if unpadded_K_w1 is not None else w1.shape[-2]
|
||||
assert hidden_states.shape[-1] == expected_K_w1, (
|
||||
f"hidden_states K={hidden_states.shape[-1]} != "
|
||||
f"expected K={expected_K_w1} (w1 K={w1.shape[-2]})"
|
||||
)
|
||||
assert w2.shape[-1] == w1.shape[1]
|
||||
|
||||
E, _, N = w1.shape
|
||||
@@ -483,6 +494,12 @@ def triton_kernel_fused_mxfp4_w4a8_experts(
|
||||
unpadded_K=unpadded_K_w2,
|
||||
)
|
||||
|
||||
# When hidden_size was padded for alignment (e.g. GFX950 swizzle),
|
||||
# the kernel output has the padded dimension. Slice back to the
|
||||
# original hidden_size so downstream layers see the expected shape.
|
||||
if unpadded_N_w2 is not None and intermediate_cache3.shape[-1] != unpadded_N_w2:
|
||||
intermediate_cache3 = intermediate_cache3[..., :unpadded_N_w2].contiguous()
|
||||
|
||||
return intermediate_cache3
|
||||
|
||||
|
||||
|
||||
@@ -741,11 +741,14 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
# TP=4 yields intermediate_size_per_partition=384), AITER raises:
|
||||
# "device_gemm ... does not support this GEMM problem".
|
||||
# Fall back to emulation in that case.
|
||||
# For gpt_oss models, create_weights rounds up the dimensions
|
||||
# internally, so the alignment check is skipped.
|
||||
if (
|
||||
not self.emulate
|
||||
and self.use_rocm_aiter_moe
|
||||
and self.ocp_mx_scheme is not None
|
||||
and self.ocp_mx_scheme.startswith("w_mxfp4")
|
||||
and self.model_type != "gpt_oss"
|
||||
and moe.intermediate_size_per_partition % CK_MXFP4_MOE_DIM_ALIGNMENT != 0
|
||||
):
|
||||
logger.warning_once(
|
||||
@@ -819,6 +822,18 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
"unpadded_hidden_size", hidden_size
|
||||
)
|
||||
|
||||
# On GFX950, the GFX950MXScaleLayout swizzle requires
|
||||
# hidden_size to be a multiple of 256 (SCALE_K = hidden_size / 32
|
||||
# must be divisible by 8). Pad hidden_size for weight/scale
|
||||
# allocation; the original value is preserved in unpadded_hidden_size.
|
||||
# Only applies to the native (non-emulated) CK path on GFX950.
|
||||
if (
|
||||
self.model_type == "gpt_oss"
|
||||
and current_platform.is_rocm()
|
||||
and not self.emulate
|
||||
):
|
||||
hidden_size = round_up(hidden_size, 256)
|
||||
|
||||
# WEIGHTS
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
|
||||
@@ -615,8 +615,8 @@ def _per_token_group_quant_fp8(
|
||||
# Avoid to divide zero
|
||||
eps,
|
||||
# Information for float8
|
||||
fp8_min,
|
||||
fp8_max,
|
||||
fp8_min: tl.constexpr,
|
||||
fp8_max: tl.constexpr,
|
||||
use_ue8m0: tl.constexpr,
|
||||
# Meta-parameters
|
||||
BLOCK: tl.constexpr,
|
||||
@@ -647,8 +647,12 @@ def _per_token_group_quant_fp8(
|
||||
|
||||
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
|
||||
# Quant
|
||||
# Use multiply-by-reciprocal instead of division to match PyTorch's
|
||||
# tensor/scalar division precision (GPU fast-division for constexpr
|
||||
# divisors can introduce 1-ULP error that flips FP8 quantization at
|
||||
# representable-value boundaries).
|
||||
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
|
||||
scale_raw = _absmax / fp8_max
|
||||
scale_raw = _absmax * (1.0 / fp8_max)
|
||||
y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw
|
||||
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
|
||||
|
||||
@@ -667,8 +671,8 @@ def _silu_mul_per_token_group_quant_fp8_colmajor(
|
||||
y_s_col_stride: tl.int64,
|
||||
# Information for float8
|
||||
eps,
|
||||
fp8_min,
|
||||
fp8_max,
|
||||
fp8_min: tl.constexpr,
|
||||
fp8_max: tl.constexpr,
|
||||
use_ue8m0: tl.constexpr,
|
||||
# Meta-parameters
|
||||
GROUP_SIZE: tl.constexpr,
|
||||
@@ -709,7 +713,7 @@ def _silu_mul_per_token_group_quant_fp8_colmajor(
|
||||
|
||||
# quant
|
||||
_absmax = tl.maximum(tl.max(tl.abs(y), axis=1), eps)
|
||||
scale_raw = _absmax / fp8_max
|
||||
scale_raw = _absmax * (1.0 / fp8_max)
|
||||
y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw
|
||||
y_s = tl.reshape(y_s, (BLOCK_M, 1))
|
||||
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
|
||||
@@ -808,8 +812,8 @@ def _per_token_group_quant_fp8_colmajor(
|
||||
# Avoid to divide zero
|
||||
eps,
|
||||
# Information for float8
|
||||
fp8_min,
|
||||
fp8_max,
|
||||
fp8_min: tl.constexpr,
|
||||
fp8_max: tl.constexpr,
|
||||
use_ue8m0: tl.constexpr,
|
||||
# Meta-parameters
|
||||
BLOCK: tl.constexpr,
|
||||
@@ -849,7 +853,7 @@ def _per_token_group_quant_fp8_colmajor(
|
||||
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
|
||||
# Quant
|
||||
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
|
||||
scale_raw = _absmax / fp8_max
|
||||
scale_raw = _absmax * (1.0 / fp8_max)
|
||||
y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw
|
||||
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user