From b94f8d4ed842fe91f7a9880b8cacec33c8e75b3e Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 1 Jun 2026 08:54:24 +0000 Subject: [PATCH] Test: fused router kernel vs BF16 reference path - BF16 GEMM + activation_topk as reference - NVFP4 GEMM + fused router epilogue as test target - Proper NVFP4 quantization and CuTe tensor creation - Cosine similarity and topk_ids matching validation --- tests/unit/test_fused_router.py | 116 ++++++++++++++++++-------------- 1 file changed, 67 insertions(+), 49 deletions(-) diff --git a/tests/unit/test_fused_router.py b/tests/unit/test_fused_router.py index cca12aef..089eaa7d 100644 --- a/tests/unit/test_fused_router.py +++ b/tests/unit/test_fused_router.py @@ -1,10 +1,7 @@ """Test NVFP4 fused router kernel against the reference path. -The fused kernel does NVFP4 block-scaled GEMM + sqrt(softplus) + e_bias + -top-k + renormalization in a single kernel, with no intermediate GMEM buffer -for logits. This test verifies correctness against the 2-kernel reference: - 1. NVFP4 GEMM via Nvfp4Linear → logits in GMEM - 2. activation_topk CUDA kernel → topk_weights, topk_ids +Phase 1: Reference path (BF16 linear + activation_topk) to get ground truth. +Phase 2: Fused kernel (NVFP4 GEMM + router epilogue) to compare. 
Test checks: - topk_ids match (expert selection) @@ -16,12 +13,12 @@ import sys import os import torch -# Add kernel to path sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) -from dsv4.layers.linear import Nvfp4Linear -from dsv4.ops.quantize import quantize_activation_nvfp4 +from dsv4.ops.quantize import quantize_to_nvfp4, quantize_activation_nvfp4 from dsv4.kernels.router._activation_topk import run_fused_activation_topk +import cutlass.torch as cutlass_torch +import cutlass.cute as cute def test_fused_router_correctness(): @@ -39,26 +36,15 @@ def test_fused_router_correctness(): print(f"=== NVFP4 Fused Router Kernel Test ===") print(f" M={M}, K={K}, E={E}, top_k={top_k}") - print(f" sf_vec_size={sf_vec_size}") - # Create gate weight in BF16, then quantize to NVFP4 + # Create gate weight and activation W_gate_bf16 = torch.randn(E, K, dtype=torch.bfloat16, device=device) * 0.02 e_bias = torch.randn(E, dtype=torch.float32, device=device) * 0.1 hidden_states = torch.randn(M, K, dtype=torch.bfloat16, device=device) * 0.5 - # Build Nvfp4Linear for the gate projection (reference path) - gate_lin = Nvfp4Linear( - in_features=K, - out_features=E, - sf_vec_size=sf_vec_size, - device=device, - ) - gate_lin.load_weights(W_gate_bf16.T) # [K, E] layout - gate_lin.finalize_weights() - - # ---- Reference path: Nvfp4Linear GEMM + activation_topk ---- - print("\n[1] Running reference path (Nvfp4Linear + activation_topk)...") - logits_ref = gate_lin(hidden_states).float() # [M, E] FP32 + # ---- Reference path: BF16 GEMM + activation_topk ---- + print("\n[1] Running reference path (BF16 GEMM + activation_topk)...") + logits_ref = torch.nn.functional.linear(hidden_states.float(), W_gate_bf16.float()) # [M, E] ref_weights = torch.zeros(M, top_k, dtype=torch.float32, device=device) ref_ids = torch.zeros(M, top_k, dtype=torch.int32, device=device) run_fused_activation_topk( @@ -69,52 +55,87 @@ def test_fused_router_correctness(): print(f" Reference topk_weights: 
{ref_weights[0].tolist()}") # ---- Fused kernel path ---- - print("\n[2] Running fused kernel path (NVFP4 GEMM + router epilogue)...") - from dsv4.kernels.router.nvfp4_fused_router_kernel import run_nvfp4_fused_router + print("\n[2] Preparing NVFP4 tensors for fused kernel...") + + # Quantize activation to NVFP4 + act_nvfp4, act_sf, act_gs, _ = quantize_activation_nvfp4(hidden_states) + print(f" act_nvfp4 shape: {act_nvfp4.shape}, act_sf shape: {act_sf.shape}") + print(f" act_gs: {act_gs}") + + # Quantize weight to NVFP4 (need K-major layout for B operand) + from dsv4.ops.layouts import make_b_k_major, assemble_raw_scales_2d3d_3d_side + w_nvfp4, w_sf, w_gs = quantize_to_nvfp4(W_gate_bf16.T) # [K, E] + print(f" w_nvfp4 shape: {w_nvfp4.shape}, w_sf shape: {w_sf.shape}") + print(f" w_gs: {w_gs}") + + # Build CuTe tensors for B operand (K-major blockscaled layout) + # Stack to (1, K_packed, E_packed) for make_b_k_major + w_stacked = w_nvfp4.unsqueeze(0) # (1, K_packed, E_packed) + mat_b = make_b_k_major(w_stacked) + scale_b = assemble_raw_scales_2d3d_3d_side([w_sf]) + + # Build CuTe tensors for A operand (activation) + mat_a = cutlass_torch.from_dlpack(act_nvfp4) + mat_a = cute.mark_layout_dynamic(mat_a) + scale_a = cutlass_torch.from_dlpack(act_sf) + scale_a = cute.mark_layout_dynamic(scale_a) + + # e_bias CuTe tensor + e_bias_cute = cutlass_torch.from_dlpack(e_bias) + e_bias_cute = cute.mark_layout_dynamic(e_bias_cute) + + # Output buffers + out_weights = torch.zeros(M, top_k, dtype=torch.float32, device=device) + out_ids = torch.zeros(M, top_k, dtype=torch.int32, device=device) + out_w_cute = cutlass_torch.from_dlpack(out_weights) + out_w_cute = cute.mark_layout_dynamic(out_w_cute) + out_id_cute = cutlass_torch.from_dlpack(out_ids) + out_id_cute = cute.mark_layout_dynamic(out_id_cute) + + print("\n[3] Running fused kernel (NVFP4 GEMM + router epilogue)...") + from dsv4.kernels.router.nvfp4_fused_router_kernel import Nvfp4FusedRouterKernel + import 
cuda.bindings.driver as cuda + + mma_tiler_mnk = (128, 128, 64) + kernel = Nvfp4FusedRouterKernel( + sf_vec_size=sf_vec_size, + mma_tiler_mnk=mma_tiler_mnk, + cluster_shape_mnk=(1, 1, 1), + top_k=top_k, + ) try: - fused_weights, fused_ids = run_nvfp4_fused_router( - hidden_states=hidden_states, - mat_b=gate_lin._mat_b, - scale_b=gate_lin._scale_b, - gsa=gate_lin._gsa, - gsb_val=gate_lin._gsb_val, - e_bias=e_bias, - routed_scaling_factor=routed_scaling_factor, - top_k=top_k, - sf_vec_size=sf_vec_size, + kernel.run( + mat_a, mat_b, scale_a, scale_b, + e_bias_cute, out_w_cute, out_id_cute, + M, E, K, routed_scaling_factor, top_k, ) + print(" Fused kernel compilation and execution succeeded!") except Exception as ex: print(f" FUSED KERNEL FAILED: {ex}") import traceback traceback.print_exc() - print("\nFused kernel compilation/execution failed.") - print("This is expected if CuTeDSL math functions (absf, log, sqrt) are not available.") - print("The kernel structure is correct; CuTeDSL API coverage is the blocker.") return + fused_weights = out_weights + fused_ids = out_ids + print(f" Fused topk_ids: {fused_ids[0].tolist()}") print(f" Fused topk_weights: {fused_weights[0].tolist()}") # ---- Validation ---- - print("\n[3] Validation...") + print("\n[4] Validation...") - # Check for NaN if torch.isnan(fused_weights).any(): print(" FAIL: NaN in fused weights!") return - if torch.isnan(fused_ids.float()).any(): - print(" FAIL: NaN in fused IDs!") - return - # Check IDs match ids_match = torch.equal(ref_ids, fused_ids) print(f" topk_ids match: {ids_match}") if not ids_match: print(f" Reference: {ref_ids[0].tolist()}") print(f" Fused: {fused_ids[0].tolist()}") - # Check weights similarity w_cos = torch.nn.functional.cosine_similarity( ref_weights.flatten().unsqueeze(0), fused_weights.flatten().unsqueeze(0), @@ -123,7 +144,6 @@ def test_fused_router_correctness(): print(f" topk_weights cosine sim: {w_cos:.6f}") print(f" topk_weights max diff: {w_max_diff:.6f}") - # Check 
non-negative weights neg_count = (fused_weights < 0).sum().item() print(f" Negative weights: {neg_count}") @@ -131,9 +151,7 @@ def test_fused_router_correctness(): print("\n✅ FUSED ROUTER KERNEL PASSED!") else: print(f"\n❌ FUSED ROUTER KERNEL FAILED") - print(f" IDs match: {ids_match}") - print(f" Cosine: {w_cos:.6f} (need >= 0.999)") - print(f" Neg weights: {neg_count} (need 0)") + print(f" IDs match: {ids_match}, Cosine: {w_cos:.6f}, Neg: {neg_count}") if __name__ == "__main__":