[ROCm] Fix MoE kernel test failures on gfx950 (#37833)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: Matthew Wong <Matthew.Wong2@amd.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Matthew Wong <Matthew.Wong2@amd.com>
This commit is contained in:
Andreas Karatzas
2026-03-25 13:46:40 -05:00
committed by GitHub
parent e38817fadb
commit 7d6917bef5
12 changed files with 478 additions and 86 deletions

View File

@@ -19,6 +19,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
get_fp8_min_max,
kFp8Dynamic128Sym,
kFp8Static128BlockSym,
)
@@ -117,7 +118,10 @@ def _silu_mul_fp8_quant_deep_gemm(
gate = gate * (1.0 / (1.0 + tl.exp(-gate)))
y = gate * up
y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max
# Use multiply-by-reciprocal to match PyTorch's tensor/scalar
# division precision (Triton GPU fast-division for constexpr
# divisors can introduce 1-ULP error).
y_s = tl.maximum(tl.max(tl.abs(y)), eps) * (1.0 / fp8_max)
if ceil_ue8m0:
y_s = tl.exp2(tl.ceil(tl.log2(y_s)))
@@ -190,7 +194,7 @@ def persistent_masked_m_silu_mul_quant(
tokens_per_expert = tokens_per_expert.to(device=y.device, dtype=torch.int32)
fp8_dtype = torch.float8_e4m3fn
fp8_dtype = current_platform.fp8_dtype()
y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)
ys_shape, ys_strides, ys_dtype = scales_shape_stride_dtype(E, T, G, quant_scale_fmt)
@@ -210,11 +214,14 @@ def persistent_masked_m_silu_mul_quant(
device_id=y.device.index
).to_int()
if cuda_arch >= 80:
if current_platform.is_cuda() and cuda_arch >= 80:
torch.ops._C.persistent_masked_m_silu_mul_quant(
y, tokens_per_expert, y_q, y_s, ceil_ue8m0
)
else:
# Triton fallback for ROCm -- the C++ kernel is guarded by
# #ifndef USE_ROCM in activation_kernels.cu.
# https://github.com/ROCm/aiter/issues/2420
stride_cnt_e = tokens_per_expert.stride()[0]
# Static grid over experts and H-groups.
@@ -224,13 +231,11 @@ def persistent_masked_m_silu_mul_quant(
stride_i_e, stride_i_t, stride_i_h = y.stride()
stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride()
f_info = torch.finfo(fp8_dtype)
fp8_max = f_info.max
fp8_min = f_info.min
fp8_min, fp8_max = get_fp8_min_max()
eps: float = 1e-10
assert y_s.dtype == torch.float32, (
"_silu_mul_fp8_quant_deep_gemm does"
"not support {y_s.dtype} scales. Only torch.float32 supported."
"_silu_mul_fp8_quant_deep_gemm Triton fallback does not "
f"support {y_s.dtype} scales. Only torch.float32 supported."
)
_silu_mul_fp8_quant_deep_gemm[grid](
y,

View File

@@ -253,10 +253,16 @@ def triton_kernel_moe_forward(
logits = gating_output
if sm_first:
logits = torch.softmax(logits, dim=-1)
sparse_logits = topk_fn(logits, topk, apply_softmax=not sm_first)
# sparse_logits.indx contains global expert IDs remap to local.
topk_ids = expert_map[sparse_logits.indx.to(torch.long)]
topk_weights = sparse_logits.vals
topk_result = topk_fn(logits, topk, apply_softmax=not sm_first)
# topk may return a tuple (vals, indx, bitmatrix) or a
# SparseMatrix depending on the triton_kernels version.
if isinstance(topk_result, tuple):
topk_weights, topk_ids_raw, _ = topk_result
else:
topk_weights = topk_result.vals
topk_ids_raw = topk_result.indx
# topk_ids_raw contains global expert IDs - remap to local.
topk_ids = expert_map[topk_ids_raw.to(torch.long)]
local_num_experts = w1.shape[0]
routing_data, gather_idx, scatter_idx = make_routing_data(
topk_ids, topk_weights, local_num_experts
@@ -422,8 +428,13 @@ def triton_kernel_fused_mxfp4_w4a8_experts(
assert quant_config.w1_bias is None or quant_config.w1_bias.dtype == torch.float32
assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32
# Shape check, only check non-mxfp4
assert hidden_states.shape[-1] == w1.shape[-2]
# Shape check: when weights are padded (e.g. hidden_size padded for
# GFX950 swizzle), unpadded_K_w1 carries the original dimension.
expected_K_w1 = unpadded_K_w1 if unpadded_K_w1 is not None else w1.shape[-2]
assert hidden_states.shape[-1] == expected_K_w1, (
f"hidden_states K={hidden_states.shape[-1]} != "
f"expected K={expected_K_w1} (w1 K={w1.shape[-2]})"
)
assert w2.shape[-1] == w1.shape[1]
E, _, N = w1.shape
@@ -483,6 +494,12 @@ def triton_kernel_fused_mxfp4_w4a8_experts(
unpadded_K=unpadded_K_w2,
)
# When hidden_size was padded for alignment (e.g. GFX950 swizzle),
# the kernel output has the padded dimension. Slice back to the
# original hidden_size so downstream layers see the expected shape.
if unpadded_N_w2 is not None and intermediate_cache3.shape[-1] != unpadded_N_w2:
intermediate_cache3 = intermediate_cache3[..., :unpadded_N_w2].contiguous()
return intermediate_cache3

View File

@@ -741,11 +741,14 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
# TP=4 yields intermediate_size_per_partition=384), AITER raises:
# "device_gemm ... does not support this GEMM problem".
# Fall back to emulation in that case.
# For gpt_oss models, create_weights rounds up the dimensions
# internally, so the alignment check is skipped.
if (
not self.emulate
and self.use_rocm_aiter_moe
and self.ocp_mx_scheme is not None
and self.ocp_mx_scheme.startswith("w_mxfp4")
and self.model_type != "gpt_oss"
and moe.intermediate_size_per_partition % CK_MXFP4_MOE_DIM_ALIGNMENT != 0
):
logger.warning_once(
@@ -819,6 +822,18 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
"unpadded_hidden_size", hidden_size
)
# On GFX950, the GFX950MXScaleLayout swizzle requires
# hidden_size to be a multiple of 256 (SCALE_K = hidden_size / 32
# must be divisible by 8). Pad hidden_size for weight/scale
# allocation; the original value is preserved in unpadded_hidden_size.
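# Only applies to the native (non-emulated) CK path on GFX950.
# NOTE(review): the guard below checks current_platform.is_rocm() and
# not self.emulate, with no GFX950-specific check - confirm intent.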
if (
self.model_type == "gpt_oss"
and current_platform.is_rocm()
and not self.emulate
):
hidden_size = round_up(hidden_size, 256)
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(

View File

@@ -615,8 +615,8 @@ def _per_token_group_quant_fp8(
# Avoid division by zero
eps,
# Information for float8
fp8_min,
fp8_max,
fp8_min: tl.constexpr,
fp8_max: tl.constexpr,
use_ue8m0: tl.constexpr,
# Meta-parameters
BLOCK: tl.constexpr,
@@ -647,8 +647,12 @@ def _per_token_group_quant_fp8(
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
# Quant
# Use multiply-by-reciprocal instead of division to match PyTorch's
# tensor/scalar division precision (GPU fast-division for constexpr
# divisors can introduce 1-ULP error that flips FP8 quantization at
# representable-value boundaries).
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
scale_raw = _absmax / fp8_max
scale_raw = _absmax * (1.0 / fp8_max)
y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
@@ -667,8 +671,8 @@ def _silu_mul_per_token_group_quant_fp8_colmajor(
y_s_col_stride: tl.int64,
# Information for float8
eps,
fp8_min,
fp8_max,
fp8_min: tl.constexpr,
fp8_max: tl.constexpr,
use_ue8m0: tl.constexpr,
# Meta-parameters
GROUP_SIZE: tl.constexpr,
@@ -709,7 +713,7 @@ def _silu_mul_per_token_group_quant_fp8_colmajor(
# quant
_absmax = tl.maximum(tl.max(tl.abs(y), axis=1), eps)
scale_raw = _absmax / fp8_max
scale_raw = _absmax * (1.0 / fp8_max)
y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw
y_s = tl.reshape(y_s, (BLOCK_M, 1))
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
@@ -808,8 +812,8 @@ def _per_token_group_quant_fp8_colmajor(
# Avoid division by zero
eps,
# Information for float8
fp8_min,
fp8_max,
fp8_min: tl.constexpr,
fp8_max: tl.constexpr,
use_ue8m0: tl.constexpr,
# Meta-parameters
BLOCK: tl.constexpr,
@@ -849,7 +853,7 @@ def _per_token_group_quant_fp8_colmajor(
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
scale_raw = _absmax / fp8_max
scale_raw = _absmax * (1.0 / fp8_max)
y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)