fix: B tensor must be K-major (transpose last 2 dims)
The reference shows B with stride=(16384,1,128): K is the stride-1 dimension (K-major). Our torch.stack produced N-major stride=(16384,128,1). Added .T.contiguous() to each expert's weight before stacking.
@@ -241,16 +241,15 @@ def main():
 global_scale_a = torch.tensor([x_gs] * num_experts, dtype=torch.float32, device=device)
 global_scale_b = torch.tensor([w_gs_list[e] for e in range(num_experts)], dtype=torch.float32, device=device)

-# mat_a is already (tokens_sum, K_packed) in float4_e2m1fn_x2
-# The kernel's 2Dx3D scenario expects mat_a: (tokens, hidden) where
-# hidden is the LOGICAL K dimension (packed as float4_e2m1fn_x2)
+# mat_a is (tokens_sum, K_packed) in float4_e2m1fn_x2, row-major (K-major)
+# This matches the reference: A shape=(128,128) stride=(128,1)
 mat_a = x_fp4

-# mat_b: (experts, hidden, intermediate) in float4_e2m1fn_x2
-# packed_dim=1 means hidden (K) is packed
-# w_bf16[e] is (hidden, intermediate); we need (hidden, intermediate) in FP4
-# with K (hidden) as the packed dimension
-mat_b = torch.stack(w_fp4_list) # (experts, K_packed, N_packed)
+# mat_b: (experts, K_packed, N_packed) in float4_e2m1fn_x2, K-major
+# Reference: B shape=(2,128,128) stride=(16384,1,128), so K is stride-1
+# w_fp4_list[e] is (K_packed, N_packed) with stride (N_packed, 1), i.e. N-major
+# We need K-major: stride (1, K_packed), so transpose last 2 dims
+mat_b = torch.stack([w.T.contiguous() for w in w_fp4_list]) # K-major

 print(f"\nKernel inputs:")
 print(f" mat_a: {mat_a.shape} {mat_a.dtype}")
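The stride reasoning in the commit message can be checked with plain arithmetic. The sketch below is a minimal illustration, not code from this repository: `contiguous_strides` and `transpose_last_two` are hypothetical helpers that mimic how a dense row-major tensor reports its strides and how a transpose view swaps them. In PyTorch the corresponding values would be read from `Tensor.stride()`. The sizes (2 experts, 128 x 128) are the ones quoted for the reference.

```python
def contiguous_strides(shape):
    """Row-major (C-order) strides, in elements, for a dense tensor of `shape`."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def transpose_last_two(shape, strides):
    """Swap the last two dims as a view: shape and strides swap, data stays put."""
    new_shape = shape[:-2] + (shape[-1], shape[-2])
    new_strides = strides[:-2] + (strides[-1], strides[-2])
    return new_shape, new_strides


# Sizes quoted for the reference tensor in the commit message.
experts, k_packed, n_packed = 2, 128, 128

# Before the fix: stacking (K, N) row-major weights leaves N as the stride-1 dim.
old_shape = (experts, k_packed, n_packed)
old_strides = contiguous_strides(old_shape)

# After the fix: each expert is transposed and copied, so the stacked buffer
# is dense with shape (experts, N, K) and K as the stride-1 dim.
new_shape = (experts, n_packed, k_packed)
new_strides = contiguous_strides(new_shape)

# Presenting that buffer as (experts, K, N) is a transpose view of the last two dims.
view_shape, view_strides = transpose_last_two(new_shape, new_strides)

print("old :", old_shape, old_strides)
print("new :", new_shape, new_strides)
print("view:", view_shape, view_strides)
```

Under these assumptions the stacked result of `.T.contiguous()` is dense in (experts, N, K) order, and the reference's stride pattern with shape (experts, K, N) only appears after a further transpose view of the last two dims (in PyTorch, something like `mat_b.transpose(-2, -1)`). Whether the kernel expects that view or the dense (experts, N, K) tensor is not visible in this hunk, so printing `mat_b.stride()` next to `mat_b.shape` would be a cheap way to confirm.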