diff --git a/tests/test_diag_multitile.py b/tests/test_diag_multitile.py new file mode 100644 index 00000000..d0d974bc --- /dev/null +++ b/tests/test_diag_multitile.py @@ -0,0 +1,30 @@ +"""Test the identity diag for multi-tile n=256,384""" +import torch, math, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline, cutlass.torch as ct, cuda.bindings.driver as cuda +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean +from cutlass.utils import LayoutEnum +from test_fmha_v3_diag import FmhaV3Diag + +HEAD_DIM = 64 +for n in [128, 256, 384]: + torch.manual_seed(42) + q = torch.randn(128, HEAD_DIM, 1, dtype=torch.bfloat16, device='cuda') + k = torch.randn(n, HEAD_DIM, 1, dtype=torch.bfloat16, device='cuda') + v = torch.ones(n, HEAD_DIM, dtype=torch.bfloat16, device='cuda') + v_kernel = v.unsqueeze(-1) + c = torch.zeros(128, HEAD_DIM, 1, dtype=torch.bfloat16, device='cuda') + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) + mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) + mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + kernel = FmhaV3Diag(s_k=n) + print(f'n={n}: Compiling...', flush=True) + compiled = cute.compile(kernel, mQ, mK, mV, mC, stream) + compiled(mQ, mK, mV, mC, stream) + torch.cuda.synchronize() + out = c[:,:,0].float() + qf = q[:,:,0].float(); kf = k[:,:,0].float() + ref = (qf @ kf.T * (1.0/math.sqrt(HEAD_DIM))) @ v.float() + cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item() + print(f'Identity diag n={n}: cos {cos:.6f} {"PASS" if cos >= 0.99 else "FAIL"}') diff --git a/tests/test_tma_coord.py b/tests/test_tma_coord.py new file mode 100644 index 00000000..e07cc854 --- /dev/null +++ b/tests/test_tma_coord.py @@ -0,0 +1,102 @@ +"""Minimal TMA multi-tile test: does the TMA coordinate propagate at runtime? +This is a tiny kernel that loads K from two different tiles and checks if the data is different. +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean +from cutlass.utils import LayoutEnum +import cuda.bindings.driver as cuda +import cutlass.torch as ct +import math + +HEAD_DIM = 64 + +class TmaMultiTileTest: + def __init__(self, s_k=256): + self.s_k = s_k + self.q_dtype = BFloat16 + self.cta_group = tcgen05.CtaGroup.ONE + self.kv_stage = 2 + + def _setup(self, qk_mma): + qk_ik = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (128, 128, qk_ik * 4) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage) + cta = cute.size(qk_mma.thr_id.shape) + k_s = cute.slice_(self.k_smem_s,(None,None,None,0)) + self.kv_tx_bytes = cute.size_in_bytes(self.q_dtype, k_s) * cta + + @cute.jit + def __call__(self, k, out_buf, stream): + b_major = LayoutEnum.from_tensor(k).mma_major_mode() + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.q_dtype, + cute.nvgpu.OperandMajorMode.K, b_major, + Float32, self.cta_group, (128,128), tcgen05.OperandSource.SMEM) + self._setup(qk_mma) + k_s = cute.slice_(self.k_smem_s,(None,None,None,0)) + tma_k,mK = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_layout_vmnk.shape, qk_mma.thr_id), + k, k_s, self.qk_mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + self._kernel(qk_mma, tma_k, mK, out_buf, self.cluster_layout_vmnk, self.k_smem_s).launch( + grid=(1,1,1), block=[32,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, tma_k, mK, out_buf, cl_vmnk, k_smem_s): + tidx,_,_ = cute.arch.thread_idx() + cpasync.prefetch_descriptor(tma_k) + + sK = cute.make_tensor(self.q_dtype, cute.slice_(k_smem_s,(None,None,None,0)), byte_alignment=128) + + gK = cute.local_tile(mK, cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None)) + n_kv_tiles = cute.size(gK, mode=[3]) + + qk_thr = qk_mma.get_slice(0) + tCgK = qk_thr.partition_B(gK) + b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) + tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) + + # Test different slice patterns + tBgK_nn = tBgK[(None,None,0,0)] # Mode 1 = GMEM tiles (should work for multi-tile) + tBgK_n0 = tBgK[(None,0,None,0)] # Mode 1 fixed to 0 (broken for multi-tile) + + print(f"tBgK shape: {cute.shape(tBgK)}", end="") + print(f" tBgK_nn shape: {cute.shape(tBgK_nn)}", end="") + print(f" tBgK_n0 shape: {cute.shape(tBgK_n0)}", end="") + print(f" n_kv_tiles: {n_kv_tiles}") + + # Load K from each tile using the (None,None,0,0) slice + # and write the first 4 elements to out_buf for verification + for kt in range(n_kv_tiles): + cute.copy(tma_k, tBgK_nn[(None, Int32(kt))], tBsK[(None, 0)], tma_bar_ptr=None) + # Read first 4 BF16 values from SMEM and write to out_buf + for i in range(4): + out_buf[kt * 4 + i] = sK[0, i, 0, 0, 0] + + +def test(): + for n in [256]: + torch.manual_seed(42) + k = torch.randn(n, HEAD_DIM, 1, dtype=torch.bfloat16, device='cuda') + out_buf = torch.zeros(n // 128 * 4, dtype=torch.bfloat16, device='cuda') + + mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) + mO = ct.from_dlpack(out_buf).mark_layout_dynamic(leading_dim=ct.get_leading_dim(out_buf)) + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + kernel = TmaMultiTileTest(s_k=n) + print(f'n={n}: Compiling...', flush=True) + compiled = cute.compile(kernel, mK, mO, stream) + compiled(mK, mO, stream) + torch.cuda.synchronize() + + print(f'out_buf: {out_buf.tolist()}') + # Check if tile 0 and tile 1 data are different + tile0 = k[:128, 0, 0][:4].tolist() + tile1 = k[128:, 0, 0][:4].tolist() + print(f'tile0 first 4: {tile0}') + print(f'tile1 first 4: {tile1}') + +if __name__ == '__main__': + test()