fused_moe_kernel - cast accumulator after applying router weights (#32002)

Signed-off-by: gnovack <gnovack@amazon.com>
gnovack authored 2026-01-10 12:36:45 -08:00, committed by GitHub
parent 543c23be78
commit d1fd802fa3

@@ -539,20 +539,17 @@ def fused_moe_kernel(
         accumulator = accumulator * moe_weight[:, None]
     if use_int8_w8a16:
-        accumulator = (accumulator * b_scale).to(compute_type)
-    elif use_fp8_w8a8 or use_int8_w8a8:
-        if group_k > 0 and group_n > 0:
-            accumulator = accumulator.to(compute_type)
-        else:
-            accumulator = (accumulator * a_scale * b_scale).to(compute_type)
-    else:
-        accumulator = accumulator.to(compute_type)
+        accumulator = accumulator * b_scale
+    elif (use_fp8_w8a8 or use_int8_w8a8) and not (group_k > 0 and group_n > 0):
+        accumulator = accumulator * a_scale * b_scale
     # Bias is added AFTER dequantization since bias is typically stored in
     # the output dtype and should not be scaled by quantization factors.
     if HAS_BIAS:
         accumulator = accumulator + bias[None, :]
+    accumulator = accumulator.to(compute_type)
     # -----------------------------------------------------------
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
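
The change above casts the accumulator to compute_type only once, after the dequantization scales, router weight, and bias have all been applied in float32, rather than casting inside the dequantization branches and adding the bias in reduced precision. The following is a minimal, self-contained PyTorch sketch of that numerical effect; it is illustrative only (not the Triton kernel), and the shapes, values, and the choice of bfloat16 as compute_type are assumptions made for the example.

import torch

torch.manual_seed(0)

# Made-up fp32 accumulator block plus dequant scale, bias, and router weights.
acc = torch.randn(4, 8) * 100.0   # accumulator as produced by the fp32 dot product
b_scale = torch.rand(8) * 0.01    # per-channel dequantization scale (illustrative)
bias = torch.randn(8)             # bias, nominally stored in the output dtype
moe_weight = torch.rand(4, 1)     # router (top-k) weight per token

# Previous ordering shown in the removed lines: cast to compute_type right after
# applying the scale, then add the bias in the reduced-precision dtype.
old = ((acc * moe_weight) * b_scale).to(torch.bfloat16)
old = old + bias.to(torch.bfloat16)

# New ordering: keep everything in fp32 and cast to compute_type once at the end.
new = ((acc * moe_weight) * b_scale + bias).to(torch.bfloat16)

# Full-precision reference for comparison.
ref = (acc * moe_weight) * b_scale + bias
print("old-path max abs error:", (old.float() - ref).abs().max().item())
print("new-path max abs error:", (new.float() - ref).abs().max().item())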