From a2d0dec7bbfd40ff629b2881a9efc3d929665755 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 06:33:26 +0000 Subject: [PATCH] D1.2: TMEM budget probe script for hd=64,128,256,512 --- tests/unit/test_tmem_budget.py | 100 +++++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 tests/unit/test_tmem_budget.py diff --git a/tests/unit/test_tmem_budget.py b/tests/unit/test_tmem_budget.py new file mode 100644 index 00000000..3a70e603 --- /dev/null +++ b/tests/unit/test_tmem_budget.py @@ -0,0 +1,100 @@ +"""TMEM column budget probe for FMHA at various head_dims. + +Prints find_tmem_tensor_col_offset(tOtO) and related shapes so we can +plan the SMEM-P path and verify TMEM fits in 512 columns at hd=512. +""" +import torch, math +import cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.nvgpu.tcgen05 as tcgen05 +from cutlass import BFloat16, Float32, LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset + +def probe_hd(hd): + print(f"\n=== HEAD_DIM={hd} ===") + + # QK MMA: always (128, 128) + qk_mma = utils.sm100.make_trivial_tiled_mma( + BFloat16, BFloat16, + LayoutEnum.ROW_MAJOR, LayoutEnum.ROW_MAJOR, + Float32, tcgen05.CtaGroup.ONE, (128, 128), + tcgen05.OperandSource.SMEM, + ) + qk_thr = qk_mma.get_slice(0) + qk_as = qk_thr.partition_shape_C((128, 128)) + tStS = qk_thr.make_fragment_C(qk_as) + s_cols = find_tmem_tensor_col_offset(tStS) + print(f" QK C-fragment: qk_as={qk_as}, tStS.layout shape={cute.shape(tStS)}, s_cols={s_cols}") + + # PV MMA: (128, hd) + # TMEM-P path + pv_mma_tmem = utils.sm100.make_trivial_tiled_mma( + BFloat16, BFloat16, + LayoutEnum.ROW_MAJOR, LayoutEnum.ROW_MAJOR, + Float32, tcgen05.CtaGroup.ONE, (128, hd), + tcgen05.OperandSource.TMEM, + ) + pv_thr_tmem = pv_mma_tmem.get_slice(0) + pv_as_tmem = pv_thr_tmem.partition_shape_C((128, hd)) + tOtO_tmem = pv_thr_tmem.make_fragment_C(pv_as_tmem) + o_cols_tmem = find_tmem_tensor_col_offset(tOtO_tmem) + print(f" PV C-fragment (TMEM-P): pv_as={pv_as_tmem}, tOtO.layout shape={cute.shape(tOtO_tmem)}, o_cols={o_cols_tmem}") + + # SMEM-P path (PV from SMEM) + pv_mma_smem = utils.sm100.make_trivial_tiled_mma( + BFloat16, BFloat16, + LayoutEnum.ROW_MAJOR, LayoutEnum.ROW_MAJOR, + Float32, tcgen05.CtaGroup.ONE, (128, hd), + tcgen05.OperandSource.SMEM, + ) + pv_thr_smem = pv_mma_smem.get_slice(0) + pv_as_smem = pv_thr_smem.partition_shape_C((128, hd)) + tOtO_smem = pv_thr_smem.make_fragment_C(pv_as_smem) + o_cols_smem = find_tmem_tensor_col_offset(tOtO_smem) + print(f" PV C-fragment (SMEM-P): pv_as={pv_as_smem}, tOtO.layout shape={cute.shape(tOtO_smem)}, o_cols={o_cols_smem}") + + # P columns in TMEM (TMEM-P path only) + # pv_mma_tiler[2] is the K-dim of the PV MMA, which determines P cols + # At hd=64: pv_mma_tiler = (128, 64, 128), pv_mma_tiler[2] = 128 + # p_cols_fp32 = 128 * 16 / 32 = 64 + pv_mma_tiler = (128, hd, 128) # assuming s_k=128 + p_cols_fp32 = pv_mma_tiler[2] * BFloat16.width // Float32.width + print(f" P cols (FP32): {p_cols_fp32} (pv_mma_tiler[2]={pv_mma_tiler[2]})") + + # TMEM budget calculation + print(f" --- TMEM Budget ---") + print(f" S cols: {s_cols}") + print(f" P cols (TMEM-P): {p_cols_fp32}") + print(f" O cols: {o_cols_tmem} (TMEM-P) / {o_cols_smem} (SMEM-P)") + + # TMEM-P: S at 0, P at 32, O after max(S, P_end) + tmem_p0 = 32 + p_end = tmem_p0 + p_cols_fp32 + o_after = max(s_cols, p_end) + tmem_o0_tmem_p = ((o_after + 31) // 32) * 32 + total_tmem_p = tmem_o0_tmem_p + o_cols_tmem + print(f" TMEM-P total: S(0) + P({tmem_p0}) + O({tmem_o0_tmem_p}) + O_size({o_cols_tmem}) = {total_tmem_p} / 512 cols {'✅' if total_tmem_p <= 512 else '❌ OVER BUDGET'}") + + # SMEM-P: P not in TMEM. S and O sequential (S consumed before O written). + # Best case: O at 0 (reuses S space), total = max(s_cols, o_cols) + total_smem_p = o_cols_smem # O starts at 0 + print(f" SMEM-P total (O at 0, reuses S): {total_smem_p} / 512 cols {'✅' if total_smem_p <= 512 else '❌ OVER BUDGET'}") + + # Split-PV: if hd > 256, process (128, 256) PV tiles + if hd > 256: + pv_n_tile = 256 + pv_mma_split = utils.sm100.make_trivial_tiled_mma( + BFloat16, BFloat16, + LayoutEnum.ROW_MAJOR, LayoutEnum.ROW_MAJOR, + Float32, tcgen05.CtaGroup.ONE, (128, pv_n_tile), + tcgen05.OperandSource.SMEM, + ) + pv_thr_split = pv_mma_split.get_slice(0) + pv_as_split = pv_thr_split.partition_shape_C((128, pv_n_tile)) + tOtO_split = pv_thr_split.make_fragment_C(pv_as_split) + o_cols_split = find_tmem_tensor_col_offset(tOtO_split) + total_split = o_cols_split + print(f" Split-PV (128, {pv_n_tile}) O cols: {o_cols_split}, total SMEM-P: {total_split} / 512 {'✅' if total_split <= 512 else '❌ OVER BUDGET'}") + + +if __name__ == '__main__': + for hd in [64, 128, 256, 512]: + probe_hd(hd)