diff --git a/tests/test_cutedsl.py b/tests/test_cutedsl.py index 0eef11ce..f93bda7f 100644 --- a/tests/test_cutedsl.py +++ b/tests/test_cutedsl.py @@ -241,16 +241,15 @@ def main(): global_scale_a = torch.tensor([x_gs] * num_experts, dtype=torch.float32, device=device) global_scale_b = torch.tensor([w_gs_list[e] for e in range(num_experts)], dtype=torch.float32, device=device) - # mat_a is already (tokens_sum, K_packed) in float4_e2m1fn_x2 - # The kernel's 2Dx3D scenario expects mat_a: (tokens, hidden) where - # hidden is the LOGICAL K dimension (packed as float4_e2m1fn_x2) + # mat_a is (tokens_sum, K_packed) in float4_e2m1fn_x2, row-major (K-major) + # This matches the reference: A shape=(128,128) stride=(128,1) mat_a = x_fp4 - # mat_b: (experts, hidden, intermediate) in float4_e2m1fn_x2 - # packed_dim=1 means hidden (K) is packed - # w_bf16[e] is (hidden, intermediate) — we need (hidden, intermediate) in FP4 - # with K (hidden) as the packed dimension - mat_b = torch.stack(w_fp4_list) # (experts, K_packed, N_packed) + # mat_b: (experts, K_packed, N_packed) in float4_e2m1fn_x2, K-major + # Reference: B shape=(2,128,128) stride=(16384,1,128) — K is stride-1 + # w_fp4_list[e] is (K_packed, N_packed) with stride (N_packed, 1) — N-major + # We need K-major: stride (1, K_packed), so transpose last 2 dims + mat_b = torch.stack([w.T.contiguous() for w in w_fp4_list]) # K-major print(f"\nKernel inputs:") print(f" mat_a: {mat_a.shape} {mat_a.dtype}")