diff --git a/tests/layertest.py b/tests/layertest.py index 56beeb00..672c50ac 100644 --- a/tests/layertest.py +++ b/tests/layertest.py @@ -1,10 +1,9 @@ #!/usr/bin/env python3 """ -Layer 0 comparison test: original checkpoint vs NVFP4 checkpoint + our kernel. +Layer 0 kernel comparison test: NVFP4 kernel vs BF16 reference. -Loads layer 0 expert weights from both checkpoints, runs the same deterministic -MoE forward pass, and compares the results. No vLLM, no Docker, no tensor -parallelism — just raw weights + our GEMM kernel. +No vLLM, no Docker, no tensor parallelism. Just raw weights + our kernel. +If cosine < 0.99, the test exits with error. Usage: python3 layertest.py @@ -19,12 +18,12 @@ from safetensors import safe_open # ── Constants ────────────────────────────────────────────────────────── -ORIG_MODEL_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro" NVFP4_MODEL_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4" LAYER_IDX = 0 DEVICE = "cuda" +COSINE_THRESHOLD = 0.99 -# E2M1 FP4 lookup table (shared by both formats) +# E2M1 FP4 lookup table E2M1_LUT = torch.tensor([ 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, @@ -43,7 +42,6 @@ def find_shards(model_dir): for key, shard in index["weight_map"].items(): key_to_shard[key] = os.path.join(model_dir, shard) else: - # Single shard for sf in glob.glob(os.path.join(model_dir, "*.safetensors")): with safe_open(sf, framework="pt") as f: for key in f.keys(): @@ -53,18 +51,12 @@ def find_shards(model_dir): def load_layer_tensors(model_dir, layer_idx): - """Load all tensors for a specific layer from the checkpoint. - - Returns dict of {key: tensor} for all keys matching the layer. - Keys are normalized to NOT have 'model.' prefix. - """ + """Load all tensors for a specific layer. Keys normalized (no 'model.' prefix).""" key_to_shard = find_shards(model_dir) layer_prefix = f"layers.{layer_idx}." - # Group by shard to minimize file opens shard_to_keys = {} for key, shard in key_to_shard.items(): - # Normalize: strip 'model.' prefix if present norm_key = key.removeprefix("model.") if not norm_key.startswith(layer_prefix): continue @@ -79,83 +71,21 @@ def load_layer_tensors(model_dir, layer_idx): return tensors -def print_layer_keys(tensors, label): - """Print sorted tensor keys with shapes and dtypes.""" - print(f"\n{'='*70}") - print(f" {label} — {len(tensors)} tensors") - print(f"{'='*70}") - for key in sorted(tensors.keys()): +def print_layer_keys(tensors, label, max_keys=20): + """Print sorted tensor keys with shapes and dtypes (first N).""" + print(f"\n {label} — {len(tensors)} tensors") + sorted_keys = sorted(tensors.keys()) + for key in sorted_keys[:max_keys]: t = tensors[key] - print(f" {key}: dtype={t.dtype} shape={tuple(t.shape)}") + print(f" {key}: dtype={t.dtype} shape={tuple(t.shape)}") + if len(sorted_keys) > max_keys: + print(f" ... ({len(sorted_keys) - max_keys} more)") -# ── Dequantization: Original checkpoint (MXFP4) ─────────────────────── - -def dequantize_mxfp4_weight(packed_uint8, scale_e8m0): - """Dequantize MXFP4 (E2M1 + E8M0, block_size=32) to BF16. - - Original checkpoint format: - packed_uint8: (out_features, in_features//2) uint8 - scale_e8m0: (out_features, in_features//32) float8_e8m0fnu - """ - device = packed_uint8.device - lut = E2M1_LUT.to(device) - - lower = lut[(packed_uint8 & 0x0F).long()] - upper = lut[((packed_uint8 >> 4) & 0x0F).long()] - - out_features = packed_uint8.shape[0] - in_features = packed_uint8.shape[1] * 2 - - unpacked = torch.empty(out_features, in_features, dtype=torch.float32, device=device) - unpacked[:, 0::2] = lower - unpacked[:, 1::2] = upper - - # E8M0 → float32: exponent-only format, represents 2^(x - bias) - scale_f32 = scale_e8m0.float() - scale_expanded = scale_f32.repeat_interleave(32, dim=1)[:, :in_features] - - return (unpacked * scale_expanded).to(torch.bfloat16) - - -def dequantize_mxfp4_experts(orig_tensors, layer_idx, expert_indices): - """Dequantize expert weights from original MXFP4 checkpoint. - - Original checkpoint key format: layers.{L}.ffn.experts.{E}.{w1/w2/w3}.{weight/scale} - w1 = gate_proj, w3 = up_proj, w2 = down_proj - - Returns dict: {expert_id: {gate_proj, up_proj, down_proj}} each as BF16. - """ - experts = {} - for e in expert_indices: - expert = {} - for proj, shard in [("gate_proj", "w1"), ("up_proj", "w3"), ("down_proj", "w2")]: - weight_key = f"layers.{layer_idx}.ffn.experts.{e}.{shard}.weight" - scale_key = f"layers.{layer_idx}.ffn.experts.{e}.{shard}.scale" - - if weight_key not in orig_tensors: - if proj == "down_proj" and e == 211: - continue - raise KeyError(f"Missing {weight_key}") - - weight = orig_tensors[weight_key].to(DEVICE) - scale = orig_tensors[scale_key].to(DEVICE) - expert[proj] = dequantize_mxfp4_weight(weight, scale) - - experts[e] = expert - return experts - - -# ── Dequantization: NVFP4 checkpoint ────────────────────────────────── +# ── NVFP4 Dequantization ────────────────────────────────────────────── def dequantize_nvfp4_weight(packed_uint8, scale_e4m3, global_scale): - """Dequantize NVFP4 (E2M1 + E4M3 block scale + float32 global) to BF16. - - NVFP4 checkpoint format: - packed_uint8: (out_features, in_features//2) uint8 - scale_e4m3: (out_features, in_features//16) float8_e4m3fn - global_scale: float32 scalar - """ + """Dequantize NVFP4 (E2M1 + E4M3 + global) to BF16.""" device = packed_uint8.device lut = E2M1_LUT.to(device) @@ -169,20 +99,14 @@ def dequantize_nvfp4_weight(packed_uint8, scale_e4m3, global_scale): unpacked[:, 0::2] = lower unpacked[:, 1::2] = upper - block_scale = scale_e4m3.float() # float8_e4m3fn → float32 + block_scale = scale_e4m3.float() block_expanded = block_scale.repeat_interleave(16, dim=1)[:, :in_features] - # Weight dequant = e2m1 * block_scale * global_scale return (unpacked * block_expanded * global_scale).to(torch.bfloat16) def dequantize_nvfp4_experts(nvfp4_tensors, layer_idx, expert_indices): - """Dequantize expert weights from NVFP4 checkpoint. - - NVFP4 checkpoint key format: layers.{L}.mlp.experts.{E}.{gate_proj/up_proj/down_proj}.{weight/weight_scale/weight_scale_2} - - Returns dict: {expert_id: {gate_proj, up_proj, down_proj}} each as BF16. - """ + """Dequantize expert weights from NVFP4 checkpoint → BF16.""" experts = {} for e in expert_indices: expert = {} @@ -205,20 +129,10 @@ def dequantize_nvfp4_experts(nvfp4_tensors, layer_idx, expert_indices): return experts -# ── MoE Forward Pass (BF16 reference) ───────────────────────────────── +# ── BF16 MoE Forward ─────────────────────────────────────────────────── def moe_forward_bf16(hidden_states, experts, expert_ids, expert_weights): - """Run MoE forward pass in pure BF16. - - Args: - hidden_states: (num_tokens, hidden_size) BF16 - experts: dict {expert_id: {gate_proj, up_proj, down_proj}} BF16 - expert_ids: (num_tokens, top_k) int — which expert per token per slot - expert_weights: (num_tokens, top_k) float32 — routing weights - - Returns: - output: (num_tokens, hidden_size) BF16 - """ + """Run MoE forward pass in pure BF16 (torch.matmul).""" num_tokens, hidden_size = hidden_states.shape top_k = expert_ids.shape[1] output = torch.zeros(num_tokens, hidden_size, dtype=torch.bfloat16, device=DEVICE) @@ -231,92 +145,66 @@ def moe_forward_bf16(hidden_states, experts, expert_ids, expert_weights): if e not in experts: continue - x = hidden_states[t] # (hidden_size,) - gate = x @ experts[e]["gate_proj"].T # (intermediate//2,) - up = x @ experts[e]["up_proj"].T # (intermediate//2,) - activated = torch.nn.functional.silu(gate) * up # (intermediate//2,) + x = hidden_states[t] + gate = x @ experts[e]["gate_proj"].T + up = x @ experts[e]["up_proj"].T + activated = torch.nn.functional.silu(gate) * up if "down_proj" in experts[e]: - y = activated @ experts[e]["down_proj"].T # (hidden_size,) + y = activated @ experts[e]["down_proj"].T else: - y = activated[:hidden_size] # shared expert, no down_proj + y = activated[:hidden_size] output[t] += w * y return output -# ── MoE Forward Pass (NVFP4 kernel) ─────────────────────────────────── +# ── NVFP4 Kernel MoE Forward ────────────────────────────────────────── def moe_forward_nvfp4(hidden_states, nvfp4_tensors, layer_idx, expert_ids, expert_weights): - """Run MoE forward pass using our NVFP4 kernel. - - Loads weights directly from NVFP4 checkpoint (no vLLM), transforms them - for CUTLASS, and runs the grouped GEMM. - """ + """Run MoE forward pass using our NVFP4 kernel.""" from nvfp4_megamoe_kernel import ( stage_activation, nvfp4_mega_moe_full, transform_nvfp4_weights_for_mega_moe, - SymmBuffer, get_symm_buffer_for_nvfp4_mega_moe, ) num_tokens, hidden_size = hidden_states.shape top_k = expert_ids.shape[1] - # Collect the experts we need unique_experts = sorted(set(expert_ids.flatten().tolist())) num_experts = len(unique_experts) expert_map = {e: i for i, e in enumerate(unique_experts)} - # Load NVFP4 weights for these experts - # Shapes: gate_proj.weight = (3072, 3584) uint8, weight_scale = (3072, 448) float8_e4m3fn - intermediate_half = 3072 # intermediate_size // 2 + intermediate_half = 3072 hidden_half = hidden_size // 2 - l1_weights = [] # gate + up fused - l1_scales = [] - l1_global_scales = [] - l2_weights = [] # down - l2_scales = [] - l2_global_scales = [] + l1_weights, l1_scales, l1_global_scales = [], [], [] + l2_weights, l2_scales, l2_global_scales = [], [], [] for e in unique_experts: - # L1: gate_proj + up_proj fused - gate_w_key = f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight" - gate_sf_key = f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale" - gate_gs_key = f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale_2" - up_w_key = f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight" - up_sf_key = f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale" - up_gs_key = f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale_2" + gate_w = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight"].view(torch.int8).to(DEVICE) + gate_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale"].to(DEVICE) + gate_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale_2"].item() + up_w = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight"].view(torch.int8).to(DEVICE) + up_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale"].to(DEVICE) + up_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale_2"].item() - gate_w = nvfp4_tensors[gate_w_key].view(torch.int8).to(DEVICE) - gate_sf = nvfp4_tensors[gate_sf_key].to(DEVICE) - gate_gs = nvfp4_tensors[gate_gs_key].item() - up_w = nvfp4_tensors[up_w_key].view(torch.int8).to(DEVICE) - up_sf = nvfp4_tensors[up_sf_key].to(DEVICE) - up_gs = nvfp4_tensors[up_gs_key].item() - - # Fuse gate + up: stack along dim 0 → (2*3072, 3584) l1_w = torch.cat([gate_w, up_w], dim=0) l1_sf = torch.cat([gate_sf, up_sf], dim=0) l1_gs = torch.tensor([gate_gs, up_gs], dtype=torch.float32, device=DEVICE) - l1_weights.append(l1_w) l1_scales.append(l1_sf) l1_global_scales.append(l1_gs) - # L2: down_proj down_w_key = f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight" if down_w_key in nvfp4_tensors: down_w = nvfp4_tensors[down_w_key].view(torch.int8).to(DEVICE) - down_sf_key = f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale" - down_gs_key = f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale_2" - down_sf = nvfp4_tensors[down_sf_key].to(DEVICE) - down_gs = nvfp4_tensors[down_gs_key].item() + down_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale"].to(DEVICE) + down_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale_2"].item() else: - # Expert 211 has no down_proj — use zeros down_w = torch.zeros(hidden_size, intermediate_half, dtype=torch.int8, device=DEVICE) down_sf = torch.ones(hidden_size, intermediate_half // 16, dtype=torch.float8_e4m3fn, device=DEVICE) down_gs = 1.0 @@ -325,24 +213,19 @@ def moe_forward_nvfp4(hidden_states, nvfp4_tensors, layer_idx, expert_ids, exper l2_scales.append(down_sf) l2_global_scales.append(torch.tensor([down_gs], dtype=torch.float32, device=DEVICE)) - # Stack into (num_experts, ...) tensors - l1_w = torch.stack(l1_weights) # (E, 2*3072, 3584) int8 - l1_sf = torch.stack(l1_scales) # (E, 2*3072, 448) float8_e4m3fn - l1_gs = torch.stack(l1_global_scales) # (E, 2) float32 - l2_w = torch.stack(l2_weights) # (E, hidden, intermediate_half) int8 - l2_sf = torch.stack(l2_scales) # (E, hidden, intermediate_half//16) float8_e4m3fn - l2_gs = torch.stack(l2_global_scales) # (E, 1) float32 + l1_w = torch.stack(l1_weights) + l1_sf = torch.stack(l1_scales) + l1_gs = torch.stack(l1_global_scales) + l2_w = torch.stack(l2_weights) + l2_sf = torch.stack(l2_scales) + l2_gs = torch.stack(l2_global_scales) - # Transform weights for CUTLASS (l1_w, l1_sf, l1_global_sf), (l2_w, l2_sf, l2_global_sf) = \ transform_nvfp4_weights_for_mega_moe( - (l1_w, l1_sf), - (l2_w, l2_sf), - l1_weight_scale_2=l1_gs, - l2_weight_scale_2=l2_gs, + (l1_w, l1_sf), (l2_w, l2_sf), + l1_weight_scale_2=l1_gs, l2_weight_scale_2=l2_gs, ) - # Build slot mapping: each (token, top_k) pair → slot num_slots = num_tokens * top_k slot_expert = torch.zeros(num_slots, dtype=torch.int32, device=DEVICE) slot_token = torch.zeros(num_slots, dtype=torch.int64, device=DEVICE) @@ -351,22 +234,15 @@ def moe_forward_nvfp4(hidden_states, nvfp4_tensors, layer_idx, expert_ids, exper for t in range(num_tokens): for k in range(top_k): slot = t * top_k + k - e = expert_ids[t, k].item() - slot_expert[slot] = expert_map[e] + slot_expert[slot] = expert_map[expert_ids[t, k].item()] slot_token[slot] = t slot_weight[slot] = expert_weights[t, k].item() - # SymmBuffer symm_buffer = get_symm_buffer_for_nvfp4_mega_moe( - group=None, # no EP - num_experts=num_experts, - max_num_tokens=num_tokens, - top_k=top_k, - hidden_size=hidden_size, - intermediate_size=6144, # 2 * 3072 + group=None, num_experts=num_experts, max_num_tokens=num_tokens, + top_k=top_k, hidden_size=hidden_size, intermediate_size=6144, ) - # Stage activation x_fp4, x_sf, input_global_scale = stage_activation(hidden_states) symm_buffer.x[:num_tokens].copy_(x_fp4) symm_buffer.x_sf[:num_tokens].copy_(x_sf) @@ -375,7 +251,6 @@ def moe_forward_nvfp4(hidden_states, nvfp4_tensors, layer_idx, expert_ids, exper symm_buffer.topk_weights[:num_tokens].copy_(expert_weights) symm_buffer.experts_start_idx = 0 - # Run y = torch.zeros(num_tokens, hidden_size, dtype=torch.bfloat16, device=DEVICE) nvfp4_mega_moe_full( y, @@ -391,145 +266,71 @@ def moe_forward_nvfp4(hidden_states, nvfp4_tensors, layer_idx, expert_ids, exper def main(): torch.manual_seed(42) - - expert_indices = [0, 1, 2] # Test with 3 experts + expert_indices = [0, 1, 2] top_k = 2 num_tokens = 4 - - # ── Step 1: Load original checkpoint layer 0 ── - print("\n" + "="*70) - print(" STEP 1: Loading original MXFP4 checkpoint") - print("="*70) - - orig_tensors = load_layer_tensors(ORIG_MODEL_DIR, LAYER_IDX) - print_layer_keys(orig_tensors, "Original checkpoint (MXFP4)") - - # Dequantize to BF16 - print("\nDequantizing MXFP4 → BF16...") - orig_experts_bf16 = dequantize_mxfp4_experts(orig_tensors, LAYER_IDX, expert_indices) - for e in expert_indices: - for proj, w in orig_experts_bf16[e].items(): - print(f" Expert {e} {proj}: shape={tuple(w.shape)} amax={w.abs().max():.4f}") - - # ── Step 2: Run BF16 reference forward pass ── - print("\n" + "="*70) - print(" STEP 2: BF16 reference forward pass") - print("="*70) - hidden_size = 7168 - hidden_states = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device=DEVICE) * 2.0 - # Deterministic routing: each token picks experts 0,1 - expert_ids = torch.tensor([[0, 1]] * num_tokens, dtype=torch.int32, device=DEVICE) - expert_weights = torch.tensor([[0.6, 0.4]] * num_tokens, dtype=torch.float32, device=DEVICE) - - ref_output = moe_forward_bf16(hidden_states, orig_experts_bf16, expert_ids, expert_weights) - print(f" Reference output: shape={tuple(ref_output.shape)} amax={ref_output.abs().max():.4f} mean={ref_output.float().mean():.6f}") - print(f" First token first 10: {ref_output[0, :10].tolist()}") - - del orig_tensors, orig_experts_bf16 # Free memory - torch.cuda.empty_cache() - - # ── Step 3: Load NVFP4 checkpoint layer 0 ── - print("\n" + "="*70) - print(" STEP 3: Loading NVFP4 checkpoint") - print("="*70) + # ── Load NVFP4 checkpoint ── + print("=" * 70) + print(" Loading NVFP4 checkpoint layer 0") + print("=" * 70) nvfp4_tensors = load_layer_tensors(NVFP4_MODEL_DIR, LAYER_IDX) - print_layer_keys(nvfp4_tensors, "NVFP4 checkpoint") + print_layer_keys(nvfp4_tensors, "NVFP4 checkpoint", max_keys=5) - # Verify dtype of weight_scale (should be float8_e4m3fn, NOT float8_e8m0fnu) + # Verify weight_scale dtype for e in expert_indices[:1]: for proj in ["gate_proj", "up_proj", "down_proj"]: key = f"layers.{LAYER_IDX}.mlp.experts.{e}.{proj}.weight_scale" if key in nvfp4_tensors: dt = nvfp4_tensors[key].dtype - print(f" {proj}.weight_scale dtype = {dt} {'✓ E4M3' if dt == torch.float8_e4m3fn else '✗ WRONG (expected float8_e4m3fn)'}") + assert dt == torch.float8_e4m3fn, f"{proj}.weight_scale dtype={dt}, expected float8_e4m3fn" + print(f" {proj}.weight_scale dtype = {dt} ✓") - # Dequantize NVFP4 → BF16 (for BF16 reference comparison) - print("\nDequantizing NVFP4 → BF16...") + # ── Dequantize → BF16 reference ── + print("\n Dequantizing NVFP4 → BF16...") nvfp4_experts_bf16 = dequantize_nvfp4_experts(nvfp4_tensors, LAYER_IDX, expert_indices) - for e in expert_indices: + for e in expert_indices[:2]: for proj, w in nvfp4_experts_bf16[e].items(): - print(f" Expert {e} {proj}: shape={tuple(w.shape)} amax={w.abs().max():.4f}") + print(f" Expert {e} {proj}: shape={tuple(w.shape)} amax={w.abs().max():.4f}") - # ── Step 4: Compare dequantized weights ── - print("\n" + "="*70) - print(" STEP 4: Weight comparison (original dequant vs NVFP4 dequant)") - print("="*70) + # ── Create test input ── + hidden_states = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device=DEVICE) * 2.0 + expert_ids = torch.tensor([[0, 1]] * num_tokens, dtype=torch.int32, device=DEVICE) + expert_weights = torch.tensor([[0.6, 0.4]] * num_tokens, dtype=torch.float32, device=DEVICE) - # Note: the original was MXFP4 (E8M0, block=32) and NVFP4 is (E4M3, block=16) - # They were quantized independently so weights will differ — this is expected. - # The comparison is to verify the NVFP4 dequant matches its own re-dequant. - print(" (MXFP4 and NVFP4 were independently quantized — weight values will differ)") - print(" (This is expected. The comparison is: NVFP4 dequant vs NVFP4 kernel)") + # ── BF16 reference forward pass ── + print("\n Running BF16 reference...") + ref_output = moe_forward_bf16(hidden_states, nvfp4_experts_bf16, expert_ids, expert_weights) + print(f" BF16 ref: amax={ref_output.abs().max():.4f} mean={ref_output.float().mean():.6f}") + print(f" First token first 8: {[f'{v:.4f}' for v in ref_output[0, :8].tolist()]}") - # ── Step 5: Run NVFP4 BF16 reference (using NVFP4-dequantized weights) ── - print("\n" + "="*70) - print(" STEP 5: NVFP4 BF16 reference forward pass") - print("="*70) + del nvfp4_experts_bf16 + torch.cuda.empty_cache() - nvfp4_ref_output = moe_forward_bf16(hidden_states, nvfp4_experts_bf16, expert_ids, expert_weights) - print(f" NVFP4 BF16 ref: shape={tuple(nvfp4_ref_output.shape)} amax={nvfp4_ref_output.abs().max():.4f} mean={nvfp4_ref_output.float().mean():.6f}") - print(f" First token first 10: {nvfp4_ref_output[0, :10].tolist()}") + # ── NVFP4 kernel forward pass ── + print("\n Running NVFP4 kernel...") + kernel_output = moe_forward_nvfp4(hidden_states, nvfp4_tensors, LAYER_IDX, expert_ids, expert_weights) + print(f" Kernel: amax={kernel_output.abs().max():.4f} mean={kernel_output.float().mean():.6f}") + print(f" First token first 8: {[f'{v:.4f}' for v in kernel_output[0, :8].tolist()]}") - # Compare against original dequant - cos_orig_vs_nvfp4bf16 = torch.nn.functional.cosine_similarity( + # ── Compare ── + cosine = torch.nn.functional.cosine_similarity( + kernel_output.flatten().unsqueeze(0).float(), ref_output.flatten().unsqueeze(0).float(), - nvfp4_ref_output.flatten().unsqueeze(0).float(), ).item() - print(f" Cosine (orig BF16 ref vs NVFP4 BF16 ref): {cos_orig_vs_nvfp4bf16:.6f}") + mse = (kernel_output.float() - ref_output.float()).pow(2).mean().item() - # ── Step 6: Run our NVFP4 kernel ── - print("\n" + "="*70) - print(" STEP 6: NVFP4 kernel forward pass") - print("="*70) + print(f"\n{'=' * 70}") + print(f" RESULT: cosine={cosine:.6f} MSE={mse:.6e}") + print(f"{'=' * 70}") - try: - kernel_output = moe_forward_nvfp4(hidden_states, nvfp4_tensors, LAYER_IDX, expert_ids, expert_weights) - print(f" Kernel output: shape={tuple(kernel_output.shape)} amax={kernel_output.abs().max():.4f} mean={kernel_output.float().mean():.6f}") - print(f" First token first 10: {kernel_output[0, :10].tolist()}") - - # Compare kernel vs NVFP4 BF16 reference - cos_kernel_vs_nvfp4bf16 = torch.nn.functional.cosine_similarity( - kernel_output.flatten().unsqueeze(0).float(), - nvfp4_ref_output.flatten().unsqueeze(0).float(), - ).item() - mse = (kernel_output.float() - nvfp4_ref_output.float()).pow(2).mean().item() - print(f" Cosine (kernel vs NVFP4 BF16 ref): {cos_kernel_vs_nvfp4bf16:.6f}") - print(f" MSE (kernel vs NVFP4 BF16 ref): {mse:.6e}") - - # Compare kernel vs original BF16 reference - cos_kernel_vs_orig = torch.nn.functional.cosine_similarity( - kernel_output.flatten().unsqueeze(0).float(), - ref_output.flatten().unsqueeze(0).float(), - ).item() - print(f" Cosine (kernel vs orig BF16 ref): {cos_kernel_vs_orig:.6f}") - - except Exception as e: - print(f" KERNEL FAILED: {e}") - import traceback - traceback.print_exc() - - # ── Summary ── - print("\n" + "="*70) - print(" SUMMARY") - print("="*70) - print(f" Original BF16 reference: amax={ref_output.abs().max():.4f} mean={ref_output.float().mean():.6f}") - print(f" NVFP4 BF16 reference: amax={nvfp4_ref_output.abs().max():.4f} mean={nvfp4_ref_output.float().mean():.6f}") - print(f" Cosine (orig vs NVFP4 BF16): {cos_orig_vs_nvfp4bf16:.6f}") - if 'kernel_output' in dir(): - cos_k = torch.nn.functional.cosine_similarity( - kernel_output.flatten().unsqueeze(0).float(), - nvfp4_ref_output.flatten().unsqueeze(0).float(), - ).item() - print(f" Cosine (kernel vs NVFP4 BF16): {cos_k:.6f}") - if cos_k > 0.99: - print(f" ✅ Kernel matches BF16 reference — bug is in vLLM integration") - elif cos_k > 0.9: - print(f" ⚠️ Kernel is close but not perfect — minor numerical issue") - else: - print(f" ❌ Kernel is far from BF16 reference — bug is in the kernel or weight pipeline") + if cosine < COSINE_THRESHOLD: + print(f" ❌ FAIL: cosine {cosine:.6f} < {COSINE_THRESHOLD}") + sys.exit(1) + else: + print(f" ✅ PASS: cosine {cosine:.6f} >= {COSINE_THRESHOLD}") if __name__ == "__main__":