52 lines
2.2 KiB
Python
52 lines
2.2 KiB
Python
#!/usr/bin/env python3
|
|
"""Test: dequantize SE L1 weight and do BF16 matmul."""
|
|
import torch
|
|
from safetensors.torch import load_file
|
|
import json, os
|
|
|
|
# Directory holding the sharded safetensors checkpoint and its index file.
cdir = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"

# The index maps every tensor name to the shard file that stores it.
with open(os.path.join(cdir, "model.safetensors.index.json"), encoding="utf-8") as f:
    wmap = json.load(f)["weight_map"]

# Load L0 SE weights
# Resolve the shard of every tensor this script reads (packed weight plus its
# scales), not just the packed weight: nothing guarantees that a weight and
# its scale tensors were written to the same shard file.
shards_needed = set()
for proj in ('gate_proj', 'up_proj', 'down_proj'):
    for suffix in ('weight', 'weight_scale', 'weight_scale_2', 'input_scale'):
        k = f"model.layers.0.mlp.shared_experts.{proj}.{suffix}"
        if k in wmap:
            shards_needed.add(wmap[k])

# Merge the tensors of every required shard into a single name -> tensor dict.
# Sorted iteration keeps the load order deterministic across runs.
all_w = {}
for sn in sorted(shards_needed):
    all_w.update(load_file(os.path.join(cdir, sn)))
|
|
|
|
# Magnitudes of the eight non-negative 4-bit codes, indexed by the low 3 bits.
FP4_LUT = torch.tensor([0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def dequant_nvfp4(weight, weight_scale, weight_scale_2=None, input_scale=None,
                  *, block_size=16):
    """Unpack a 4-bit packed weight matrix and return it as bfloat16.

    Each element of ``weight`` holds two 4-bit codes: the low nibble is the
    even output column, the high nibble the following odd column.  Within a
    code, the low three bits select a magnitude from ``FP4_LUT`` and bit 3 is
    the sign.  The decoded values are multiplied by ``weight_scale`` (one
    scale per ``block_size`` consecutive columns) and, when given, by
    ``weight_scale_2``.

    Args:
        weight: 2-D integer tensor of shape ``(O, I / 2)`` with packed codes
            (assumed to be an 8-bit integer dtype such as uint8).
        weight_scale: tensor of shape ``(O, I / block_size)``; converted to
            float32 before use.
        weight_scale_2: optional extra scale that is broadcast-multiplied
            onto the per-block scales.
        input_scale: accepted for call-site compatibility; not used here.
        block_size: number of consecutive columns sharing one entry of
            ``weight_scale``.  Defaults to 16.

    Returns:
        A bfloat16 tensor of shape ``(O, I)`` on the device of ``weight``.

    Raises:
        ValueError: if ``block_size`` is not positive, does not divide ``I``,
            or ``weight_scale`` does not have shape ``(O, I / block_size)``.
    """
    out_features, packed_cols = weight.shape
    in_features = packed_cols * 2

    if block_size <= 0 or in_features % block_size:
        raise ValueError(
            f"block_size={block_size} must be positive and divide the "
            f"unpacked width {in_features}"
        )
    expected_scale_shape = (out_features, in_features // block_size)
    if tuple(weight_scale.shape) != expected_scale_shape:
        # Without this check a scale with a single row would broadcast
        # silently and produce wrong values instead of an error.
        raise ValueError(
            f"weight_scale has shape {tuple(weight_scale.shape)}, "
            f"expected {expected_scale_shape}"
        )

    # 16-entry table indexed directly by the 4-bit code: entries 0-7 are the
    # positive magnitudes, entries 8-15 (sign bit set) their negations.
    magnitudes = FP4_LUT.to(device=weight.device, dtype=torch.float32)
    signed_lut = torch.cat([magnitudes, -magnitudes])

    # Mask after shifting so a signed input dtype cannot yield a negative index.
    lo_codes = (weight & 0x0F).long()
    hi_codes = ((weight >> 4) & 0x0F).long()

    # Interleave low/high nibbles back into their original column order.
    w = torch.stack([signed_lut[lo_codes], signed_lut[hi_codes]], dim=-1)
    w = w.reshape(out_features, in_features)

    # Expand each per-block scale across the columns it covers.
    s = weight_scale.float().repeat_interleave(block_size, dim=1)
    if weight_scale_2 is not None:
        s = s * weight_scale_2.float()
    return (w * s).bfloat16()
|
|
|
|
# Run the same dequantize + matmul check on each GPU and print, per device,
# the largest absolute value and whether any NaN appeared.
for gpu in [0, 1]:
    dev = f"cuda:{gpu}"

    # Dequantize weights
    gate_prefix = 'model.layers.0.mlp.shared_experts.gate_proj'
    gw = all_w[f'{gate_prefix}.weight'].to(dev)
    gws = all_w[f'{gate_prefix}.weight_scale'].to(dev)
    # The second-level scale is optional in the checkpoint.
    gws2 = all_w.get(f'{gate_prefix}.weight_scale_2')
    gws2 = gws2.to(dev) if gws2 is not None else None

    gate_dequant = dequant_nvfp4(gw, gws, gws2)
    print(f"GPU {gpu} gate_dequant: shape={gate_dequant.shape} |max|={gate_dequant.abs().max().item():.4f} has_nan={torch.isnan(gate_dequant).any().item()}")

    # BF16 matmul
    # Size the random input from the dequantized weight so the check does not
    # depend on a hard-coded hidden dimension.
    x = torch.randn(1, gate_dequant.shape[1], dtype=torch.bfloat16, device=dev)
    gate_out = torch.nn.functional.linear(x, gate_dequant)
    print(f"GPU {gpu} gate_out: shape={gate_out.shape} |max|={gate_out.abs().max().item():.4f} has_nan={torch.isnan(gate_out).any().item()}")
|