D1.2: TMEM budget probe using @cute.jit for MLIR context
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
"""TMEM column budget probe for FMHA at various head_dims.
|
||||
|
||||
Uses real tensors to get correct OperandMajorMode values, same as FmhaKernel.
|
||||
Uses @cute.jit to construct MMA objects inside a compiled context,
|
||||
where OperandMajorMode values are valid MLIR operands.
|
||||
"""
|
||||
import torch, math
|
||||
import torch, math, sys
|
||||
import cutlass, cutlass.cute as cute, cutlass.utils as utils
|
||||
from cutlass.cute.nvgpu import tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32
|
||||
@@ -11,98 +12,99 @@ from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cutlass.torch as ct
|
||||
|
||||
|
||||
def probe_hd(hd):
|
||||
print(f"\n=== HEAD_DIM={hd} ===")
|
||||
m = 128 # M tile
|
||||
n = 128 # KV length (s_k)
|
||||
class BudgetProbe:
|
||||
def __init__(self, head_dim):
|
||||
self.hd = head_dim
|
||||
|
||||
# Create dummy tensors to extract major modes (same as FmhaKernel)
|
||||
@cute.jit
|
||||
def __call__(self, mQ, mK, stream):
|
||||
a_major = LayoutEnum.from_tensor(mQ).mma_major_mode()
|
||||
b_major = LayoutEnum.from_tensor(mK).mma_major_mode()
|
||||
|
||||
# QK MMA: always (128, 128)
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
BFloat16, BFloat16, a_major, b_major,
|
||||
Float32, tcgen05.CtaGroup.ONE, (128, 128),
|
||||
tcgen05.OperandSource.SMEM,
|
||||
)
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_as = qk_thr.partition_shape_C((128, 128))
|
||||
tStS = qk_thr.make_fragment_C(qk_as)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
cute.printf("hd=%d: QK C-fragment s_cols=%d", Int32(self.hd), Int32(s_cols))
|
||||
cute.printf(" qk_as=(%d,%d)", Int32(qk_as[0]), Int32(qk_as[1]))
|
||||
|
||||
# PV MMA: (128, hd), TMEM-P (P from TMEM, K-major)
|
||||
pv_a_major_tmem = cute.nvgpu.OperandMajorMode.K
|
||||
pv_mma_tmem = utils.sm100.make_trivial_tiled_mma(
|
||||
BFloat16, BFloat16, pv_a_major_tmem, b_major,
|
||||
Float32, tcgen05.CtaGroup.ONE, (128, self.hd),
|
||||
tcgen05.OperandSource.TMEM,
|
||||
)
|
||||
pv_thr_tmem = pv_mma_tmem.get_slice(0)
|
||||
pv_as_tmem = pv_thr_tmem.partition_shape_C((128, self.hd))
|
||||
tOtO_tmem = pv_thr_tmem.make_fragment_C(pv_as_tmem)
|
||||
o_cols_tmem = find_tmem_tensor_col_offset(tOtO_tmem)
|
||||
cute.printf(" PV TMEM-P: o_cols=%d", Int32(o_cols_tmem))
|
||||
|
||||
# PV MMA: (128, hd), SMEM-P (P from SMEM, same a_major as Q)
|
||||
pv_mma_smem = utils.sm100.make_trivial_tiled_mma(
|
||||
BFloat16, BFloat16, a_major, b_major,
|
||||
Float32, tcgen05.CtaGroup.ONE, (128, self.hd),
|
||||
tcgen05.OperandSource.SMEM,
|
||||
)
|
||||
pv_thr_smem = pv_mma_smem.get_slice(0)
|
||||
pv_as_smem = pv_thr_smem.partition_shape_C((128, self.hd))
|
||||
tOtO_smem = pv_thr_smem.make_fragment_C(pv_as_smem)
|
||||
o_cols_smem = find_tmem_tensor_col_offset(tOtO_smem)
|
||||
cute.printf(" PV SMEM-P: o_cols=%d", Int32(o_cols_smem))
|
||||
|
||||
# P columns in TMEM
|
||||
p_cols = 128 * 16 // 32 # pv_mma_tiler[2] * bf16_width / fp32_width
|
||||
cute.printf(" P cols (FP32): %d", Int32(p_cols))
|
||||
|
||||
# TMEM-P total
|
||||
tmem_p0 = 32
|
||||
p_end = tmem_p0 + p_cols
|
||||
o_after = s_cols if s_cols > p_end else p_end
|
||||
tmem_o0 = ((o_after + 31) // 32) * 32
|
||||
total_tmem_p = tmem_o0 + o_cols_tmem
|
||||
cute.printf(" TMEM-P total: %d / 512", Int32(total_tmem_p))
|
||||
|
||||
# SMEM-P total (O reuses S space, starting at col 0)
|
||||
cute.printf(" SMEM-P total: %d / 512", Int32(o_cols_smem))
|
||||
|
||||
# Split-PV at hd > 256
|
||||
if self.hd > 256:
|
||||
pv_n_tile = 256
|
||||
pv_mma_split = utils.sm100.make_trivial_tiled_mma(
|
||||
BFloat16, BFloat16, a_major, b_major,
|
||||
Float32, tcgen05.CtaGroup.ONE, (128, pv_n_tile),
|
||||
tcgen05.OperandSource.SMEM,
|
||||
)
|
||||
pv_thr_split = pv_mma_split.get_slice(0)
|
||||
pv_as_split = pv_thr_split.partition_shape_C((128, pv_n_tile))
|
||||
tOtO_split = pv_thr_split.make_fragment_C(pv_as_split)
|
||||
o_cols_split = find_tmem_tensor_col_offset(tOtO_split)
|
||||
cute.printf(" Split-PV (128, 256): o_cols=%d", Int32(o_cols_split))
|
||||
|
||||
|
||||
def probe_hd(hd):
|
||||
print(f"\n=== HEAD_DIM={hd} ===", flush=True)
|
||||
m, n = 128, 128
|
||||
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
v = torch.randn(n, hd, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
|
||||
stream = cutlass.cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
# Get major modes from the actual tensors
|
||||
a_major = LayoutEnum.from_tensor(mQ).mma_major_mode()
|
||||
b_major = LayoutEnum.from_tensor(mK).mma_major_mode()
|
||||
|
||||
# QK MMA: always (128, 128)
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
BFloat16, BFloat16, a_major, b_major,
|
||||
Float32, tcgen05.CtaGroup.ONE, (128, 128),
|
||||
tcgen05.OperandSource.SMEM,
|
||||
)
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_as = qk_thr.partition_shape_C((128, 128))
|
||||
tStS = qk_thr.make_fragment_C(qk_as)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
print(f" QK C-fragment: qk_as={qk_as}, tStS.layout shape={cute.shape(tStS)}, s_cols={s_cols}")
|
||||
|
||||
# PV MMA: (128, hd), TMEM-P (P from TMEM, K-major)
|
||||
pv_a_major_tmem = cute.nvgpu.OperandMajorMode.K
|
||||
pv_mma_tmem = utils.sm100.make_trivial_tiled_mma(
|
||||
BFloat16, BFloat16, pv_a_major_tmem, b_major,
|
||||
Float32, tcgen05.CtaGroup.ONE, (128, hd),
|
||||
tcgen05.OperandSource.TMEM,
|
||||
)
|
||||
pv_thr_tmem = pv_mma_tmem.get_slice(0)
|
||||
pv_as_tmem = pv_thr_tmem.partition_shape_C((128, hd))
|
||||
tOtO_tmem = pv_thr_tmem.make_fragment_C(pv_as_tmem)
|
||||
o_cols_tmem = find_tmem_tensor_col_offset(tOtO_tmem)
|
||||
print(f" PV C-fragment (TMEM-P): pv_as={pv_as_tmem}, tOtO.layout shape={cute.shape(tOtO_tmem)}, o_cols={o_cols_tmem}")
|
||||
|
||||
# PV MMA: (128, hd), SMEM-P (P from SMEM, same a_major as Q)
|
||||
pv_mma_smem = utils.sm100.make_trivial_tiled_mma(
|
||||
BFloat16, BFloat16, a_major, b_major,
|
||||
Float32, tcgen05.CtaGroup.ONE, (128, hd),
|
||||
tcgen05.OperandSource.SMEM,
|
||||
)
|
||||
pv_thr_smem = pv_mma_smem.get_slice(0)
|
||||
pv_as_smem = pv_thr_smem.partition_shape_C((128, hd))
|
||||
tOtO_smem = pv_thr_smem.make_fragment_C(pv_as_smem)
|
||||
o_cols_smem = find_tmem_tensor_col_offset(tOtO_smem)
|
||||
print(f" PV C-fragment (SMEM-P): pv_as={pv_as_smem}, tOtO.layout shape={cute.shape(tOtO_smem)}, o_cols={o_cols_smem}")
|
||||
|
||||
# P columns in TMEM (TMEM-P path only)
|
||||
pv_mma_tiler = (128, hd, 128)
|
||||
p_cols_fp32 = pv_mma_tiler[2] * BFloat16.width // Float32.width
|
||||
print(f" P cols (FP32): {p_cols_fp32} (pv_mma_tiler[2]={pv_mma_tiler[2]})")
|
||||
|
||||
# TMEM budget calculations
|
||||
print(f" --- TMEM Budget ---")
|
||||
print(f" S cols: {s_cols}")
|
||||
print(f" P cols (TMEM-P): {p_cols_fp32}")
|
||||
print(f" O cols: {o_cols_tmem} (TMEM-P) / {o_cols_smem} (SMEM-P)")
|
||||
|
||||
# TMEM-P: S at 0, P at 32, O after max(S, P_end)
|
||||
tmem_p0 = 32
|
||||
p_end = tmem_p0 + p_cols_fp32
|
||||
o_after = max(s_cols, p_end)
|
||||
tmem_o0_tmem_p = ((o_after + 31) // 32) * 32
|
||||
total_tmem_p = tmem_o0_tmem_p + o_cols_tmem
|
||||
print(f" TMEM-P total: S(0) + P({tmem_p0}) + O({tmem_o0_tmem_p}) + O_size({o_cols_tmem}) = {total_tmem_p} / 512 cols {'✅' if total_tmem_p <= 512 else '❌ OVER BUDGET'}")
|
||||
|
||||
# SMEM-P: P not in TMEM. S and O sequential (after softmax, S is dead, O reuses S space).
|
||||
total_smem_p = o_cols_smem
|
||||
print(f" SMEM-P total (O at 0, reuses S): {total_smem_p} / 512 cols {'✅' if total_smem_p <= 512 else '❌ OVER BUDGET'}")
|
||||
|
||||
# Split-PV: if hd > 256, process (128, 256) PV tiles
|
||||
if hd > 256:
|
||||
pv_n_tile = 256
|
||||
pv_mma_split = utils.sm100.make_trivial_tiled_mma(
|
||||
BFloat16, BFloat16, a_major, b_major,
|
||||
Float32, tcgen05.CtaGroup.ONE, (128, pv_n_tile),
|
||||
tcgen05.OperandSource.SMEM,
|
||||
)
|
||||
pv_thr_split = pv_mma_split.get_slice(0)
|
||||
pv_as_split = pv_thr_split.partition_shape_C((128, pv_n_tile))
|
||||
tOtO_split = pv_thr_split.make_fragment_C(pv_as_split)
|
||||
o_cols_split = find_tmem_tensor_col_offset(tOtO_split)
|
||||
total_split = o_cols_split
|
||||
print(f" Split-PV (128, {pv_n_tile}) O cols: {o_cols_split}, total SMEM-P: {total_split} / 512 {'✅' if total_split <= 512 else '❌ OVER BUDGET'}")
|
||||
probe = BudgetProbe(head_dim=hd)
|
||||
print(f' Compiling hd={hd}...', flush=True)
|
||||
compiled = cute.compile(probe, mQ, mK, stream)
|
||||
compiled(mQ, mK, stream)
|
||||
torch.cuda.synchronize()
|
||||
print(f' Done.', flush=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user