diff --git a/tests/unit/test_fused_swiglu_kernel.py b/tests/unit/test_fused_swiglu_kernel.py index e6f7901f..3ee4ff7a 100644 --- a/tests/unit/test_fused_swiglu_kernel.py +++ b/tests/unit/test_fused_swiglu_kernel.py @@ -17,15 +17,18 @@ def test_fused_swiglu_compilation(): run_nvfp4_grouped_gemm, run_fused_swiglu_grouped_gemm, ) - from dsv4.ops.quantize import quantize_to_nvfp4, SF_VEC_SIZE - from dsv4.ops.layouts import make_b_k_major, interleave_l1_weights + from dsv4.ops.quantize import quantize_to_nvfp4, quantize_activation_nvfp4, SF_VEC_SIZE + from dsv4.ops.layouts import ( + make_b_k_major, interleave_l1_weights, deinterleave_l1_weights, + pad_and_swizzle_single, ceil_div as cutedsl_ceil_div, + assemble_scales_3d_side, + ) device = "cuda:0" # Production MoE shapes (DeepSeek-V4 Pro L1 GEMM) - # L1: K=7168, N=6144 (gate+up combined) → K_packed=3584, N_packed=3072 - K_packed = 3584 - N_packed = 3072 - num_experts = 4 # Small for testing, but >1 for MoE path + K_packed = 3584 # 7168 / 2 + N_packed = 3072 # 6144 / 2 + num_experts = 4 swiglu_limit = 10.0 print(f"Testing fused SwiGLU kernel compilation...") @@ -43,59 +46,53 @@ def test_fused_swiglu_compilation(): swiglu_limit=swiglu_limit, ) print(" ✅ Fused SwiGLU kernel compiled successfully!") - except TypeError as e: - print(f" ❌ Fused SwiGLU compilation FAILED with TypeError: {e}") - print(f" This is the arg-binding bug from the previous session.") - raise except Exception as e: print(f" ❌ Fused SwiGLU compilation FAILED: {type(e).__name__}: {e}") raise # Now test correctness: run both fused and unfused, compare print("\n Testing fused vs unfused output correctness...") - tokens = 6 # top-k=6 + tokens = 128 # Use 128 to match padding (no OOB) K = K_packed * 2 # 7168 N = N_packed * 2 # 6144 + intermediate = N // 2 # 3072 # Create random input + torch.manual_seed(42) x_bf16 = torch.randn(tokens, K, dtype=torch.bfloat16, device=device) * 0.5 # Create random weight (same for both paths) w_bf16 = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device=device) * 0.1 - # Quantize activation - x_fp4, x_sf, x_gs = quantize_to_nvfp4(x_bf16) + # Quantize activation using the proper pipeline + x_gs = 1.0 / (6.0 * 448.0) # placeholder gsa + x_fp4, x_sf = quantize_activation_nvfp4(x_bf16, x_gs) - # Quantize weight (interleaved for L1 gate+up) - w_bf16_t = w_bf16.permute(0, 2, 1).contiguous() # (E, N, K) for make_b_k_major + # Quantize weight + w_bf16_t = w_bf16.permute(0, 2, 1).contiguous() # (E, N, K) w_fp4, w_sf, w_gs = quantize_to_nvfp4(w_bf16_t) - # w_fp4: (E, N_packed, K_packed) — interleave along N for gate/up pairing if w_fp4.dtype == torch.uint8: w_fp4 = w_fp4.view(torch.float4_e2m1fn_x2) w_fp4_il = interleave_l1_weights(w_fp4) # (E, N_packed, K_packed) interleaved mat_b = make_b_k_major(w_fp4_il) - # Expert offsets (all tokens go to expert 0 for simplicity) - expert_offsets = torch.tensor([0, tokens], dtype=torch.int32, device=device) - padded_offsets = torch.tensor([128], dtype=torch.int32, device=device) # padded to 128 + # Expert offsets: 1 expert with 128 tokens + padded_offsets = torch.tensor([128], dtype=torch.int32, device=device) - # Pad activation to 128 rows - x_padded = torch.zeros(128, K_packed, dtype=torch.uint8, device=device).view(torch.float4_e2m1fn_x2) - x_padded.view(torch.uint8)[:tokens] = x_fp4.view(torch.uint8) - - # Assemble scales (simplified — just pad + swizzle) - from dsv4.ops.layouts import pad_and_swizzle_single, ceil_div as cutedsl_ceil_div + # Scale assembly K_sf = cutedsl_ceil_div(K, 16) padded_cols = cutedsl_ceil_div(K_sf, 4) * 4 scale_a_buf = torch.zeros(128, padded_cols, dtype=torch.float16, device=device).to(torch.float8_e4m3fn) scale_a_buf[:tokens, :x_sf.shape[1]] = x_sf scale_a = pad_and_swizzle_single(scale_a_buf).reshape(128, padded_cols) - - from dsv4.ops.layouts import assemble_scales_3d_side scale_b = assemble_scales_3d_side(w_sf) - global_scale_a = torch.full((num_experts,), x_gs, dtype=torch.float32, device=device) - global_scale_b = torch.tensor(w_gs, dtype=torch.float32, device=device) + gsa = torch.full((num_experts,), x_gs, dtype=torch.float32, device=device) + gsb = torch.tensor(w_gs, dtype=torch.float32, device=device) + + # Pad activation + x_padded = torch.zeros(128, K_packed, dtype=torch.uint8, device=device).view(torch.float4_e2m1fn_x2) + x_padded.view(torch.uint8)[:tokens] = x_fp4.view(torch.uint8) # Run UNFUSED path print(" Running unfused GEMM...") @@ -103,12 +100,11 @@ def test_fused_swiglu_compilation(): mat_a=x_padded, mat_b=mat_b, scale_a=scale_a, scale_b=scale_b, expert_offsets=padded_offsets, - global_scale_a=global_scale_a, global_scale_b=global_scale_b, - )[:tokens] # (6, 6144) BF16 + global_scale_a=gsa, global_scale_b=gsb, + )[:tokens] # (128, 6144) BF16 # Manual SwiGLU on unfused output - intermediate = N // 2 # 3072 - l1_deil = interleave_l1_weights(l1_unfused.unsqueeze(0).contiguous())[0] + l1_deil = deinterleave_l1_weights(l1_unfused.unsqueeze(0).contiguous())[0] gate = l1_deil[:, :intermediate] up = l1_deil[:, intermediate:] gate_silu = torch.nn.functional.silu(gate) @@ -123,17 +119,15 @@ def test_fused_swiglu_compilation(): mat_a=x_padded, mat_b=mat_b, scale_a=scale_a, scale_b=scale_b, expert_offsets=padded_offsets, - global_scale_a=global_scale_a, global_scale_b=global_scale_b, + global_scale_a=gsa, global_scale_b=gsb, swiglu_limit=swiglu_limit, - )[:tokens] # (6, 3072) BF16 — SwiGLU already applied + )[:tokens] print(" ✅ Fused SwiGLU GEMM ran successfully!") except Exception as e: print(f" ❌ Fused SwiGLU GEMM FAILED: {type(e).__name__}: {e}") raise # Compare - # The fused kernel outputs only the silu(gate)*up result (N/2 = 3072) - # The unfused path's activated_unfused is the same computation in Python cos = torch.nn.functional.cosine_similarity( l1_fused.flatten().float(), activated_unfused.flatten().float(), dim=0 ).item()