#!/usr/bin/env python3
"""Test: dequantize a shared-expert (SE) NVFP4 weight and do a BF16 matmul.

Loads the shared-expert MLP tensors of one decoder layer (layer 0 by default)
from an NVFP4 checkpoint, dequantizes ``gate_proj.weight`` to BF16 on each
listed GPU, and runs ``F.linear`` on a fixed random input, printing the shape,
max magnitude and NaN status of both the dequantized weight and the output.
"""
import json
import os

import torch

cdir = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"

# Magnitudes of the 8 non-negative FP4 codes 0b000..0b111; bit 3 of each
# 4-bit nibble is treated as the sign bit by dequant_nvfp4.
FP4_LUT = torch.tensor([0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def dequant_nvfp4(weight, weight_scale, weight_scale_2=None, input_scale=None, *, block_size=16):
    """Dequantize a packed NVFP4 weight matrix to BF16.

    Args:
        weight: 2-D integer tensor ``(O, I/2)``; each byte packs two FP4 codes,
            low nibble = even column, high nibble = odd column. uint8 and int8
            storage decode identically.
        weight_scale: per-block scales, shape ``(O, I / block_size)``; each
            scale is repeated over ``block_size`` consecutive columns.
        weight_scale_2: optional tensor multiplied into every scale
            (must broadcast against ``(O, I)``).
        input_scale: accepted for signature compatibility; unused here
            (presumably the activation scale, which weight dequant does not need).
        block_size: number of columns sharing one ``weight_scale`` entry.

    Returns:
        BF16 tensor of shape ``(O, I)``.

    Raises:
        ValueError: if ``weight_scale`` does not have shape ``(O, I / block_size)``.
    """
    O, I2 = weight.shape
    I = I2 * 2
    if weight_scale.dim() != 2 or weight_scale.shape[0] != O or weight_scale.shape[1] * block_size != I:
        raise ValueError(
            f"weight_scale shape {tuple(weight_scale.shape)} does not match "
            f"({O}, {I} / {block_size}) for weight shape {tuple(weight.shape)}"
        )
    lut = FP4_LUT.to(device=weight.device, dtype=torch.float32)
    # Codes 0-7 are +magnitude, codes 8-15 the same magnitudes negated.
    signed_lut = torch.cat([lut, -lut])
    # Mask both nibbles to 0..15: for int8 storage `>> 4` sign-extends, and the
    # mask makes the result identical to uint8 storage.
    lo = (weight & 0x0F).long()
    hi = ((weight >> 4) & 0x0F).long()
    w = torch.stack([signed_lut[lo], signed_lut[hi]], dim=-1).reshape(O, I)
    s = weight_scale.to(device=weight.device, dtype=torch.float32).repeat_interleave(block_size, dim=1)
    if weight_scale_2 is not None:
        s = s * weight_scale_2.to(device=weight.device, dtype=torch.float32)
    return (w * s).bfloat16()


def shards_for_prefix(wmap, prefix):
    """Return the set of shard filenames holding any tensor whose name starts with *prefix*."""
    return {shard for name, shard in wmap.items() if name.startswith(prefix)}


def load_shared_expert_tensors(ckpt_dir=cdir, layer=0):
    """Load every ``model.layers.{layer}.mlp.shared_experts.*`` tensor from *ckpt_dir*.

    Collects the shards of all matching tensors (weights AND their scales), so
    a scale stored in a different shard from its weight is still found.

    Raises:
        KeyError: if the checkpoint index lists no tensor with that prefix.
    """
    from safetensors.torch import load_file  # only needed when reading the checkpoint

    with open(os.path.join(ckpt_dir, "model.safetensors.index.json"), encoding="utf-8") as f:
        wmap = json.load(f)["weight_map"]
    prefix = f"model.layers.{layer}.mlp.shared_experts."
    shards = shards_for_prefix(wmap, prefix)
    if not shards:
        raise KeyError(f"no tensors with prefix {prefix!r} in the checkpoint index")
    tensors = {}
    for shard in sorted(shards):
        # Keep only the tensors we need so the rest of each shard can be freed.
        tensors.update({k: v for k, v in load_file(os.path.join(ckpt_dir, shard)).items() if k.startswith(prefix)})
    return tensors


def main(gpus=(0, 1), layer=0):
    """Dequantize the shared-expert gate_proj weight on each GPU and run a BF16 matmul."""
    tensors = load_shared_expert_tensors(cdir, layer)
    base = f"model.layers.{layer}.mlp.shared_experts.gate_proj."
    for gpu in gpus:
        dev = f"cuda:{gpu}"
        # Dequantize weights
        gw = tensors[base + "weight"].to(dev)
        gws = tensors[base + "weight_scale"].to(dev)
        gws2 = tensors.get(base + "weight_scale_2")
        gws2 = gws2.to(dev) if gws2 is not None else None
        gate_dequant = dequant_nvfp4(gw, gws, gws2)
        print(f"GPU {gpu} gate_dequant: shape={gate_dequant.shape} |max|={gate_dequant.abs().max().item():.4f} has_nan={torch.isnan(gate_dequant).any().item()}")
        # BF16 matmul. The input width comes from the weight itself, and the
        # input is drawn from a fixed-seed CPU generator so every GPU sees the
        # same x and the printed outputs are directly comparable.
        gen = torch.Generator().manual_seed(0)
        x = torch.randn(1, gate_dequant.shape[1], generator=gen).to(device=dev, dtype=torch.bfloat16)
        gate_out = torch.nn.functional.linear(x, gate_dequant)
        print(f"GPU {gpu} gate_out: shape={gate_out.shape} |max|={gate_out.abs().max().item():.4f} has_nan={torch.isnan(gate_out).any().item()}")


if __name__ == "__main__":
    main()