[ROCm][LoRA] Fix MoE accuracy regression by preserving float32 router weight scaling (#31931)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -519,6 +519,12 @@ def fused_moe_kernel(
|
|||||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||||
|
|
||||||
|
# Router weight multiplication MUST happen in float32 before precision
|
||||||
|
# conversion for numerical stability (especially critical on ROCm).
|
||||||
|
if MUL_ROUTED_WEIGHT:
|
||||||
|
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
|
||||||
|
accumulator = accumulator * moe_weight[:, None]
|
||||||
|
|
||||||
if use_int8_w8a16:
|
if use_int8_w8a16:
|
||||||
accumulator = (accumulator * b_scale).to(compute_type)
|
accumulator = (accumulator * b_scale).to(compute_type)
|
||||||
elif use_fp8_w8a8 or use_int8_w8a8:
|
elif use_fp8_w8a8 or use_int8_w8a8:
|
||||||
@@ -529,12 +535,10 @@ def fused_moe_kernel(
|
|||||||
else:
|
else:
|
||||||
accumulator = accumulator.to(compute_type)
|
accumulator = accumulator.to(compute_type)
|
||||||
|
|
||||||
# Since bias is typically not quantized, it's added after dequantization.
|
# Bias is added AFTER dequantization since bias is typically stored in
|
||||||
|
# the output dtype and should not be scaled by quantization factors.
|
||||||
if HAS_BIAS:
|
if HAS_BIAS:
|
||||||
accumulator = accumulator + bias[None, :]
|
accumulator = accumulator + bias[None, :]
|
||||||
if MUL_ROUTED_WEIGHT:
|
|
||||||
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
|
|
||||||
accumulator = accumulator * moe_weight[:, None]
|
|
||||||
|
|
||||||
# -----------------------------------------------------------
|
# -----------------------------------------------------------
|
||||||
# Write back the block of the output
|
# Write back the block of the output
|
||||||
|
|||||||
Reference in New Issue
Block a user