diff --git a/tests/unit/test_fused_swiglu_kernel.py b/tests/unit/test_fused_swiglu_kernel.py index 87b8a465..7527b668 100644 --- a/tests/unit/test_fused_swiglu_kernel.py +++ b/tests/unit/test_fused_swiglu_kernel.py @@ -69,7 +69,10 @@ def test_fused_swiglu_compilation(): # Quantize weight (interleaved for L1 gate+up) w_bf16_t = w_bf16.permute(0, 2, 1).contiguous() # (E, N, K) for make_b_k_major w_fp4, w_sf, w_gs = quantize_to_nvfp4(w_bf16_t) - w_fp4_il = interleave_l1_weights(w_fp4.unsqueeze(0)).squeeze(0) # interleave for SwiGLU + # w_fp4: (E, N_packed, K_packed) — interleave along N for gate/up pairing + if w_fp4.dtype == torch.uint8: + w_fp4 = w_fp4.view(torch.float4_e2m1fn_x2) + w_fp4_il = interleave_l1_weights(w_fp4) # (E, N_packed, K_packed) interleaved mat_b = make_b_k_major(w_fp4_il) # Expert offsets (all tokens go to expert 0 for simplicity)