auto: pre-test push for test_se_multi_gpu.py
This commit is contained in:
68
test_se_multi_gpu.py
Normal file
68
test_se_multi_gpu.py
Normal file
@@ -0,0 +1,68 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Test: does the SE's L1 GEMM produce NaN on non-zero GPUs?"""
|
||||
import torch
|
||||
from dsv4.layers.shared_expert import Nvfp4SharedExpert
|
||||
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Load a real checkpoint weight for layer 0's shared expert
|
||||
from safetensors.torch import load_file
|
||||
import json, os
|
||||
cdir = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
|
||||
|
||||
# We'll use L0's weights and try running on different GPUs
|
||||
with open(os.path.join(cdir, "model.safetensors.index.json")) as f:
|
||||
wmap = json.load(f)["weight_map"]
|
||||
|
||||
# Load L0 SE weights
|
||||
shards_needed = set()
|
||||
for proj in ['gate_proj', 'up_proj', 'down_proj']:
|
||||
k = f"model.layers.0.mlp.shared_experts.{proj}.weight"
|
||||
if k in wmap:
|
||||
shards_needed.add(wmap[k])
|
||||
|
||||
all_w = {}
|
||||
for sn in shards_needed:
|
||||
all_w.update(load_file(os.path.join(cdir, sn)))
|
||||
|
||||
def get_weight(proj):
|
||||
w = all_w.get(f"model.layers.0.mlp.shared_experts.{proj}.weight")
|
||||
ws = all_w.get(f"model.layers.0.mlp.shared_experts.{proj}.weight_scale")
|
||||
ws2 = all_w.get(f"model.layers.0.mlp.shared_experts.{proj}.weight_scale_2")
|
||||
isc = all_w.get(f"model.layers.0.mlp.shared_experts.{proj}.input_scale")
|
||||
return w, ws, ws2, isc
|
||||
|
||||
for gpu in [0, 1]:
|
||||
torch.cuda.set_device(gpu)
|
||||
dev = f"cuda:{gpu}"
|
||||
|
||||
se = Nvfp4SharedExpert(hidden_size=7168, intermediate_size=3072, device=dev)
|
||||
|
||||
gw, gws, gws2, gisc = get_weight('gate_proj')
|
||||
uw, uws, uws2, uisc = get_weight('up_proj')
|
||||
dw, dws, dws2, disc = get_weight('down_proj')
|
||||
|
||||
se.l1_fp4 = [torch.cat([gw, uw], dim=0).to(dev)]
|
||||
se.l1_sf = [torch.cat([gws, uws], dim=0).to(dev)]
|
||||
se.l1_gs = [1.0]
|
||||
se.l1_ws2 = [gws2.to(dev) if gws2 is not None else None]
|
||||
se._saved_l1_gsa = gisc.float().item()
|
||||
|
||||
se.l2_fp4 = [dw.to(dev)]
|
||||
se.l2_sf = [dws.to(dev)]
|
||||
se.l2_gs = [1.0]
|
||||
se.l2_ws2 = [dws2.to(dev) if dws2 is not None else None]
|
||||
se._saved_l2_gsa = disc.float().item()
|
||||
|
||||
# Run
|
||||
x = torch.randn(1, 7168, dtype=torch.bfloat16, device=dev)
|
||||
out = se.run(x)
|
||||
|
||||
# Fix gsa
|
||||
if hasattr(se, '_saved_l1_gsa'):
|
||||
se._l1_activation_global_scale = se._saved_l1_gsa
|
||||
if hasattr(se, '_saved_l2_gsa'):
|
||||
se._l2_activation_global_scale = se._saved_l2_gsa
|
||||
|
||||
has_nan = torch.isnan(out).any().item()
|
||||
print(f"GPU {gpu}: |out|={out.abs().max().item() if not has_nan else 'NaN'} has_nan={has_nan} shape={out.shape}")
|
||||
Reference in New Issue
Block a user