diff --git a/patches/deepseek_v4.py b/patches/deepseek_v4.py index 7fb63c1..f204a9b 100644 --- a/patches/deepseek_v4.py +++ b/patches/deepseek_v4.py @@ -1970,6 +1970,15 @@ class DeepseekV4Model(nn.Module): w_dequant = w_bf16 # Replace weight with bf16 version + # Diagnostics: check for NaN/Inf from bad dequant + if w_dequant.isnan().any() or w_dequant.isinf().any(): + nan_count = w_dequant.isnan().sum().item() + inf_count = w_dequant.isinf().sum().item() + print(f"[NVFP4-DEQUANT-WARN] {getattr(mod, 'prefix', 'unknown')}: " + f"shape={w_dequant.shape}, dtype={w_dequant.dtype}, " + f"NaN={nan_count}, Inf={inf_count}, " + f"block_scale range=[{block_scale.min().item():.6f}, {block_scale.max().item():.6f}], " + f"global_scale={global_scale:.6f}") mod.weight = torch.nn.Parameter(w_dequant, requires_grad=False) from vllm.model_executor.layers.linear import UnquantizedLinearMethod mod.quant_method = UnquantizedLinearMethod() @@ -2066,6 +2075,14 @@ class DeepseekV4Model(nn.Module): delattr(mod, attr) from vllm.model_executor.layers.linear import UnquantizedLinearMethod mod.quant_method = UnquantizedLinearMethod() + # Safety check: UnquantizedLinearMethod with FP8 weight will crash CUBLAS + if mod.weight.dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + print(f"[NVFP4-FP8-CRASH] {getattr(mod, 'prefix', 'unknown')}: " + f"weight is {mod.weight.dtype} but quant_method=UnquantizedLinearMethod! " + f"This will crash CUBLAS.") + for attr in ("weight_scale", "weight_scale_2", "input_scale"): + if hasattr(mod, attr): + delattr(mod, attr) def _reconstruct_compressor_weight(self, fused_mod, parent_mod, layer_idx, e2m1_lut, sub_path=""): """Reconstruct compressor fused_wkv_wgate from checkpoint.