fused_moe_kernel - cast accumulator after applying router weights (#32002)
Signed-off-by: gnovack <gnovack@amazon.com>
@@ -539,20 +539,17 @@ def fused_moe_kernel(
         accumulator = accumulator * moe_weight[:, None]
 
     if use_int8_w8a16:
-        accumulator = (accumulator * b_scale).to(compute_type)
-    elif use_fp8_w8a8 or use_int8_w8a8:
-        if group_k > 0 and group_n > 0:
-            accumulator = accumulator.to(compute_type)
-        else:
-            accumulator = (accumulator * a_scale * b_scale).to(compute_type)
-    else:
-        accumulator = accumulator.to(compute_type)
+        accumulator = accumulator * b_scale
+    elif (use_fp8_w8a8 or use_int8_w8a8) and not (group_k > 0 and group_n > 0):
+        accumulator = accumulator * a_scale * b_scale
 
     # Bias is added AFTER dequantization since bias is typically stored in
     # the output dtype and should not be scaled by quantization factors.
     if HAS_BIAS:
         accumulator = accumulator + bias[None, :]
 
+    accumulator = accumulator.to(compute_type)
+
     # -----------------------------------------------------------
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
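The hunk consolidates the per-branch .to(compute_type) casts into a single cast at the end, so the router-weight multiply, the dequantization scales, and the bias add all happen in the fp32 accumulator before one rounding step. Below is a minimal PyTorch sketch of the numerical effect; it is illustrative only, not the Triton kernel, and the shapes, values, and the choice of bfloat16 as compute_type are assumptions:

import torch

torch.manual_seed(0)
acc = torch.randn(4, 8) * 100.0      # stands in for the fp32 accumulator after the K loop
moe_weight = torch.rand(4)           # per-token router weights
bias = torch.randn(8)                # bias stored in the output dtype

weighted = acc * moe_weight[:, None]

# Old order: cast to the compute type first, then add bias in bf16
# (rounds the accumulator, the bias, and the sum).
old = weighted.to(torch.bfloat16) + bias.to(torch.bfloat16)

# New order: add bias in fp32, cast once at the end (single rounding step).
new = (weighted + bias).to(torch.bfloat16)

ref = weighted + bias                # fp32 reference
print("cast-then-add max error:", (old.float() - ref).abs().max().item())
print("add-then-cast max error:", (new.float() - ref).abs().max().item())

Rounding once at the end bounds the error at a single bf16 quantization step, whereas the old order could round the weighted accumulator and the bias separately and then round again on their sum.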