#!/usr/bin/env python3
"""
Test C: Warmup-based gs computation — verify that exact warmup gs values
produce good cosine when used with quantize_activation_nvfp4.

The warmup runs quantize_to_nvfp4 (dynamic gs) on representative input,
captures the exact gs for both L1 and L2, then feeds those values to
quantize_activation_nvfp4 (fixed gs, cudagraph-safe).

Usage (on B200):
    source /root/nvfp4-megamoe-kernel/tests/.venv/bin/activate
    python3 tests/test_warmup_gs.py
"""
import functools
import json
import os
import sys

import torch

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from cutedsl.bridge import (
    quantize_to_nvfp4,
    quantize_activation_nvfp4,
    make_b_k_major,
    assemble_scales_2d_side,
    assemble_scales_3d_side,
    run_nvfp4_grouped_gemm,
    compute_expert_offsets,
)
from cutedsl.moe_pipeline import run_nvfp4_moe
from vllm.nvfp4_cutedsl import CuTeDSLMoERunner
from safetensors import safe_open

MODEL_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
DEVICE = "cuda"

# E2M1 (FP4) code -> value table: codes 0-7 are the positive values,
# codes 8-15 are the same magnitudes with the sign bit set.
E2M1_LUT = torch.tensor(
    [0., 0.5, 1., 1.5, 2., 3., 4., 6., -0., -.5, -1., -1.5, -2., -3., -4., -6.],
    dtype=torch.float32,
)


def dequant(w, sf, gs):
    """Dequantize packed FP4 weights to bfloat16 (reference helper).

    Each byte of ``w`` packs two E2M1 codes: the low nibble becomes the even
    output column and the high nibble the odd one. Every group of 16 output
    columns is multiplied by the matching column of ``sf``, then by ``gs``.
    """
    dev = w.device
    lut = E2M1_LUT.to(dev)
    lo = lut[(w & 0xF).long()]
    up = lut[((w >> 4) & 0xF).long()]
    o = torch.empty(w.shape[0], w.shape[1] * 2, dtype=torch.float32, device=dev)
    o[:, 0::2] = lo
    o[:, 1::2] = up
    block_scale = sf.float().repeat_interleave(16, dim=1)[:, :o.shape[1]]
    return (o * block_scale * gs).to(torch.bfloat16)


@functools.lru_cache(maxsize=1)
def _load_weight_map():
    """Return the checkpoint's tensor-name -> shard-file map (parsed once)."""
    with open(os.path.join(MODEL_DIR, "model.safetensors.index.json")) as f:
        return json.load(f)["weight_map"]


def load_tensor(key):
    """Load tensor ``key`` from its safetensors shard and move it to DEVICE.

    Returns None if the key is absent from the index, its shard file does
    not exist, or the shard does not contain the key.
    """
    shard_name = _load_weight_map().get(key)
    if shard_name is None:
        # Must bail out here: joining an empty name would yield MODEL_DIR
        # itself, which exists, and safe_open would be handed a directory.
        return None
    shard = os.path.join(MODEL_DIR, shard_name)
    if not os.path.exists(shard):
        return None
    with safe_open(shard, framework="pt") as f:
        if key in f.keys():
            return f.get_tensor(key).to(DEVICE)
    return None


def _require_tensor(key):
    """Like load_tensor, but raise KeyError naming ``key`` when it is missing."""
    t = load_tensor(key)
    if t is None:
        raise KeyError(f"tensor not found in checkpoint: {key}")
    return t


