fix: B tensor must be K-major (transpose last 2 dims)

Reference shows B stride=(16384,1,128) — K is stride-1 (K-major).
Our torch.stack output is N-major, stride=(16384,128,1). Added .T.contiguous() to each expert's weights before stacking.
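
The stride arithmetic behind this fix can be checked without a GPU. The sketch below is plain Python, not PyTorch: `contiguous_strides` and `swap_last_two` are hypothetical helpers, and the sizes (2 experts, K = N = 128) are assumed from the strides quoted above. It also assumes strides are counted in elements and that stacking contiguous matrices yields a contiguous result.

```python
def contiguous_strides(shape):
    """Row-major (C-contiguous) strides, in elements, for the given shape."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def swap_last_two(shape, strides):
    """Shape and strides of a transposed view: metadata swaps, data stays put."""
    return (
        shape[:-2] + (shape[-1], shape[-2]),
        strides[:-2] + (strides[-1], strides[-2]),
    )


# Hypothetical sizes chosen to match the strides quoted above.
experts, k, n = 2, 128, 128

# Stacking contiguous (K, N) matrices: N is stride-1 (N-major).
n_major = contiguous_strides((experts, k, n))

# Stacking contiguous (N, K) copies, then viewing the result as (experts, K, N).
shape, k_major = swap_last_two((experts, n, k), contiguous_strides((experts, n, k)))

print(n_major)         # (16384, 128, 1)
print(shape, k_major)  # (2, 128, 128) (16384, 1, 128)
```

In this model the reference layout (16384,1,128) appears only once the stacked (experts, N, K) buffer is viewed with its last two dimensions swapped. A stack of contiguous copies on its own is row-major again. Printing mat_b.stride() before the kernel call is a cheap way to confirm which layout is actually passed.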
2026-05-16 03:03:00 +00:00
parent 7c882fe2e0
commit 6294b84213

@@ -241,16 +241,15 @@ def main():
     global_scale_a = torch.tensor([x_gs] * num_experts, dtype=torch.float32, device=device)
     global_scale_b = torch.tensor([w_gs_list[e] for e in range(num_experts)], dtype=torch.float32, device=device)
-    # mat_a is already (tokens_sum, K_packed) in float4_e2m1fn_x2
-    # The kernel's 2Dx3D scenario expects mat_a: (tokens, hidden) where
-    # hidden is the LOGICAL K dimension (packed as float4_e2m1fn_x2)
+    # mat_a is (tokens_sum, K_packed) in float4_e2m1fn_x2, row-major (K-major)
+    # This matches the reference: A shape=(128,128) stride=(128,1)
     mat_a = x_fp4
-    # mat_b: (experts, hidden, intermediate) in float4_e2m1fn_x2
-    # packed_dim=1 means hidden (K) is packed
-    # w_bf16[e] is (hidden, intermediate) — we need (hidden, intermediate) in FP4
-    # with K (hidden) as the packed dimension
-    mat_b = torch.stack(w_fp4_list) # (experts, K_packed, N_packed)
+    # mat_b: (experts, K_packed, N_packed) in float4_e2m1fn_x2, K-major
+    # Reference: B shape=(2,128,128) stride=(16384,1,128) — K is stride-1
+    # w_fp4_list[e] is (K_packed, N_packed) with stride (N_packed, 1) — N-major
+    # We need K-major: stride (1, K_packed), so transpose last 2 dims
+    mat_b = torch.stack([w.T.contiguous() for w in w_fp4_list]) # K-major
     print(f"\nKernel inputs:")
     print(f" mat_a: {mat_a.shape} {mat_a.dtype}")