#!/usr/bin/env python3
"""
Layer 0 full MoE pipeline test: CuTeDSL NVFP4 vs BF16 reference.

Tests the complete pipeline: L1→SiLU→L2→scatter

If the cosine similarity between the kernel output and the BF16 reference
is below COSINE_THRESHOLD, exits with error (status 1).
"""

import os
import sys
import json
import glob

import torch
from safetensors import safe_open

REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, REPO_ROOT)

from cutedsl.moe_pipeline import (
    run_nvfp4_moe,
)

NVFP4_MODEL_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
LAYER_IDX = 0
DEVICE = "cuda"
COSINE_THRESHOLD = 0.98  # Double quantization loss from checkpoint dequant→requant

# Number of consecutive elements along K that share one block scale
# (dequantize_nvfp4_weight expands every block scale over this many values).
NVFP4_BLOCK_SIZE = 16

# Lookup table: 4-bit E2M1 code -> float value (codes 8..15 are the negatives).
E2M1_LUT = torch.tensor([
    0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
    -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
], dtype=torch.float32)


def find_shards(model_dir):
    """Map every tensor key found in ``model_dir`` to the shard file holding it.

    When ``model.safetensors.index.json`` exists, its ``weight_map`` is used.
    Otherwise every ``*.safetensors`` file in the directory is opened and its
    keys are enumerated.

    Args:
        model_dir: Directory containing the safetensors checkpoint.

    Returns:
        dict mapping tensor key -> path of the shard file.
    """
    index_path = os.path.join(model_dir, "model.safetensors.index.json")
    key_to_shard = {}
    if os.path.exists(index_path):
        with open(index_path, encoding="utf-8") as f:
            index = json.load(f)
        for key, shard in index["weight_map"].items():
            key_to_shard[key] = os.path.join(model_dir, shard)
    else:
        for sf in glob.glob(os.path.join(model_dir, "*.safetensors")):
            with safe_open(sf, framework="pt") as f:
                for key in f.keys():
                    key_to_shard[key] = sf
    return key_to_shard


def load_layer_tensors(model_dir, layer_idx):
    """Load all tensors belonging to one layer of the checkpoint.

    Keys are normalized by stripping a leading ``model.`` prefix; only keys
    that then start with ``layers.{layer_idx}.`` are loaded.

    Args:
        model_dir: Directory containing the safetensors checkpoint.
        layer_idx: Index of the layer to load.

    Returns:
        dict mapping normalized key -> tensor.
    """
    key_to_shard = find_shards(model_dir)
    layer_prefix = f"layers.{layer_idx}."

    # Group the wanted keys per shard so that each shard file is opened once.
    shard_to_keys = {}
    for key, shard in key_to_shard.items():
        norm_key = key.removeprefix("model.")
        if not norm_key.startswith(layer_prefix):
            continue
        shard_to_keys.setdefault(shard, []).append((key, norm_key))

    tensors = {}
    for shard, keys in shard_to_keys.items():
        with safe_open(shard, framework="pt") as f:
            for orig_key, norm_key in keys:
                tensors[norm_key] = f.get_tensor(orig_key)
    return tensors


def dequantize_nvfp4_weight(packed_uint8, scale_e4m3, global_scale):
    """Dequantize one packed NVFP4 weight matrix to bfloat16.

    Args:
        packed_uint8: 2-D uint8 tensor; each byte packs two 4-bit E2M1 codes
            (low nibble = even column, high nibble = odd column).
        scale_e4m3: Per-block scales, one per NVFP4_BLOCK_SIZE unpacked
            columns of each row.
        global_scale: Python float multiplied into every element.

    Returns:
        bfloat16 tensor of shape ``(rows, 2 * packed_cols)``.
    """
    device = packed_uint8.device
    lut = E2M1_LUT.to(device)
    lower = lut[(packed_uint8 & 0x0F).long()]
    upper = lut[((packed_uint8 >> 4) & 0x0F).long()]

    out_features = packed_uint8.shape[0]
    in_features = packed_uint8.shape[1] * 2
    unpacked = torch.empty(out_features, in_features, dtype=torch.float32, device=device)
    unpacked[:, 0::2] = lower
    unpacked[:, 1::2] = upper

    block_scale = scale_e4m3.float()
    # Expand each block scale across its block; the slice trims any excess
    # if the scale tensor covers more columns than the weight has.
    block_expanded = block_scale.repeat_interleave(NVFP4_BLOCK_SIZE, dim=1)[:, :in_features]
    return (unpacked * block_expanded * global_scale).to(torch.bfloat16)


