diff --git a/deep_gemm/mega/__init__.py b/deep_gemm/mega/__init__.py index 213054a..08b4269 100644 --- a/deep_gemm/mega/__init__.py +++ b/deep_gemm/mega/__init__.py @@ -319,6 +319,15 @@ def fp8_nvfp4_mega_moe(y: torch.Tensor, """ l1_w, l1_w_sf = l1_weights l2_w, l2_w_sf = l2_weights + + # Force contiguous on SF tensors — non-contiguous SF breaks TMA descriptors + for name, t in [("l1_w_sf", l1_w_sf), ("l2_w_sf", l2_w_sf)]: + if not t.is_contiguous(): + print(f"[contig-fix] {name}: was NOT contiguous, forcing", flush=True) + # (assign back to correct variable) + l1_w_sf = l1_w_sf.contiguous() + l2_w_sf = l2_w_sf.contiguous() + for name, t in [("l1_w", l1_w), ("l1_w_sf", l1_w_sf), ("l2_w", l2_w), ("l2_w_sf", l2_w_sf)]: print(f"[debug] {name}: dtype={t.dtype} shape={tuple(t.shape)} contig={t.is_contiguous()}", flush=True) @@ -330,7 +339,7 @@ def fp8_nvfp4_mega_moe(y: torch.Tensor, _C.fp8_nvfp4_mega_moe( y, - l1_weights, l2_weights, + (l1_w, l1_w_sf), (l2_w, l2_w_sf), cumulative_local_expert_recv_stats, sym_buffer.buffer, sym_buffer.handle.buffer_ptrs, sym_buffer.group.rank(),