Diagnostic: check tBgK/tVgV layout strides for degenerate dims

This commit is contained in:
2026-05-22 21:20:46 +00:00
parent ae173d3963
commit 2a9f764f8b

116
tests/diag_tma_layout.py Normal file
View File

@@ -0,0 +1,116 @@
"""Diagnostic: Check tBgK layout strides to see if the GMEM tile dim is degenerate."""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32
from cutlass.utils import LayoutEnum
import cutlass.torch as ct
import math
# Problem sizes used throughout the diagnostic.
HEAD_DIM = 64
n = 256

# Every input is a random bf16 tensor on the CUDA device.
_bf16_cuda = {'dtype': torch.bfloat16, 'device': 'cuda'}


def _as_dynamic_cute(t):
    """Wrap torch tensor *t* as a CuTe tensor with a dynamic layout.

    The leading dimension handed to ``mark_layout_dynamic`` is whatever
    ``ct.get_leading_dim`` reports for the same tensor.
    """
    return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))


# q and k are created 3-D with a trailing size-1 dim; v is created 2-D and
# gets its trailing dim from unsqueeze(-1).  Creation order (q, k, v) is kept
# so the random draws are unchanged.
q = torch.randn(128, HEAD_DIM, 1, **_bf16_cuda)
k = torch.randn(n, HEAD_DIM, 1, **_bf16_cuda)
v = torch.randn(n, HEAD_DIM, **_bf16_cuda)
v_kernel = v.unsqueeze(-1)

# NOTE(review): q / mQ are not referenced again in the visible part of this
# script; only mK and mV feed the TMA setup.
mQ, mK, mV = (_as_dynamic_cute(t) for t in (q, k, v_kernel))
# Tiled MMAs, MMA tilers, cluster layout and per-stage SMEM layouts.
# NOTE(review): all of these cute.* / utils.sm100.* calls run at module top
# level, outside any @cute.jit function -- confirm the DSL permits that.
_Major = cute.nvgpu.OperandMajorMode
_one_cta = tcgen05.CtaGroup.ONE

# QK MMA: both major modes are K, 128x128 tile, operand source SMEM.
# (Positional argument meanings presumably follow make_trivial_tiled_mma's
# signature -- TODO confirm against the installed CUTLASS version.)
qk_mma = utils.sm100.make_trivial_tiled_mma(
    BFloat16,
    BFloat16,
    _Major.K,
    _Major.K,
    Float32,
    _one_cta,
    (128, 128),
    tcgen05.OperandSource.SMEM,
)
# Size of mode 2 of the MMA shape; the tiler's last entry is four of these.
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
qk_mma_tiler = (128, 128, qk_ik * 4)

# PV MMA: major modes K and MN, 128xHEAD_DIM tile, operand source TMEM.
pv_mma = utils.sm100.make_trivial_tiled_mma(
    BFloat16,
    BFloat16,
    _Major.K,
    _Major.MN,
    Float32,
    _one_cta,
    (128, HEAD_DIM),
    tcgen05.OperandSource.TMEM,
)
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
# Last entry is the largest multiple of pv_ik that does not exceed 128.
pv_mma_tiler = (128, HEAD_DIM, pv_ik * (128 // pv_ik))

# (1, 1, 1) layout tiled-divided by the QK MMA's thr_id shape.
cluster_layout_vmnk = cute.tiled_divide(
    cute.make_layout((1, 1, 1)),
    (qk_mma.thr_id.shape,),
)

kv_stage, q_stage = 2, 1

# B-operand SMEM layouts built with kv_stage, then the last mode fixed at
# index 0 (presumably the stage mode -- confirm) to get a single-stage layout.
k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, qk_mma_tiler, BFloat16, kv_stage)
v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, pv_mma_tiler, BFloat16, kv_stage)
_first_stage = (None, None, None, 0)
k_s = cute.slice_(k_smem_s, _first_stage)
v_s = cute.slice_(v_smem_s, _first_stage)
# Build the B-operand TMA atoms for K (QK MMA) and V (PV MMA).  Each call
# returns the atom together with the tensor that is tiled further below.
_cluster_shape = cluster_layout_vmnk.shape

_k_copy_op = utils.sm100.cluster_shape_to_tma_atom_B(_cluster_shape, qk_mma.thr_id)
tma_k, mK_tma = cute.nvgpu.make_tiled_tma_atom_B(
    _k_copy_op,
    mK,
    k_s,
    qk_mma_tiler,
    qk_mma,
    _cluster_shape,
)

