diag: add NaN/Inf + FP8-dtype checks after NVFP4 dequant

This commit is contained in:
2026-05-11 19:26:12 +00:00
parent 8ae2214bad
commit cd24182e36

View File

@@ -1970,6 +1970,15 @@ class DeepseekV4Model(nn.Module):
w_dequant = w_bf16
# Replace weight with bf16 version
# Diagnostics: check for NaN/Inf from bad dequant
if w_dequant.isnan().any() or w_dequant.isinf().any():
nan_count = w_dequant.isnan().sum().item()
inf_count = w_dequant.isinf().sum().item()
print(f"[NVFP4-DEQUANT-WARN] {getattr(mod, 'prefix', 'unknown')}: "
f"shape={w_dequant.shape}, dtype={w_dequant.dtype}, "
f"NaN={nan_count}, Inf={inf_count}, "
f"block_scale range=[{block_scale.min().item():.6f}, {block_scale.max().item():.6f}], "
f"global_scale={global_scale:.6f}")
mod.weight = torch.nn.Parameter(w_dequant, requires_grad=False)
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
mod.quant_method = UnquantizedLinearMethod()
@@ -2066,6 +2075,14 @@ class DeepseekV4Model(nn.Module):
delattr(mod, attr)
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
mod.quant_method = UnquantizedLinearMethod()
# Diagnostic only (does not prevent the failure): an FP8 weight paired with UnquantizedLinearMethod will crash cuBLAS, so log it here
if mod.weight.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
print(f"[NVFP4-FP8-CRASH] {getattr(mod, 'prefix', 'unknown')}: "
f"weight is {mod.weight.dtype} but quant_method=UnquantizedLinearMethod! "
f"This will crash CUBLAS.")
for attr in ("weight_scale", "weight_scale_2", "input_scale"):
if hasattr(mod, attr):
delattr(mod, attr)
def _reconstruct_compressor_weight(self, fused_mod, parent_mod, layer_idx, e2m1_lut, sub_path=""):
"""Reconstruct compressor fused_wkv_wgate from checkpoint.