101 lines
4.1 KiB
Python
101 lines
4.1 KiB
Python
"""TMEM column budget probe for FMHA at various head_dims.
|
|
|
|
Uses @cute.jit to construct MMA objects inside a compiled context.
|
|
Only prints the critical budget numbers.
|
|
"""
|
|
import torch, math
|
|
import cuda.bindings.driver as cuda
|
|
import cutlass, cutlass.cute as cute, cutlass.utils as utils
|
|
from cutlass.cute.nvgpu import tcgen05
|
|
from cutlass import Float32, BFloat16, Int32
|
|
from cutlass.utils import LayoutEnum
|
|
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
|
import cutlass.torch as ct
|
|
|
|
|
|
class BudgetProbe:
    """Report TMEM column usage of the FMHA accumulators for one head_dim.

    Builds the QK (S) and PV (O) tcgen05 tiled MMAs inside a ``@cute.jit``
    context. Measures how many TMEM columns each accumulator fragment spans
    using ``find_tmem_tensor_col_offset``. Prints totals for two P placements:

    * TMEM-P: the PV MMA reads P (its A operand) from TMEM.
    * SMEM-P: the PV MMA reads P from shared memory.

    Each total is also compared against ``TMEM_CAPACITY_COLS``.
    """

    # TMEM columns available to one CTA; larger totals cannot be allocated.
    TMEM_CAPACITY_COLS = 512
    # tcgen05 MMA max N = 256 (single-CTA, M = 128).
    MAX_MMA_N = 256
    # Column alignment used when placing one TMEM region after another.
    TMEM_COL_ALIGN = 32
    # Q tile length: the M extent of both MMAs.
    Q_TILE = 128
    # KV tile length: the N extent of the QK MMA and the K extent of the PV MMA.
    KV_TILE = 128

    def __init__(self, head_dim):
        """
        Args:
            head_dim: Attention head dimension; the N extent of the PV MMA.

        Raises:
            ValueError: if ``head_dim`` is not positive.
        """
        if head_dim <= 0:
            raise ValueError(f"head_dim must be positive, got {head_dim}")
        self.hd = head_dim

    @staticmethod
    def _align_up(cols, align):
        """Round ``cols`` up to the next multiple of ``align``."""
        return ((cols + align - 1) // align) * align

    @staticmethod
    def _acc_cols(a_major, b_major, a_source, tile_mn):
        """Return the TMEM columns spanned by one accumulator fragment.

        Builds a BF16 x BF16 -> FP32 single-CTA tcgen05 tiled MMA with the
        given A-operand source and (M, N) tile, then measures its C fragment.
        """
        mma = utils.sm100.make_trivial_tiled_mma(
            BFloat16, BFloat16, a_major, b_major,
            Float32, tcgen05.CtaGroup.ONE, tile_mn,
            a_source,
        )
        thr = mma.get_slice(0)
        acc = thr.make_fragment_C(thr.partition_shape_C(tile_mn))
        return find_tmem_tensor_col_offset(acc)

    @cute.jit
    def __call__(self, mQ, mK, stream):
        """Build the FMHA MMAs and print the TMEM column budget via cute.printf.

        Args:
            mQ: Q tensor; its layout selects the A-operand major mode.
            mK: K tensor; its layout selects the B-operand major mode.
            stream: CUDA stream; part of the compile/launch signature but not
                used by the probe body.
        """
        hd = self.hd
        a_major = LayoutEnum.from_tensor(mQ).mma_major_mode()
        b_major = LayoutEnum.from_tensor(mK).mma_major_mode()
        src_smem = tcgen05.OperandSource.SMEM
        src_tmem = tcgen05.OperandSource.TMEM

        # S = Q @ K^T accumulator: one (Q_TILE, KV_TILE) tile.
        s_cols = self._acc_cols(
            a_major, b_major, src_smem, (self.Q_TILE, self.KV_TILE)
        )
        cute.printf("hd=%d s_cols=%d", Int32(hd), Int32(s_cols))

        # O = P @ V accumulator. tcgen05 MMA N is capped at MAX_MMA_N, so for
        # hd > MAX_MMA_N the O accumulator is covered by several N-tiles of
        # width pv_n_tile. Assumes every O N-tile stays resident for the whole
        # KV loop, so the O budget is per-tile columns times the tile count.
        pv_n_tile = hd if hd <= self.MAX_MMA_N else self.MAX_MMA_N
        pv_n_tiles = (hd + pv_n_tile - 1) // pv_n_tile
        pv_tile_mn = (self.Q_TILE, pv_n_tile)

        # TMEM-P: P is the A operand, read from TMEM in K-major mode.
        # NOTE(review): b_major comes from K, while V is the real B operand of
        # the PV MMA; the B major mode is not expected to change the C
        # fragment's TMEM footprint -- confirm.
        o_cols_tmem = self._acc_cols(
            cute.nvgpu.OperandMajorMode.K, b_major, src_tmem, pv_tile_mn
        )
        cute.printf(
            "hd=%d pv_n_tile=%d o_cols_tmem=%d",
            Int32(hd), Int32(pv_n_tile), Int32(o_cols_tmem),
        )

        # SMEM-P: P is read from shared memory with the same major mode as Q.
        o_cols_smem = self._acc_cols(a_major, b_major, src_smem, pv_tile_mn)
        cute.printf(
            "hd=%d pv_n_tile=%d o_cols_smem=%d",
            Int32(hd), Int32(pv_n_tile), Int32(o_cols_smem),
        )

        o_cols_tmem_total = o_cols_tmem * pv_n_tiles
        o_cols_smem_total = o_cols_smem * pv_n_tiles
        cute.printf(
            "hd=%d pv_n_tiles=%d o_cols_tmem_total=%d o_cols_smem_total=%d",
            Int32(hd), Int32(pv_n_tiles),
            Int32(o_cols_tmem_total), Int32(o_cols_smem_total),
        )

        # P columns in TMEM (TMEM-P path only): KV_TILE bf16 values packed
        # into 32-bit columns (pv_mma_tiler[2] * bf16_width / fp32_width).
        p_cols = self.KV_TILE * BFloat16.width // Float32.width
        cute.printf("hd=%d p_cols=%d", Int32(hd), Int32(p_cols))

        # TMEM-P layout: P starts at column 32. O starts after whichever of S
        # and P ends last, aligned to TMEM_COL_ALIGN columns.
        tmem_p0 = 32
        p_end = tmem_p0 + p_cols
        o_after = s_cols if s_cols > p_end else p_end
        tmem_o0 = self._align_up(o_after, self.TMEM_COL_ALIGN)
        total_tmem_p = tmem_o0 + o_cols_tmem_total
        cute.printf("hd=%d TMEM-P_total=%d", Int32(hd), Int32(total_tmem_p))

        # SMEM-P layout: P is not in TMEM, but S is recomputed every KV
        # iteration while O keeps accumulating. S and O must therefore
        # coexist, exactly as in the TMEM-P layout: O starts after S.
        total_smem_p = (
            self._align_up(s_cols, self.TMEM_COL_ALIGN) + o_cols_smem_total
        )
        cute.printf("hd=%d SMEM-P_total=%d", Int32(hd), Int32(total_smem_p))

        cap = self.TMEM_CAPACITY_COLS
        cute.printf(
            "hd=%d tmem_cap=%d TMEM-P_fits=%d SMEM-P_fits=%d",
            Int32(hd), Int32(cap),
            Int32(1 if total_tmem_p <= cap else 0),
            Int32(1 if total_smem_p <= cap else 0),
        )
|
|
|
|
|
|
def probe_hd(hd):
    """Compile and run a ``BudgetProbe`` for one head dimension.

    Steps:

    1. Create random bf16 Q (128 x hd x 1) and K (128 x hd x 1) tensors on
       the CUDA device and wrap them as CuTe tensors with dynamic layouts.
    2. Compile the probe with ``cute.compile``.
    3. Launch it once on the current torch CUDA stream.
    4. Synchronize the device before reporting completion.

    Args:
        hd: Head dimension to probe.
    """
    print(f"\n=== HEAD_DIM={hd} ===", flush=True)

    m, n = 128, 128  # rows of Q and rows of K
    q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device='cuda')

    # BudgetProbe only reads the tensors' layouts (to pick MMA major modes).
    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    probe = BudgetProbe(head_dim=hd)
    print(f' Compiling hd={hd}...', flush=True)
    compiled = cute.compile(probe, mQ, mK, stream)
    compiled(mQ, mK, stream)
    torch.cuda.synchronize()
    # Plain string: the original f-string had no placeholders (ruff F541).
    print(' Done.', flush=True)
|
|
|
|
|
|
if __name__ == '__main__':
    # Probe each head dimension in ascending order: one compile + launch each.
    head_dims = (64, 128, 256, 512)
    for head_dim in head_dims:
        probe_hd(head_dim)
|