fused_moe_kernel - cast accumulator after applying router weights (#32002)
Signed-off-by: gnovack <gnovack@amazon.com>
This commit is contained in:
@@ -539,20 +539,17 @@ def fused_moe_kernel(
|
||||
accumulator = accumulator * moe_weight[:, None]
|
||||
|
||||
if use_int8_w8a16:
|
||||
accumulator = (accumulator * b_scale).to(compute_type)
|
||||
elif use_fp8_w8a8 or use_int8_w8a8:
|
||||
if group_k > 0 and group_n > 0:
|
||||
accumulator = accumulator.to(compute_type)
|
||||
else:
|
||||
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
|
||||
else:
|
||||
accumulator = accumulator.to(compute_type)
|
||||
accumulator = accumulator * b_scale
|
||||
elif (use_fp8_w8a8 or use_int8_w8a8) and not (group_k > 0 and group_n > 0):
|
||||
accumulator = accumulator * a_scale * b_scale
|
||||
|
||||
# Bias is added AFTER dequantization since bias is typically stored in
|
||||
# the output dtype and should not be scaled by quantization factors.
|
||||
if HAS_BIAS:
|
||||
accumulator = accumulator + bias[None, :]
|
||||
|
||||
accumulator = accumulator.to(compute_type)
|
||||
|
||||
# -----------------------------------------------------------
|
||||
# Write back the block of the output
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
|
||||
Reference in New Issue
Block a user