From 67f9086a268fc69e1e445ecee411b463f5ab67fd Mon Sep 17 00:00:00 2001
From: biondizzle
Date: Mon, 11 May 2026 02:23:18 +0000
Subject: [PATCH] Fix critical dequantization bug: remove input_scale from weight dequant

input_scale is for ACTIVATIONS, not weights. The correct NVFP4 weight
dequantization formula is:

  weight_bf16 = e2m1_value * block_scale * global_scale

Including input_scale made weights ~5000x too small, causing completely
garbled output (multilingual gibberish with repeating patterns).
---
 patches/deepseek_v4.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/patches/deepseek_v4.py b/patches/deepseek_v4.py
index 079a82d..9778866 100644
--- a/patches/deepseek_v4.py
+++ b/patches/deepseek_v4.py
@@ -1724,7 +1724,9 @@ class DeepseekV4Model(nn.Module):
 if hasattr(mod, "input_scale")
 else 1.0
 )
- w_dequant = w_bf16.float() * block_scale_expanded * global_scale * input_scale
+ # NOTE: input_scale is for ACTIVATIONS, not weights.
+ # Weight dequant = e2m1 * block_scale * global_scale (NO input_scale)
+ w_dequant = w_bf16.float() * block_scale_expanded * global_scale
 w_dequant = w_dequant.to(torch.bfloat16)
 else:
 w_dequant = w_bf16
@@ -1764,7 +1766,9 @@ class DeepseekV4Model(nn.Module):
 if hasattr(mod, "input_scale")
 else 1.0
 )
- w_dequant = w_bf16.float() * block_scale_expanded * global_scale * input_scale
+ # NOTE: input_scale is for ACTIVATIONS, not weights.
+ # Weight dequant = e2m1 * block_scale * global_scale (NO input_scale)
+ w_dequant = w_bf16.float() * block_scale_expanded * global_scale
 w_dequant = w_dequant.to(torch.bfloat16)
 else:
 w_dequant = w_bf16
@@ -1911,8 +1915,9 @@ class DeepseekV4Model(nn.Module):
 else:
 block_scale_exp = block_scale
 gs = global_scale.to(device).max().item()
- inp_s = input_scale.to(device).max().item() if input_scale is not None else 1.0
- w = w_bf16.float() * block_scale_exp * gs * inp_s
+ # NOTE: input_scale is for activations, not weights.
+ # Weight dequant = e2m1 * block_scale * global_scale (NO input_scale)
+ w = w_bf16.float() * block_scale_exp * gs
 return w.to(torch.bfloat16)
 return w_bf16