D1.2: simplify TMEM budget probe, fix printf args
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
"""TMEM column budget probe for FMHA at various head_dims.
|
||||
|
||||
Uses @cute.jit to construct MMA objects inside a compiled context,
|
||||
where OperandMajorMode values are valid MLIR operands.
|
||||
Uses @cute.jit to construct MMA objects inside a compiled context.
|
||||
Only prints the critical budget numbers.
|
||||
"""
|
||||
import torch, math, sys
|
||||
import torch, math
|
||||
import cuda.bindings.driver as cuda
|
||||
import cutlass, cutlass.cute as cute, cutlass.utils as utils
|
||||
from cutlass.cute.nvgpu import tcgen05
|
||||
@@ -32,8 +32,7 @@ class BudgetProbe:
|
||||
qk_as = qk_thr.partition_shape_C((128, 128))
|
||||
tStS = qk_thr.make_fragment_C(qk_as)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
cute.printf("hd=%d: QK C-fragment s_cols=%d", Int32(self.hd), Int32(s_cols))
|
||||
cute.printf(" qk_as=(%d,%d)", Int32(qk_as[0]), Int32(qk_as[1]))
|
||||
cute.printf("hd=%d s_cols=%d", Int32(self.hd), Int32(s_cols))
|
||||
|
||||
# PV MMA: (128, hd), TMEM-P (P from TMEM, K-major)
|
||||
pv_a_major_tmem = cute.nvgpu.OperandMajorMode.K
|
||||
@@ -46,7 +45,7 @@ class BudgetProbe:
|
||||
pv_as_tmem = pv_thr_tmem.partition_shape_C((128, self.hd))
|
||||
tOtO_tmem = pv_thr_tmem.make_fragment_C(pv_as_tmem)
|
||||
o_cols_tmem = find_tmem_tensor_col_offset(tOtO_tmem)
|
||||
cute.printf(" PV TMEM-P: o_cols=%d", Int32(o_cols_tmem))
|
||||
cute.printf("hd=%d o_cols_tmem=%d", Int32(self.hd), Int32(o_cols_tmem))
|
||||
|
||||
# PV MMA: (128, hd), SMEM-P (P from SMEM, same a_major as Q)
|
||||
pv_mma_smem = utils.sm100.make_trivial_tiled_mma(
|
||||
@@ -58,11 +57,11 @@ class BudgetProbe:
|
||||
pv_as_smem = pv_thr_smem.partition_shape_C((128, self.hd))
|
||||
tOtO_smem = pv_thr_smem.make_fragment_C(pv_as_smem)
|
||||
o_cols_smem = find_tmem_tensor_col_offset(tOtO_smem)
|
||||
cute.printf(" PV SMEM-P: o_cols=%d", Int32(o_cols_smem))
|
||||
cute.printf("hd=%d o_cols_smem=%d", Int32(self.hd), Int32(o_cols_smem))
|
||||
|
||||
# P columns in TMEM
|
||||
# P columns in TMEM (TMEM-P path only)
|
||||
p_cols = 128 * 16 // 32 # pv_mma_tiler[2] * bf16_width / fp32_width
|
||||
cute.printf(" P cols (FP32): %d", Int32(p_cols))
|
||||
cute.printf("hd=%d p_cols=%d", Int32(self.hd), Int32(p_cols))
|
||||
|
||||
# TMEM-P total
|
||||
tmem_p0 = 32
|
||||
@@ -70,10 +69,10 @@ class BudgetProbe:
|
||||
o_after = s_cols if s_cols > p_end else p_end
|
||||
tmem_o0 = ((o_after + 31) // 32) * 32
|
||||
total_tmem_p = tmem_o0 + o_cols_tmem
|
||||
cute.printf(" TMEM-P total: %d / 512", Int32(total_tmem_p))
|
||||
cute.printf("hd=%d TMEM-P_total=%d", Int32(self.hd), Int32(total_tmem_p))
|
||||
|
||||
# SMEM-P total (O reuses S space, starting at col 0)
|
||||
cute.printf(" SMEM-P total: %d / 512", Int32(o_cols_smem))
|
||||
cute.printf("hd=%d SMEM-P_total=%d", Int32(self.hd), Int32(o_cols_smem))
|
||||
|
||||
# Split-PV at hd > 256
|
||||
if self.hd > 256:
|
||||
@@ -87,7 +86,7 @@ class BudgetProbe:
|
||||
pv_as_split = pv_thr_split.partition_shape_C((128, pv_n_tile))
|
||||
tOtO_split = pv_thr_split.make_fragment_C(pv_as_split)
|
||||
o_cols_split = find_tmem_tensor_col_offset(tOtO_split)
|
||||
cute.printf(" Split-PV (128, 256): o_cols=%d", Int32(o_cols_split))
|
||||
cute.printf("hd=%d split_pv_256_o_cols=%d", Int32(self.hd), Int32(o_cols_split))
|
||||
|
||||
|
||||
def probe_hd(hd):
|
||||
|
||||
Reference in New Issue
Block a user