Revert to BF16 dequant for attention NVFP4 (input_scale fix was too early)
process_weights_after_loading sets input_global_scale_inv AFTER _convert_nvfp4_post_load runs, so the fix couldn't find the attrs. Going back to BF16 dequant approach. The zeros in the dummy run are expected (attention_impl returns early with out.zero_()). Need to test with a real request under cudagraph_mode=NONE.
This commit is contained in:
@@ -1685,26 +1685,37 @@ class DeepseekV4Model(nn.Module):
|
||||
def _convert_nvfp4_post_load(self):
|
||||
"""Post-load conversion of NVFP4 weights for vLLM compatibility.
|
||||
|
||||
Fixes the attention input_global_scale_inv (activation global scale)
|
||||
by running a warmup forward and computing the correct scale from
|
||||
actual activation magnitudes. The checkpoint input_scale values are
|
||||
calibrated incorrectly and cause NaN during activation quantization.
|
||||
All attention NVFP4 projections (except wo_a) are dequantized to BF16.
|
||||
The checkpoint input_scale values cause NaN during activation quantization
|
||||
in FlashInferCutlassNvFp4LinearKernel. BF16 bypasses this entirely.
|
||||
|
||||
wo_a is converted to FP8 for fp8_einsum (no input_scale needed).
|
||||
Compressor weights are reconstructed from checkpoint sub-weights.
|
||||
"""
|
||||
# wo_a → FP8 (fp8_einsum path, no input_scale)
|
||||
bf16_proj_names = {"fused_wqa_wkv", "wq_b", "wo_b"}
|
||||
fp8_proj_names = {"wo_a"}
|
||||
bf16_converted = 0
|
||||
fp8_converted = 0
|
||||
compressor_converted = 0
|
||||
input_scale_fixes = 0
|
||||
|
||||
_shard_index = self._build_shard_index("/model") if os.path.isdir("/model") else None
|
||||
|
||||
from tqdm import tqdm
|
||||
for layer_idx, layer in tqdm(enumerate(self.layers), total=len(self.layers), desc=" (fix)NVFP4 attn input_scale", unit="layer"):
|
||||
for layer_idx, layer in tqdm(enumerate(self.layers), total=len(self.layers), desc=" (upcast)NVFP4→BF16 attn projs", unit="layer"):
|
||||
attn = layer.attn
|
||||
|
||||
# BF16 dequantization: attention projections (except wo_a)
|
||||
for proj_name in bf16_proj_names:
|
||||
if not hasattr(attn, proj_name):
|
||||
continue
|
||||
mod = getattr(attn, proj_name)
|
||||
if not hasattr(mod, "weight"):
|
||||
continue
|
||||
if mod.weight.dtype in (torch.uint8, torch.int8):
|
||||
E2M1_LUT = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16)
|
||||
self._dequant_nvfp4_to_bf16(mod, E2M1_LUT)
|
||||
bf16_converted += 1
|
||||
|
||||
# FP8 conversion: wo_a (used by fp8_einsum, no input_scale)
|
||||
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
for proj_name in fp8_proj_names:
|
||||
@@ -1718,68 +1729,6 @@ class DeepseekV4Model(nn.Module):
|
||||
self._convert_nvfp4_to_fp8(mod, E2M1_LUT, FP8_MAX)
|
||||
fp8_converted += 1
|
||||
|
||||
# Fix input_global_scale_inv for NVFP4 attention projections
|
||||
# The checkpoint input_scale is wrong. We compute the correct scale
|
||||
# by dequantizing to BF16 temporarily and running a warmup.
|
||||
for proj_name in ["fused_wqa_wkv", "wq_b", "wo_b"]:
|
||||
if not hasattr(attn, proj_name):
|
||||
if layer_idx == 0:
|
||||
print(f"[CLAWMINE] Layer 0: {proj_name} NOT on attn")
|
||||
continue
|
||||
mod = getattr(attn, proj_name)
|
||||
if layer_idx == 0:
|
||||
print(f"[CLAWMINE] Layer 0: {proj_name} dtype={mod.weight.dtype} has_input_global_scale_inv={hasattr(mod, 'input_global_scale_inv')} has_input_scale={hasattr(mod, 'input_scale')}")
|
||||
if not hasattr(mod, "input_global_scale_inv"):
|
||||
continue
|
||||
if mod.weight.dtype not in (torch.uint8, torch.int8):
|
||||
continue
|
||||
|
||||
# Temporarily dequantize weight to BF16 for warmup
|
||||
E2M1_LUT = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16)
|
||||
w_uint8 = mod.weight.data
|
||||
w_bf16_unpacked = self._unpack_nvfp4_to_bf16(w_uint8, E2M1_LUT, w_uint8.device)
|
||||
if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"):
|
||||
block_scale = self._block_scale_to_float32(mod.weight_scale.data)
|
||||
if block_scale.dim() == 2 and w_bf16_unpacked.dim() == 2:
|
||||
block_size = w_bf16_unpacked.shape[1] // block_scale.shape[1]
|
||||
block_scale_expanded = block_scale.unsqueeze(-1).expand(-1, -1, block_size).reshape(w_bf16_unpacked.shape)
|
||||
else:
|
||||
block_scale_expanded = block_scale
|
||||
global_scale = mod.weight_scale_2.data.max().item()
|
||||
w_bf16_dequant = (w_bf16_unpacked.float() * block_scale_expanded * global_scale).to(torch.bfloat16)
|
||||
else:
|
||||
w_bf16_dequant = w_bf16_unpacked
|
||||
|
||||
# Warmup: compute actual activation amax using BF16 reference
|
||||
with torch.no_grad():
|
||||
in_features = w_bf16_dequant.shape[-1]
|
||||
dummy_input = torch.randn(256, in_features, dtype=torch.bfloat16, device=mod.weight.device) * 2.0
|
||||
ref_output = torch.nn.functional.linear(dummy_input, w_bf16_dequant)
|
||||
act_amax = ref_output.amax().item()
|
||||
# Clean up temp tensors
|
||||
del w_bf16_unpacked, w_bf16_dequant, ref_output
|
||||
|
||||
# Set correct input_global_scale_inv: 1/(amax * headroom)
|
||||
# scaled_fp4_quant divides by input_global_scale_inv
|
||||
# so input_global_scale_inv should be ~ amax (to map amax → 1.0 in FP4)
|
||||
headroom = 1.2 # slight headroom to avoid clipping
|
||||
new_inv = act_amax * headroom if act_amax > 0 else 1.0
|
||||
new_scale = 1.0 / new_inv
|
||||
|
||||
if layer_idx == 0:
|
||||
old_inv = mod.input_global_scale_inv.item() if hasattr(mod.input_global_scale_inv, 'item') else float(mod.input_global_scale_inv)
|
||||
old_scale = mod.input_global_scale.item() if hasattr(mod.input_global_scale, 'item') else float(mod.input_global_scale)
|
||||
print(f"[CLAWMINE] Layer 0: {proj_name} scale_inv: {old_inv:.8f}→{new_inv:.8f} scale: {old_scale:.8f}→{new_scale:.8f} (act_amax={act_amax:.4f})")
|
||||
|
||||
mod.input_global_scale = torch.nn.Parameter(torch.tensor(new_scale, dtype=torch.float32), requires_grad=False)
|
||||
mod.input_global_scale_inv = torch.nn.Parameter(torch.tensor(new_inv, dtype=torch.float32), requires_grad=False)
|
||||
# Update alpha: input_scale * weight_scale (both are the "1/x" form now)
|
||||
if hasattr(mod, "weight_global_scale") and hasattr(mod, "alpha"):
|
||||
wgs = mod.weight_global_scale.item() if hasattr(mod.weight_global_scale, 'item') else float(mod.weight_global_scale)
|
||||
mod.alpha = torch.nn.Parameter(torch.tensor(new_scale * wgs, dtype=torch.float32), requires_grad=False)
|
||||
|
||||
input_scale_fixes += 1
|
||||
|
||||
# Compressor: still needs BF16 reconstruction
|
||||
mla_attn = getattr(attn, "mla_attn", None)
|
||||
if mla_attn is not None:
|
||||
|
||||
Reference in New Issue
Block a user