From 7ad4ddb6ba107cb3e8c471895e75cdde7b7e4b7c Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 18:43:05 +0000 Subject: [PATCH] Diag: print TMA partition shapes for multi-tile debugging --- tests/diag_tma_shapes.py | 99 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 tests/diag_tma_shapes.py diff --git a/tests/diag_tma_shapes.py b/tests/diag_tma_shapes.py new file mode 100644 index 00000000..88ccfd01 --- /dev/null +++ b/tests/diag_tma_shapes.py @@ -0,0 +1,99 @@ +"""Diagnostic: print TMA partition tensor shapes for multi-tile K/V.""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda +import cutlass.torch as ct +import math + +HEAD_DIM = 64 +n = 256 # 2 KV tiles + +q = torch.randn(128, HEAD_DIM, 1, dtype=torch.bfloat16, device='cuda') +k = torch.randn(n, HEAD_DIM, 1, dtype=torch.bfloat16, device='cuda') +v = torch.randn(n, HEAD_DIM, dtype=torch.bfloat16, device='cuda') +v_kernel = v.unsqueeze(-1) +c = torch.zeros(128, HEAD_DIM, 1, dtype=torch.bfloat16, device='cuda') + +mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) +mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) +mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) +mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) + +# V layout for FMHA +v_fmha = cute.make_tensor( + mV.iterator, + cute.make_layout( + (HEAD_DIM, n, 1), + stride=(1, HEAD_DIM, HEAD_DIM * n), + ), +) + +qk_mma = utils.sm100.make_trivial_tiled_mma(BFloat16, BFloat16, LayoutEnum.from_tensor(mQ).mma_major_mode(), LayoutEnum.from_tensor(mK).mma_major_mode(), Float32, tcgen05.CtaGroup.ONE, (128,128), tcgen05.OperandSource.SMEM) +v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode() +pv_mma = utils.sm100.make_trivial_tiled_mma(BFloat16, BFloat16, cute.nvgpu.OperandMajorMode.K, v_major, Float32, tcgen05.CtaGroup.ONE, (128,HEAD_DIM), tcgen05.OperandSource.TMEM) + +qk_ik = cute.size(qk_mma.shape_mnk, mode=[2]) +qk_mma_tiler = (128, 128, qk_ik * 4) +pv_ik = cute.size(pv_mma.shape_mnk, mode=[2]) +pv_mma_tiler = (128, HEAD_DIM, pv_ik * (128 // pv_ik)) + +cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + +kv_stage = 2 +k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, qk_mma_tiler, BFloat16, kv_stage) +v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, pv_mma_tiler, BFloat16, kv_stage) + +tma_k, mK_tma = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(cluster_layout_vmnk.shape, qk_mma.thr_id), + mK, cute.slice_(k_smem_s,(None,None,None,0)), qk_mma_tiler, qk_mma, cluster_layout_vmnk.shape +) +tma_v, mV_tma = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(cluster_layout_vmnk.shape, pv_mma.thr_id), + v_fmha, cute.slice_(v_smem_s,(None,None,None,0)), pv_mma_tiler, pv_mma, cluster_layout_vmnk.shape +) + +gK = cute.local_tile(mK_tma, cute.slice_(qk_mma_tiler,(0,None,None)),(None,None,None)) +gV = cute.local_tile(mV_tma, cute.slice_(pv_mma_tiler,(0,None,None)),(None,None,None)) + +print(f'gK shape: {cute.shape(gK)}') +print(f'gV shape: {cute.shape(gV)}') + +qk_thr = qk_mma.get_slice(0) +pv_thr = pv_mma.get_slice(0) +tCgK = qk_thr.partition_B(gK) +tCgV = pv_thr.partition_B(gV) + +print(f'tCgK shape: {cute.shape(tCgK)}') +print(f'tCgV shape: {cute.shape(tCgV)}') + +k_s = cute.slice_(k_smem_s,(None,None,None,0)) +v_s = cute.slice_(v_smem_s,(None,None,None,0)) +sK = cute.make_tensor(BFloat16, k_s.outer) +sV = cute.make_tensor(BFloat16, v_s.outer) + +b_lay = cute.make_layout(cute.slice_(cluster_layout_vmnk,(0,None,0,0)).shape) + +tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay, cute.group_modes(sK,0,3), cute.group_modes(tCgK,0,3)) +tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV,0,3), cute.group_modes(tCgV,0,3)) + +print(f'tBsK shape: {cute.shape(tBsK)}') +print(f'tBgK shape: {cute.shape(tBgK)}') +print(f'tVsV shape: {cute.shape(tVsV)}') +print(f'tVgV shape: {cute.shape(tVgV)}') + +# Now apply the slice +tBgK_sliced = tBgK[(None,0,None,0)] +tVgV_sliced = tVgV[(None,0,None,0)] +print(f'tBgK after (None,0,None,0) shape: {cute.shape(tBgK_sliced)}') +print(f'tVgV after (None,0,None,0) shape: {cute.shape(tVgV_sliced)}') + +# What about the CUTLASS-style pre-slice? +# Try (None,None,0,0) — keeps first 2 modes, fixes last 2 +# tBgK_refstyle = tBgK[(None,None,0,0)] +# print(f'tBgK after (None,None,0,0) shape: {cute.shape(tBgK_refstyle)}') + +n_kv_tiles = cute.size(gK, mode=[3]) +print(f'n_kv_tiles = {n_kv_tiles}')