diff --git a/tests/test_cutedsl.py b/tests/test_cutedsl.py index f93bda7f..c110e2d8 100644 --- a/tests/test_cutedsl.py +++ b/tests/test_cutedsl.py @@ -235,7 +235,10 @@ def main(): scale_a = assemble_raw_scales_2d3d_2d_side(raw_scale_a) # Assemble scale_b (3D side: per-expert, pad and swizzle each) - scale_b = assemble_raw_scales_2d3d_3d_side(w_sf_list) + # Reference uses (N, K_sf) = (intermediate, hidden//16) for each expert + # Our w_sf is (K_sf, intermediate) — need to transpose + w_sf_t = [sf.T.contiguous() for sf in w_sf_list] + scale_b = assemble_raw_scales_2d3d_3d_side(w_sf_t) # Global scales global_scale_a = torch.tensor([x_gs] * num_experts, dtype=torch.float32, device=device) @@ -247,9 +250,9 @@ def main(): # mat_b: (experts, K_packed, N_packed) in float4_e2m1fn_x2, K-major # Reference: B shape=(2,128,128) stride=(16384,1,128) — K is stride-1 - # w_fp4_list[e] is (K_packed, N_packed) with stride (N_packed, 1) — N-major - # We need K-major: stride (1, K_packed), so transpose last 2 dims - mat_b = torch.stack([w.T.contiguous() for w in w_fp4_list]) # K-major + # torch.stack gives stride (16384, 128, 1) — N is stride-1 (wrong) + # We need K-major: permute, make contiguous, permute back + mat_b = torch.stack(w_fp4_list).permute(0, 2, 1).contiguous().permute(0, 2, 1) print(f"\nKernel inputs:") print(f" mat_a: {mat_a.shape} {mat_a.dtype}")