Diag: simplified TMA shape analysis

This commit is contained in:
2026-05-22 18:43:30 +00:00
parent 73109dde85
commit d2a3b83aa2

View File

@@ -1,99 +1,97 @@
"""Diagnostic: print TMA partition tensor shapes for multi-tile K/V."""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
"""Diagnostic: print TMA partition tensor shapes for multi-tile K/V.
Simplified to avoid JIT-only constructs."""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass import Float32, BFloat16, Int32
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
import cutlass.torch as ct
import math
HEAD_DIM = 64
n = 256 # 2 KV tiles
n = 256
q = torch.randn(128, HEAD_DIM, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(n, HEAD_DIM, 1, dtype=torch.bfloat16, device='cuda')
v = torch.randn(n, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
v_kernel = v.unsqueeze(-1)
c = torch.zeros(128, HEAD_DIM, 1, dtype=torch.bfloat16, device='cuda')
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
# V layout for FMHA
v_fmha = cute.make_tensor(
mV.iterator,
cute.make_layout(
(HEAD_DIM, n, 1),
stride=(1, HEAD_DIM, HEAD_DIM * n),
),
)
qk_mma = utils.sm100.make_trivial_tiled_mma(BFloat16, BFloat16, LayoutEnum.from_tensor(mQ).mma_major_mode(), LayoutEnum.from_tensor(mK).mma_major_mode(), Float32, tcgen05.CtaGroup.ONE, (128,128), tcgen05.OperandSource.SMEM)
v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()
pv_mma = utils.sm100.make_trivial_tiled_mma(BFloat16, BFloat16, cute.nvgpu.OperandMajorMode.K, v_major, Float32, tcgen05.CtaGroup.ONE, (128,HEAD_DIM), tcgen05.OperandSource.TMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(BFloat16, BFloat16, cute.nvgpu.OperandMajorMode.K, LayoutEnum.from_tensor(mV).mma_major_mode(), Float32, tcgen05.CtaGroup.ONE, (128,HEAD_DIM), tcgen05.OperandSource.TMEM)
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
qk_mma_tiler = (128, 128, qk_ik * 4)
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
pv_mma_tiler = (128, HEAD_DIM, pv_ik * (128 // pv_ik))
cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
kv_stage = 2
print(f'qk_mma_tiler: {qk_mma_tiler}')
print(f'pv_mma_tiler: {pv_mma_tiler}')
print(f'cluster_layout_vmnk: {cute.shape(cluster_layout_vmnk)}')
kv_stage = 2; q_stage = 1
k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, qk_mma_tiler, BFloat16, kv_stage)
v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, pv_mma_tiler, BFloat16, kv_stage)
tma_k, mK_tma = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(cluster_layout_vmnk.shape, qk_mma.thr_id),
mK, cute.slice_(k_smem_s,(None,None,None,0)), qk_mma_tiler, qk_mma, cluster_layout_vmnk.shape
)
tma_v, mV_tma = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(cluster_layout_vmnk.shape, pv_mma.thr_id),
v_fmha, cute.slice_(v_smem_s,(None,None,None,0)), pv_mma_tiler, pv_mma, cluster_layout_vmnk.shape
)
gK = cute.local_tile(mK_tma, cute.slice_(qk_mma_tiler,(0,None,None)),(None,None,None))
gV = cute.local_tile(mV_tma, cute.slice_(pv_mma_tiler,(0,None,None)),(None,None,None))
print(f'gK shape: {cute.shape(gK)}')
print(f'gV shape: {cute.shape(gV)}')
qk_thr = qk_mma.get_slice(0)
pv_thr = pv_mma.get_slice(0)
tCgK = qk_thr.partition_B(gK)
tCgV = pv_thr.partition_B(gV)
print(f'tCgK shape: {cute.shape(tCgK)}')
print(f'tCgV shape: {cute.shape(tCgV)}')
q_s = cute.slice_(k_smem_s,(None,None,None,0)) # just to get the per-stage shape
k_s = cute.slice_(k_smem_s,(None,None,None,0))
v_s = cute.slice_(v_smem_s,(None,None,None,0))
sK = cute.make_tensor(BFloat16, k_s.outer)
sV = cute.make_tensor(BFloat16, v_s.outer)
print(f'k_smem_s outer shape: {cute.shape(k_smem_s.outer)}')
print(f'k_smem_s inner shape: {cute.shape(k_smem_s.inner)}')
print(f'v_smem_s outer shape: {cute.shape(v_smem_s.outer)}')
print(f'v_smem_s inner shape: {cute.shape(v_smem_s.inner)}')
print(f'k_s (per-stage) shape: {cute.shape(k_s)}')
print(f'v_s (per-stage) shape: {cute.shape(v_s)}')
tma_k, mK_tma = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(cluster_layout_vmnk.shape, qk_mma.thr_id),
mK, k_s, qk_mma_tiler, qk_mma, cluster_layout_vmnk.shape
)
# For V, we use the raw mV (not v_fmha) just for shape diag
# The FMHA layout matters for the actual kernel but for shape analysis the mV is sufficient
v_major = LayoutEnum.from_tensor(mV).mma_major_mode()
print(f'v_major (raw mV): {v_major}')
gK = cute.local_tile(mK_tma, cute.slice_(qk_mma_tiler,(0,None,None)),(None,None,None))
print(f'mK_tma shape: {cute.shape(mK_tma)}')
print(f'gK shape: {cute.shape(gK)}')
n_kv_tiles = cute.size(gK, mode=[3])
print(f'n_kv_tiles (mode 3): {n_kv_tiles}')
qk_thr = qk_mma.get_slice(0)
tCgK = qk_thr.partition_B(gK)
print(f'tCgK shape: {cute.shape(tCgK)}')
sK = cute.make_tensor(BFloat16, k_s)
b_lay = cute.make_layout(cute.slice_(cluster_layout_vmnk,(0,None,0,0)).shape)
print(f'b_lay: {cute.shape(b_lay)}')
tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay, cute.group_modes(sK,0,3), cute.group_modes(tCgK,0,3))
tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV,0,3), cute.group_modes(tCgV,0,3))
print(f'tBsK shape: {cute.shape(tBsK)}')
print(f'tBgK shape: {cute.shape(tBgK)}')
print(f'tVsV shape: {cute.shape(tVsV)}')
print(f'tVgV shape: {cute.shape(tVgV)}')
print(f'tBsK layout: {tBsK.layout}')
print(f'tBgK layout: {tBgK.layout}')
# Now apply the slice
# Apply the problematic slice
tBgK_sliced = tBgK[(None,0,None,0)]
tVgV_sliced = tVgV[(None,0,None,0)]
print(f'tBgK after (None,0,None,0) shape: {cute.shape(tBgK_sliced)}')
print(f'tVgV after (None,0,None,0) shape: {cute.shape(tVgV_sliced)}')
# What about the CUTLASS-style pre-slice?
# Try (None,None,0,0) — keeps first 2 modes, fixes last 2
# tBgK_refstyle = tBgK[(None,None,0,0)]
# print(f'tBgK after (None,None,0,0) shape: {cute.shape(tBgK_refstyle)}')
n_kv_tiles = cute.size(gK, mode=[3])
print(f'n_kv_tiles = {n_kv_tiles}')
# Try different slices to understand the mode meanings
for desc, sl in [
("(None,0,0,0)", (None,0,0,0)),
("(0,None,0,0)", (0,None,0,0)),
("(None,None,0,0)", (None,None,0,0)),
("(0,0,None,0)", (0,0,None,0)),
("(None,0,None,None)", (None,0,None,None)),
]:
try:
result = tBgK[sl]
print(f'tBgK after {desc} shape: {cute.shape(result)}')
except Exception as e:
print(f'tBgK after {desc}: ERROR {e}')