From e498a2c7291adddc48aaa821e9228251937786a5 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 12 May 2026 14:23:02 +0000 Subject: [PATCH] fix: single transpose back to MN-major, don't double-transpose The .contiguous().transpose() dance was swapping dims back. A single transpose from (g,k,mn) gives (g,mn,k) with stride(-2)=1, which is exactly the MN-major layout TMA expects. --- deep_gemm/mega/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deep_gemm/mega/__init__.py b/deep_gemm/mega/__init__.py index f2afc50..e2ff491 100644 --- a/deep_gemm/mega/__init__.py +++ b/deep_gemm/mega/__init__.py @@ -152,7 +152,7 @@ def _interleave_l1_weights(l1_weights: Tuple[torch.Tensor, torch.Tensor]) -> Tup Interleaves the mn dimension: [gate_0..7, up_0..7, gate_8..15, up_8..15, ...] """ # t: (groups, mn, packed_sf_k) MN-major, stride(-2)=1 - # Step 1: transpose to K-major so we can use C-contiguous ops + # Transpose to K-major C-contiguous for safe interleave ops t_k = t.transpose(-2, -1).contiguous() # (groups, packed_sf_k, mn) C-contiguous g, k, mn = t_k.shape half = mn // 2 @@ -160,8 +160,8 @@ def _interleave_l1_weights(l1_weights: Tuple[torch.Tensor, torch.Tensor]) -> Tup up = t_k[:, :, half:].reshape(g, k, half // gran, gran) interleaved_k = torch.empty(g, k, mn, dtype=t.dtype, device=t.device) interleaved_k.copy_(torch.stack([gate, up], dim=3).reshape(g, k, mn)) - # Step 2: transpose back to MN-major - return interleaved_k.transpose(-2, -1).contiguous().transpose(-2, -1) + # Single transpose back to MN-major: (g, mn, k) with stride(-2)=1 + return interleaved_k.transpose(-2, -1) return interleave(l1_weights[0]), interleave_sf_mn_major(l1_weights[1])