From 1c20b826d9fa627909eafef3c28546b4f09ced95 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 06:39:27 +0000 Subject: [PATCH] D1.2: TMEM budget probe using @cute.jit for MLIR context --- tests/unit/test_tmem_budget.py | 176 +++++++++++++++++---------------- 1 file changed, 89 insertions(+), 87 deletions(-) diff --git a/tests/unit/test_tmem_budget.py b/tests/unit/test_tmem_budget.py index a1eb3a96..a03eff32 100644 --- a/tests/unit/test_tmem_budget.py +++ b/tests/unit/test_tmem_budget.py @@ -1,8 +1,9 @@ """TMEM column budget probe for FMHA at various head_dims. -Uses real tensors to get correct OperandMajorMode values, same as FmhaKernel. +Uses @cute.jit to construct MMA objects inside a compiled context, +where OperandMajorMode values are valid MLIR operands. """ -import torch, math +import torch, math, sys import cutlass, cutlass.cute as cute, cutlass.utils as utils from cutlass.cute.nvgpu import tcgen05 from cutlass import Float32, BFloat16, Int32 @@ -11,98 +12,99 @@ from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset import cutlass.torch as ct -def probe_hd(hd): - print(f"\n=== HEAD_DIM={hd} ===") - m = 128 # M tile - n = 128 # KV length (s_k) +class BudgetProbe: + def __init__(self, head_dim): + self.hd = head_dim - # Create dummy tensors to extract major modes (same as FmhaKernel) + @cute.jit + def __call__(self, mQ, mK, stream): + a_major = LayoutEnum.from_tensor(mQ).mma_major_mode() + b_major = LayoutEnum.from_tensor(mK).mma_major_mode() + + # QK MMA: always (128, 128) + qk_mma = utils.sm100.make_trivial_tiled_mma( + BFloat16, BFloat16, a_major, b_major, + Float32, tcgen05.CtaGroup.ONE, (128, 128), + tcgen05.OperandSource.SMEM, + ) + qk_thr = qk_mma.get_slice(0) + qk_as = qk_thr.partition_shape_C((128, 128)) + tStS = qk_thr.make_fragment_C(qk_as) + s_cols = find_tmem_tensor_col_offset(tStS) + cute.printf("hd=%d: QK C-fragment s_cols=%d", Int32(self.hd), Int32(s_cols)) + cute.printf(" qk_as=(%d,%d)", Int32(qk_as[0]), Int32(qk_as[1])) + + # PV MMA: (128, hd), TMEM-P (P from TMEM, K-major) + pv_a_major_tmem = cute.nvgpu.OperandMajorMode.K + pv_mma_tmem = utils.sm100.make_trivial_tiled_mma( + BFloat16, BFloat16, pv_a_major_tmem, b_major, + Float32, tcgen05.CtaGroup.ONE, (128, self.hd), + tcgen05.OperandSource.TMEM, + ) + pv_thr_tmem = pv_mma_tmem.get_slice(0) + pv_as_tmem = pv_thr_tmem.partition_shape_C((128, self.hd)) + tOtO_tmem = pv_thr_tmem.make_fragment_C(pv_as_tmem) + o_cols_tmem = find_tmem_tensor_col_offset(tOtO_tmem) + cute.printf(" PV TMEM-P: o_cols=%d", Int32(o_cols_tmem)) + + # PV MMA: (128, hd), SMEM-P (P from SMEM, same a_major as Q) + pv_mma_smem = utils.sm100.make_trivial_tiled_mma( + BFloat16, BFloat16, a_major, b_major, + Float32, tcgen05.CtaGroup.ONE, (128, self.hd), + tcgen05.OperandSource.SMEM, + ) + pv_thr_smem = pv_mma_smem.get_slice(0) + pv_as_smem = pv_thr_smem.partition_shape_C((128, self.hd)) + tOtO_smem = pv_thr_smem.make_fragment_C(pv_as_smem) + o_cols_smem = find_tmem_tensor_col_offset(tOtO_smem) + cute.printf(" PV SMEM-P: o_cols=%d", Int32(o_cols_smem)) + + # P columns in TMEM + p_cols = 128 * 16 // 32 # pv_mma_tiler[2] * bf16_width / fp32_width + cute.printf(" P cols (FP32): %d", Int32(p_cols)) + + # TMEM-P total + tmem_p0 = 32 + p_end = tmem_p0 + p_cols + o_after = s_cols if s_cols > p_end else p_end + tmem_o0 = ((o_after + 31) // 32) * 32 + total_tmem_p = tmem_o0 + o_cols_tmem + cute.printf(" TMEM-P total: %d / 512", Int32(total_tmem_p)) + + # SMEM-P total (O reuses S space, starting at col 0) + cute.printf(" SMEM-P total: %d / 512", Int32(o_cols_smem)) + + # Split-PV at hd > 256 + if self.hd > 256: + pv_n_tile = 256 + pv_mma_split = utils.sm100.make_trivial_tiled_mma( + BFloat16, BFloat16, a_major, b_major, + Float32, tcgen05.CtaGroup.ONE, (128, pv_n_tile), + tcgen05.OperandSource.SMEM, + ) + pv_thr_split = pv_mma_split.get_slice(0) + pv_as_split = pv_thr_split.partition_shape_C((128, pv_n_tile)) + tOtO_split = pv_thr_split.make_fragment_C(pv_as_split) + o_cols_split = find_tmem_tensor_col_offset(tOtO_split) + cute.printf(" Split-PV (128, 256): o_cols=%d", Int32(o_cols_split)) + + +def probe_hd(hd): + print(f"\n=== HEAD_DIM={hd} ===", flush=True) + m, n = 128, 128 q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda') k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device='cuda') - v = torch.randn(n, hd, dtype=torch.bfloat16, device='cuda') - c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda') mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) + stream = cutlass.cuda.CUstream(torch.cuda.current_stream().cuda_stream) - # Get major modes from the actual tensors - a_major = LayoutEnum.from_tensor(mQ).mma_major_mode() - b_major = LayoutEnum.from_tensor(mK).mma_major_mode() - - # QK MMA: always (128, 128) - qk_mma = utils.sm100.make_trivial_tiled_mma( - BFloat16, BFloat16, a_major, b_major, - Float32, tcgen05.CtaGroup.ONE, (128, 128), - tcgen05.OperandSource.SMEM, - ) - qk_thr = qk_mma.get_slice(0) - qk_as = qk_thr.partition_shape_C((128, 128)) - tStS = qk_thr.make_fragment_C(qk_as) - s_cols = find_tmem_tensor_col_offset(tStS) - print(f" QK C-fragment: qk_as={qk_as}, tStS.layout shape={cute.shape(tStS)}, s_cols={s_cols}") - - # PV MMA: (128, hd), TMEM-P (P from TMEM, K-major) - pv_a_major_tmem = cute.nvgpu.OperandMajorMode.K - pv_mma_tmem = utils.sm100.make_trivial_tiled_mma( - BFloat16, BFloat16, pv_a_major_tmem, b_major, - Float32, tcgen05.CtaGroup.ONE, (128, hd), - tcgen05.OperandSource.TMEM, - ) - pv_thr_tmem = pv_mma_tmem.get_slice(0) - pv_as_tmem = pv_thr_tmem.partition_shape_C((128, hd)) - tOtO_tmem = pv_thr_tmem.make_fragment_C(pv_as_tmem) - o_cols_tmem = find_tmem_tensor_col_offset(tOtO_tmem) - print(f" PV C-fragment (TMEM-P): pv_as={pv_as_tmem}, tOtO.layout shape={cute.shape(tOtO_tmem)}, o_cols={o_cols_tmem}") - - # PV MMA: (128, hd), SMEM-P (P from SMEM, same a_major as Q) - pv_mma_smem = utils.sm100.make_trivial_tiled_mma( - BFloat16, BFloat16, a_major, b_major, - Float32, tcgen05.CtaGroup.ONE, (128, hd), - tcgen05.OperandSource.SMEM, - ) - pv_thr_smem = pv_mma_smem.get_slice(0) - pv_as_smem = pv_thr_smem.partition_shape_C((128, hd)) - tOtO_smem = pv_thr_smem.make_fragment_C(pv_as_smem) - o_cols_smem = find_tmem_tensor_col_offset(tOtO_smem) - print(f" PV C-fragment (SMEM-P): pv_as={pv_as_smem}, tOtO.layout shape={cute.shape(tOtO_smem)}, o_cols={o_cols_smem}") - - # P columns in TMEM (TMEM-P path only) - pv_mma_tiler = (128, hd, 128) - p_cols_fp32 = pv_mma_tiler[2] * BFloat16.width // Float32.width - print(f" P cols (FP32): {p_cols_fp32} (pv_mma_tiler[2]={pv_mma_tiler[2]})") - - # TMEM budget calculations - print(f" --- TMEM Budget ---") - print(f" S cols: {s_cols}") - print(f" P cols (TMEM-P): {p_cols_fp32}") - print(f" O cols: {o_cols_tmem} (TMEM-P) / {o_cols_smem} (SMEM-P)") - - # TMEM-P: S at 0, P at 32, O after max(S, P_end) - tmem_p0 = 32 - p_end = tmem_p0 + p_cols_fp32 - o_after = max(s_cols, p_end) - tmem_o0_tmem_p = ((o_after + 31) // 32) * 32 - total_tmem_p = tmem_o0_tmem_p + o_cols_tmem - print(f" TMEM-P total: S(0) + P({tmem_p0}) + O({tmem_o0_tmem_p}) + O_size({o_cols_tmem}) = {total_tmem_p} / 512 cols {'✅' if total_tmem_p <= 512 else '❌ OVER BUDGET'}") - - # SMEM-P: P not in TMEM. S and O sequential (after softmax, S is dead, O reuses S space). - total_smem_p = o_cols_smem - print(f" SMEM-P total (O at 0, reuses S): {total_smem_p} / 512 cols {'✅' if total_smem_p <= 512 else '❌ OVER BUDGET'}") - - # Split-PV: if hd > 256, process (128, 256) PV tiles - if hd > 256: - pv_n_tile = 256 - pv_mma_split = utils.sm100.make_trivial_tiled_mma( - BFloat16, BFloat16, a_major, b_major, - Float32, tcgen05.CtaGroup.ONE, (128, pv_n_tile), - tcgen05.OperandSource.SMEM, - ) - pv_thr_split = pv_mma_split.get_slice(0) - pv_as_split = pv_thr_split.partition_shape_C((128, pv_n_tile)) - tOtO_split = pv_thr_split.make_fragment_C(pv_as_split) - o_cols_split = find_tmem_tensor_col_offset(tOtO_split) - total_split = o_cols_split - print(f" Split-PV (128, {pv_n_tile}) O cols: {o_cols_split}, total SMEM-P: {total_split} / 512 {'✅' if total_split <= 512 else '❌ OVER BUDGET'}") + probe = BudgetProbe(head_dim=hd) + print(f' Compiling hd={hd}...', flush=True) + compiled = cute.compile(probe, mQ, mK, stream) + compiled(mQ, mK, stream) + torch.cuda.synchronize() + print(f' Done.', flush=True) if __name__ == '__main__':