def dequantize_nvfp4_experts(nvfp4_tensors, layer_idx, expert_indices):
    """Dequantize gate/up/down projections of the selected experts to BF16.

    Args:
        nvfp4_tensors: dict of normalized key -> tensor for the layer.
        layer_idx: Layer index used to build the tensor keys.
        expert_indices: Iterable of expert ids to dequantize.

    Returns:
        dict expert id -> dict projection name -> bfloat16 weight on DEVICE.
        A missing ``down_proj`` is tolerated only for expert 211, whose entry
        then simply lacks that key.

    Raises:
        KeyError: If a weight is missing (other than expert 211's down_proj),
            or if a scale tensor for a present weight is missing.
    """
    experts = {}
    for e in expert_indices:
        expert = {}
        for proj in ["gate_proj", "up_proj", "down_proj"]:
            weight_key = f"layers.{layer_idx}.mlp.experts.{e}.{proj}.weight"
            scale_key = f"layers.{layer_idx}.mlp.experts.{e}.{proj}.weight_scale"
            gs_key = f"layers.{layer_idx}.mlp.experts.{e}.{proj}.weight_scale_2"
            if weight_key not in nvfp4_tensors:
                if proj == "down_proj" and e == 211:
                    continue
                raise KeyError(f"Missing {weight_key}")
            weight = nvfp4_tensors[weight_key].to(DEVICE)
            scale = nvfp4_tensors[scale_key].to(DEVICE)
            global_scale = nvfp4_tensors[gs_key].item()
            expert[proj] = dequantize_nvfp4_weight(weight, scale, global_scale)
        experts[e] = expert
    return experts


def moe_forward_bf16(hidden_states, experts, expert_ids, expert_weights):
    """Reference MoE forward pass on dequantized BF16 weights.

    For every token and every routed expert computes
    ``silu(x @ gate.T) * (x @ up.T) @ down.T`` and accumulates it, scaled by
    the routing weight, into the output row of that token.

    Args:
        hidden_states: ``(num_tokens, hidden_size)`` activations.
        experts: dict as returned by ``dequantize_nvfp4_experts``.
        expert_ids: ``(num_tokens, top_k)`` integer expert ids.
        expert_weights: ``(num_tokens, top_k)`` routing weights.

    Returns:
        ``(num_tokens, hidden_size)`` bfloat16 tensor on the device of
        ``hidden_states``.
    """
    num_tokens, hidden_size = hidden_states.shape
    top_k = expert_ids.shape[1]
    output = torch.zeros(
        num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device
    )
    for t in range(num_tokens):
        x = hidden_states[t]
        for k in range(top_k):
            e = expert_ids[t, k].item()
            w = expert_weights[t, k].item()
            if e not in experts:
                continue
            expert = experts[e]
            if "down_proj" not in expert:
                # prepare_nvfp4_weights_direct substitutes all-zero L2 weights
                # for an expert without down_proj, so the quantized path adds
                # nothing for it; the reference must do the same. (Slicing the
                # activation instead does not match the output row width when
                # intermediate size != hidden size.)
                continue
            gate = x @ expert["gate_proj"].T
            up = x @ expert["up_proj"].T
            activated = torch.nn.functional.silu(gate) * up
            y = activated @ expert["down_proj"].T
            output[t] += w * y
    return output


