fix: preserve MN-major layout when interleaving L1 SF tensors

_interleave_l1_weights used empty_like+copy_ which destroyed the
MN-major stride layout required by TMA. Added interleave_sf_mn_major
that works in K-major, interleaves, then transposes back to MN-major.
This commit is contained in:
2026-05-12 14:01:58 +00:00
parent 76220ac6ee
commit bfe612969b

View File

@@ -145,7 +145,21 @@ def _interleave_l1_weights(l1_weights: Tuple[torch.Tensor, torch.Tensor]) -> Tup
up = t[:, half:].reshape(g, half // gran, gran, *rest)
return torch.empty_like(t).copy_(torch.stack([gate, up], dim=2).reshape(g, n, *rest))
return interleave(l1_weights[0]), interleave(l1_weights[1])
def interleave_sf_mn_major(t, gran: int = 8) -> torch.Tensor:
"""Interleave SF while preserving MN-major layout (stride(-2)=1, stride(-1)=TMA-aligned)."""
# First do the interleave in K-major, then transpose back to MN-major
# t is MN-major: (num_groups, mn, packed_sf_k) with stride(-2)=1
t_k_major = t.transpose(-2, -1).contiguous() # (num_groups, packed_sf_k, mn) C-contiguous
g, k, mn = t_k_major.shape
half = mn // 2
gate = t_k_major[:, :, :half].reshape(g, k, half // gran, gran)
up = t_k_major[:, :, half:].reshape(g, k, half // gran, gran)
interleaved_k = torch.empty_like(t_k_major).copy_(
torch.stack([gate, up], dim=3).reshape(g, k, mn))
# Transpose back to MN-major
return interleaved_k.transpose(-2, -1).contiguous().transpose(-2, -1)
return interleave(l1_weights[0]), interleave_sf_mn_major(l1_weights[1])
def _transpose_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor:
@@ -320,14 +334,6 @@ def fp8_nvfp4_mega_moe(y: torch.Tensor,
l1_w, l1_w_sf = l1_weights
l2_w, l2_w_sf = l2_weights
# Force contiguous on SF tensors — non-contiguous SF breaks TMA descriptors
for name, t in [("l1_w_sf", l1_w_sf), ("l2_w_sf", l2_w_sf)]:
if not t.is_contiguous():
print(f"[contig-fix] {name}: was NOT contiguous, forcing", flush=True)
# (assign back to correct variable)
l1_w_sf = l1_w_sf.contiguous()
l2_w_sf = l2_w_sf.contiguous()
for name, t in [("l1_w", l1_w), ("l1_w_sf", l1_w_sf),
("l2_w", l2_w), ("l2_w_sf", l2_w_sf)]:
print(f"[debug] {name}: dtype={t.dtype} shape={tuple(t.shape)} contig={t.is_contiguous()}", flush=True)