diff --git a/deep_gemm/mega/__init__.py b/deep_gemm/mega/__init__.py index 206c6cb..f2afc50 100644 --- a/deep_gemm/mega/__init__.py +++ b/deep_gemm/mega/__init__.py @@ -146,17 +146,21 @@ def _interleave_l1_weights(l1_weights: Tuple[torch.Tensor, torch.Tensor]) -> Tup return torch.empty_like(t).copy_(torch.stack([gate, up], dim=2).reshape(g, n, *rest)) def interleave_sf_mn_major(t, gran: int = 8) -> torch.Tensor: - """Interleave SF while preserving MN-major layout (stride(-2)=1, stride(-1)=TMA-aligned).""" - # First do the interleave in K-major, then transpose back to MN-major - # t is MN-major: (num_groups, mn, packed_sf_k) with stride(-2)=1 - t_k_major = t.transpose(-2, -1).contiguous() # (num_groups, packed_sf_k, mn) C-contiguous - g, k, mn = t_k_major.shape + """Interleave SF while preserving MN-major layout (stride(-2)=1, stride(-1)=TMA-aligned). + + Input/Output shape: (num_groups, mn, packed_sf_k) with MN-major strides. + Interleaves the mn dimension: [gate_0..7, up_0..7, gate_8..15, up_8..15, ...] + """ + # t: (groups, mn, packed_sf_k) MN-major, stride(-2)=1 + # Step 1: transpose to K-major so we can use C-contiguous ops + t_k = t.transpose(-2, -1).contiguous() # (groups, packed_sf_k, mn) C-contiguous + g, k, mn = t_k.shape half = mn // 2 - gate = t_k_major[:, :, :half].reshape(g, k, half // gran, gran) - up = t_k_major[:, :, half:].reshape(g, k, half // gran, gran) - interleaved_k = torch.empty_like(t_k_major).copy_( - torch.stack([gate, up], dim=3).reshape(g, k, mn)) - # Transpose back to MN-major + gate = t_k[:, :, :half].reshape(g, k, half // gran, gran) + up = t_k[:, :, half:].reshape(g, k, half // gran, gran) + interleaved_k = torch.empty(g, k, mn, dtype=t.dtype, device=t.device) + interleaved_k.copy_(torch.stack([gate, up], dim=3).reshape(g, k, mn)) + # Step 2: transpose back to MN-major return interleaved_k.transpose(-2, -1).contiguous().transpose(-2, -1) return interleave(l1_weights[0]), interleave_sf_mn_major(l1_weights[1]) @@ -288,6 +292,8 @@ def transform_nvfp4_weights_for_mega_moe( # DeepGEMM expects int8 (kPackedFP4 = torch.kInt8) l1_out = (l1_interleaved[0].view(torch.int8), l1_interleaved[1]) l2_out = (l2_weights[0].view(torch.int8), l2_sf_transformed) + print(f"[debug-transform] l1_out_sf: dtype={l1_out[1].dtype} shape={tuple(l1_out[1].shape)} strides={l1_out[1].stride()}", flush=True) + print(f"[debug-transform] l2_out_sf: dtype={l2_out[1].dtype} shape={tuple(l2_out[1].shape)} strides={l2_out[1].stride()}", flush=True) return l1_out, l2_out