[Kernel][MoE] fix computation order of MoE weight multiplication and improve flow (#31962)
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
This commit is contained in:
@@ -531,22 +531,37 @@ def fused_moe_kernel(
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
|
||||
# Router weight multiplication MUST happen in float32 before precision
|
||||
# conversion for numerical stability (especially critical on ROCm).
|
||||
if MUL_ROUTED_WEIGHT:
|
||||
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
|
||||
accumulator = accumulator * moe_weight[:, None]
|
||||
|
||||
# Dequantization for supported quantization schemes:
|
||||
# - int8_w8a16
|
||||
# - fp8_w8a8
|
||||
# - int8_w8a8
|
||||
# The accumulator and scale factors are kept in float32 to preserve numerical accuracy.
|
||||
if use_int8_w8a16:
|
||||
accumulator = accumulator * b_scale
|
||||
elif (use_fp8_w8a8 or use_int8_w8a8) and not (group_k > 0 and group_n > 0):
|
||||
accumulator = accumulator * a_scale * b_scale
|
||||
|
||||
# Bias is added AFTER dequantization since bias is typically stored in
|
||||
# the output dtype and should not be scaled by quantization factors.
|
||||
# Bias addition:
|
||||
# Bias must be applied after dequantization:
|
||||
# - Bias is typically not quantized
|
||||
# - Bias should not be scaled by quantization factors
|
||||
if HAS_BIAS:
|
||||
accumulator = accumulator + bias[None, :]
|
||||
accumulator += bias[None, :]
|
||||
|
||||
# Router (MoE) weight multiplication (applied after dequantization and bias addition):
|
||||
# This multiplication MUST be performed in float32 before any precision
|
||||
# conversion to ensure numerical stability, which is especially critical
|
||||
# on ROCm platforms.
|
||||
if MUL_ROUTED_WEIGHT:
|
||||
moe_weight = tl.load(
|
||||
topk_weights_ptr + offs_token,
|
||||
mask=token_mask,
|
||||
other=0,
|
||||
)
|
||||
accumulator *= moe_weight[:, None]
|
||||
|
||||
# Final precision conversion:
|
||||
# Cast once at the end to the desired compute/output dtype.
|
||||
accumulator = accumulator.to(compute_type)
|
||||
|
||||
# -----------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user