def load_layer0_experts(expert_indices):
    """Load layer-0 expert weights in the layout the grouped GEMMs consume.

    For each expert, gate_proj and up_proj are concatenated along dim 0 into
    one L1 weight; down_proj is the L2 weight. FP4 weights (viewed as
    float4_e2m1fn_x2) and their block scales are transposed with
    ``permute(1, 0)`` and made contiguous.

    Returns six lists ordered like ``expert_indices``:
    ``(l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs)``; the ``*_gs`` lists hold
    Python floats.

    Raises:
        KeyError: if any required tensor is missing from the checkpoint.
    """
    l1_fp4, l1_sf, l1_gs = [], [], []
    l2_fp4, l2_sf, l2_gs = [], [], []
    for e in expert_indices:
        p = f"model.layers.0.mlp.experts.{e}"
        gw = _require_tensor(f"{p}.gate_proj.weight")
        uw = _require_tensor(f"{p}.up_proj.weight")
        gsf = _require_tensor(f"{p}.gate_proj.weight_scale")
        usf = _require_tensor(f"{p}.up_proj.weight_scale")
        ggs = _require_tensor(f"{p}.gate_proj.weight_scale_2").item()
        ugs = _require_tensor(f"{p}.up_proj.weight_scale_2").item()

        fw = torch.cat([gw, uw], dim=0)
        fw4 = fw.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
        fs = torch.cat([gsf, usf], dim=0).permute(1, 0).contiguous()
        # L1 keeps a single global scale per expert (the max of gate/up), so a
        # gate/up mismatch is folded into the FP8 block scales instead. After
        # the permute, the first gsf.shape[0] columns belong to gate_proj.
        n_gate = gsf.shape[0]
        mgs = max(ggs, ugs)
        if ggs != ugs:
            f32 = fs.float()
            f32[:, :n_gate] *= (ggs / mgs)
            f32[:, n_gate:] *= (ugs / mgs)
            fs = f32.to(torch.float8_e4m3fn)
        l1_fp4.append(fw4)
        l1_sf.append(fs)
        l1_gs.append(mgs)

        dw = _require_tensor(f"{p}.down_proj.weight")
        dsf = _require_tensor(f"{p}.down_proj.weight_scale")
        dgs = _require_tensor(f"{p}.down_proj.weight_scale_2").item()
        l2_fp4.append(dw.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous())
        l2_sf.append(dsf.permute(1, 0).contiguous())
        l2_gs.append(dgs)
    return l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs


def warmup_compute_gs(runner, hidden_states, topk_weights, topk_ids):
    """Run a full forward pass with quantize_to_nvfp4 (dynamic gs) to
    capture the exact gs values for L1 and L2.

    * L1 gs is taken from the routed (one row per token/expert slot) hidden
      states.
    * L2 gs is taken from ``silu(gate) * up`` computed on the real L1 GEMM
      output.

    ``topk_weights`` is accepted to match ``runner.run``'s argument order but
    is not used. This helper relies on runner internals (``_ensure_stacked``,
    ``_token_indices``, ``_expert_id_range``, ``_expert_offsets_buf``, the
    L1 scale buffers and weights) and overwrites ``_expert_offsets_buf``.

    Returns:
        ``(l1_gs, l2_gs)`` exactly as returned by ``quantize_to_nvfp4``.
    """
    device = hidden_states.device
    num_tokens = hidden_states.shape[0]
    top_k = topk_ids.shape[1]
    num_experts = runner.num_experts

    # Build slot mapping (same as runner.run()): one slot per (token, expert)
    # pair, stably sorted by expert id.
    runner._ensure_stacked()
    flat_ids = topk_ids.reshape(-1)
    num_slots = num_tokens * top_k
    token_indices = runner._token_indices[:num_slots]
    sort_idx = flat_ids.argsort(stable=True)
    sorted_ids = flat_ids[sort_idx]
    sorted_token_ids = token_indices[sort_idx]
    slot_hidden = hidden_states[sorted_token_ids]

    # L1: dynamic gs from the routed activations.
    _, _, l1_gs = quantize_to_nvfp4(slot_hidden)

    # Quantize with that gs and run the L1 grouped GEMM to get the L1 output.
    x_fp4, x_sf = quantize_activation_nvfp4(slot_hidden, l1_gs)
    expert_id_range = runner._expert_id_range
    tokens_per_expert = (
        (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0).int()
    )
    expert_offsets = runner._expert_offsets_buf
    expert_offsets.zero_()
    # Cumulative slot counts per expert; entry 0 stays 0.
    expert_offsets[1:num_experts + 1] = tokens_per_expert.cumsum(0)
    l1_scale_a = runner._assemble_scales_cudagraph_safe(
        x_sf,
        expert_offsets[:num_experts + 1],
        runner._padded_x_sf_buf_l1,
        runner._per_expert_scale_bufs_l1,
    )
    l1_gsa = torch.full((num_experts,), l1_gs, dtype=torch.float32, device=device)
    l1_out = run_nvfp4_grouped_gemm(
        mat_a=x_fp4,
        mat_b=runner._l1_mat_b,
        scale_a=l1_scale_a,
        scale_b=runner._l1_scale_b,
        expert_offsets=expert_offsets[1:num_experts + 1],
        global_scale_a=l1_gsa,
        global_scale_b=runner._l1_gsb,
    )

    # L2: compute gs from the actual SwiGLU-activated L1 output.
    gate = l1_out[:, :runner.intermediate_size]
    up = l1_out[:, runner.intermediate_size:]
    activated = torch.nn.functional.silu(gate) * up
    _, _, l2_gs = quantize_to_nvfp4(activated)
    return l1_gs, l2_gs


