#!/usr/bin/env python3
"""Test fused SwiGLU NVFP4 GEMM kernel compilation and correctness.

Validates P0/P1 from PERFORMANCE_AUDIT.md:
- Fused SwiGLU kernel compiles via cute.compile
- Output cosine similarity vs unfused path >= 0.9995
- Tests both multi-expert (MoE) and single-expert (SharedExpert) modes

Run under pytest, or directly as a script (exit status 0 on success,
non-zero on any failure).
"""
# Provenance: tests/unit/test_fused_swiglu_kernel.py, added by commit 55ea109c
# ("test: fused SwiGLU kernel compilation + correctness (P0/P1 gate)").
import sys

import torch


def test_fused_swiglu_compilation():
    """Compile the fused SwiGLU kernel and compare it with the unfused path.

    Steps, in order:
      1. Warm up (compile) the standard and the fused grouped GEMM.
      2. Quantize one random activation/weight pair, run the unfused GEMM
         followed by a SwiGLU computed in PyTorch, then run the fused GEMM
         on the same operands.
      3. Print cosine similarity and max abs difference of the two outputs.
      4. Warm up the fused kernel once more with a single expert.
      5. Fail if the cosine similarity is below 0.9995.

    Requires the ``dsv4`` package and a CUDA device at ``cuda:0``.

    Raises:
        AssertionError: if the fused/unfused cosine similarity is below the
            threshold (also when it is NaN).
        Exception: anything raised by the warmup or GEMM calls is re-raised
            after a diagnostic line has been printed.
    """
    # Seed here rather than only in the __main__ guard so that pytest runs
    # use the same inputs as script runs. Note this reseeds torch's global
    # RNG for whatever runs afterwards in the same process.
    torch.manual_seed(42)

    # Function-scope imports: importing this module does not require dsv4.
    from dsv4.ops.gemm_runner import (
        run_fused_swiglu_grouped_gemm,
        run_nvfp4_grouped_gemm,
        warmup_compilation,
        warmup_fused_swiglu_compilation,
    )
    from dsv4.ops.layouts import (
        assemble_scales_3d_side,
        ceil_div as cutedsl_ceil_div,
        interleave_l1_weights,
        make_b_k_major,
        pad_and_swizzle_single,
    )
    # NOTE(review): SF_VEC_SIZE is imported but never used; the literal 16 in
    # the scale-factor column computation below is presumably this value —
    # confirm and use the constant.
    from dsv4.ops.quantize import quantize_to_nvfp4, SF_VEC_SIZE

    device = "cuda:0"
    # Production MoE shapes (DeepSeek-V4 Pro L1 GEMM)
    # L1: K=7168, N=6144 (gate+up combined) → K_packed=3584, N_packed=3072
    K_packed = 3584
    N_packed = 3072
    num_experts = 4  # Small for testing, but >1 for MoE path
    swiglu_limit = 10.0
    cos_threshold = 0.9995  # minimum fused-vs-unfused cosine similarity
    padded_rows = 128  # activation rows are zero-padded up to this count

    print("Testing fused SwiGLU kernel compilation...")
    print(f" K_packed={K_packed}, N_packed={N_packed}, E={num_experts}, limit={swiglu_limit}")

    # Warmup standard GEMM first
    print(" Warming up standard GEMM...")
    warmup_compilation(num_experts, K_packed, N_packed, device)

    # Warmup fused GEMM
    print(" Warming up fused SwiGLU GEMM...")
    try:
        warmup_fused_swiglu_compilation(
            num_experts, K_packed, N_packed, device,
            swiglu_limit=swiglu_limit,
        )
    except TypeError as e:
        print(f" ❌ Fused SwiGLU compilation FAILED with TypeError: {e}")
        print(" This is the arg-binding bug from the previous session.")
        raise
    except Exception as e:
        print(f" ❌ Fused SwiGLU compilation FAILED: {type(e).__name__}: {e}")
        raise
    else:
        print(" ✅ Fused SwiGLU kernel compiled successfully!")

    # Now test correctness: run both fused and unfused, compare
    print("\n Testing fused vs unfused output correctness...")
    tokens = 6  # top-k=6
    K = K_packed * 2  # 7168
    N = N_packed * 2  # 6144

    # Random input and random weight (the same tensors feed both paths)
    x_bf16 = torch.randn(tokens, K, dtype=torch.bfloat16, device=device) * 0.5
    w_bf16 = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device=device) * 0.1

    # Quantize activation
    x_fp4, x_sf, x_gs = quantize_to_nvfp4(x_bf16)

    # Quantize weight (interleaved for L1 gate+up)
    w_bf16_t = w_bf16.permute(0, 2, 1).contiguous()  # (E, N, K) for make_b_k_major
    w_fp4, w_sf, w_gs = quantize_to_nvfp4(w_bf16_t)
    w_fp4_il = interleave_l1_weights(w_fp4.unsqueeze(0)).squeeze(0)  # interleave for SwiGLU
    mat_b = make_b_k_major(w_fp4_il)

    # All tokens go to expert 0 for simplicity; the single offset is the
    # padded row count.
    # NOTE(review): only one offset is supplied although num_experts is 4 —
    # confirm this is what the grouped GEMM runners expect.
    padded_offsets = torch.tensor([padded_rows], dtype=torch.int32, device=device)

    # Pad activation to padded_rows rows
    x_padded = torch.zeros(padded_rows, K_packed, dtype=x_fp4.dtype, device=device)
    x_padded[:tokens] = x_fp4

    # Assemble scales (simplified — just pad + swizzle)
    K_sf = cutedsl_ceil_div(K, 16)
    padded_cols = cutedsl_ceil_div(K_sf, 4) * 4  # round up to a multiple of 4
    scale_a_buf = torch.zeros(
        padded_rows, padded_cols, dtype=torch.float16, device=device
    ).to(torch.float8_e4m3fn)
    scale_a_buf[:tokens, :x_sf.shape[1]] = x_sf
    scale_a = pad_and_swizzle_single(scale_a_buf).reshape(padded_rows, padded_cols)

    scale_b = assemble_scales_3d_side(w_sf)

    global_scale_a = torch.full((num_experts,), x_gs, dtype=torch.float32, device=device)
    global_scale_b = torch.tensor(w_gs, dtype=torch.float32, device=device)

    # Run UNFUSED path
    print(" Running unfused GEMM...")
    l1_unfused = run_nvfp4_grouped_gemm(
        mat_a=x_padded, mat_b=mat_b,
        scale_a=scale_a, scale_b=scale_b,
        expert_offsets=padded_offsets,
        global_scale_a=global_scale_a, global_scale_b=global_scale_b,
    )[:tokens]  # (6, 6144) BF16

    # Manual SwiGLU on unfused output
    intermediate = N // 2  # 3072
    # NOTE(review): the same interleave_l1_weights helper that interleaved
    # the weights is applied here to obtain the [gate | up] layout — confirm
    # it really acts as its own inverse on the output columns.
    l1_deil = interleave_l1_weights(l1_unfused.unsqueeze(0).contiguous())[0]
    gate = l1_deil[:, :intermediate]
    up = l1_deil[:, intermediate:]
    gate_silu = torch.nn.functional.silu(gate)
    gate_silu = gate_silu.clamp(max=swiglu_limit)
    up = up.clamp(min=-swiglu_limit, max=swiglu_limit)
    activated_unfused = gate_silu * up

    # Run FUSED path
    print(" Running fused SwiGLU GEMM...")
    try:
        l1_fused = run_fused_swiglu_grouped_gemm(
            mat_a=x_padded, mat_b=mat_b,
            scale_a=scale_a, scale_b=scale_b,
            expert_offsets=padded_offsets,
            global_scale_a=global_scale_a, global_scale_b=global_scale_b,
            swiglu_limit=swiglu_limit,
        )[:tokens]  # (6, 3072) BF16 — SwiGLU already applied
    except Exception as e:
        print(f" ❌ Fused SwiGLU GEMM FAILED: {type(e).__name__}: {e}")
        raise
    else:
        print(" ✅ Fused SwiGLU GEMM ran successfully!")

    # Compare
    # The fused kernel outputs only the silu(gate)*up result (N/2 = 3072)
    # The unfused path's activated_unfused is the same computation in Python
    cos = torch.nn.functional.cosine_similarity(
        l1_fused.flatten().float(), activated_unfused.flatten().float(), dim=0
    ).item()
    max_diff = (l1_fused.float() - activated_unfused.float()).abs().max().item()
    print("\n Fused vs Unfused SwiGLU output:")
    print(f" Cosine similarity: {cos:.6f}")
    print(f" Max abs diff: {max_diff:.6f}")
    print(f" |fused|: {l1_fused.abs().max().item():.4f}")
    print(f" |unfused|: {activated_unfused.abs().max().item():.4f}")

    # A NaN cosine compares False here and is therefore reported as a failure.
    passed = cos >= cos_threshold
    if passed:
        print(f" ✅ PASS: cosine >= {cos_threshold}")
    else:
        print(f" ❌ FAIL: cosine < {cos_threshold}")

    # Test single-expert mode (for SharedExpert P1). This runs even when the
    # comparison above failed, so both results show up in one run.
    print("\n--- Testing single-expert (SharedExpert) mode ---")
    warmup_fused_swiglu_compilation(
        1, K_packed, N_packed, device, swiglu_limit=swiglu_limit
    )
    print(" ✅ Single-expert fused SwiGLU compiled!")

    # Raise instead of returning a bool: pytest ignores a test's return value,
    # so a returned False would be reported as a pass. An explicit raise (not
    # a bare assert) also survives `python -O`.
    if not passed:
        raise AssertionError(
            f"fused vs unfused SwiGLU cosine similarity {cos:.6f} "
            f"is below {cos_threshold} (max abs diff {max_diff:.6f})"
        )


if __name__ == "__main__":
    # Any failure propagates as an exception, which already produces a
    # non-zero exit status; reaching sys.exit(0) means the test passed.
    test_fused_swiglu_compilation()
    sys.exit(0)