def prepare_nvfp4_weights_direct(nvfp4_tensors, layer_idx, expert_indices, intermediate_size):
    """Prepare weights via direct view-cast (no BF16 round-trip).

    Checkpoint uint8 → float4_e2m1fn_x2 (byte-preserving).
    Block scales float8_e4m3fn → used directly.
    Global scales float32 → used directly.

    For L1 (gate+up fused): normalize dual global scales to max, fold ratio
    into block scales via float32 (one multiply + float8 round-trip on ratio
    only).

    Args:
        nvfp4_tensors: dict of normalized key -> tensor for the layer.
        layer_idx: Layer index used to build the tensor keys.
        expert_indices: Iterable of expert ids to prepare.
        intermediate_size: Expert intermediate dimension; marks the gate/up
            split in the fused L1 scales and sizes the L2 placeholder used
            when an expert has no down_proj.

    Returns:
        dict with per-expert lists ``l1_fp4``, ``l1_sf``, ``l1_gs``,
        ``l2_fp4``, ``l2_sf``, ``l2_gs`` (one entry per expert, in the order
        of ``expert_indices``).
    """
    l1_fp4, l1_sf, l1_gs = [], [], []
    l2_fp4, l2_sf, l2_gs = [], [], []

    for e in expert_indices:
        prefix = f"layers.{layer_idx}.mlp.experts.{e}"

        # L1: gate + up
        gate_w = nvfp4_tensors[f"{prefix}.gate_proj.weight"].to(DEVICE)
        up_w = nvfp4_tensors[f"{prefix}.up_proj.weight"].to(DEVICE)
        gate_sf = nvfp4_tensors[f"{prefix}.gate_proj.weight_scale"].to(DEVICE)
        up_sf = nvfp4_tensors[f"{prefix}.up_proj.weight_scale"].to(DEVICE)
        gate_gs = nvfp4_tensors[f"{prefix}.gate_proj.weight_scale_2"].item()
        up_gs = nvfp4_tensors[f"{prefix}.up_proj.weight_scale_2"].item()

        # Fuse gate+up along N, transpose to K-major
        fused_w = torch.cat([gate_w, up_w], dim=0)  # (2*intermediate, hidden//2) uint8
        fused_w_fp4 = fused_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
        # (hidden//2, 2*intermediate) — K=hidden packed, N=2*intermediate

        # Fuse block scales: checkpoint is (N, K_sf), bridge expects (K_sf, N)
        fused_sf = torch.cat([gate_sf, up_sf], dim=0)  # (2*intermediate, hidden//16) = (N, K_sf)
        fused_sf = fused_sf.permute(1, 0).contiguous()  # → (K_sf, N)

        # Normalize dual global scales
        l1_max_gs = max(gate_gs, up_gs)
        if gate_gs != up_gs:
            fused_sf_f32 = fused_sf.float()
            # Gate is first intermediate cols, up is second (after transpose)
            fused_sf_f32[:, :intermediate_size] *= (gate_gs / l1_max_gs)
            fused_sf_f32[:, intermediate_size:] *= (up_gs / l1_max_gs)
            fused_sf = fused_sf_f32.to(torch.float8_e4m3fn)

        l1_fp4.append(fused_w_fp4)
        l1_sf.append(fused_sf)
        l1_gs.append(l1_max_gs)

        # L2: down
        down_key = f"{prefix}.down_proj.weight"
        if down_key in nvfp4_tensors:
            down_w = nvfp4_tensors[down_key].to(DEVICE)
            down_sf = nvfp4_tensors[f"{prefix}.down_proj.weight_scale"].to(DEVICE)
            down_gs = nvfp4_tensors[f"{prefix}.down_proj.weight_scale_2"].item()

            down_w_fp4 = down_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
            # (intermediate//2, hidden) — K=intermediate packed, N=hidden
            l2_fp4.append(down_w_fp4)
            l2_sf.append(down_sf.permute(1, 0).contiguous())  # (N, K_sf) → (K_sf, N)
            l2_gs.append(down_gs)
        else:
            # Expert without down_proj (expert 211 per the original note):
            # substitute all-zero weights so it contributes nothing.
            # Shapes follow the real tensors instead of fixed literals; the
            # gate weight packs two FP4 values per byte along hidden.
            hidden_size = gate_w.shape[1] * 2
            # Build the zeros as uint8 and view-cast, exactly like the real
            # weights above; a zero byte is two +0.0 E2M1 values, and this
            # does not rely on fill support for the packed FP4 dtype.
            zero_w = torch.zeros(
                intermediate_size // 2, hidden_size, dtype=torch.uint8, device=DEVICE
            )
            l2_fp4.append(zero_w.view(torch.float4_e2m1fn_x2))
            l2_sf.append(torch.ones(
                intermediate_size // NVFP4_BLOCK_SIZE, hidden_size,
                dtype=torch.float8_e4m3fn, device=DEVICE,
            ))  # (K_sf, N)
            l2_gs.append(1.0)

    return {
        'l1_fp4': l1_fp4,
        'l1_sf': l1_sf,
        'l1_gs': l1_gs,
        'l2_fp4': l2_fp4,
        'l2_sf': l2_sf,
        'l2_gs': l2_gs,
    }


