Fix critical dequantization bug: remove input_scale from weight dequant

input_scale is for ACTIVATIONS, not weights. The correct NVFP4 weight
dequantization formula is: weight_bf16 = e2m1_value * block_scale * global_scale

Including input_scale made weights ~5000x too small, causing completely
garbled output (multilingual gibberish with repeating patterns).
This commit is contained in:
2026-05-11 02:23:18 +00:00
parent 02b8ea536f
commit 67f9086a26

View File

@@ -1724,7 +1724,9 @@ class DeepseekV4Model(nn.Module):
if hasattr(mod, "input_scale")
else 1.0
)
w_dequant = w_bf16.float() * block_scale_expanded * global_scale * input_scale
# NOTE: input_scale is for ACTIVATIONS, not weights.
# Weight dequant = e2m1 * block_scale * global_scale (NO input_scale)
w_dequant = w_bf16.float() * block_scale_expanded * global_scale
w_dequant = w_dequant.to(torch.bfloat16)
else:
w_dequant = w_bf16
@@ -1764,7 +1766,9 @@ class DeepseekV4Model(nn.Module):
if hasattr(mod, "input_scale")
else 1.0
)
w_dequant = w_bf16.float() * block_scale_expanded * global_scale * input_scale
# NOTE: input_scale is for ACTIVATIONS, not weights.
# Weight dequant = e2m1 * block_scale * global_scale (NO input_scale)
w_dequant = w_bf16.float() * block_scale_expanded * global_scale
w_dequant = w_dequant.to(torch.bfloat16)
else:
w_dequant = w_bf16
@@ -1911,8 +1915,9 @@ class DeepseekV4Model(nn.Module):
else:
block_scale_exp = block_scale
gs = global_scale.to(device).max().item()
inp_s = input_scale.to(device).max().item() if input_scale is not None else 1.0
w = w_bf16.float() * block_scale_exp * gs * inp_s
# NOTE: input_scale is for activations, not weights.
# Weight dequant = e2m1 * block_scale * global_scale (NO input_scale)
w = w_bf16.float() * block_scale_exp * gs
return w.to(torch.bfloat16)
return w_bf16