test: direct SE L1 test on different GPUs
This commit is contained in:
64
test_se_l1_direct.py
Normal file
64
test_se_l1_direct.py
Normal file
@@ -0,0 +1,64 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Test: shared expert L1 on different GPUs with correct quantization."""
|
||||
import torch
|
||||
from dsv4.layers.shared_expert import Nvfp4SharedExpert
|
||||
from safetensors.torch import load_file
|
||||
import json, os
|
||||
|
||||
# Directory holding the sharded checkpoint and its safetensors index file.
cdir = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
with open(os.path.join(cdir, "model.safetensors.index.json"), encoding="utf-8") as f:
    # "weight_map" maps each tensor name to the shard file that stores it.
    wmap = json.load(f)["weight_map"]

# Common key prefix of every layer-0 shared-expert tensor used by this script.
se_prefix = "model.layers.0.mlp.shared_experts."

# Resolve the shard of every tensor that is read later, not just `.weight`:
# the index maps tensors individually, so a scale tensor may sit in a
# different shard than the weight it belongs to.
shards_needed = set()
for proj in ['gate_proj', 'up_proj', 'down_proj']:
    weight_key = f"{se_prefix}{proj}.weight"
    if weight_key not in wmap:
        # Fail here with the offending key instead of later on a None tensor.
        raise KeyError(f"{weight_key} is not listed in the checkpoint index under {cdir}")
    for suffix in ('weight', 'weight_scale', 'weight_scale_2', 'input_scale'):
        k = f"{se_prefix}{proj}.{suffix}"
        if k in wmap:
            shards_needed.add(wmap[k])

# Tensor name -> tensor, restricted to the shared-expert tensors so the rest
# of each shard is not kept alive for the lifetime of the script.
all_w = {}
for sn in shards_needed:
    shard = load_file(os.path.join(cdir, sn))
    all_w.update(
        {name: tensor for name, tensor in shard.items() if name.startswith(se_prefix)}
    )
|
||||
|
||||
def get_weight(proj):
    """Look up the four layer-0 shared-expert tensors of one projection.

    Args:
        proj: Projection name used in the checkpoint keys, e.g. ``'gate_proj'``.

    Returns:
        A 4-tuple ``(weight, weight_scale, weight_scale_2, input_scale)`` read
        from the module-level ``all_w`` mapping. An entry is ``None`` when its
        key is absent from ``all_w``.
    """
    suffixes = ("weight", "weight_scale", "weight_scale_2", "input_scale")
    return tuple(
        all_w.get(f"model.layers.0.mlp.shared_experts.{proj}.{suffix}")
        for suffix in suffixes
    )
|
||||
|
||||
# Order of the tuple returned by get_weight().
_FIELDS = ('weight', 'weight_scale', 'weight_scale_2', 'input_scale')


def _fetch(proj, required):
    """Return get_weight(proj) as a dict keyed by field name.

    Raises:
        KeyError: if any field named in *required* came back as None, so the
            failure names the missing tensor instead of surfacing later as an
            AttributeError on None.
    """
    values = dict(zip(_FIELDS, get_weight(proj)))
    missing = [name for name in required if values[name] is None]
    if missing:
        raise KeyError(
            f"model.layers.0.mlp.shared_experts.{proj}: missing tensor(s) {missing}"
        )
    return values


def _report(gpu, label, out):
    """Print max |out|, the NaN flag and the shape of *out* for one GPU."""
    has_nan = torch.isnan(out).any().item()
    peak = 'NaN' if has_nan else out.abs().max().item()
    print(f"GPU {gpu} SE {label}: |out|={peak} has_nan={has_nan} shape={out.shape}")


# The checkpoint tensors are identical for every GPU: fetch, validate and
# concatenate them once, then only copy to the device inside the loop.
gate = _fetch('gate_proj', ('weight', 'weight_scale', 'input_scale'))
up = _fetch('up_proj', ('weight', 'weight_scale'))
down = _fetch('down_proj', ('weight', 'weight_scale', 'input_scale'))

# L1 stacks gate_proj and up_proj along dim 0 (weights and their scales).
l1_fp4 = torch.cat([gate['weight'], up['weight']], dim=0)
l1_sf = torch.cat([gate['weight_scale'], up['weight_scale']], dim=0)

# NOTE(review): only gate_proj's weight_scale_2 and input_scale are applied to
# the stacked L1; up_proj's values are never read. Confirm that both
# projections share these scales in the checkpoint.
l1_ws2 = gate['weight_scale_2']
l2_ws2 = down['weight_scale_2']
l1_act_scale = gate['input_scale'].float().item()
l2_act_scale = down['input_scale'].float().item()

for gpu in [0, 1]:
    torch.cuda.set_device(gpu)
    dev = f"cuda:{gpu}"

    se = Nvfp4SharedExpert(hidden_size=7168, intermediate_size=3072, device=dev, swiglu_limit=10.0)

    se.l1_fp4 = [l1_fp4.to(dev)]
    se.l1_sf = [l1_sf.to(dev)]
    se.l1_gs = [1.0]
    se.l1_ws2 = [l1_ws2.to(dev) if l1_ws2 is not None else None]

    se.l2_fp4 = [down['weight'].to(dev)]
    se.l2_sf = [down['weight_scale'].to(dev)]
    se.l2_gs = [1.0]
    se.l2_ws2 = [l2_ws2.to(dev) if l2_ws2 is not None else None]

    # The activation global scales are overwritten after initialization with
    # the checkpoint's input_scale values (presumably _ensure_initialized()
    # sets its own defaults -- TODO confirm).
    se._ensure_initialized()
    se._l1_activation_global_scale = l1_act_scale
    se._l2_activation_global_scale = l2_act_scale

    x = torch.randn(1, 7168, dtype=torch.bfloat16, device=dev) * 0.5

    # L1 only, then the full run, on the same input.
    _report(gpu, "L1", se._run_l1(x))
    _report(gpu, "full", se.run(x))
|
||||
Reference in New Issue
Block a user