Files
nvfp4-megamoe-kernel/test_se_dequant.py

52 lines
2.2 KiB
Python

#!/usr/bin/env python3
"""Test: dequantize SE L1 weight and do BF16 matmul."""
import torch
from safetensors.torch import load_file
import json, os
# Checkpoint directory; set NVFP4_CKPT_DIR to point at another copy.
cdir = os.environ.get("NVFP4_CKPT_DIR", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4")
with open(os.path.join(cdir, "model.safetensors.index.json"), encoding="utf-8") as f:
    wmap = json.load(f)["weight_map"]

# Load L0 shared-expert (SE) tensors for gate/up/down: weights plus their
# scales. Tensors are read individually by name instead of loading whole
# shards, and each one is fetched from the shard the index lists for it, so a
# scale stored in a different shard than its weight is still found.
from safetensors import safe_open  # reads single tensors without loading a whole shard

se_prefixes = tuple(
    f"model.layers.0.mlp.shared_experts.{proj}."
    for proj in ("gate_proj", "up_proj", "down_proj")
)
keys_by_shard = {}
for name, shard in wmap.items():
    if name.startswith(se_prefixes):
        keys_by_shard.setdefault(shard, []).append(name)
if not keys_by_shard:
    raise SystemExit(f"no layer-0 shared-expert tensors listed in the index under {cdir}")
all_w = {}
for shard, names in keys_by_shard.items():
    with safe_open(os.path.join(cdir, shard), framework="pt", device="cpu") as sf:
        for name in names:
            all_w[name] = sf.get_tensor(name)
# FP4 (E2M1) magnitudes indexed by the low 3 bits of a 4-bit code; bit 3 is the sign.
FP4_LUT = torch.tensor([0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def dequant_nvfp4(weight, weight_scale, weight_scale_2=None, input_scale=None, block_size=16):
    """Dequantize a packed NVFP4 weight matrix to bfloat16.

    Args:
        weight: (O, I/2) integer tensor (uint8 or int8); each byte packs two
            FP4 codes, low nibble -> even column, high nibble -> odd column.
        weight_scale: (O, I/block_size) per-block scales, any dtype that
            supports ``.float()``; each scale covers ``block_size``
            consecutive columns of the unpacked row.
        weight_scale_2: optional extra scale multiplied into every block
            scale (broadcast; typically a scalar).
        input_scale: accepted for call-site compatibility; unused, since only
            the weight is dequantized here.
        block_size: unpacked columns per scale entry (16 by default).

    Returns:
        (O, I) bfloat16 tensor on ``weight``'s device.

    Raises:
        ValueError: if ``weight_scale`` is not shaped (O, I/block_size).
    """
    out_features, packed_cols = weight.shape
    in_features = packed_cols * 2
    expected = (out_features, in_features // block_size)
    if in_features % block_size or tuple(weight_scale.shape) != expected:
        raise ValueError(
            f"weight_scale shape {tuple(weight_scale.shape)} does not match "
            f"{expected} for weight {tuple(weight.shape)} with block_size={block_size}"
        )
    # 16-entry signed table: codes 0-7 are the magnitudes, 8-15 their negatives.
    lut = FP4_LUT.to(device=weight.device, dtype=torch.float32)
    signed_lut = torch.cat([lut, -lut])
    # Mask after the shift so a signed (int8) byte tensor decodes like uint8.
    lo = (weight & 0x0F).long()
    hi = ((weight >> 4) & 0x0F).long()
    w = torch.stack([signed_lut[lo], signed_lut[hi]], -1).reshape(out_features, in_features)
    s = weight_scale.float().repeat_interleave(block_size, 1)
    if weight_scale_2 is not None:
        s = s * weight_scale_2.float()
    return (w * s).bfloat16()
# Dequantize the layer-0 shared-expert gate_proj on GPUs 0 and 1 and run a
# BF16 matmul, printing |max| and NaN diagnostics for each result.
gate_prefix = "model.layers.0.mlp.shared_experts.gate_proj."
# Skip requested GPUs this machine does not have instead of crashing on .to().
gpus = [g for g in (0, 1) if g < torch.cuda.device_count()]
if not gpus:
    raise SystemExit("no CUDA device available")
for gpu in gpus:
    dev = f"cuda:{gpu}"
    # Dequantize weights (the activation input_scale is not needed for this).
    gw = all_w[gate_prefix + "weight"].to(dev)
    gws = all_w[gate_prefix + "weight_scale"].to(dev)
    gws2 = all_w.get(gate_prefix + "weight_scale_2")
    gws2 = gws2.to(dev) if gws2 is not None else None
    gate_dequant = dequant_nvfp4(gw, gws, gws2)
    print(f"GPU {gpu} gate_dequant: shape={gate_dequant.shape} |max|={gate_dequant.abs().max().item():.4f} has_nan={torch.isnan(gate_dequant).any().item()}")
    # BF16 matmul. The activation width is taken from the weight's input dim,
    # and the same seeded CPU generator is used on every GPU so the printed
    # numbers are directly comparable across devices.
    gen = torch.Generator().manual_seed(0)
    x = torch.randn(1, gate_dequant.shape[1], generator=gen, dtype=torch.bfloat16).to(dev)
    gate_out = torch.nn.functional.linear(x, gate_dequant)
    print(f"GPU {gpu} gate_out: shape={gate_out.shape} |max|={gate_out.abs().max().item():.4f} has_nan={torch.isnan(gate_out).any().item()}")