[Bugfix][Kernel] fix bias adding in triton kernel implemented fused moe (#31676)
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
@@ -518,11 +518,7 @@ def fused_moe_kernel(
         # Advance the ptrs to the next K block.
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk
-    if HAS_BIAS:
-        accumulator = accumulator + bias[None, :]
-    if MUL_ROUTED_WEIGHT:
-        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
-        accumulator = accumulator * moe_weight[:, None]
+
     if use_int8_w8a16:
         accumulator = (accumulator * b_scale).to(compute_type)
     elif use_fp8_w8a8 or use_int8_w8a8:
@@ -533,6 +529,13 @@ def fused_moe_kernel(
     else:
         accumulator = accumulator.to(compute_type)
 
+    # Since bias is typically not quantized, it's added after dequantization.
+    if HAS_BIAS:
+        accumulator = accumulator + bias[None, :]
+    if MUL_ROUTED_WEIGHT:
+        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
+        accumulator = accumulator * moe_weight[:, None]
+
     # -----------------------------------------------------------
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
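The bug this commit fixes: in the quantized paths the accumulator still holds raw quantized-GEMM output when the bias is added, so the subsequent multiplication by b_scale also scales the unquantized bias. Moving the HAS_BIAS block (and the dependent MUL_ROUTED_WEIGHT block) after the dequantization branches applies the bias at full precision. A minimal NumPy sketch with made-up scalar values (not taken from the kernel) illustrating the difference for the use_int8_w8a16 path:

import numpy as np

# Made-up stand-ins for one output element of the kernel: the raw
# accumulator of the int8 matmul, the weight dequantization scale,
# and a full-precision (unquantized) bias.
acc = np.float32(100.0)
b_scale = np.float32(0.05)
bias = np.float32(1.0)

# Before the fix: bias was added before dequantization, so it was
# scaled by b_scale along with the accumulator.
before = (acc + bias) * b_scale   # 5.05 -- bias contributes only 0.05

# After the fix: dequantize first, then add the bias unscaled.
after = acc * b_scale + bias      # 6.0  -- bias contributes the full 1.0

print(before, after)

The MUL_ROUTED_WEIGHT block moves together with the bias so the routed expert weight still multiplies the bias-adjusted result; since multiplying by b_scale and by moe_weight commute, its position relative to dequantization is otherwise neutral.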