Fix critical dequantization bug: remove input_scale from weight dequant
input_scale is for ACTIVATIONS, not weights. The correct NVFP4 weight dequantization formula is: weight_bf16 = e2m1_value * block_scale * global_scale. Including input_scale in that product made the weights ~5000x too small, causing completely garbled output (multilingual gibberish with repeating patterns).
This commit is contained in:
@@ -1724,7 +1724,9 @@ class DeepseekV4Model(nn.Module):
|
||||
if hasattr(mod, "input_scale")
|
||||
else 1.0
|
||||
)
|
||||
w_dequant = w_bf16.float() * block_scale_expanded * global_scale * input_scale
|
||||
# NOTE: input_scale is for ACTIVATIONS, not weights.
|
||||
# Weight dequant = e2m1 * block_scale * global_scale (NO input_scale)
|
||||
w_dequant = w_bf16.float() * block_scale_expanded * global_scale
|
||||
w_dequant = w_dequant.to(torch.bfloat16)
|
||||
else:
|
||||
w_dequant = w_bf16
|
||||
@@ -1764,7 +1766,9 @@ class DeepseekV4Model(nn.Module):
|
||||
if hasattr(mod, "input_scale")
|
||||
else 1.0
|
||||
)
|
||||
w_dequant = w_bf16.float() * block_scale_expanded * global_scale * input_scale
|
||||
# NOTE: input_scale is for ACTIVATIONS, not weights.
|
||||
# Weight dequant = e2m1 * block_scale * global_scale (NO input_scale)
|
||||
w_dequant = w_bf16.float() * block_scale_expanded * global_scale
|
||||
w_dequant = w_dequant.to(torch.bfloat16)
|
||||
else:
|
||||
w_dequant = w_bf16
|
||||
@@ -1911,8 +1915,9 @@ class DeepseekV4Model(nn.Module):
|
||||
else:
|
||||
block_scale_exp = block_scale
|
||||
gs = global_scale.to(device).max().item()
|
||||
inp_s = input_scale.to(device).max().item() if input_scale is not None else 1.0
|
||||
w = w_bf16.float() * block_scale_exp * gs * inp_s
|
||||
# NOTE: input_scale is for activations, not weights.
|
||||
# Weight dequant = e2m1 * block_scale * global_scale (NO input_scale)
|
||||
w = w_bf16.float() * block_scale_exp * gs
|
||||
return w.to(torch.bfloat16)
|
||||
return w_bf16
|
||||
|
||||
|
||||
Reference in New Issue
Block a user