diff --git a/test_se_multi_gpu.py b/test_se_multi_gpu.py new file mode 100644 index 00000000..794c8f55 --- /dev/null +++ b/test_se_multi_gpu.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 +"""Test: does the SE's L1 GEMM produce NaN on non-zero GPUs?""" +import torch +from dsv4.layers.shared_expert import Nvfp4SharedExpert + +torch.manual_seed(42) + +# Load a real checkpoint weight for layer 0's shared expert +from safetensors.torch import load_file +import json, os +cdir = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4" + +# We'll use L0's weights and try running on different GPUs +with open(os.path.join(cdir, "model.safetensors.index.json")) as f: + wmap = json.load(f)["weight_map"] + +# Load L0 SE weights +shards_needed = set() +for proj in ['gate_proj', 'up_proj', 'down_proj']: + k = f"model.layers.0.mlp.shared_experts.{proj}.weight" + if k in wmap: + shards_needed.add(wmap[k]) + +all_w = {} +for sn in shards_needed: + all_w.update(load_file(os.path.join(cdir, sn))) + +def get_weight(proj): + w = all_w.get(f"model.layers.0.mlp.shared_experts.{proj}.weight") + ws = all_w.get(f"model.layers.0.mlp.shared_experts.{proj}.weight_scale") + ws2 = all_w.get(f"model.layers.0.mlp.shared_experts.{proj}.weight_scale_2") + isc = all_w.get(f"model.layers.0.mlp.shared_experts.{proj}.input_scale") + return w, ws, ws2, isc + +for gpu in [0, 1]: + torch.cuda.set_device(gpu) + dev = f"cuda:{gpu}" + + se = Nvfp4SharedExpert(hidden_size=7168, intermediate_size=3072, device=dev) + + gw, gws, gws2, gisc = get_weight('gate_proj') + uw, uws, uws2, uisc = get_weight('up_proj') + dw, dws, dws2, disc = get_weight('down_proj') + + se.l1_fp4 = [torch.cat([gw, uw], dim=0).to(dev)] + se.l1_sf = [torch.cat([gws, uws], dim=0).to(dev)] + se.l1_gs = [1.0] + se.l1_ws2 = [gws2.to(dev) if gws2 is not None else None] + se._saved_l1_gsa = gisc.float().item() + + se.l2_fp4 = [dw.to(dev)] + se.l2_sf = [dws.to(dev)] + se.l2_gs = [1.0] + se.l2_ws2 = [dws2.to(dev) if dws2 is not None else None] + se._saved_l2_gsa = disc.float().item() + + # Run + x = torch.randn(1, 7168, dtype=torch.bfloat16, device=dev) + out = se.run(x) + + # Fix gsa + if hasattr(se, '_saved_l1_gsa'): + se._l1_activation_global_scale = se._saved_l1_gsa + if hasattr(se, '_saved_l2_gsa'): + se._l2_activation_global_scale = se._saved_l2_gsa + + has_nan = torch.isnan(out).any().item() + print(f"GPU {gpu}: |out|={out.abs().max().item() if not has_nan else 'NaN'} has_nan={has_nan} shape={out.shape}")