def _cosine(a, b):
    """Cosine similarity of two tensors, flattened and compared in float32."""
    return torch.nn.functional.cosine_similarity(
        a.flatten().unsqueeze(0).float(), b.flatten().unsqueeze(0).float()
    ).item()


def main():
    """Compare CuTeDSLMoERunner output against the run_nvfp4_moe reference
    when activation global scales come from a warmup pass."""
    expert_indices = [0, 1, 2]
    num_experts = len(expert_indices)
    hidden_size = 7168
    intermediate_size = 3072

    print("Loading weights...")
    l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs = load_layer0_experts(expert_indices)

    def make_runner():
        """Build a runner holding fresh copies of the loaded expert weights."""
        r = CuTeDSLMoERunner(num_experts, hidden_size, intermediate_size, device=DEVICE)
        r.prepare_weights_direct(
            [w.clone() for w in l1_fp4], [w.clone() for w in l1_sf], list(l1_gs),
            [w.clone() for w in l2_fp4], [w.clone() for w in l2_sf], list(l2_gs),
        )
        return r

    torch.manual_seed(42)
    hidden_states = torch.randn(4, hidden_size, dtype=torch.bfloat16, device=DEVICE) * 2.0
    topk_ids = torch.tensor([[0, 1]] * 4, dtype=torch.int32, device=DEVICE)
    topk_weights = torch.tensor([[0.6, 0.4]] * 4, dtype=torch.float32, device=DEVICE)

    # Pipeline reference
    weights = {'l1_fp4': l1_fp4, 'l1_sf': l1_sf, 'l1_gs': l1_gs,
               'l2_fp4': l2_fp4, 'l2_sf': l2_sf, 'l2_gs': l2_gs}
    ref = run_nvfp4_moe(hidden_states.clone(), topk_ids.clone(), topk_weights.clone(),
                        weights, expert_indices)
    print(f"Pipeline: amax={ref.abs().max():.4f}, mean={ref.float().mean():.6f}")

    # ── Test 1: Runner with warmup gs (no safety margin) ──
    print("\n--- Test 1: Warmup gs, no safety margin ---")
    runner = make_runner()
    # Use the runner's built-in warmup method
    runner.compute_activation_global_scales(hidden_states.clone(), topk_weights, topk_ids)
    result = runner.run(hidden_states.clone(), topk_weights, topk_ids)
    cos = _cosine(result, ref)
    print(f" Cosine: {cos:.6f}, amax={result.abs().max():.4f}")

    # ── Test 2: Runner with warmup gs + safety margins ──
    # Capture the exact warmup gs with this file's warmup helper on a dedicated
    # runner (the helper overwrites runner scratch buffers), then scale it by
    # each safety factor. float() assumes quantize_to_nvfp4 yields a scalar gs.
    warmup_runner = make_runner()
    l1_warm, l2_warm = warmup_compute_gs(
        warmup_runner, hidden_states.clone(), topk_weights, topk_ids
    )
    l1_gs_val, l2_gs_val = float(l1_warm), float(l2_warm)
    del warmup_runner
    print(f"\n--- Test 2: Warmup gs + safety margins "
          f"(L1 gs={l1_gs_val:.6g}, L2 gs={l2_gs_val:.6g}) ---")
    for safety in [1.0, 1.1, 1.2, 1.5, 2.0]:
        runner2 = make_runner()
        runner2._l1_activation_global_scale = l1_gs_val * safety
        runner2._l2_activation_global_scale = l2_gs_val * safety
        result2 = runner2.run(hidden_states.clone(), topk_weights, topk_ids)
        cos2 = _cosine(result2, ref)
        print(f" Safety {safety:.1f}x: cosine={cos2:.6f}, amax={result2.abs().max():.4f}")

    # ── Test 3: Different input (verify warmup gs generalizes) ──
    print("\n--- Test 3: Different input with same warmup gs ---")
    torch.manual_seed(99)
    hidden_states2 = torch.randn(4, hidden_size, dtype=torch.bfloat16, device=DEVICE) * 2.0
    topk_ids2 = torch.tensor([[0, 1]] * 4, dtype=torch.int32, device=DEVICE)
    topk_weights2 = torch.tensor([[0.6, 0.4]] * 4, dtype=torch.float32, device=DEVICE)
    ref2 = run_nvfp4_moe(hidden_states2.clone(), topk_ids2.clone(), topk_weights2.clone(),
                         weights, expert_indices)
    # Reuses the Test 1 runner, whose gs came from the first input.
    result3 = runner.run(hidden_states2.clone(), topk_weights2, topk_ids2)
    cos3 = _cosine(result3, ref2)
    print(f" Different input: cosine={cos3:.6f}")


if __name__ == "__main__":
    main()