fix: B tensor K-major strides, scale_b axis swap
Two fixes:
1. B tensor: permute(0,2,1).contiguous().permute(0,2,1) gives the K-major stride (16384,1,128), matching the reference.
2. scale_b: transpose each expert's scales to (N, K_sf) before swizzling — the reference uses (intermediate, hidden//16), not (hidden//16, intermediate).
This commit is contained in:
@@ -235,7 +235,10 @@ def main():
|
||||
scale_a = assemble_raw_scales_2d3d_2d_side(raw_scale_a)
|
||||
|
||||
# Assemble scale_b (3D side: per-expert, pad and swizzle each)
|
||||
scale_b = assemble_raw_scales_2d3d_3d_side(w_sf_list)
|
||||
# Reference uses (N, K_sf) = (intermediate, hidden//16) for each expert
|
||||
# Our w_sf is (K_sf, intermediate) — need to transpose
|
||||
w_sf_t = [sf.T.contiguous() for sf in w_sf_list]
|
||||
scale_b = assemble_raw_scales_2d3d_3d_side(w_sf_t)
|
||||
|
||||
# Global scales
|
||||
global_scale_a = torch.tensor([x_gs] * num_experts, dtype=torch.float32, device=device)
|
||||
@@ -247,9 +250,9 @@ def main():
|
||||
|
||||
# mat_b: (experts, K_packed, N_packed) in float4_e2m1fn_x2, K-major
|
||||
# Reference: B shape=(2,128,128) stride=(16384,1,128) — K is stride-1
|
||||
# w_fp4_list[e] is (K_packed, N_packed) with stride (N_packed, 1) — N-major
|
||||
# We need K-major: stride (1, K_packed), so transpose last 2 dims
|
||||
mat_b = torch.stack([w.T.contiguous() for w in w_fp4_list]) # K-major
|
||||
# torch.stack gives stride (16384, 128, 1) — N is stride-1 (wrong)
|
||||
# We need K-major: permute, make contiguous, permute back
|
||||
mat_b = torch.stack(w_fp4_list).permute(0, 2, 1).contiguous().permute(0, 2, 1)
|
||||
|
||||
print(f"\nKernel inputs:")
|
||||
print(f" mat_a: {mat_a.shape} {mat_a.dtype}")
|
||||
|
||||
Reference in New Issue
Block a user