From 2deb28827a518c1857402bd3cf4030d53366e42e Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 18:43:55 +0000 Subject: [PATCH] Diag: TMA shapes with hardcoded major modes --- tests/diag_tma_shapes.py | 67 ++++++++++++++++++++-------------------- 1 file changed, 34 insertions(+), 33 deletions(-) diff --git a/tests/diag_tma_shapes.py b/tests/diag_tma_shapes.py index f7cab341..1aa0e694 100644 --- a/tests/diag_tma_shapes.py +++ b/tests/diag_tma_shapes.py @@ -1,9 +1,7 @@ -"""Diagnostic: print TMA partition tensor shapes for multi-tile K/V. -Simplified to avoid JIT-only constructs.""" +"""Diagnostic: print TMA partition tensor shapes for multi-tile K/V.""" import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils from cutlass.cute.nvgpu import cpasync, tcgen05 from cutlass import Float32, BFloat16, Int32 -from cutlass.utils import LayoutEnum import cutlass.torch as ct import math @@ -19,8 +17,9 @@ mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) -qk_mma = utils.sm100.make_trivial_tiled_mma(BFloat16, BFloat16, LayoutEnum.from_tensor(mQ).mma_major_mode(), LayoutEnum.from_tensor(mK).mma_major_mode(), Float32, tcgen05.CtaGroup.ONE, (128,128), tcgen05.OperandSource.SMEM) -pv_mma = utils.sm100.make_trivial_tiled_mma(BFloat16, BFloat16, cute.nvgpu.OperandMajorMode.K, LayoutEnum.from_tensor(mV).mma_major_mode(), Float32, tcgen05.CtaGroup.ONE, (128,HEAD_DIM), tcgen05.OperandSource.TMEM) +# Hardcode major modes since LayoutEnum.from_tensor needs JIT context +qk_mma = utils.sm100.make_trivial_tiled_mma(BFloat16, BFloat16, cute.nvgpu.OperandMajorMode.K, cute.nvgpu.OperandMajorMode.K, Float32, tcgen05.CtaGroup.ONE, (128,128), tcgen05.OperandSource.SMEM) +pv_mma = utils.sm100.make_trivial_tiled_mma(BFloat16, BFloat16, cute.nvgpu.OperandMajorMode.K, cute.nvgpu.OperandMajorMode.MN, Float32, tcgen05.CtaGroup.ONE, (128,HEAD_DIM), tcgen05.OperandSource.TMEM) qk_ik = cute.size(qk_mma.shape_mnk, mode=[2]) qk_mma_tiler = (128, 128, qk_ik * 4) @@ -30,38 +29,26 @@ cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_i print(f'qk_mma_tiler: {qk_mma_tiler}') print(f'pv_mma_tiler: {pv_mma_tiler}') -print(f'cluster_layout_vmnk: {cute.shape(cluster_layout_vmnk)}') -kv_stage = 2; q_stage = 1 +kv_stage = 2 k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, qk_mma_tiler, BFloat16, kv_stage) v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, pv_mma_tiler, BFloat16, kv_stage) -q_s = cute.slice_(k_smem_s,(None,None,None,0)) # just to get the per-stage shape k_s = cute.slice_(k_smem_s,(None,None,None,0)) v_s = cute.slice_(v_smem_s,(None,None,None,0)) -print(f'k_smem_s outer shape: {cute.shape(k_smem_s.outer)}') -print(f'k_smem_s inner shape: {cute.shape(k_smem_s.inner)}') -print(f'v_smem_s outer shape: {cute.shape(v_smem_s.outer)}') -print(f'v_smem_s inner shape: {cute.shape(v_smem_s.inner)}') -print(f'k_s (per-stage) shape: {cute.shape(k_s)}') -print(f'v_s (per-stage) shape: {cute.shape(v_s)}') +print(f'k_s shape: {cute.shape(k_s)}') +print(f'v_s shape: {cute.shape(v_s)}') tma_k, mK_tma = cute.nvgpu.make_tiled_tma_atom_B( utils.sm100.cluster_shape_to_tma_atom_B(cluster_layout_vmnk.shape, qk_mma.thr_id), mK, k_s, qk_mma_tiler, qk_mma, cluster_layout_vmnk.shape ) -# For V, we use the raw mV (not v_fmha) just for shape diag -# The FMHA layout matters for the actual kernel but for shape analysis the mV is sufficient -v_major = LayoutEnum.from_tensor(mV).mma_major_mode() -print(f'v_major (raw mV): {v_major}') gK = cute.local_tile(mK_tma, cute.slice_(qk_mma_tiler,(0,None,None)),(None,None,None)) - print(f'mK_tma shape: {cute.shape(mK_tma)}') print(f'gK shape: {cute.shape(gK)}') -n_kv_tiles = cute.size(gK, mode=[3]) -print(f'n_kv_tiles (mode 3): {n_kv_tiles}') +print(f'n_kv_tiles: {cute.size(gK, mode=[3])}') qk_thr = qk_mma.get_slice(0) tCgK = qk_thr.partition_B(gK) @@ -69,7 +56,6 @@ print(f'tCgK shape: {cute.shape(tCgK)}') sK = cute.make_tensor(BFloat16, k_s) b_lay = cute.make_layout(cute.slice_(cluster_layout_vmnk,(0,None,0,0)).shape) -print(f'b_lay: {cute.shape(b_lay)}') tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay, cute.group_modes(sK,0,3), cute.group_modes(tCgK,0,3)) @@ -78,20 +64,35 @@ print(f'tBgK shape: {cute.shape(tBgK)}') print(f'tBsK layout: {tBsK.layout}') print(f'tBgK layout: {tBgK.layout}') -# Apply the problematic slice -tBgK_sliced = tBgK[(None,0,None,0)] -print(f'tBgK after (None,0,None,0) shape: {cute.shape(tBgK_sliced)}') - -# Try different slices to understand the mode meanings +# Test slices for desc, sl in [ - ("(None,0,0,0)", (None,0,0,0)), - ("(0,None,0,0)", (0,None,0,0)), - ("(None,None,0,0)", (None,None,0,0)), - ("(0,0,None,0)", (0,0,None,0)), - ("(None,0,None,None)", (None,0,None,None)), + ("(None,0,None,0)", (None,0,None,0)), # Current (broken) + ("(None,None,0,0)", (None,None,0,0)), # CUTLASS reference style + ("(0,None,None,0)", (0,None,None,0)), # Alternative ]: try: result = tBgK[sl] print(f'tBgK after {desc} shape: {cute.shape(result)}') except Exception as e: - print(f'tBgK after {desc}: ERROR {e}') + print(f'tBgK after {desc}: ERROR {type(e).__name__}: {e}') + +# Also check V +tma_v, mV_tma = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(cluster_layout_vmnk.shape, pv_mma.thr_id), + mV, v_s, pv_mma_tiler, pv_mma, cluster_layout_vmnk.shape +) +gV = cute.local_tile(mV_tma, cute.slice_(pv_mma_tiler,(0,None,None)),(None,None,None)) +pv_thr = pv_mma.get_slice(0) +tCgV = pv_thr.partition_B(gV) +sV = cute.make_tensor(BFloat16, v_s) +tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV,0,3), cute.group_modes(tCgV,0,3)) +print(f'tVgV shape: {cute.shape(tVgV)}') +for desc, sl in [ + ("(None,0,None,0)", (None,0,None,0)), + ("(None,None,0,0)", (None,None,0,0)), +]: + try: + result = tVgV[sl] + print(f'tVgV after {desc} shape: {cute.shape(result)}') + except Exception as e: + print(f'tVgV after {desc}: ERROR {type(e).__name__}: {e}')