fix: single transpose back to MN-major, don't double-transpose

The old .transpose().contiguous().transpose() chain swapped the dims back to (g,k,mn), so the result had the wrong shape.
A single transpose from (g,k,mn) gives (g,mn,k) with stride(-2)=1,
which is exactly the MN-major layout TMA expects.
This commit is contained in:
2026-05-12 14:23:02 +00:00
parent 916f03d528
commit e498a2c729

View File

@@ -152,7 +152,7 @@ def _interleave_l1_weights(l1_weights: Tuple[torch.Tensor, torch.Tensor]) -> Tup
Interleaves the mn dimension: [gate_0..7, up_0..7, gate_8..15, up_8..15, ...]
"""
# t: (groups, mn, packed_sf_k) MN-major, stride(-2)=1
# Step 1: transpose to K-major so we can use C-contiguous ops
# Transpose to K-major C-contiguous for safe interleave ops
t_k = t.transpose(-2, -1).contiguous() # (groups, packed_sf_k, mn) C-contiguous
g, k, mn = t_k.shape
half = mn // 2
@@ -160,8 +160,8 @@ def _interleave_l1_weights(l1_weights: Tuple[torch.Tensor, torch.Tensor]) -> Tup
up = t_k[:, :, half:].reshape(g, k, half // gran, gran)
interleaved_k = torch.empty(g, k, mn, dtype=t.dtype, device=t.device)
interleaved_k.copy_(torch.stack([gate, up], dim=3).reshape(g, k, mn))
# Step 2: transpose back to MN-major
return interleaved_k.transpose(-2, -1).contiguous().transpose(-2, -1)
# Single transpose back to MN-major: (g, mn, k) with stride(-2)=1
return interleaved_k.transpose(-2, -1)
return interleave(l1_weights[0]), interleave_sf_mn_major(l1_weights[1])