debug: add transform output shape/stride prints

This commit is contained in:
2026-05-12 14:22:05 +00:00
parent 1f13b24354
commit 916f03d528

View File

@@ -146,17 +146,21 @@ def _interleave_l1_weights(l1_weights: Tuple[torch.Tensor, torch.Tensor]) -> Tup
return torch.empty_like(t).copy_(torch.stack([gate, up], dim=2).reshape(g, n, *rest))
def interleave_sf_mn_major(t, gran: int = 8) -> torch.Tensor:
"""Interleave SF while preserving MN-major layout (stride(-2)=1, stride(-1)=TMA-aligned)."""
# First do the interleave in K-major, then transpose back to MN-major
# t is MN-major: (num_groups, mn, packed_sf_k) with stride(-2)=1
t_k_major = t.transpose(-2, -1).contiguous() # (num_groups, packed_sf_k, mn) C-contiguous
g, k, mn = t_k_major.shape
"""Interleave SF while preserving MN-major layout (stride(-2)=1, stride(-1)=TMA-aligned).
Input/Output shape: (num_groups, mn, packed_sf_k) with MN-major strides.
Interleaves the mn dimension: [gate_0..7, up_0..7, gate_8..15, up_8..15, ...]
"""
# t: (groups, mn, packed_sf_k) MN-major, stride(-2)=1
# Step 1: transpose to K-major so we can use C-contiguous ops
t_k = t.transpose(-2, -1).contiguous() # (groups, packed_sf_k, mn) C-contiguous
g, k, mn = t_k.shape
half = mn // 2
gate = t_k_major[:, :, :half].reshape(g, k, half // gran, gran)
up = t_k_major[:, :, half:].reshape(g, k, half // gran, gran)
interleaved_k = torch.empty_like(t_k_major).copy_(
torch.stack([gate, up], dim=3).reshape(g, k, mn))
# Transpose back to MN-major
gate = t_k[:, :, :half].reshape(g, k, half // gran, gran)
up = t_k[:, :, half:].reshape(g, k, half // gran, gran)
interleaved_k = torch.empty(g, k, mn, dtype=t.dtype, device=t.device)
interleaved_k.copy_(torch.stack([gate, up], dim=3).reshape(g, k, mn))
# Step 2: transpose back to MN-major
return interleaved_k.transpose(-2, -1).contiguous().transpose(-2, -1)
return interleave(l1_weights[0]), interleave_sf_mn_major(l1_weights[1])
@@ -288,6 +292,8 @@ def transform_nvfp4_weights_for_mega_moe(
# DeepGEMM expects int8 (kPackedFP4 = torch.kInt8)
l1_out = (l1_interleaved[0].view(torch.int8), l1_interleaved[1])
l2_out = (l2_weights[0].view(torch.int8), l2_sf_transformed)
print(f"[debug-transform] l1_out_sf: dtype={l1_out[1].dtype} shape={tuple(l1_out[1].shape)} strides={l1_out[1].stride()}", flush=True)
print(f"[debug-transform] l2_out_sf: dtype={l2_out[1].dtype} shape={tuple(l2_out[1].shape)} strides={l2_out[1].stride()}", flush=True)
return l1_out, l2_out