debug: add transform output shape/stride prints
This commit is contained in:
@@ -146,17 +146,21 @@ def _interleave_l1_weights(l1_weights: Tuple[torch.Tensor, torch.Tensor]) -> Tup
|
||||
return torch.empty_like(t).copy_(torch.stack([gate, up], dim=2).reshape(g, n, *rest))
|
||||
|
||||
def interleave_sf_mn_major(t: torch.Tensor, gran: int = 8) -> torch.Tensor:
    """Interleave the gate/up halves of the mn axis, returning an MN-major tensor.

    The first half of the mn axis is treated as "gate" and the second half as
    "up". Blocks of ``gran`` consecutive entries are alternated along mn:
    ``[gate_0..7, up_0..7, gate_8..15, up_8..15, ...]`` for ``gran == 8``.

    Args:
        t: 3-D tensor of shape ``(num_groups, mn, packed_sf_k)``. It is
            expected to be MN-major (``stride(-2) == 1``); any other layout is
            still accepted because the data is re-materialised below.
        gran: Number of consecutive mn entries kept together per block.

    Returns:
        A tensor of shape ``(num_groups, mn, packed_sf_k)`` with
        ``stride(-2) == 1`` and ``stride(-1) == mn``.
        NOTE(review): if the input carried extra padding in ``stride(-1)`` for
        TMA alignment, that padding is not reproduced here -- confirm whether
        callers need it.

    Raises:
        RuntimeError: raised by ``reshape`` when ``mn // 2`` is not a multiple
            of ``gran`` (or ``mn`` is odd).
    """
    # Work in (groups, packed_sf_k, mn) order so the mn axis is innermost and
    # the block interleave can be expressed with plain reshape/stack.
    t_k = t.transpose(-2, -1).contiguous()
    g, k, mn = t_k.shape
    half = mn // 2

    # Split each half of the mn axis into blocks of `gran` entries.
    gate = t_k[:, :, :half].reshape(g, k, half // gran, gran)
    up = t_k[:, :, half:].reshape(g, k, half // gran, gran)

    # Stacking on dim=3 gives (g, k, n_blocks, 2, gran); flattening the last
    # three axes yields gate-block / up-block alternation along mn.
    # torch.stack already allocates a fresh tensor, so no extra copy is needed.
    interleaved_k = torch.stack([gate, up], dim=3).reshape(g, k, mn)

    # A transposed view of the C-contiguous (g, k, mn) buffer is exactly the
    # MN-major (g, mn, k) layout. Calling .contiguous().transpose() on top of
    # it would swap the axes back and return shape (g, k, mn) instead.
    return interleaved_k.transpose(-2, -1)
|
||||
|
||||
return interleave(l1_weights[0]), interleave_sf_mn_major(l1_weights[1])
|
||||
@@ -288,6 +292,8 @@ def transform_nvfp4_weights_for_mega_moe(
|
||||
# DeepGEMM expects int8 (kPackedFP4 = torch.kInt8)
|
||||
l1_out = (l1_interleaved[0].view(torch.int8), l1_interleaved[1])
|
||||
l2_out = (l2_weights[0].view(torch.int8), l2_sf_transformed)
|
||||
print(f"[debug-transform] l1_out_sf: dtype={l1_out[1].dtype} shape={tuple(l1_out[1].shape)} strides={l1_out[1].stride()}", flush=True)
|
||||
print(f"[debug-transform] l2_out_sf: dtype={l2_out[1].dtype} shape={tuple(l2_out[1].shape)} strides={l2_out[1].stride()}", flush=True)
|
||||
return l1_out, l2_out
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user