From 760d120b1cd9d0324c9dcce699bbdb4c8350c604 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 06:35:55 +0000 Subject: [PATCH] fix: typo + OperandMajorMode for TMEM budget probe --- tests/unit/test_tmem_budget.py | 43 +++++++++++++++++----------------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/tests/unit/test_tmem_budget.py b/tests/unit/test_tmem_budget.py index fd1b08e5..e0ede7e2 100644 --- a/tests/unit/test_tmem_budget.py +++ b/tests/unit/test_tmem_budget.py @@ -6,17 +6,25 @@ plan the SMEM-P path and verify TMEM fits in 512 columns at hd=512. import torch, math import cutlass, cutlass.cute as cute, cutlass.utils as utils from cutlass.cute.nvgpu import tcgen05 -from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass import Float32, BFloat16, Int32 from cutlass.utils import LayoutEnum from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +# Use the same major modes as FmhaKernel: +# a_major = OperandMajorMode.M (Q is M-major, typical for row-major Q) +# b_major = OperandMajorMode.K (K is K-major, typical for column-major K) +# For PV TMEM-P: a_major = OperandMajorMode.K (P in TMEM, K-major) +# For PV SMEM-P: a_major = OperandMajorMode.M (P from SMEM, M-major) + +M = cute.nvgpu.OperandMajorMode.M +K = cute.nvgpu.OperandMajorMode.K + def probe_hd(hd): print(f"\n=== HEAD_DIM={hd} ===") - # QK MMA: always (128, 128) + # QK MMA: always (128, 128), SMEM A + SMEM B qk_mma = utils.sm100.make_trivial_tiled_mma( - BFloat16, BFloat16, - LayoutEnum.ROW_MAJOR, LayoutEnum.ROW_MAJOR, + BFloat16, BFloat16, M, K, Float32, tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM, ) @@ -26,11 +34,9 @@ def probe_hd(hd): s_cols = find_tmem_tensor_col_offset(tStS) print(f" QK C-fragment: qk_as={qk_as}, tStS.layout shape={cute.shape(tStS)}, s_cols={s_cols}") - # PV MMA: (128, hd) - # TMEM-P path + # PV MMA: (128, hd), TMEM-P (P from TMEM, K-major) pv_mma_tmem = utils.sm100.make_trivial_tiled_mma( - BFloat16, BFloat16, - LayoutEnum.ROW_MAJOR, LayoutEnum.ROW_MAJOR, + BFloat16, BFloat16, K, K, Float32, tcgen05.CtaGroup.ONE, (128, hd), tcgen05.OperandSource.TMEM, ) @@ -40,10 +46,9 @@ def probe_hd(hd): o_cols_tmem = find_tmem_tensor_col_offset(tOtO_tmem) print(f" PV C-fragment (TMEM-P): pv_as={pv_as_tmem}, tOtO.layout shape={cute.shape(tOtO_tmem)}, o_cols={o_cols_tmem}") - # SMEM-P path (PV from SMEM) + # PV MMA: (128, hd), SMEM-P (P from SMEM, M-major) pv_mma_smem = utils.sm100.make_trivial_tiled_mma( - BFloat16, BFloat16, - LayoutEnum.ROW_MAJOR, LayoutEnum.ROW_MAJOR, + BFloat16, BFloat16, M, K, Float32, tcgen05.CtaGroup.ONE, (128, hd), tcgen05.OperandSource.SMEM, ) @@ -54,14 +59,11 @@ def probe_hd(hd): print(f" PV C-fragment (SMEM-P): pv_as={pv_as_smem}, tOtO.layout shape={cute.shape(tOtO_smem)}, o_cols={o_cols_smem}") # P columns in TMEM (TMEM-P path only) - # pv_mma_tiler[2] is the K-dim of the PV MMA, which determines P cols - # At hd=64: pv_mma_tiler = (128, 64, 128), pv_mma_tiler[2] = 128 - # p_cols_fp32 = 128 * 16 / 32 = 64 - pv_mma_tiler = (128, hd, 128) # assuming s_k=128 + pv_mma_tiler = (128, hd, 128) p_cols_fp32 = pv_mma_tiler[2] * BFloat16.width // Float32.width print(f" P cols (FP32): {p_cols_fp32} (pv_mma_tiler[2]={pv_mma_tiler[2]})") - # TMEM budget calculation + # TMEM budget calculations print(f" --- TMEM Budget ---") print(f" S cols: {s_cols}") print(f" P cols (TMEM-P): {p_cols_fp32}") @@ -75,17 +77,16 @@ def probe_hd(hd): total_tmem_p = tmem_o0_tmem_p + o_cols_tmem print(f" TMEM-P total: S(0) + P({tmem_p0}) + O({tmem_o0_tmem_p}) + O_size({o_cols_tmem}) = {total_tmem_p} / 512 cols {'✅' if total_tmem_p <= 512 else '❌ OVER BUDGET'}") - # SMEM-P: P not in TMEM. S and O sequential (S consumed before O written). - # Best case: O at 0 (reuses S space), total = max(s_cols, o_cols) - total_smem_p = o_cols_smem # O starts at 0 + # SMEM-P: P not in TMEM. S and O share TMEM (sequential). + # After softmax consumes S, PV writes O starting at col 0. + total_smem_p = o_cols_smem print(f" SMEM-P total (O at 0, reuses S): {total_smem_p} / 512 cols {'✅' if total_smem_p <= 512 else '❌ OVER BUDGET'}") # Split-PV: if hd > 256, process (128, 256) PV tiles if hd > 256: pv_n_tile = 256 pv_mma_split = utils.sm100.make_trivial_tiled_mma( - BFloat16, BFloat16, - LayoutEnum.ROW_MAJOR, LayoutEnum.ROW_MAJOR, + BFloat16, BFloat16, M, K, Float32, tcgen05.CtaGroup.ONE, (128, pv_n_tile), tcgen05.OperandSource.SMEM, )