D1.2: simplify TMEM budget probe, fix printf args

This commit is contained in:
2026-05-23 06:40:55 +00:00
parent 640bae606e
commit 97899b42a3

View File

@@ -1,9 +1,9 @@
"""TMEM column budget probe for FMHA at various head_dims.
Uses @cute.jit to construct MMA objects inside a compiled context,
where OperandMajorMode values are valid MLIR operands.
Uses @cute.jit to construct MMA objects inside a compiled context.
Only prints the critical budget numbers.
"""
import torch, math, sys
import torch, math
import cuda.bindings.driver as cuda
import cutlass, cutlass.cute as cute, cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05
@@ -32,8 +32,7 @@ class BudgetProbe:
qk_as = qk_thr.partition_shape_C((128, 128))
tStS = qk_thr.make_fragment_C(qk_as)
s_cols = find_tmem_tensor_col_offset(tStS)
cute.printf("hd=%d: QK C-fragment s_cols=%d", Int32(self.hd), Int32(s_cols))
cute.printf(" qk_as=(%d,%d)", Int32(qk_as[0]), Int32(qk_as[1]))
cute.printf("hd=%d s_cols=%d", Int32(self.hd), Int32(s_cols))
# PV MMA: (128, hd), TMEM-P (P from TMEM, K-major)
pv_a_major_tmem = cute.nvgpu.OperandMajorMode.K
@@ -46,7 +45,7 @@ class BudgetProbe:
pv_as_tmem = pv_thr_tmem.partition_shape_C((128, self.hd))
tOtO_tmem = pv_thr_tmem.make_fragment_C(pv_as_tmem)
o_cols_tmem = find_tmem_tensor_col_offset(tOtO_tmem)
cute.printf(" PV TMEM-P: o_cols=%d", Int32(o_cols_tmem))
cute.printf("hd=%d o_cols_tmem=%d", Int32(self.hd), Int32(o_cols_tmem))
# PV MMA: (128, hd), SMEM-P (P from SMEM, same a_major as Q)
pv_mma_smem = utils.sm100.make_trivial_tiled_mma(
@@ -58,11 +57,11 @@ class BudgetProbe:
pv_as_smem = pv_thr_smem.partition_shape_C((128, self.hd))
tOtO_smem = pv_thr_smem.make_fragment_C(pv_as_smem)
o_cols_smem = find_tmem_tensor_col_offset(tOtO_smem)
cute.printf(" PV SMEM-P: o_cols=%d", Int32(o_cols_smem))
cute.printf("hd=%d o_cols_smem=%d", Int32(self.hd), Int32(o_cols_smem))
# P columns in TMEM
# P columns in TMEM (TMEM-P path only)
p_cols = 128 * 16 // 32 # pv_mma_tiler[2] * bf16_width / fp32_width
cute.printf(" P cols (FP32): %d", Int32(p_cols))
cute.printf("hd=%d p_cols=%d", Int32(self.hd), Int32(p_cols))
# TMEM-P total
tmem_p0 = 32
@@ -70,10 +69,10 @@ class BudgetProbe:
o_after = s_cols if s_cols > p_end else p_end
tmem_o0 = ((o_after + 31) // 32) * 32
total_tmem_p = tmem_o0 + o_cols_tmem
cute.printf(" TMEM-P total: %d / 512", Int32(total_tmem_p))
cute.printf("hd=%d TMEM-P_total=%d", Int32(self.hd), Int32(total_tmem_p))
# SMEM-P total (O reuses S space, starting at col 0)
cute.printf(" SMEM-P total: %d / 512", Int32(o_cols_smem))
cute.printf("hd=%d SMEM-P_total=%d", Int32(self.hd), Int32(o_cols_smem))
# Split-PV at hd > 256
if self.hd > 256:
@@ -87,7 +86,7 @@ class BudgetProbe:
pv_as_split = pv_thr_split.partition_shape_C((128, pv_n_tile))
tOtO_split = pv_thr_split.make_fragment_C(pv_as_split)
o_cols_split = find_tmem_tensor_col_offset(tOtO_split)
cute.printf(" Split-PV (128, 256): o_cols=%d", Int32(o_cols_split))
cute.printf("hd=%d split_pv_256_o_cols=%d", Int32(self.hd), Int32(o_cols_split))
def probe_hd(hd):