Files
nvfp4-megamoe-kernel/tests/integration/test_se_multi_gpu.py
biondizzle 8de47e26ce Cleanup Step 1: Move root-level files to proper directories
- Move test_*.py → tests/integration/
- Move probe_*.py, dump_*.py → helpers/
- Move PERFORMANCE_AUDIT.md → docs/
- Move single_shot_PYTORCH_REFERENCE.py → dsv4/reference/
- Fix 3 import references in test_layer_comparison, test_mhc_comparison, test_compressor_position_bias
- Add helpers/import_closure.py (dead-code detection tool)
2026-06-02 19:24:39 +00:00

71 lines
2.4 KiB
Python

#!/usr/bin/env python3
"""Test: does the SE's L1 GEMM produce NaN on non-zero GPUs?"""
import torch
from dsv4.layers.shared_expert import Nvfp4SharedExpert
torch.manual_seed(42)
# Load a real checkpoint weight for layer 0's shared expert
from safetensors.torch import load_file
import json, os
cdir = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"

# The index file maps each tensor name to the shard file that stores it.
with open(os.path.join(cdir, "model.safetensors.index.json"), encoding="utf-8") as f:
    wmap = json.load(f)["weight_map"]

# Collect every shard holding a layer-0 shared-expert tensor. The scale
# tensors are looked up as well as the weights: they are read later from the
# same dict, and nothing guarantees they sit in the same shard as the weight.
shards_needed = set()
for proj in ['gate_proj', 'up_proj', 'down_proj']:
    prefix = f"model.layers.0.mlp.shared_experts.{proj}."
    if prefix + "weight" not in wmap:
        # Fail here with a readable message instead of later on a None tensor.
        raise KeyError(f"{prefix}weight not found in checkpoint index under {cdir}")
    for suffix in ("weight", "weight_scale", "weight_scale_2", "input_scale"):
        k = prefix + suffix
        # Scale tensors are optional; only the ones listed in the index are loaded.
        if k in wmap:
            shards_needed.add(wmap[k])

# Merge all tensors of the required shards into one name -> tensor dict.
# Sorting makes the load order reproducible from run to run.
all_w = {}
for sn in sorted(shards_needed):
    all_w.update(load_file(os.path.join(cdir, sn)))
def get_weight(proj):
    """Fetch the layer-0 shared-expert tensors for one projection.

    Args:
        proj: Projection name used in the checkpoint key, e.g. 'gate_proj'.

    Returns:
        A 4-tuple ``(weight, weight_scale, weight_scale_2, input_scale)``
        looked up in the module-level ``all_w`` dict. Any tensor that is not
        present in ``all_w`` is returned as ``None``.
    """
    base = f"model.layers.0.mlp.shared_experts.{proj}."
    fields = ("weight", "weight_scale", "weight_scale_2", "input_scale")
    return tuple(all_w.get(base + field) for field in fields)
# The checkpoint tensors are the same for every GPU, so fetch them and build
# the fused L1 (gate + up) tensors on the host once, outside the per-GPU loop.
gw, gws, gws2, gisc = get_weight('gate_proj')
uw, uws, _, _ = get_weight('up_proj')
dw, dws, dws2, disc = get_weight('down_proj')
# NOTE(review): only gate_proj's weight_scale_2 is handed to the expert below;
# up_proj's value is never used. Confirm both projections share the same scale.

if gisc is None or disc is None:
    # get_weight() returns None for absent tensors; name the problem explicitly
    # rather than failing on `None.float()`.
    raise KeyError("input_scale for gate_proj/down_proj not found in the loaded shards")

l1_fp4_host = torch.cat([gw, uw], dim=0)
l1_sf_host = torch.cat([gws, uws], dim=0)
l1_gsa = gisc.float().item()
l2_gsa = disc.float().item()

available_gpus = torch.cuda.device_count()

for gpu in [0, 1]:
    if gpu >= available_gpus:
        # Report a missing device instead of raising inside set_device().
        print(f"GPU {gpu}: skipped (only {available_gpus} CUDA device(s) visible)")
        continue

    torch.cuda.set_device(gpu)
    dev = f"cuda:{gpu}"
    se = Nvfp4SharedExpert(hidden_size=7168, intermediate_size=3072, device=dev)

    # L1: gate_proj and up_proj concatenated along dim 0.
    se.l1_fp4 = [l1_fp4_host.to(dev)]
    se.l1_sf = [l1_sf_host.to(dev)]
    se.l1_gs = [1.0]
    se.l1_ws2 = [gws2.to(dev) if gws2 is not None else None]
    se._saved_l1_gsa = l1_gsa

    # L2: down_proj.
    se.l2_fp4 = [dw.to(dev)]
    se.l2_sf = [dws.to(dev)]
    se.l2_gs = [1.0]
    se.l2_ws2 = [dws2.to(dev) if dws2 is not None else None]
    se._saved_l2_gsa = l2_gsa

    # Run
    x = torch.randn(1, 7168, dtype=torch.bfloat16, device=dev)

    # The activation global scales must be set AFTER _ensure_initialized() but
    # BEFORE run(). run() calls _ensure_initialized() lazily, so call it
    # explicitly first and then overwrite the scales.
    se._ensure_initialized()
    se._l1_activation_global_scale = l1_gsa
    se._l2_activation_global_scale = l2_gsa

    out = se.run(x)
    has_nan = torch.isnan(out).any().item()
    print(f"GPU {gpu}: |out|={out.abs().max().item() if not has_nan else 'NaN'} has_nan={has_nan} shape={out.shape}")