From 2f670e33d14b65b2860c04c00bc2c3c92ee95d15 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 21:54:22 +0000 Subject: [PATCH] TMA shape diagnostic: exact setup from example9 + shape prints --- tests/unit/test_tma_shapes.py | 163 ++++++++++++++++++++++++---------- 1 file changed, 115 insertions(+), 48 deletions(-) diff --git a/tests/unit/test_tma_shapes.py b/tests/unit/test_tma_shapes.py index b66870bb..89d49ae8 100644 --- a/tests/unit/test_tma_shapes.py +++ b/tests/unit/test_tma_shapes.py @@ -1,32 +1,77 @@ """ -Diagnostic: print tBgK and tVgV shapes after tma_partition. -Just need to see how many modes and which is the KV tile dim. +Diagnostic: print tBgK and tVgV shapes BEFORE pre-slicing. +This runs at JIT trace time, so Python print gives us static shape info. """ import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline from cutlass.cute.nvgpu import cpasync, tcgen05 from cutlass import Float32, BFloat16, Int32, const_expr from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset import cuda.bindings.driver as cuda import cutlass.torch as ct import math HEAD_DIM = 64 + class TmaShapeDiag: def __init__(self, s_k=256): self.s_k = s_k self.n_kv_tiles = s_k // 128 - self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 + self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1 self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE - self.kv_stage = 2; self.q_stage = 1 - self.threads_per_cta = 192 - self.qk_mma_tiler = (128, 128, 4) - self.pv_mma_tiler = (128, HEAD_DIM, 4) + self.epilogue_warp_id = (0,1,2,3); self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192; self.num_c_stage = 2 + self.kv_stage = 2; self.q_stage = 1; self.num_c_stage = 2 + self.scale_softmax = 1.0 / math.sqrt(HEAD_DIM) + self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e) + + def _setup(self, qk_mma, pv_mma): + qk_ik = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (128, 128, qk_ik * 4) + pv_ik = cute.size(pv_mma.shape_mnk, mode=[2]) + self.pv_mma_tiler = (128, HEAD_DIM, pv_ik * (128 // pv_ik)) + self.mma_tiler = self.qk_mma_tiler + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), HEAD_DIM, self.qk_mma_tiler[2]) + self.c_layout = LayoutEnum.ROW_MAJOR + self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype) + self.num_ab_stage = 1; self.num_acc_stage = 1 + self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage) + self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + qk_thr = qk_mma.get_slice(0); qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_as) + pv_thr = pv_mma.get_slice(0); pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_as) + self.tmem_s0_offset = 0; self.tmem_p0_offset = 32 + p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width + p_end = self.tmem_p0_offset + p_cols_fp32 + s_cols = self.qk_mma_tiler[1] + o_after = max(s_cols, p_end) + self.tmem_o0_offset = ((o_after + 31) // 32) * 32 + o_cols = find_tmem_tensor_col_offset(tOtO) + total = self.tmem_o0_offset + o_cols + self.num_tmem_alloc_cols = 1 + while self.num_tmem_alloc_cols < total: + self.num_tmem_alloc_cols *= 2 + cta = cute.size(qk_mma.thr_id.shape) + q_s = cute.slice_(self.q_smem_s,(None,None,None,0)) + k_s = cute.slice_(self.k_smem_s,(None,None,None,0)) + v_s = cute.slice_(self.v_smem_s,(None,None,None,0)) + self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta + self.kv_tx_bytes = (cute.size_in_bytes(self.q_dtype, k_s) + + cute.size_in_bytes(self.q_dtype, v_s)) * cta @cute.jit - def __call__(self, q, k, v, stream): - a_major = LayoutEnum.from_tensor(q).mma_major_mode() - b_major = LayoutEnum.from_tensor(k).mma_major_mode() + def __call__(self, q, k, v, c, stream): + self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype + self.a_major = LayoutEnum.from_tensor(q).mma_major_mode() + self.b_major = LayoutEnum.from_tensor(k).mma_major_mode() v_fmha = cute.make_tensor( v.iterator, cute.make_layout( @@ -34,21 +79,28 @@ class TmaShapeDiag: stride=(1, HEAD_DIM, HEAD_DIM * self.s_k), ), ) - v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode() - qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, a_major, b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM) - pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, v_major, self.qk_acc_dtype, self.cta_group, (128,HEAD_DIM), tcgen05.OperandSource.TMEM) + self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode() + self.c_layout = LayoutEnum.from_tensor(c) + qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major, self.qk_acc_dtype, self.cta_group, (128,HEAD_DIM), tcgen05.OperandSource.TMEM) + self._setup(qk_mma, pv_mma) + q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0)) + tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape) + tma_k,mK = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,qk_mma.thr_id),k,k_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape) + tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape) + epi_s = cute.select(self.c_smem_s,mode=[0,1]) + tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile) + # Stop here — just check shapes + self._diag_kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,self.cluster_layout_vmnk).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream) - q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage) - k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage) - v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage) - - q_s = cute.slice_(q_smem_s,(None,None,None,0)) - k_s = cute.slice_(k_smem_s,(None,None,None,0)) - v_s = cute.slice_(v_smem_s,(None,None,None,0)) - - tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,(1,1,1)) - tma_k,mK = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,qk_mma.thr_id),k,k_s,self.qk_mma_tiler,qk_mma,(1,1,1)) - tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,(1,1,1)) + @cute.kernel + def _diag_kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, cl_vmnk): + q_s = cute.slice_(self.q_smem_s,(None,None,None,0)) + k_s = cute.slice_(self.k_smem_s,(None,None,None,0)) + v_s = cute.slice_(self.v_smem_s,(None,None,None,0)) + sQ = cute.make_tensor(BFloat16, q_s.outer) + sK = cute.make_tensor(BFloat16, k_s.outer) + sV = cute.make_tensor(BFloat16, v_s.outer) gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None)) gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None)) @@ -58,29 +110,39 @@ class TmaShapeDiag: tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK) tCgV = pv_thr.partition_B(gV) - a_lay = cute.make_layout(1) - tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(q_s,0,3),cute.group_modes(tCgQ,0,3)) - b_lay = cute.make_layout(1) - tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(k_s,0,3),cute.group_modes(tCgK,0,3)) - tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(v_s,0,3),cute.group_modes(tCgV,0,3)) + a_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,0,None,0)).shape) + tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(sQ,0,3),cute.group_modes(tCgQ,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) + tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) + tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) - print(f"tAgQ shape: {cute.shape(tAgQ)} rank: {tAgQ.rank}") - print(f"tBgK shape: {cute.shape(tBgK)} rank: {tBgK.rank}") - print(f"tVgV shape: {cute.shape(tVgV)} rank: {tVgV.rank}") - print(f"tAsQ shape: {cute.shape(tAsQ)} rank: {tAsQ.rank}") - print(f"tBsK shape: {cute.shape(tBsK)} rank: {tBsK.rank}") - print(f"tVsV shape: {cute.shape(tVsV)} rank: {tVsV.rank}") - # Print size of each mode - for i in range(tBgK.rank): - try: - print(f" tBgK mode {i} size: {cute.size(tBgK, mode=[i])}") - except: - print(f" tBgK mode {i}: error getting size") - for i in range(tVgV.rank): - try: - print(f" tVgV mode {i} size: {cute.size(tVgV, mode=[i])}") - except: - print(f" tVgV mode {i}: error getting size") + # Print shapes BEFORE any pre-slice + print(f"=== TMA partition shapes (n_kv_tiles={self.n_kv_tiles}) ===") + print(f"tAgQ: shape={cute.shape(tAgQ)} rank={tAgQ.rank}") + print(f"tBgK: shape={cute.shape(tBgK)} rank={tBgK.rank}") + print(f"tVgV: shape={cute.shape(tVgV)} rank={tVgV.rank}") + print(f"tAsQ: shape={cute.shape(tAsQ)} rank={tAsQ.rank}") + print(f"tBsK: shape={cute.shape(tBsK)} rank={tBsK.rank}") + print(f"tVsV: shape={cute.shape(tVsV)} rank={tVsV.rank}") + + # Print per-mode sizes + for name, t in [("tAgQ", tAgQ), ("tBgK", tBgK), ("tVgV", tVgV)]: + for i in range(t.rank): + sz = cute.size(t, mode=[i]) + print(f" {name} mode {i} size={sz}") + + # Also print after the original pre-slice + tAgQ2 = tAgQ[(None,0,None,0)] + tBgK2 = tBgK[(None,None,0,0)] + tVgV2 = tVgV[(None,0,None,0)] + print(f"\nAfter pre-slice:") + print(f"tAgQ[(None,0,None,0)]: shape={cute.shape(tAgQ2)} rank={tAgQ2.rank}") + print(f"tBgK[(None,None,0,0)]: shape={cute.shape(tBgK2)} rank={tBgK2.rank}") + print(f"tVgV[(None,0,None,0)]: shape={cute.shape(tVgV2)} rank={tVgV2.rank}") + for name, t in [("tAgQ2", tAgQ2), ("tBgK2", tBgK2), ("tVgV2", tVgV2)]: + for i in range(t.rank): + sz = cute.size(t, mode=[i]) + print(f" {name} mode {i} size={sz}") def test(): @@ -90,16 +152,21 @@ def test(): k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device='cuda') v = torch.randn(n, hd, dtype=torch.bfloat16, device='cuda') v_kernel = v.unsqueeze(-1) + c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda') mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) + mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) diag = TmaShapeDiag(s_k=n) - compiled = cute.compile(diag, mQ, mK, mV, stream) - compiled(mQ, mK, mV, stream) + print('Compiling...', flush=True) + compiled = cute.compile(diag, mQ, mK, mV, mC, stream) + print('Running...', flush=True) + compiled(mQ, mK, mV, mC, stream) torch.cuda.synchronize() + print('Done.') if __name__ == '__main__':