fix: single transpose back to MN-major, don't double-transpose
The .contiguous().transpose() dance was swapping dims back. A single transpose from (g,k,mn) gives (g,mn,k) with stride(-2)=1, which is exactly the MN-major layout TMA expects.
This commit is contained in:
@@ -152,7 +152,7 @@ def _interleave_l1_weights(l1_weights: Tuple[torch.Tensor, torch.Tensor]) -> Tup
|
||||
Interleaves the mn dimension: [gate_0..7, up_0..7, gate_8..15, up_8..15, ...]
|
||||
"""
|
||||
# t: (groups, mn, packed_sf_k) MN-major, stride(-2)=1
|
||||
# Step 1: transpose to K-major so we can use C-contiguous ops
|
||||
# Transpose to K-major C-contiguous for safe interleave ops
|
||||
t_k = t.transpose(-2, -1).contiguous() # (groups, packed_sf_k, mn) C-contiguous
|
||||
g, k, mn = t_k.shape
|
||||
half = mn // 2
|
||||
@@ -160,8 +160,8 @@ def _interleave_l1_weights(l1_weights: Tuple[torch.Tensor, torch.Tensor]) -> Tup
|
||||
up = t_k[:, :, half:].reshape(g, k, half // gran, gran)
|
||||
interleaved_k = torch.empty(g, k, mn, dtype=t.dtype, device=t.device)
|
||||
interleaved_k.copy_(torch.stack([gate, up], dim=3).reshape(g, k, mn))
|
||||
# Step 2: transpose back to MN-major
|
||||
return interleaved_k.transpose(-2, -1).contiguous().transpose(-2, -1)
|
||||
# Single transpose back to MN-major: (g, mn, k) with stride(-2)=1
|
||||
return interleaved_k.transpose(-2, -1)
|
||||
|
||||
return interleave(l1_weights[0]), interleave_sf_mn_major(l1_weights[1])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user