#!/usr/bin/env python3
"""Test: does the SE's L1 GEMM produce NaN on non-zero GPUs?

Loads layer 0's shared-expert (SE) NVFP4 tensors from the checkpoint in
``CKPT_DIR``, installs them into an ``Nvfp4SharedExpert`` on each requested
GPU, runs one forward pass on the same random input, and reports whether the
output contains NaN (plus the max abs difference against the first GPU).

Usage: ``python <this script> [GPU_ID ...]`` (defaults to GPUs 0 and 1).
"""

import json
import os
import sys
import warnings

import torch

CKPT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
# Every tensor of layer 0's shared expert lives under this name prefix.
SE_PREFIX = "model.layers.0.mlp.shared_experts."
HIDDEN_SIZE = 7168
INTERMEDIATE_SIZE = 3072
DEFAULT_GPUS = (0, 1)
SEED = 42

# Tensors (relative to SE_PREFIX) this script cannot run without.
# weight_scale_2 is optional (None is passed through), and up_proj's
# input_scale is not consumed by the fused L1 below.
REQUIRED_SUFFIXES = (
    "gate_proj.weight",
    "gate_proj.weight_scale",
    "gate_proj.input_scale",
    "up_proj.weight",
    "up_proj.weight_scale",
    "down_proj.weight",
    "down_proj.weight_scale",
    "down_proj.input_scale",
)


def parse_gpu_ids(args):
    """Return GPU ids parsed from command-line *args*, or ``DEFAULT_GPUS`` if none."""
    return [int(a) for a in args] or list(DEFAULT_GPUS)


def shards_for_prefix(weight_map, prefix):
    """Return the set of shard files holding any tensor whose name starts with *prefix*.

    Looking at every key under the prefix (not just ``.weight``) ensures that
    scale tensors stored in a different shard than their weight are loaded too.
    """
    return {shard for name, shard in weight_map.items() if name.startswith(prefix)}


def projection_tensors(tensors, proj, prefix=SE_PREFIX):
    """Return ``(weight, weight_scale, weight_scale_2, input_scale)`` for *proj*.

    Entries absent from *tensors* are returned as ``None``.
    """
    base = f"{prefix}{proj}."
    return tuple(
        tensors.get(base + suffix)
        for suffix in ("weight", "weight_scale", "weight_scale_2", "input_scale")
    )


def missing_required(tensors, prefix=SE_PREFIX):
    """Return the full names of entries in ``REQUIRED_SUFFIXES`` absent from *tensors*."""
    return [prefix + s for s in REQUIRED_SUFFIXES if prefix + s not in tensors]


def fused_scale_mismatch(gate_scale, up_scale):
    """Return True if gate_proj's and up_proj's scale tensors disagree.

    The fused L1 built in ``main`` applies gate_proj's scale to both halves,
    which is only faithful when up_proj's value is identical. One side being
    ``None`` while the other is not also counts as a mismatch.
    """
    if gate_scale is None or up_scale is None:
        return (gate_scale is None) != (up_scale is None)
    return not torch.equal(gate_scale.float(), up_scale.float())


def load_se_tensors(ckpt_dir=CKPT_DIR, prefix=SE_PREFIX):
    """Load only the tensors under *prefix* from the sharded safetensors checkpoint."""
    from safetensors.torch import load_file  # deferred: only needed when loading

    index_path = os.path.join(ckpt_dir, "model.safetensors.index.json")
    with open(index_path, encoding="utf-8") as f:
        weight_map = json.load(f)["weight_map"]

    tensors = {}
    for shard in sorted(shards_for_prefix(weight_map, prefix)):
        shard_tensors = load_file(os.path.join(ckpt_dir, shard))
        # Keep just the SE tensors so the rest of each shard can be freed.
        tensors.update({k: v for k, v in shard_tensors.items() if k.startswith(prefix)})
    return tensors


def main(argv=None):
    """Run layer 0's shared expert on each requested GPU and report NaNs."""
    # Project import deferred so the helpers above are importable without it.
    from dsv4.layers.shared_expert import Nvfp4SharedExpert

    gpus = parse_gpu_ids(sys.argv[1:] if argv is None else argv)
    n_dev = torch.cuda.device_count()
    bad = [g for g in gpus if not 0 <= g < n_dev]
    if bad:
        sys.exit(f"requested GPU(s) {bad} but only {n_dev} CUDA device(s) are visible")

    torch.manual_seed(SEED)

    # Load a real checkpoint weight for layer 0's shared expert.
    tensors = load_se_tensors()
    missing = missing_required(tensors)
    if missing:
        sys.exit(f"checkpoint is missing required shared-expert tensors: {missing}")

    gw, gws, gws2, gisc = projection_tensors(tensors, "gate_proj")
    uw, uws, uws2, uisc = projection_tensors(tensors, "up_proj")
    dw, dws, dws2, disc = projection_tensors(tensors, "down_proj")

    for label, gate_val, up_val in (
        ("weight_scale_2", gws2, uws2),
        ("input_scale", gisc, uisc),
    ):
        if fused_scale_mismatch(gate_val, up_val):
            warnings.warn(
                f"gate_proj/up_proj {label} differ ({gate_val} vs {up_val}); "
                "the fused L1 uses gate_proj's value for both halves"
            )

    # Fuse gate/up once on the CPU; this is identical for every GPU.
    l1_fp4 = torch.cat([gw, uw], dim=0)
    l1_sf = torch.cat([gws, uws], dim=0)
    l1_gsa = gisc.float().item()
    l2_gsa = disc.float().item()

    # Draw the input once on the CPU so every GPU sees the exact same values.
    x_cpu = torch.randn(1, HIDDEN_SIZE, dtype=torch.bfloat16)

    reference = None  # (gpu id, output on CPU) of the first GPU run
    for gpu in gpus:
        torch.cuda.set_device(gpu)
        dev = f"cuda:{gpu}"
        se = Nvfp4SharedExpert(
            hidden_size=HIDDEN_SIZE, intermediate_size=INTERMEDIATE_SIZE, device=dev
        )
        se.l1_fp4 = [l1_fp4.to(dev)]
        se.l1_sf = [l1_sf.to(dev)]
        se.l1_gs = [1.0]
        se.l1_ws2 = [gws2.to(dev) if gws2 is not None else None]
        se._saved_l1_gsa = l1_gsa
        se.l2_fp4 = [dw.to(dev)]
        se.l2_sf = [dws.to(dev)]
        se.l2_gs = [1.0]
        se.l2_ws2 = [dws2.to(dev) if dws2 is not None else None]
        se._saved_l2_gsa = l2_gsa

        x = x_cpu.to(dev)
        # The activation global scales must be set AFTER _ensure_initialized()
        # but BEFORE run(); run() calls _ensure_initialized() lazily, so call it
        # explicitly first, then patch the scales.
        se._ensure_initialized()
        se._l1_activation_global_scale = l1_gsa
        se._l2_activation_global_scale = l2_gsa
        out = se.run(x)

        has_nan = torch.isnan(out).any().item()
        print(f"GPU {gpu}: |out|={out.abs().max().item() if not has_nan else 'NaN'} has_nan={has_nan} shape={out.shape}")

        # Compare against the first GPU so non-NaN garbage is caught too.
        out_cpu = out.float().cpu()
        if reference is None:
            reference = (gpu, out_cpu)
        else:
            ref_gpu, ref_out = reference
            diff = (out_cpu - ref_out).abs().max().item()
            print(f"GPU {gpu} vs GPU {ref_gpu}: max|diff|={diff}")


if __name__ == "__main__":
    main()