diff --git a/tests/unit/test_tmem_budget.py b/tests/unit/test_tmem_budget.py index 3fb5d759..b2321215 100644 --- a/tests/unit/test_tmem_budget.py +++ b/tests/unit/test_tmem_budget.py @@ -1,9 +1,9 @@ """TMEM column budget probe for FMHA at various head_dims. -Uses @cute.jit to construct MMA objects inside a compiled context, -where OperandMajorMode values are valid MLIR operands. +Uses @cute.jit to construct MMA objects inside a compiled context. +Only prints the critical budget numbers. """ -import torch, math, sys +import torch, math import cuda.bindings.driver as cuda import cutlass, cutlass.cute as cute, cutlass.utils as utils from cutlass.cute.nvgpu import tcgen05 @@ -32,8 +32,7 @@ class BudgetProbe: qk_as = qk_thr.partition_shape_C((128, 128)) tStS = qk_thr.make_fragment_C(qk_as) s_cols = find_tmem_tensor_col_offset(tStS) - cute.printf("hd=%d: QK C-fragment s_cols=%d", Int32(self.hd), Int32(s_cols)) - cute.printf(" qk_as=(%d,%d)", Int32(qk_as[0]), Int32(qk_as[1])) + cute.printf("hd=%d s_cols=%d", Int32(self.hd), Int32(s_cols)) # PV MMA: (128, hd), TMEM-P (P from TMEM, K-major) pv_a_major_tmem = cute.nvgpu.OperandMajorMode.K @@ -46,7 +45,7 @@ class BudgetProbe: pv_as_tmem = pv_thr_tmem.partition_shape_C((128, self.hd)) tOtO_tmem = pv_thr_tmem.make_fragment_C(pv_as_tmem) o_cols_tmem = find_tmem_tensor_col_offset(tOtO_tmem) - cute.printf(" PV TMEM-P: o_cols=%d", Int32(o_cols_tmem)) + cute.printf("hd=%d o_cols_tmem=%d", Int32(self.hd), Int32(o_cols_tmem)) # PV MMA: (128, hd), SMEM-P (P from SMEM, same a_major as Q) pv_mma_smem = utils.sm100.make_trivial_tiled_mma( @@ -58,11 +57,11 @@ class BudgetProbe: pv_as_smem = pv_thr_smem.partition_shape_C((128, self.hd)) tOtO_smem = pv_thr_smem.make_fragment_C(pv_as_smem) o_cols_smem = find_tmem_tensor_col_offset(tOtO_smem) - cute.printf(" PV SMEM-P: o_cols=%d", Int32(o_cols_smem)) + cute.printf("hd=%d o_cols_smem=%d", Int32(self.hd), Int32(o_cols_smem)) - # P columns in TMEM + # P columns in TMEM (TMEM-P path only) p_cols = 128 * 16 // 32 # pv_mma_tiler[2] * bf16_width / fp32_width - cute.printf(" P cols (FP32): %d", Int32(p_cols)) + cute.printf("hd=%d p_cols=%d", Int32(self.hd), Int32(p_cols)) # TMEM-P total tmem_p0 = 32 @@ -70,10 +69,10 @@ class BudgetProbe: o_after = s_cols if s_cols > p_end else p_end tmem_o0 = ((o_after + 31) // 32) * 32 total_tmem_p = tmem_o0 + o_cols_tmem - cute.printf(" TMEM-P total: %d / 512", Int32(total_tmem_p)) + cute.printf("hd=%d TMEM-P_total=%d", Int32(self.hd), Int32(total_tmem_p)) # SMEM-P total (O reuses S space, starting at col 0) - cute.printf(" SMEM-P total: %d / 512", Int32(o_cols_smem)) + cute.printf("hd=%d SMEM-P_total=%d", Int32(self.hd), Int32(o_cols_smem)) # Split-PV at hd > 256 if self.hd > 256: @@ -87,7 +86,7 @@ class BudgetProbe: pv_as_split = pv_thr_split.partition_shape_C((128, pv_n_tile)) tOtO_split = pv_thr_split.make_fragment_C(pv_as_split) o_cols_split = find_tmem_tensor_col_offset(tOtO_split) - cute.printf(" Split-PV (128, 256): o_cols=%d", Int32(o_cols_split)) + cute.printf("hd=%d split_pv_256_o_cols=%d", Int32(self.hd), Int32(o_cols_split)) def probe_hd(hd):