"""Standalone test: Shared expert using CuTeDSL dedicated runner. Tests the CuTeDSLSharedExpertRunner for the shared expert path. Compares against BF16 dequantized reference. Usage: python3 test_shared_expert.py """ import torch import torch.nn.functional as F import sys, os, json from safetensors import safe_open MODEL_PATH = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4" DEVICE = "cuda:0" LAYER_IDX = 0 HIDDEN_SIZE = 7168 # shared expert input dim (from checkpoint weight shapes) INTERMEDIATE_SIZE = 3072 SWIGLU_LIMIT = 10.0 NUM_TOKENS = 4 E2M1_LUT = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6., -0., -0.5, -1., -1.5, -2., -3., -4., -6.], dtype=torch.float32) _cache = {} def load_tensor(key, wm, model_dir): if key in _cache: return _cache[key] shard_path = os.path.join(model_dir, wm[key]) with safe_open(shard_path, framework="pt") as f: t = f.get_tensor(key) _cache[key] = t return t def dequant_nvfp4(packed_uint8, scale_e4m3, global_scale): """Dequantize NVFP4 weight to BF16 for reference.""" device = packed_uint8.device lut = E2M1_LUT.to(device) lower = lut[(packed_uint8 & 0x0F).long()] upper = lut[((packed_uint8 >> 4) & 0x0F).long()] out_features = packed_uint8.shape[0] in_features = packed_uint8.shape[1] * 2 unpacked = torch.empty(out_features, in_features, dtype=torch.float32, device=device) unpacked[:, 0::2] = lower unpacked[:, 1::2] = upper block_scale = scale_e4m3.float() block_expanded = block_scale.repeat_interleave(16, dim=1)[:out_features, :in_features] return (unpacked * block_expanded * global_scale).to(torch.bfloat16) def main(): torch.cuda.set_device(0) torch.manual_seed(42) sys.path.insert(0, "/root/nvfp4-megamoe-kernel") from cutedsl.shared_expert_pipeline import CuTeDSLSharedExpertRunner with open(os.path.join(MODEL_PATH, "model.safetensors.index.json")) as f: wm = json.load(f)["weight_map"] P = lambda key: load_tensor(key, wm, MODEL_PATH).to(DEVICE) print("=== Shared Expert Test (CuTeDSL SharedExpertRunner) ===\n") # Load shared expert weights prefix = f"model.layers.{LAYER_IDX}.mlp.shared_experts" gate_w = P(f"{prefix}.gate_proj.weight") gate_sf = P(f"{prefix}.gate_proj.weight_scale") gate_gs = P(f"{prefix}.gate_proj.weight_scale_2").item() up_w = P(f"{prefix}.up_proj.weight") up_sf = P(f"{prefix}.up_proj.weight_scale") up_gs = P(f"{prefix}.up_proj.weight_scale_2").item() down_w = P(f"{prefix}.down_proj.weight") down_sf = P(f"{prefix}.down_proj.weight_scale") down_gs = P(f"{prefix}.down_proj.weight_scale_2").item() print(f"gate_proj: shape={gate_w.shape} gs={gate_gs:.8f} sf_shape={gate_sf.shape}") print(f"up_proj: shape={up_w.shape} gs={up_gs:.8f} sf_shape={up_sf.shape}") print(f"down_proj: shape={down_w.shape} gs={down_gs:.8f} sf_shape={down_sf.shape}") # Stack gate + up into gate_up_proj (same format as MoE L1) # gate/up weights are (intermediate, hidden) uint8 packed gate_up_w = torch.cat([gate_w, up_w], dim=0) gate_up_sf = torch.cat([gate_sf, up_sf], dim=0) mgs = max(gate_gs, up_gs) if gate_gs != up_gs: sf32 = gate_up_sf.float() sf32[:INTERMEDIATE_SIZE] *= (gate_gs / mgs) sf32[INTERMEDIATE_SIZE:] *= (up_gs / mgs) gate_up_sf = sf32.to(torch.float8_e4m3fn) # Convert to CuTeDSL format: # Checkpoint weights are (out_features, in_features) uint8 packed # We need float4_e2m1fn_x2 with (out_features, in_features // 2) after view # Then permute to (in_features // 2, out_features) for K-major (K=in_features) l1_fp4 = [gate_up_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()] l1_sf = [gate_up_sf.permute(1, 0).contiguous()] l2_fp4 = [down_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()] l2_sf = [down_sf.permute(1, 0).contiguous()] # Create runner runner = CuTeDSLSharedExpertRunner( hidden_size=HIDDEN_SIZE, intermediate_size=INTERMEDIATE_SIZE, max_num_tokens=8192, device=DEVICE, swiglu_limit=SWIGLU_LIMIT, ) runner.l1_fp4 = l1_fp4 runner.l1_sf = l1_sf runner.l1_gs = [mgs] runner.l2_fp4 = l2_fp4 runner.l2_sf = l2_sf runner.l2_gs = [down_gs] runner.finalize_weights() # Warmup to compute activation global scales dummy = torch.randn(NUM_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE) * 2.0 runner._ensure_initialized() runner.compute_activation_global_scales(dummy) print(f"Warmup gs: L1={runner._l1_activation_global_scale:.6f} " f"L2={runner._l2_activation_global_scale:.6f}") # Run CuTeDSL print("\n--- CuTeDSL Forward ---") hidden = torch.randn(NUM_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE) * 2.0 with torch.no_grad(): output = runner.run(hidden) print(f"CuTeDSL output: shape={output.shape} amax={output.amax():.4f} " f"NaN={torch.isnan(output).any()}") # BF16 reference print("\n--- BF16 Reference ---") gate_bf16 = dequant_nvfp4(gate_w, gate_sf, gate_gs) up_bf16 = dequant_nvfp4(up_w, up_sf, up_gs) down_bf16 = dequant_nvfp4(down_w, down_sf, down_gs) with torch.no_grad(): gate = hidden @ gate_bf16.T up = hidden @ up_bf16.T gate_silu = F.silu(gate).clamp(max=SWIGLU_LIMIT) up = up.clamp(min=-SWIGLU_LIMIT, max=SWIGLU_LIMIT) intermediate = gate_silu * up ref_output = intermediate @ down_bf16.T print(f"BF16 ref: shape={ref_output.shape} amax={ref_output.amax():.4f}") # Compare cos = F.cosine_similarity(ref_output.flatten().unsqueeze(0), output.flatten().unsqueeze(0)).item() mse = (ref_output - output).pow(2).mean().item() print(f"\n=== RESULT: cosine={cos:.6f} MSE={mse:.6e} ===") if cos >= 0.98: print("✅ PASS") else: print("❌ FAIL") if __name__ == "__main__": main()