From bfe612969b0d74dfe3b407ddb54b16cf58bfb24f Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 12 May 2026 14:01:58 +0000 Subject: [PATCH] fix: preserve MN-major layout when interleaving L1 SF tensors _interleave_l1_weights used empty_like+copy_ which destroyed the MN-major stride layout required by TMA. Added interleave_sf_mn_major that works in K-major, interleaves, then transposes back to MN-major. --- deep_gemm/mega/__init__.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/deep_gemm/mega/__init__.py b/deep_gemm/mega/__init__.py index 08b4269..5a93415 100644 --- a/deep_gemm/mega/__init__.py +++ b/deep_gemm/mega/__init__.py @@ -145,7 +145,21 @@ def _interleave_l1_weights(l1_weights: Tuple[torch.Tensor, torch.Tensor]) -> Tup up = t[:, half:].reshape(g, half // gran, gran, *rest) return torch.empty_like(t).copy_(torch.stack([gate, up], dim=2).reshape(g, n, *rest)) - return interleave(l1_weights[0]), interleave(l1_weights[1]) + def interleave_sf_mn_major(t, gran: int = 8) -> torch.Tensor: + """Interleave SF while preserving MN-major layout (stride(-2)=1, stride(-1)=TMA-aligned).""" + # First do the interleave in K-major, then transpose back to MN-major + # t is MN-major: (num_groups, mn, packed_sf_k) with stride(-2)=1 + t_k_major = t.transpose(-2, -1).contiguous() # (num_groups, packed_sf_k, mn) C-contiguous + g, k, mn = t_k_major.shape + half = mn // 2 + gate = t_k_major[:, :, :half].reshape(g, k, half // gran, gran) + up = t_k_major[:, :, half:].reshape(g, k, half // gran, gran) + interleaved_k = torch.empty_like(t_k_major).copy_( + torch.stack([gate, up], dim=3).reshape(g, k, mn)) + # Transpose back to MN-major + return interleaved_k.transpose(-2, -1).contiguous().transpose(-2, -1) + + return interleave(l1_weights[0]), interleave_sf_mn_major(l1_weights[1]) def _transpose_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor: @@ -320,14 +334,6 @@ def fp8_nvfp4_mega_moe(y: torch.Tensor, l1_w, l1_w_sf = l1_weights l2_w, l2_w_sf = l2_weights - # Force contiguous on SF tensors — non-contiguous SF breaks TMA descriptors - for name, t in [("l1_w_sf", l1_w_sf), ("l2_w_sf", l2_w_sf)]: - if not t.is_contiguous(): - print(f"[contig-fix] {name}: was NOT contiguous, forcing", flush=True) - # (assign back to correct variable) - l1_w_sf = l1_w_sf.contiguous() - l2_w_sf = l2_w_sf.contiguous() - for name, t in [("l1_w", l1_w), ("l1_w_sf", l1_w_sf), ("l2_w", l2_w), ("l2_w_sf", l2_w_sf)]: print(f"[debug] {name}: dtype={t.dtype} shape={tuple(t.shape)} contig={t.is_contiguous()}", flush=True)