#!/usr/bin/env python3
"""Test: does the SE's L1 GEMM produce NaN on non-zero GPUs?

Loads layer 0's NVFP4 shared-expert (SE) tensors from the checkpoint in
``cdir``, installs them into an ``Nvfp4SharedExpert`` on each GPU under test,
runs one random input through it and prints whether the output contains NaN.
"""

import json
import os

import torch
from safetensors.torch import load_file

from dsv4.layers.shared_expert import Nvfp4SharedExpert

torch.manual_seed(42)

# Checkpoint directory holding model.safetensors.index.json and its shards.
# Load a real checkpoint weight for layer 0's shared expert from here; set
# DSV4_CKPT_DIR to point the test at a different copy.
cdir = os.environ.get("DSV4_CKPT_DIR", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4")
# We'll use L0's weights and try running on different GPUs.
# Only layer 0's shared-expert tensors are read: the packed weights *and* their
# scale tensors (weight_scale, weight_scale_2, input_scale). ``all_w`` therefore
# maps full tensor names to CPU tensors for that prefix only.
from safetensors import safe_open

SE_PREFIX = "model.layers.0.mlp.shared_experts."

with open(os.path.join(cdir, "model.safetensors.index.json"), encoding="utf-8") as f:
    wmap = json.load(f)["weight_map"]

# Group every tensor under SE_PREFIX by the shard that holds it. Matching the
# whole prefix (not just the three ``.weight`` keys) also picks up scale tensors
# that the index places in a different shard from their weight.
shards_needed = {}  # shard file name -> list of tensor names to read from it
for k, shard in wmap.items():
    if k.startswith(SE_PREFIX):
        shards_needed.setdefault(shard, []).append(k)

if not shards_needed:
    raise KeyError(
        f"no tensors under {SE_PREFIX!r} in {os.path.join(cdir, 'model.safetensors.index.json')}"
    )

# safe_open reads only the requested tensors, instead of materialising every
# tensor of each shard on the host as load_file would.
all_w = {}
for sn, keys in shards_needed.items():
    with safe_open(os.path.join(cdir, sn), framework="pt", device="cpu") as sf:
        for k in keys:
            all_w[k] = sf.get_tensor(k)
def get_weight(proj):
    """Look up the four NVFP4 tensors of one layer-0 shared-expert projection.

    Reads the module-level ``all_w`` dict. A tensor that is not present is
    returned as ``None`` rather than raising.

    Args:
        proj: projection name, e.g. ``"gate_proj"``, ``"up_proj"`` or
            ``"down_proj"``.

    Returns:
        Tuple ``(weight, weight_scale, weight_scale_2, input_scale)``.
    """
    names = (
        f"model.layers.0.mlp.shared_experts.{proj}.weight",
        f"model.layers.0.mlp.shared_experts.{proj}.weight_scale",
        f"model.layers.0.mlp.shared_experts.{proj}.weight_scale_2",
        f"model.layers.0.mlp.shared_experts.{proj}.input_scale",
    )
    return tuple(all_w.get(name) for name in names)
# Shared-expert dimensions and the GPUs under test.
HIDDEN_SIZE = 7168
INTERMEDIATE_SIZE = 3072
GPUS = (0, 1)

# The checkpoint tensors are the same for every GPU, so fetch, validate and
# concatenate them once on the host; only the device copies happen per GPU.
gw, gws, gws2, gisc = get_weight('gate_proj')
uw, uws, uws2, uisc = get_weight('up_proj')
dw, dws, dws2, disc = get_weight('down_proj')

# Fail early with the key names instead of a TypeError from torch.cat or an
# AttributeError from None.float() further down. weight_scale_2 is optional
# (a missing one is installed as None below).
required = {
    "gate_proj.weight": gw,
    "gate_proj.weight_scale": gws,
    "gate_proj.input_scale": gisc,
    "up_proj.weight": uw,
    "up_proj.weight_scale": uws,
    "down_proj.weight": dw,
    "down_proj.weight_scale": dws,
    "down_proj.input_scale": disc,
}
missing = [name for name, t in required.items() if t is None]
if missing:
    raise KeyError(f"layer-0 shared-expert tensors missing from checkpoint: {missing}")

# The fused L1 installs only gate_proj's weight_scale_2 / input_scale; up_proj's
# own values are discarded. Flag when they differ, since the up half may then be
# mis-scaled. NOTE(review): confirm whether the kernel requires them equal.
for what, g_t, u_t in (("weight_scale_2", gws2, uws2), ("input_scale", gisc, uisc)):
    if g_t is not None and u_t is not None and not torch.equal(g_t.float(), u_t.float()):
        print(
            f"WARNING: gate_proj/up_proj {what} differ "
            f"({g_t.float().tolist()} vs {u_t.float().tolist()}); fused L1 uses gate_proj's"
        )

# L1 stacks gate_proj and up_proj row-wise (dim 0) into a single GEMM operand.
l1_fp4_host = torch.cat([gw, uw], dim=0)
l1_sf_host = torch.cat([gws, uws], dim=0)
l1_gsa = gisc.float().item()
l2_gsa = disc.float().item()

# Draw the input once on the host and copy it to each GPU, so every device sees
# identical data and the outputs are directly comparable.
x_host = torch.randn(1, HIDDEN_SIZE, dtype=torch.bfloat16)

n_visible = torch.cuda.device_count()
ref = None  # (gpu, output copied to host) of the first GPU that ran
for gpu in GPUS:
    if gpu >= n_visible:
        print(f"GPU {gpu}: skipped (only {n_visible} CUDA device(s) visible)")
        continue

    torch.cuda.set_device(gpu)
    dev = f"cuda:{gpu}"

    se = Nvfp4SharedExpert(
        hidden_size=HIDDEN_SIZE, intermediate_size=INTERMEDIATE_SIZE, device=dev
    )

    se.l1_fp4 = [l1_fp4_host.to(dev)]
    se.l1_sf = [l1_sf_host.to(dev)]
    se.l1_gs = [1.0]
    se.l1_ws2 = [gws2.to(dev) if gws2 is not None else None]
    se._saved_l1_gsa = l1_gsa

    se.l2_fp4 = [dw.to(dev)]
    se.l2_sf = [dws.to(dev)]
    se.l2_gs = [1.0]
    se.l2_ws2 = [dws2.to(dev) if dws2 is not None else None]
    se._saved_l2_gsa = l2_gsa

    # Run
    x = x_host.to(dev)

    # Must set gsa AFTER _ensure_initialized but BEFORE run.
    # _ensure_initialized is called lazily in run(), so we need to call it first.
    se._ensure_initialized()
    # Now fix the gsa
    se._l1_activation_global_scale = l1_gsa
    se._l2_activation_global_scale = l2_gsa

    out = se.run(x)

    has_nan = torch.isnan(out).any().item()
    print(f"GPU {gpu}: |out|={out.abs().max().item() if not has_nan else 'NaN'} has_nan={has_nan} shape={out.shape}")

    # Cross-device check: with identical weights and input, later GPUs are
    # expected to reproduce the first GPU's output.
    out_host = out.float().cpu()
    if ref is None:
        ref = (gpu, out_host)
    else:
        ref_gpu, ref_out = ref
        diff = (out_host - ref_out).abs().max().item()
        print(f"GPU {gpu} vs GPU {ref_gpu}: max|diff|={diff}")