fix: test weight quantization dtype for fused SwiGLU test

This commit is contained in:
2026-06-02 08:17:35 +00:00
parent 19afa52e80
commit 024be1a60b

View File

@@ -69,7 +69,10 @@ def test_fused_swiglu_compilation():
# Quantize weight (interleaved for L1 gate+up)
w_bf16_t = w_bf16.permute(0, 2, 1).contiguous() # (E, N, K) for make_b_k_major
w_fp4, w_sf, w_gs = quantize_to_nvfp4(w_bf16_t)
w_fp4_il = interleave_l1_weights(w_fp4.unsqueeze(0)).squeeze(0) # interleave for SwiGLU
# w_fp4: (E, N_packed, K_packed) — interleave along N for gate/up pairing
if w_fp4.dtype == torch.uint8:
w_fp4 = w_fp4.view(torch.float4_e2m1fn_x2)
w_fp4_il = interleave_l1_weights(w_fp4) # (E, N_packed, K_packed) interleaved
mat_b = make_b_k_major(w_fp4_il)
# Expert offsets (all tokens go to expert 0 for simplicity)