fix: preserve MN-major layout when interleaving L1 SF tensors
_interleave_l1_weights used empty_like+copy_ which destroyed the MN-major stride layout required by TMA. Added interleave_sf_mn_major that works in K-major, interleaves, then transposes back to MN-major.
This commit is contained in:
@@ -145,7 +145,21 @@ def _interleave_l1_weights(l1_weights: Tuple[torch.Tensor, torch.Tensor]) -> Tup
|
||||
up = t[:, half:].reshape(g, half // gran, gran, *rest)
|
||||
return torch.empty_like(t).copy_(torch.stack([gate, up], dim=2).reshape(g, n, *rest))
|
||||
|
||||
return interleave(l1_weights[0]), interleave(l1_weights[1])
|
||||
def interleave_sf_mn_major(t, gran: int = 8) -> torch.Tensor:
|
||||
"""Interleave SF while preserving MN-major layout (stride(-2)=1, stride(-1)=TMA-aligned)."""
|
||||
# First do the interleave in K-major, then transpose back to MN-major
|
||||
# t is MN-major: (num_groups, mn, packed_sf_k) with stride(-2)=1
|
||||
t_k_major = t.transpose(-2, -1).contiguous() # (num_groups, packed_sf_k, mn) C-contiguous
|
||||
g, k, mn = t_k_major.shape
|
||||
half = mn // 2
|
||||
gate = t_k_major[:, :, :half].reshape(g, k, half // gran, gran)
|
||||
up = t_k_major[:, :, half:].reshape(g, k, half // gran, gran)
|
||||
interleaved_k = torch.empty_like(t_k_major).copy_(
|
||||
torch.stack([gate, up], dim=3).reshape(g, k, mn))
|
||||
# Transpose back to MN-major
|
||||
return interleaved_k.transpose(-2, -1).contiguous().transpose(-2, -1)
|
||||
|
||||
return interleave(l1_weights[0]), interleave_sf_mn_major(l1_weights[1])
|
||||
|
||||
|
||||
def _transpose_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor:
|
||||
@@ -320,14 +334,6 @@ def fp8_nvfp4_mega_moe(y: torch.Tensor,
|
||||
l1_w, l1_w_sf = l1_weights
|
||||
l2_w, l2_w_sf = l2_weights
|
||||
|
||||
# Force contiguous on SF tensors — non-contiguous SF breaks TMA descriptors
|
||||
for name, t in [("l1_w_sf", l1_w_sf), ("l2_w_sf", l2_w_sf)]:
|
||||
if not t.is_contiguous():
|
||||
print(f"[contig-fix] {name}: was NOT contiguous, forcing", flush=True)
|
||||
# (assign back to correct variable)
|
||||
l1_w_sf = l1_w_sf.contiguous()
|
||||
l2_w_sf = l2_w_sf.contiguous()
|
||||
|
||||
for name, t in [("l1_w", l1_w), ("l1_w_sf", l1_w_sf),
|
||||
("l2_w", l2_w), ("l2_w_sf", l2_w_sf)]:
|
||||
print(f"[debug] {name}: dtype={t.dtype} shape={tuple(t.shape)} contig={t.is_contiguous()}", flush=True)
|
||||
|
||||
Reference in New Issue
Block a user