_v_copy_op = utils.sm100.cluster_shape_to_tma_atom_B(_cluster_shape, pv_mma.thr_id)
tma_v, mV_tma = cute.nvgpu.make_tiled_tma_atom_B(
    _v_copy_op,
    mV,
    v_s,
    pv_mma_tiler,
    pv_mma,
    _cluster_shape,
)
# Tile the TMA tensors.  slice_ with (0, None, None) drops the first entry of
# each MMA tiler, leaving its last two entries as the tile shape; the
# all-None coordinate keeps every tile rather than selecting one.
_drop_first = (0, None, None)
_keep_all = (None, None, None)
gK = cute.local_tile(mK_tma, cute.slice_(qk_mma_tiler, _drop_first), _keep_all)
gV = cute.local_tile(mV_tma, cute.slice_(pv_mma_tiler, _drop_first), _keep_all)

# Dump shape / layout / stride for both tiles, one blank line after each.
for _tag, _tile in (('gK', gK), ('gV', gV)):
    print(f'{_tag} shape: {cute.shape(_tile)}')
    print(f'{_tag} layout: {_tile.layout}')
    print(f'{_tag} stride: {_tile.layout.stride}')
    if _tag == 'gK':
        # Only gK also reports the size of each top-level mode.
        _rank = len(cute.shape(_tile))
        print(f'{_tag} size per mode: {[cute.size(_tile, mode=[i]) for i in range(_rank)]}')
    print()
# Take slice 0 of each tiled MMA and partition the global tiles as the
# B operand of that MMA.
qk_thr, pv_thr = qk_mma.get_slice(0), pv_mma.get_slice(0)
tCgK = qk_thr.partition_B(gK)
tCgV = pv_thr.partition_B(gV)

# Dump the partitioned tensors, one blank line after each.
for _tag, _part in (('tCgK', tCgK), ('tCgV', tCgV)):
    print(f'{_tag} shape: {cute.shape(_part)}')
    _part_layout = _part.layout
    print(f'{_tag} layout: {_part_layout}')
    print(f'{_tag} stride: {_part_layout.stride}')
    print()
# SMEM-side tensors over the single-stage layouts.
# NOTE(review): make_tensor receives the element type as its first argument
# here rather than a pointer/iterator -- confirm this form is accepted.
sK, sV = (cute.make_tensor(BFloat16, _lay) for _lay in (k_s, v_s))

# Layout built from the shape of cluster_layout_vmnk sliced at (0, None, 0, 0).
_cluster_slice = cute.slice_(cluster_layout_vmnk, (0, None, 0, 0))
b_lay = cute.make_layout(_cluster_slice.shape)

# tma_partition gets each tensor with its first three modes grouped into one.
_sK_grouped = cute.group_modes(sK, 0, 3)
_gK_grouped = cute.group_modes(tCgK, 0, 3)
tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay, _sK_grouped, _gK_grouped)

_sV_grouped = cute.group_modes(sV, 0, 3)
_gV_grouped = cute.group_modes(tCgV, 0, 3)
tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, _sV_grouped, _gV_grouped)
# Dump the global-memory side of both TMA partitions: banner, shape, layout,
# stride, the size of each top-level mode, then a blank separator line.
for _banner, _part in (
    ('=== tBgK (K TMA partition) ===', tBgK),
    ('=== tVgV (V TMA partition) ===', tVgV),
):
    print(_banner)
    print(f'shape: {cute.shape(_part)}')
    print(f'layout: {_part.layout}')
    print(f'stride: {_part.layout.stride}')
    print(f'size per mode: {[cute.size(_part, mode=[i]) for i in range(len(cute.shape(_part)))]}')
    print()
# Now check the slices
def _print_slice(name, tensor, coord):
    """Print a banner, take ``tensor[coord]``, and dump the result.

    Args:
        name: Label shown in the banner (e.g. ``'tBgK'``).
        tensor: Tensor to index.
        coord: Index tuple, passed unchanged to ``tensor[...]``.

    Returns:
        The sliced tensor, so the caller can bind it to its own name.
    """
    # Render the coordinate without spaces so the banner text matches the
    # original hand-written form, e.g. "(None,None,0,0)".
    coord_text = ','.join(str(c) for c in coord)
    # The banner is printed before slicing, so a failing slice is still
    # attributable to a section in the output.
    print(f'=== {name} after ({coord_text}) ===')
    sliced = tensor[coord]
    print(f'shape: {cute.shape(sliced)}')
    print(f'layout: {sliced.layout}')
    print(f'stride: {sliced.layout.stride}')
    return sliced


# Each pair of slices keeps mode 0 plus one of modes 1 / 2 and fixes the
# remaining modes at index 0.
tBgK_nn = _print_slice('tBgK', tBgK, (None, None, 0, 0))
print()
tBgK_n0 = _print_slice('tBgK', tBgK, (None, 0, None, 0))
print()
tVgV_n0 = _print_slice('tVgV', tVgV, (None, 0, None, 0))
# This separator was missing: the last two sections used to run together,
# unlike every other section in the script.
print()
tVgV_nn = _print_slice('tVgV', tVgV, (None, None, 0, 0))