diff --git a/tests/unit/test_tmem_budget.py b/tests/unit/test_tmem_budget.py index 4d7e8422..c9dc694f 100644 --- a/tests/unit/test_tmem_budget.py +++ b/tests/unit/test_tmem_budget.py @@ -1,7 +1,6 @@ """TMEM column budget probe for FMHA at various head_dims. -Prints find_tmem_tensor_col_offset(tOtO) and related shapes so we can -plan the SMEM-P path and verify TMEM fits in 512 columns at hd=512. +Uses real tensors to get correct OperandMajorMode values, same as FmhaKernel. """ import torch, math import cutlass, cutlass.cute as cute, cutlass.utils as utils @@ -9,22 +8,34 @@ from cutlass.cute.nvgpu import tcgen05 from cutlass import Float32, BFloat16, Int32 from cutlass.utils import LayoutEnum from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cutlass.torch as ct -# Use the same major modes as FmhaKernel: -# a_major = OperandMajorMode.M (Q is M-major, typical for row-major Q) -# b_major = OperandMajorMode.K (K is K-major, typical for column-major K) -# For PV TMEM-P: a_major = OperandMajorMode.K (P in TMEM, K-major) -# For PV SMEM-P: a_major = OperandMajorMode.M (P from SMEM, M-major) - -MN = cute.nvgpu.OperandMajorMode.MN -K = cute.nvgpu.OperandMajorMode.K def probe_hd(hd): print(f"\n=== HEAD_DIM={hd} ===") + m = 128 # M tile + n = 128 # KV length (s_k) - # QK MMA: always (128, 128), SMEM A + SMEM B + # Create dummy tensors to extract major modes (same as FmhaKernel) + q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda') + k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device='cuda') + v = torch.randn(n, hd, dtype=torch.bfloat16, device='cuda') + c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda') + + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) + + # V with FMHA layout: (head_dim, s_k, 1) + v_fmha_layout = cute.make_layout( + (hd, n, 1), stride=(1, hd, hd * n), + ) + # Get major modes from the actual tensors + a_major = LayoutEnum.from_tensor(mQ).mma_major_mode() + b_major = LayoutEnum.from_tensor(mK).mma_major_mode() + + # QK MMA: always (128, 128) qk_mma = utils.sm100.make_trivial_tiled_mma( - BFloat16, BFloat16, MN, K, + BFloat16, BFloat16, a_major, b_major, Float32, tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM, ) @@ -35,8 +46,9 @@ def probe_hd(hd): print(f" QK C-fragment: qk_as={qk_as}, tStS.layout shape={cute.shape(tStS)}, s_cols={s_cols}") # PV MMA: (128, hd), TMEM-P (P from TMEM, K-major) + pv_a_major_tmem = cute.nvgpu.OperandMajorMode.K pv_mma_tmem = utils.sm100.make_trivial_tiled_mma( - BFloat16, BFloat16, K, K, + BFloat16, BFloat16, pv_a_major_tmem, b_major, Float32, tcgen05.CtaGroup.ONE, (128, hd), tcgen05.OperandSource.TMEM, ) @@ -46,9 +58,9 @@ def probe_hd(hd): o_cols_tmem = find_tmem_tensor_col_offset(tOtO_tmem) print(f" PV C-fragment (TMEM-P): pv_as={pv_as_tmem}, tOtO.layout shape={cute.shape(tOtO_tmem)}, o_cols={o_cols_tmem}") - # PV MMA: (128, hd), SMEM-P (P from SMEM, M-major) + # PV MMA: (128, hd), SMEM-P (P from SMEM, same a_major as Q) pv_mma_smem = utils.sm100.make_trivial_tiled_mma( - BFloat16, BFloat16, MN, K, + BFloat16, BFloat16, a_major, b_major, Float32, tcgen05.CtaGroup.ONE, (128, hd), tcgen05.OperandSource.SMEM, ) @@ -77,8 +89,7 @@ def probe_hd(hd): total_tmem_p = tmem_o0_tmem_p + o_cols_tmem print(f" TMEM-P total: S(0) + P({tmem_p0}) + O({tmem_o0_tmem_p}) + O_size({o_cols_tmem}) = {total_tmem_p} / 512 cols {'✅' if total_tmem_p <= 512 else '❌ OVER BUDGET'}") - # SMEM-P: P not in TMEM. S and O share TMEM (sequential). - # After softmax consumes S, PV writes O starting at col 0. + # SMEM-P: P not in TMEM. S and O sequential (after softmax, S is dead, O reuses S space). total_smem_p = o_cols_smem print(f" SMEM-P total (O at 0, reuses S): {total_smem_p} / 512 cols {'✅' if total_smem_p <= 512 else '❌ OVER BUDGET'}") @@ -86,7 +97,7 @@ def probe_hd(hd): if hd > 256: pv_n_tile = 256 pv_mma_split = utils.sm100.make_trivial_tiled_mma( - BFloat16, BFloat16, MN, K, + BFloat16, BFloat16, a_major, b_major, Float32, tcgen05.CtaGroup.ONE, (128, pv_n_tile), tcgen05.OperandSource.SMEM, )