D1.2: TMEM budget probe using @cute.jit for MLIR context

This commit is contained in:
2026-05-23 06:39:27 +00:00
parent 6575e83f6d
commit 1c20b826d9

View File

@@ -1,8 +1,9 @@
"""TMEM column budget probe for FMHA at various head_dims.
Uses real tensors to get correct OperandMajorMode values, same as FmhaKernel.
Uses @cute.jit to construct MMA objects inside a compiled context,
where OperandMajorMode values are valid MLIR operands.
"""
import torch, math
import torch, math, sys
import cutlass, cutlass.cute as cute, cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05
from cutlass import Float32, BFloat16, Int32
@@ -11,98 +12,99 @@ from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cutlass.torch as ct
def probe_hd(hd):
print(f"\n=== HEAD_DIM={hd} ===")
m = 128 # M tile
n = 128 # KV length (s_k)
class BudgetProbe:
    """Report TMEM column figures for the FMHA MMAs at one head dimension.

    Calling an instance (through ``cute.compile``) builds the QK and PV tiled
    MMAs, creates their C fragments, and prints the value returned by
    ``find_tmem_tensor_col_offset`` for each one with ``cute.printf``.
    """

    def __init__(self, head_dim):
        # Head dimension; used as the N extent of the PV tile.
        self.hd = head_dim

    @staticmethod
    def _acc_shape_and_cols(tiled_mma, tile_mn):
        """Return ``(C partition shape, TMEM column offset)`` for one MMA tile."""
        thr_mma = tiled_mma.get_slice(0)
        acc_shape = thr_mma.partition_shape_C(tile_mn)
        acc_frag = thr_mma.make_fragment_C(acc_shape)
        return acc_shape, find_tmem_tensor_col_offset(acc_frag)

    @cute.jit
    def __call__(self, mQ, mK, stream):
        """Build the MMAs and print the column figures.

        ``mQ`` and ``mK`` are only inspected for their MMA major modes;
        ``stream`` is accepted but not used in the body.
        """
        q_major = LayoutEnum.from_tensor(mQ).mma_major_mode()
        k_major = LayoutEnum.from_tensor(mK).mma_major_mode()

        # --- QK MMA: fixed (128, 128) tile, operand A read from SMEM ---
        qk_tile = (128, 128)
        qk_tiled_mma = utils.sm100.make_trivial_tiled_mma(
            BFloat16,
            BFloat16,
            q_major,
            k_major,
            Float32,
            tcgen05.CtaGroup.ONE,
            qk_tile,
            tcgen05.OperandSource.SMEM,
        )
        qk_as, s_cols = self._acc_shape_and_cols(qk_tiled_mma, qk_tile)
        cute.printf("hd=%d: QK C-fragment s_cols=%d", Int32(self.hd), Int32(s_cols))
        cute.printf(" qk_as=(%d,%d)", Int32(qk_as[0]), Int32(qk_as[1]))

        # --- PV MMA, TMEM-P variant: (128, hd) tile, P from TMEM, K-major A ---
        pv_tile = (128, self.hd)
        p_major_from_tmem = cute.nvgpu.OperandMajorMode.K
        pv_tiled_mma_tmem = utils.sm100.make_trivial_tiled_mma(
            BFloat16,
            BFloat16,
            p_major_from_tmem,
            k_major,
            Float32,
            tcgen05.CtaGroup.ONE,
            pv_tile,
            tcgen05.OperandSource.TMEM,
        )
        _, o_cols_tmem = self._acc_shape_and_cols(pv_tiled_mma_tmem, pv_tile)
        cute.printf(" PV TMEM-P: o_cols=%d", Int32(o_cols_tmem))

        # --- PV MMA, SMEM-P variant: (128, hd) tile, P from SMEM, Q's major mode ---
        pv_tiled_mma_smem = utils.sm100.make_trivial_tiled_mma(
            BFloat16,
            BFloat16,
            q_major,
            k_major,
            Float32,
            tcgen05.CtaGroup.ONE,
            pv_tile,
            tcgen05.OperandSource.SMEM,
        )
        _, o_cols_smem = self._acc_shape_and_cols(pv_tiled_mma_smem, pv_tile)
        cute.printf(" PV SMEM-P: o_cols=%d", Int32(o_cols_smem))

        # P columns in TMEM: pv_mma_tiler[2] * bf16_width / fp32_width
        p_cols = 128 * 16 // 32
        cute.printf(" P cols (FP32): %d", Int32(p_cols))

        # TMEM-P total: P starts at column 32; O starts at the first multiple
        # of 32 at or after the larger of s_cols and the end of P.
        p_start = 32
        p_stop = p_start + p_cols
        o_floor = s_cols if s_cols > p_stop else p_stop
        o_start = ((o_floor + 31) // 32) * 32
        tmem_p_total = o_start + o_cols_tmem
        cute.printf(" TMEM-P total: %d / 512", Int32(tmem_p_total))

        # SMEM-P total (O reuses S space, starting at col 0)
        cute.printf(" SMEM-P total: %d / 512", Int32(o_cols_smem))

        # Split-PV: for hd > 256, also probe a (128, 256) PV tile.
        if self.hd > 256:
            split_n = 256
            split_tile = (128, split_n)
            pv_tiled_mma_split = utils.sm100.make_trivial_tiled_mma(
                BFloat16,
                BFloat16,
                q_major,
                k_major,
                Float32,
                tcgen05.CtaGroup.ONE,
                split_tile,
                tcgen05.OperandSource.SMEM,
            )
            _, o_cols_split = self._acc_shape_and_cols(pv_tiled_mma_split, split_tile)
            cute.printf(" Split-PV (128, 256): o_cols=%d", Int32(o_cols_split))
def probe_hd(hd):
    """Compile and run the TMEM budget probe for one head dimension.

    Creates bf16 CUDA tensors for Q and K, wraps them as CuTe tensors, then
    compiles ``BudgetProbe`` with ``cute.compile`` and runs it once. All MMA
    construction happens inside the ``@cute.jit``-decorated
    ``BudgetProbe.__call__``; the module docstring states that
    OperandMajorMode values are valid MLIR operands in that compiled context,
    so nothing is built on the host here.

    Args:
        hd: Head dimension; the second extent of the Q/K tensors and the value
            passed to ``BudgetProbe``.
    """
    print(f"\n=== HEAD_DIM={hd} ===", flush=True)

    # Q and K exist only so the probe can read their MMA major modes;
    # the random values themselves are never used.
    m, n = 128, 128
    q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device='cuda')

    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
    stream = cutlass.cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    probe = BudgetProbe(head_dim=hd)
    print(f' Compiling hd={hd}...', flush=True)
    compiled = cute.compile(probe, mQ, mK, stream)
    compiled(mQ, mK, stream)

    # Wait for the GPU before printing the completion message.
    torch.cuda.synchronize()
    print(' Done.', flush=True)
if __name__ == '__main__':