def main():
    """Run the layer-0 comparison and exit with status 1 on failure.

    Loads the layer's NVFP4 tensors, prepares them for the CuTeDSL pipeline,
    builds a BF16 reference from the same checkpoint, runs both on a fixed
    random input and compares the outputs by cosine similarity and MSE.
    """
    torch.manual_seed(42)

    expert_indices = [0, 1, 2]
    num_tokens = 4
    hidden_size = 7168
    intermediate_size = 3072

    print("=" * 70)
    print(" Loading NVFP4 checkpoint layer 0")
    print("=" * 70)
    nvfp4_tensors = load_layer_tensors(NVFP4_MODEL_DIR, LAYER_IDX)
    print(f" {len(nvfp4_tensors)} tensors loaded")

    # Prepare weights — DIRECT PATH (no BF16 round-trip)
    print("\n Preparing NVFP4 weights (direct view-cast)...")
    weights = prepare_nvfp4_weights_direct(
        nvfp4_tensors, LAYER_IDX, expert_indices, intermediate_size
    )
    print(f" L1: {len(weights['l1_fp4'])} experts, shape {weights['l1_fp4'][0].shape}")
    print(f" L2: {len(weights['l2_fp4'])} experts, shape {weights['l2_fp4'][0].shape}")

    # Dequantize for BF16 reference
    print("\n Dequantizing NVFP4 -> BF16 reference...")
    nvfp4_experts_bf16 = dequantize_nvfp4_experts(nvfp4_tensors, LAYER_IDX, expert_indices)

    # Test input: every token is routed to experts 0 and 1 with fixed weights.
    hidden_states = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device=DEVICE) * 2.0
    expert_ids = torch.tensor([[0, 1]] * num_tokens, dtype=torch.int32, device=DEVICE)
    expert_weights = torch.tensor([[0.6, 0.4]] * num_tokens, dtype=torch.float32, device=DEVICE)

    # BF16 reference
    print("\n Running BF16 MoE reference...")
    ref_output = moe_forward_bf16(hidden_states, nvfp4_experts_bf16, expert_ids, expert_weights)
    print(f" BF16 ref: amax={ref_output.abs().max():.4f} mean={ref_output.float().mean():.6f}")

    # Free the dequantized reference weights before running the kernel path.
    del nvfp4_experts_bf16
    torch.cuda.empty_cache()

    # CuTeDSL NVFP4 pipeline
    print("\n Running CuTeDSL NVFP4 MoE pipeline (first run compiles)...")
    kernel_output = run_nvfp4_moe(
        hidden_states,
        expert_ids,
        expert_weights,
        weights,
        expert_indices,
    )
    print(f" Kernel: amax={kernel_output.abs().max():.4f} mean={kernel_output.float().mean():.6f}")

    # Compare
    cosine = torch.nn.functional.cosine_similarity(
        kernel_output.flatten().unsqueeze(0).float(),
        ref_output.flatten().unsqueeze(0).float(),
    ).item()
    mse = (kernel_output.float() - ref_output.float()).pow(2).mean().item()

    print(f"\n{'=' * 70}")
    print(f" RESULT: cosine={cosine:.6f} MSE={mse:.6e}")
    print(f"{'=' * 70}")

    if cosine < COSINE_THRESHOLD:
        print(f" FAIL: cosine {cosine:.6f} < {COSINE_THRESHOLD}")
        sys.exit(1)
    else:
        print(f" PASS: cosine {cosine:.6f} >= {COSINE_THRESHOLD}")


if __name__ == "__main__":
    main()