fix: B tensor K-major strides, scale_b axis swap

Two fixes:
1. B tensor: permute(0,2,1).contiguous().permute(0,2,1) gives K-major
   stride (16384,1,128) matching reference
2. scale_b: transpose to (N, K_sf) before swizzling — reference uses
   (intermediate, hidden//16) not (hidden//16, intermediate)
This commit is contained in:
2026-05-16 03:04:31 +00:00
parent 6294b84213
commit 2ef71dc21a

View File

@@ -235,7 +235,10 @@ def main():
scale_a = assemble_raw_scales_2d3d_2d_side(raw_scale_a)
# Assemble scale_b (3D side: per-expert, pad and swizzle each)
scale_b = assemble_raw_scales_2d3d_3d_side(w_sf_list)
# Reference uses (N, K_sf) = (intermediate, hidden//16) for each expert
# Our w_sf is (K_sf, intermediate) — need to transpose
w_sf_t = [sf.T.contiguous() for sf in w_sf_list]
scale_b = assemble_raw_scales_2d3d_3d_side(w_sf_t)
# Global scales
global_scale_a = torch.tensor([x_gs] * num_experts, dtype=torch.float32, device=device)
@@ -247,9 +250,9 @@ def main():
# mat_b: (experts, K_packed, N_packed) in float4_e2m1fn_x2, K-major
# Reference: B shape=(2,128,128) stride=(16384,1,128) — K is stride-1
# w_fp4_list[e] is (K_packed, N_packed) with stride (N_packed, 1) — N-major
# We need K-major: stride (1, K_packed), so transpose last 2 dims
mat_b = torch.stack([w.T.contiguous() for w in w_fp4_list]) # K-major
# torch.stack gives stride (16384, 128, 1) — N is stride-1 (wrong)
# We need K-major: permute, make contiguous, permute back
mat_b = torch.stack(w_fp4_list).permute(0, 2, 1).contiguous().permute(0, 2, 1)
print(f"\nKernel inputs:")
print(f" mat_a: {mat_a.shape} {mat_a.dtype}")