From 5ba77e355f5f94baf6be56e7d206b8af46392edd Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 17 May 2026 08:06:27 +0000 Subject: [PATCH] test: warmup gs computation with safety margin sweep --- tests/test_warmup_gs.py | 200 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 200 insertions(+) create mode 100644 tests/test_warmup_gs.py diff --git a/tests/test_warmup_gs.py b/tests/test_warmup_gs.py new file mode 100644 index 00000000..ef8114c5 --- /dev/null +++ b/tests/test_warmup_gs.py @@ -0,0 +1,200 @@ +#!/usr/bin/env python3 +""" +Test C: Warmup-based gs computation — verify that exact warmup gs values +produce good cosine when used with quantize_activation_nvfp4. + +The warmup runs quantize_to_nvfp4 (dynamic gs) on representative input, +captures the exact gs for both L1 and L2, then feeds those values to +quantize_activation_nvfp4 (fixed gs, cudagraph-safe). + +Usage (on B200): + source /root/nvfp4-megamoe-kernel/tests/.venv/bin/activate + python3 tests/test_warmup_gs.py +""" +import torch, sys, os, json +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from cutedsl.bridge import ( + quantize_to_nvfp4, quantize_activation_nvfp4, + make_b_k_major, assemble_scales_2d_side, assemble_scales_3d_side, + run_nvfp4_grouped_gemm, compute_expert_offsets, +) +from cutedsl.moe_pipeline import run_nvfp4_moe +from vllm.nvfp4_cutedsl import CuTeDSLMoERunner +from safetensors import safe_open + +MODEL_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4" +DEVICE = "cuda" +E2M1_LUT = torch.tensor([0.,0.5,1.,1.5,2.,3.,4.,6.,-0.,-.5,-1.,-1.5,-2.,-3.,-4.,-6.], dtype=torch.float32) + + +def dequant(w, sf, gs): + dev = w.device + lo = E2M1_LUT.to(dev)[(w & 0xF).long()] + up = E2M1_LUT.to(dev)[((w >> 4) & 0xF).long()] + o = torch.empty(w.shape[0], w.shape[1]*2, dtype=torch.float32, device=dev) + o[:, 0::2] = lo; o[:, 1::2] = up + return (o * sf.float().repeat_interleave(16, dim=1)[:, :o.shape[1]] * gs).to(torch.bfloat16) + + +def load_tensor(key): + with open(os.path.join(MODEL_DIR, "model.safetensors.index.json")) as f: + wm = json.load(f)["weight_map"] + shard = os.path.join(MODEL_DIR, wm.get(key, "")) + if not os.path.exists(shard): return None + with safe_open(shard, framework="pt") as f: + if key in f.keys(): return f.get_tensor(key).to(DEVICE) + return None + + +def load_layer0_experts(expert_indices): + l1_fp4, l1_sf, l1_gs = [], [], [] + l2_fp4, l2_sf, l2_gs = [], [], [] + for e in expert_indices: + gw = load_tensor(f"model.layers.0.mlp.experts.{e}.gate_proj.weight") + uw = load_tensor(f"model.layers.0.mlp.experts.{e}.up_proj.weight") + gsf = load_tensor(f"model.layers.0.mlp.experts.{e}.gate_proj.weight_scale") + usf = load_tensor(f"model.layers.0.mlp.experts.{e}.up_proj.weight_scale") + ggs = load_tensor(f"model.layers.0.mlp.experts.{e}.gate_proj.weight_scale_2").item() + ugs = load_tensor(f"model.layers.0.mlp.experts.{e}.up_proj.weight_scale_2").item() + fw = torch.cat([gw, uw], dim=0) + fw4 = fw.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() + fs = torch.cat([gsf, usf], dim=0).permute(1, 0).contiguous() + mgs = max(ggs, ugs) + if ggs != ugs: + f32 = fs.float() + f32[:, :3072] *= (ggs / mgs) + f32[:, 3072:] *= (ugs / mgs) + fs = f32.to(torch.float8_e4m3fn) + l1_fp4.append(fw4); l1_sf.append(fs); l1_gs.append(mgs) + dw = load_tensor(f"model.layers.0.mlp.experts.{e}.down_proj.weight") + dsf = load_tensor(f"model.layers.0.mlp.experts.{e}.down_proj.weight_scale") + dgs = load_tensor(f"model.layers.0.mlp.experts.{e}.down_proj.weight_scale_2").item() + l2_fp4.append(dw.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()) + l2_sf.append(dsf.permute(1, 0).contiguous()); l2_gs.append(dgs) + return l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs + + +def warmup_compute_gs(runner, hidden_states, topk_weights, topk_ids): + """Run a full forward pass with quantize_to_nvfp4 (dynamic gs) + to capture the exact gs values for L1 and L2.""" + device = hidden_states.device + num_tokens = hidden_states.shape[0] + top_k = topk_ids.shape[1] + + # Build slot mapping (same as runner.run()) + flat_ids = topk_ids.reshape(-1) + num_slots = num_tokens * top_k + token_indices = runner._token_indices[:num_slots] + sort_idx = flat_ids.argsort(stable=True) + sorted_ids = flat_ids[sort_idx] + sorted_token_ids = token_indices[sort_idx] + slot_hidden = hidden_states[sorted_token_ids] + + # L1: dynamic gs + _, _, l1_gs = quantize_to_nvfp4(slot_hidden) + + # Run L1 GEMM with dynamic gs to get L1 output + x_fp4, x_sf = quantize_activation_nvfp4(slot_hidden, l1_gs) + + expert_id_range = runner._expert_id_range + tokens_per_expert = (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0).int() + expert_offsets = runner._expert_offsets_buf + expert_offsets.zero_() + expert_offsets[1:runner.num_experts + 1] = tokens_per_expert.cumsum(0) + + l1_scale_a = runner._assemble_scales_cudagraph_safe( + x_sf, expert_offsets[:runner.num_experts + 1], + runner._padded_x_sf_buf_l1, runner._per_expert_scale_bufs_l1 + ) + l1_gsa = torch.full((runner.num_experts,), l1_gs, dtype=torch.float32, device=device) + + l1_out = run_nvfp4_grouped_gemm( + mat_a=x_fp4, mat_b=runner._l1_mat_b, + scale_a=l1_scale_a, scale_b=runner._l1_scale_b, + expert_offsets=expert_offsets[1:runner.num_experts + 1], + global_scale_a=l1_gsa, global_scale_b=runner._l1_gsb, + ) + + # L2: compute gs from actual L1 output + gate = l1_out[:, :runner.intermediate_size] + up = l1_out[:, runner.intermediate_size:] + activated = torch.nn.functional.silu(gate) * up + _, _, l2_gs = quantize_to_nvfp4(activated) + + return l1_gs, l2_gs + + +def main(): + expert_indices = [0, 1, 2] + num_experts = len(expert_indices) + hidden_size = 7168 + intermediate_size = 3072 + + print("Loading weights...") + l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs = load_layer0_experts(expert_indices) + + torch.manual_seed(42) + hidden_states = torch.randn(4, hidden_size, dtype=torch.bfloat16, device=DEVICE) * 2.0 + topk_ids = torch.tensor([[0, 1]] * 4, dtype=torch.int32, device=DEVICE) + topk_weights = torch.tensor([[0.6, 0.4]] * 4, dtype=torch.float32, device=DEVICE) + + # Pipeline reference + weights = {'l1_fp4': l1_fp4, 'l1_sf': l1_sf, 'l1_gs': l1_gs, + 'l2_fp4': l2_fp4, 'l2_sf': l2_sf, 'l2_gs': l2_gs} + ref = run_nvfp4_moe(hidden_states.clone(), topk_ids.clone(), topk_weights.clone(), weights, expert_indices) + print(f"Pipeline: amax={ref.abs().max():.4f}, mean={ref.float().mean():.6f}") + + # ── Test 1: Runner with warmup gs (no safety margin) ── + print("\n--- Test 1: Warmup gs, no safety margin ---") + runner = CuTeDSLMoERunner(num_experts, hidden_size, intermediate_size, device=DEVICE) + runner.prepare_weights_direct( + [w.clone() for w in l1_fp4], [w.clone() for w in l1_sf], list(l1_gs), + [w.clone() for w in l2_fp4], [w.clone() for w in l2_sf], list(l2_gs), + ) + + l1_gs_val, l2_gs_val = warmup_compute_gs(runner, hidden_states, topk_weights, topk_ids) + print(f" Warmup L1 gs: {l1_gs_val:.10f}") + print(f" Warmup L2 gs: {l2_gs_val:.10f}") + + runner._l1_activation_global_scale = l1_gs_val + runner._l2_activation_global_scale = l2_gs_val + result = runner.run(hidden_states.clone(), topk_weights, topk_ids) + + cos = torch.nn.functional.cosine_similarity( + result.flatten().unsqueeze(0).float(), ref.flatten().unsqueeze(0).float() + ).item() + print(f" Cosine: {cos:.6f}, amax={result.abs().max():.4f}") + + # ── Test 2: Runner with warmup gs + safety margins ── + for safety in [1.0, 1.1, 1.2, 1.5, 2.0]: + runner2 = CuTeDSLMoERunner(num_experts, hidden_size, intermediate_size, device=DEVICE) + runner2.prepare_weights_direct( + [w.clone() for w in l1_fp4], [w.clone() for w in l1_sf], list(l1_gs), + [w.clone() for w in l2_fp4], [w.clone() for w in l2_sf], list(l2_gs), + ) + runner2._l1_activation_global_scale = l1_gs_val * safety + runner2._l2_activation_global_scale = l2_gs_val * safety + result2 = runner2.run(hidden_states.clone(), topk_weights, topk_ids) + cos2 = torch.nn.functional.cosine_similarity( + result2.flatten().unsqueeze(0).float(), ref.flatten().unsqueeze(0).float() + ).item() + print(f" Safety {safety:.1f}x: cosine={cos2:.6f}, amax={result2.abs().max():.4f}") + + # ── Test 3: Different input (verify warmup gs generalizes) ── + print("\n--- Test 3: Different input with same warmup gs ---") + torch.manual_seed(99) + hidden_states2 = torch.randn(4, hidden_size, dtype=torch.bfloat16, device=DEVICE) * 2.0 + topk_ids2 = torch.tensor([[0, 1]] * 4, dtype=torch.int32, device=DEVICE) + topk_weights2 = torch.tensor([[0.6, 0.4]] * 4, dtype=torch.float32, device=DEVICE) + + ref2 = run_nvfp4_moe(hidden_states2.clone(), topk_ids2.clone(), topk_weights2.clone(), weights, expert_indices) + result3 = runner.run(hidden_states2.clone(), topk_weights2, topk_ids2) + cos3 = torch.nn.functional.cosine_similarity( + result3.flatten().unsqueeze(0).float(), ref2.flatten().unsqueeze(0).float() + ).item() + print(f" Different input: cosine={cos3:.6f}") + + +if __name__ == "__main__": + main()