fix: weight quantization dtype in fused SwiGLU test
This commit is contained in:
@@ -69,7 +69,10 @@ def test_fused_swiglu_compilation():
|
||||
# Quantize weight (interleaved for L1 gate+up)
|
||||
w_bf16_t = w_bf16.permute(0, 2, 1).contiguous() # (E, N, K) for make_b_k_major
|
||||
w_fp4, w_sf, w_gs = quantize_to_nvfp4(w_bf16_t)
|
||||
w_fp4_il = interleave_l1_weights(w_fp4.unsqueeze(0)).squeeze(0) # interleave for SwiGLU
|
||||
# w_fp4: (E, N_packed, K_packed) — interleave along N for gate/up pairing
|
||||
if w_fp4.dtype == torch.uint8:
|
||||
w_fp4 = w_fp4.view(torch.float4_e2m1fn_x2)
|
||||
w_fp4_il = interleave_l1_weights(w_fp4) # (E, N_packed, K_packed) interleaved
|
||||
mat_b = make_b_k_major(w_fp4_il)
|
||||
|
||||
# Expert offsets (all tokens go to expert 0 for simplicity)
|
||||
|
||||
Reference in New Issue
Block a user