"""Minimal TMA multi-tile test: does the TMA coordinate propagate at runtime? This is a tiny kernel that loads K from two different tiles and checks if the data is different. """ import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline from cutlass.cute.nvgpu import cpasync, tcgen05 from cutlass import Float32, BFloat16, Int32, Boolean from cutlass.utils import LayoutEnum import cuda.bindings.driver as cuda import cutlass.torch as ct import math HEAD_DIM = 64 class TmaMultiTileTest: def __init__(self, s_k=256): self.s_k = s_k self.q_dtype = BFloat16 self.cta_group = tcgen05.CtaGroup.ONE self.kv_stage = 2 def _setup(self, qk_mma): qk_ik = cute.size(qk_mma.shape_mnk, mode=[2]) self.qk_mma_tiler = (128, 128, qk_ik * 4) self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage) cta = cute.size(qk_mma.thr_id.shape) k_s = cute.slice_(self.k_smem_s,(None,None,None,0)) self.kv_tx_bytes = cute.size_in_bytes(self.q_dtype, k_s) * cta @cute.jit def __call__(self, k, out_buf, stream): b_major = LayoutEnum.from_tensor(k).mma_major_mode() qk_mma = utils.sm100.make_trivial_tiled_mma( self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, b_major, Float32, self.cta_group, (128,128), tcgen05.OperandSource.SMEM) self._setup(qk_mma) k_s = cute.slice_(self.k_smem_s,(None,None,None,0)) tma_k,mK = cute.nvgpu.make_tiled_tma_atom_B( utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_layout_vmnk.shape, qk_mma.thr_id), k, k_s, self.qk_mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) self._kernel(qk_mma, tma_k, mK, out_buf, self.cluster_layout_vmnk, self.k_smem_s).launch( grid=(1,1,1), block=[32,1,1], stream=stream) @cute.kernel def _kernel(self, qk_mma, tma_k, mK, out_buf, cl_vmnk, k_smem_s): tidx,_,_ = cute.arch.thread_idx() cpasync.prefetch_descriptor(tma_k) sK = cute.make_tensor(self.q_dtype, cute.slice_(k_smem_s,(None,None,None,0)), byte_alignment=128) gK = cute.local_tile(mK, cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None)) n_kv_tiles = cute.size(gK, mode=[3]) qk_thr = qk_mma.get_slice(0) tCgK = qk_thr.partition_B(gK) b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) # Test different slice patterns tBgK_nn = tBgK[(None,None,0,0)] # Mode 1 = GMEM tiles (should work for multi-tile) tBgK_n0 = tBgK[(None,0,None,0)] # Mode 1 fixed to 0 (broken for multi-tile) print(f"tBgK shape: {cute.shape(tBgK)}", end="") print(f" tBgK_nn shape: {cute.shape(tBgK_nn)}", end="") print(f" tBgK_n0 shape: {cute.shape(tBgK_n0)}", end="") print(f" n_kv_tiles: {n_kv_tiles}") # Load K from each tile using the (None,None,0,0) slice # and write the first 4 elements to out_buf for verification for kt in range(n_kv_tiles): cute.copy(tma_k, tBgK_nn[(None, Int32(kt))], tBsK[(None, 0)], tma_bar_ptr=None) # Read first 4 BF16 values from SMEM and write to out_buf for i in range(4): out_buf[kt * 4 + i] = sK[0, i, 0, 0, 0] def test(): for n in [256]: torch.manual_seed(42) k = torch.randn(n, HEAD_DIM, 1, dtype=torch.bfloat16, device='cuda') out_buf = torch.zeros(n // 128 * 4, dtype=torch.bfloat16, device='cuda') mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) mO = ct.from_dlpack(out_buf).mark_layout_dynamic(leading_dim=ct.get_leading_dim(out_buf)) stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) kernel = TmaMultiTileTest(s_k=n) print(f'n={n}: Compiling...', flush=True) compiled = cute.compile(kernel, mK, mO, stream) compiled(mK, mO, stream) torch.cuda.synchronize() print(f'out_buf: {out_buf.tolist()}') # Check if tile 0 and tile 1 data are different tile0 = k[:128, 0, 0][:4].tolist() tile1 = k[128:, 0, 0][:4].tolist() print(f'tile0 first 4: {tile0}') print(f'tile1 first 4: {tile1}') if __name__ == '__main__': test()