Replace BF16 dequant with input_scale warmup fix for attention NVFP4

Instead of dequantizing attention weights to BF16 (which had issues
with MergedColumnParallelLinear and different weight_scale_2 values),
keep the NVFP4 path but fix the activation global scale.

Compute correct input_global_scale_inv by:
1. Temporarily dequantizing weight to BF16
2. Running warmup forward with random input
3. Computing actual activation amax
4. Setting scale_inv = amax * headroom

This preserves the original NVFP4 quantization pipeline.
This commit is contained in:
2026-05-18 15:43:46 +00:00
parent 301015b037
commit f86892e26b

View File

@@ -1685,49 +1685,26 @@ class DeepseekV4Model(nn.Module):
def _convert_nvfp4_post_load(self):
"""Post-load conversion of NVFP4 weights for vLLM compatibility.
All attention NVFP4 projections are dequantized to BF16 because
the checkpoint input_scale values cause NaN during activation
quantization in FlashInferCutlassNvFp4LinearKernel. BF16 bypasses
the broken input_scale entirely.
Fixes the attention input_global_scale_inv (activation global scale)
by running a warmup forward and computing the correct scale from
actual activation magnitudes. The checkpoint input_scale values are
calibrated incorrectly and cause NaN during activation quantization.
Compressor weights are reconstructed from checkpoint sub-weights
because the stacking weight_loader corrupts NVFP4 uint8 data.
wo_a is converted to FP8 for fp8_einsum (no input_scale needed).
Compressor weights are reconstructed from checkpoint sub-weights.
"""
# All attention projections to dequantize to BF16
# wo_a is excluded — it uses fp8_einsum (no input_scale, weight-only FP8)
# wq_a and wkv are fused into fused_wqa_wkv
bf16_proj_names = {"wq_b", "wo_b", "fused_wqa_wkv"}
# wo_a → FP8 (fp8_einsum path, no input_scale)
fp8_proj_names = {"wo_a"}
bf16_converted = 0
fp8_converted = 0
compressor_converted = 0
input_scale_fixes = 0
_shard_index = self._build_shard_index("/model") if os.path.isdir("/model") else None
from tqdm import tqdm
for layer_idx, layer in tqdm(enumerate(self.layers), total=len(self.layers), desc=" (upcast)NVFP4→BF16 attn projs", unit="layer"):
for layer_idx, layer in tqdm(enumerate(self.layers), total=len(self.layers), desc=" (fix)NVFP4 attn input_scale", unit="layer"):
attn = layer.attn
# BF16 dequantization: attention projections (except wo_a)
for proj_name in bf16_proj_names:
if not hasattr(attn, proj_name):
if layer_idx == 0:
print(f"[CLAWMINE] Layer 0: {proj_name} NOT FOUND on attn (type={type(attn).__name__})")
continue
mod = getattr(attn, proj_name)
if not hasattr(mod, "weight"):
if layer_idx == 0:
print(f"[CLAWMINE] Layer 0: {proj_name} has no weight attr")
continue
if layer_idx == 0:
print(f"[CLAWMINE] Layer 0: {proj_name} weight dtype={mod.weight.dtype}")
if mod.weight.dtype in (torch.uint8, torch.int8):
E2M1_LUT = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16)
self._dequant_nvfp4_to_bf16(mod, E2M1_LUT)
if layer_idx == 0:
print(f"[CLAWMINE] Layer 0: {proj_name} AFTER dequant: dtype={mod.weight.dtype} amax={mod.weight.data.amax().item():.4f} NaN={torch.isnan(mod.weight.data).any().item()} quant_method={type(mod.quant_method).__name__}")
bf16_converted += 1
# FP8 conversion: wo_a (used by fp8_einsum, no input_scale)
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
for proj_name in fp8_proj_names:
@@ -1741,6 +1718,64 @@ class DeepseekV4Model(nn.Module):
self._convert_nvfp4_to_fp8(mod, E2M1_LUT, FP8_MAX)
fp8_converted += 1
# Fix input_global_scale_inv for NVFP4 attention projections
# The checkpoint input_scale is wrong. We compute the correct scale
# by dequantizing to BF16 temporarily and running a warmup.
for proj_name in ["fused_wqa_wkv", "wq_b", "wo_b"]:
if not hasattr(attn, proj_name):
continue
mod = getattr(attn, proj_name)
if not hasattr(mod, "input_global_scale_inv"):
continue
if mod.weight.dtype not in (torch.uint8, torch.int8):
continue
# Temporarily dequantize weight to BF16 for warmup
E2M1_LUT = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16)
w_uint8 = mod.weight.data
w_bf16_unpacked = self._unpack_nvfp4_to_bf16(w_uint8, E2M1_LUT, w_uint8.device)
if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"):
block_scale = self._block_scale_to_float32(mod.weight_scale.data)
if block_scale.dim() == 2 and w_bf16_unpacked.dim() == 2:
block_size = w_bf16_unpacked.shape[1] // block_scale.shape[1]
block_scale_expanded = block_scale.unsqueeze(-1).expand(-1, -1, block_size).reshape(w_bf16_unpacked.shape)
else:
block_scale_expanded = block_scale
global_scale = mod.weight_scale_2.data.max().item()
w_bf16_dequant = (w_bf16_unpacked.float() * block_scale_expanded * global_scale).to(torch.bfloat16)
else:
w_bf16_dequant = w_bf16_unpacked
# Warmup: compute actual activation amax using BF16 reference
with torch.no_grad():
in_features = w_bf16_dequant.shape[-1]
dummy_input = torch.randn(256, in_features, dtype=torch.bfloat16, device=mod.weight.device) * 2.0
ref_output = torch.nn.functional.linear(dummy_input, w_bf16_dequant)
act_amax = ref_output.amax().item()
# Clean up temp tensors
del w_bf16_unpacked, w_bf16_dequant, ref_output
# Set correct input_global_scale_inv: 1/(amax * headroom)
# scaled_fp4_quant divides by input_global_scale_inv
# so input_global_scale_inv should be ~ amax (to map amax → 1.0 in FP4)
headroom = 1.2 # slight headroom to avoid clipping
new_inv = act_amax * headroom if act_amax > 0 else 1.0
new_scale = 1.0 / new_inv
if layer_idx == 0:
old_inv = mod.input_global_scale_inv.item() if hasattr(mod.input_global_scale_inv, 'item') else float(mod.input_global_scale_inv)
old_scale = mod.input_global_scale.item() if hasattr(mod.input_global_scale, 'item') else float(mod.input_global_scale)
print(f"[CLAWMINE] Layer 0: {proj_name} scale_inv: {old_inv:.8f}{new_inv:.8f} scale: {old_scale:.8f}{new_scale:.8f} (act_amax={act_amax:.4f})")
mod.input_global_scale = torch.nn.Parameter(torch.tensor(new_scale, dtype=torch.float32), requires_grad=False)
mod.input_global_scale_inv = torch.nn.Parameter(torch.tensor(new_inv, dtype=torch.float32), requires_grad=False)
# Update alpha: input_scale * weight_scale (both are the "1/x" form now)
if hasattr(mod, "weight_global_scale") and hasattr(mod, "alpha"):
wgs = mod.weight_global_scale.item() if hasattr(mod.weight_global_scale, 'item') else float(mod.weight_global_scale)
mod.alpha = torch.nn.Parameter(torch.tensor(new_scale * wgs, dtype=torch.float32), requires_grad=False)
input_scale_fixes += 1
# Compressor: still needs BF16 reconstruction
mla_attn = getattr(attn, "mla_attn", None)
if mla_attn is not None: