From d0aec403e42e2b6cece4b8100ad1449db2ad5c2d Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 03:30:29 +0000 Subject: [PATCH] D1: Layout diagnostic for SMEM-P --- tests/unit/test_d1_layout_diag.py | 35 +++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 tests/unit/test_d1_layout_diag.py diff --git a/tests/unit/test_d1_layout_diag.py b/tests/unit/test_d1_layout_diag.py new file mode 100644 index 00000000..981b82dd --- /dev/null +++ b/tests/unit/test_d1_layout_diag.py @@ -0,0 +1,35 @@ +"""D1: Print SMEM-P layout diagnostics for hd=128.""" +import torch, math, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.torch as ct +import cuda.bindings.driver as cuda +from cutlass.cute.nvgpu import tcgen05 +from cutlass import Float32, BFloat16 +from cutlass.utils import LayoutEnum + +for hd in [64, 128, 256]: + q = torch.randn(128, hd, 1, dtype=torch.bfloat16, device='cuda') + k = torch.randn(128, hd, 1, dtype=torch.bfloat16, device='cuda') + + a_major = LayoutEnum.from_tensor(ct.from_dlpack(q)).mma_major_mode() + b_major = LayoutEnum.from_tensor(ct.from_dlpack(k)).mma_major_mode() + qk_mma = utils.sm100.make_trivial_tiled_mma(BFloat16, BFloat16, a_major, b_major, Float32, tcgen05.CtaGroup.ONE, (128,128), tcgen05.OperandSource.SMEM) + pv_n_tile = min(hd, 256) + pv_a_major = a_major # SMEM-P path uses a_major + pv_mma = utils.sm100.make_trivial_tiled_mma(BFloat16, BFloat16, pv_a_major, b_major, Float32, tcgen05.CtaGroup.ONE, (128,pv_n_tile), tcgen05.OperandSource.SMEM) + + qk_mma_tiler = (128, 128, cute.size(qk_mma.shape_mnk, mode=[2]) * 4) + pv_mma_tiler = (128, pv_n_tile, cute.size(pv_mma.shape_mnk, mode=[2]) * (128 // cute.size(pv_mma.shape_mnk, mode=[2]))) + + p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1) + p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1) + + # QK C-fragment layout + qk_thr = qk_mma.get_slice(0) + qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_as) + + print(f'--- hd={hd} ---') + print(f' p_tmem_s.outer = {p_tmem_s.outer}') + print(f' p_smem_s.outer = {p_smem_s.outer}') + print(f' tStS.layout = {tStS.layout}') + print(f' pv_mma_tiler = {pv_mma_tiler}') + print()