Files
nvfp4-megamoe-kernel/tests/unit/test_tmem_budget.py

101 lines
4.1 KiB
Python

"""TMEM column budget probe for FMHA at various head_dims.
Uses @cute.jit to construct MMA objects inside a compiled context.
Only prints the critical budget numbers.
"""
import torch, math
import cuda.bindings.driver as cuda
import cutlass, cutlass.cute as cute, cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05
from cutlass import Float32, BFloat16, Int32
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cutlass.torch as ct
class BudgetProbe:
    """Report TMEM column figures for the FMHA QK and PV MMAs at one head_dim.

    Every MMA object is built inside the ``@cute.jit`` body of ``__call__``;
    results are emitted through ``cute.printf`` only and nothing is returned.
    """

    def __init__(self, head_dim):
        """Remember the head_dim that sizes the N extent of the PV tile."""
        self.hd = head_dim

    @cute.jit
    def __call__(self, mQ, mK, stream):
        """Build the QK / PV tiled MMAs and print their TMEM column numbers.

        Args:
            mQ: tensor whose layout supplies the A-operand major mode.
            mK: tensor whose layout supplies the B-operand major mode.
            stream: accepted for the call signature; not referenced here.
        """
        q_major = LayoutEnum.from_tensor(mQ).mma_major_mode()
        k_major = LayoutEnum.from_tensor(mK).mma_major_mode()

        # ---- S accumulator: QK MMA on a fixed (128, 128) tile, SMEM operands.
        s_mma = utils.sm100.make_trivial_tiled_mma(
            BFloat16,
            BFloat16,
            q_major,
            k_major,
            Float32,
            tcgen05.CtaGroup.ONE,
            (128, 128),
            tcgen05.OperandSource.SMEM,
        )
        s_slice = s_mma.get_slice(0)
        s_frag = s_slice.make_fragment_C(s_slice.partition_shape_C((128, 128)))
        s_cols = find_tmem_tensor_col_offset(s_frag)
        cute.printf("hd=%d s_cols=%d", Int32(self.hd), Int32(s_cols))

        # ---- O accumulator, TMEM-P variant: PV MMA with a K-major A operand
        # sourced from TMEM. The N extent is capped at 256 (the original note
        # gives this as the tcgen05 MMA maximum N), so for hd > 256 the
        # numbers below describe a 256-wide tile, not the full head_dim.
        pv_n_tile = 256 if self.hd > 256 else self.hd
        p_major_tmem = cute.nvgpu.OperandMajorMode.K
        o_mma_tmem = utils.sm100.make_trivial_tiled_mma(
            BFloat16,
            BFloat16,
            p_major_tmem,
            k_major,
            Float32,
            tcgen05.CtaGroup.ONE,
            (128, pv_n_tile),
            tcgen05.OperandSource.TMEM,
        )
        o_slice_tmem = o_mma_tmem.get_slice(0)
        o_frag_tmem = o_slice_tmem.make_fragment_C(
            o_slice_tmem.partition_shape_C((128, pv_n_tile))
        )
        o_cols_tmem = find_tmem_tensor_col_offset(o_frag_tmem)
        cute.printf(
            "hd=%d pv_n_tile=%d o_cols_tmem=%d",
            Int32(self.hd),
            Int32(pv_n_tile),
            Int32(o_cols_tmem),
        )

        # ---- O accumulator, SMEM-P variant: same tile, but the A operand
        # comes from SMEM and reuses Q's major mode.
        o_mma_smem = utils.sm100.make_trivial_tiled_mma(
            BFloat16,
            BFloat16,
            q_major,
            k_major,
            Float32,
            tcgen05.CtaGroup.ONE,
            (128, pv_n_tile),
            tcgen05.OperandSource.SMEM,
        )
        o_slice_smem = o_mma_smem.get_slice(0)
        o_frag_smem = o_slice_smem.make_fragment_C(
            o_slice_smem.partition_shape_C((128, pv_n_tile))
        )
        o_cols_smem = find_tmem_tensor_col_offset(o_frag_smem)
        cute.printf(
            "hd=%d pv_n_tile=%d o_cols_smem=%d",
            Int32(self.hd),
            Int32(pv_n_tile),
            Int32(o_cols_smem),
        )

        # ---- Columns taken by P when it lives in TMEM (TMEM-P path only):
        # pv_mma_tiler[2] * bf16_width / fp32_width.
        p_cols = 128 * 16 // 32
        cute.printf("hd=%d p_cols=%d", Int32(self.hd), Int32(p_cols))

        # ---- TMEM-P total: P starts at column 32; O starts at the first
        # multiple of 32 at or past whichever of S and P ends later.
        tmem_p0 = 32
        p_end = tmem_p0 + p_cols
        o_after = p_end if p_end >= s_cols else s_cols
        tmem_o0 = ((o_after + 31) // 32) * 32
        total_tmem_p = tmem_o0 + o_cols_tmem
        cute.printf("hd=%d TMEM-P_total=%d", Int32(self.hd), Int32(total_tmem_p))

        # ---- SMEM-P total: O reuses the S space from column 0, so the figure
        # reported is just the O column count.
        cute.printf("hd=%d SMEM-P_total=%d", Int32(self.hd), Int32(o_cols_smem))
def probe_hd(hd):
    """Compile and launch one ``BudgetProbe`` for head_dim ``hd``.

    Creates random bf16 Q and K tensors of shape (128, hd, 1) on the CUDA
    device, wraps them for CuTe, compiles the probe against them, runs it on
    the current torch CUDA stream, and synchronizes before returning.
    Progress is printed to stdout; nothing is returned.
    """
    print(f"\n=== HEAD_DIM={hd} ===", flush=True)

    q_rows = 128
    k_rows = 128
    tensor_opts = {'dtype': torch.bfloat16, 'device': 'cuda'}
    q = torch.randn(q_rows, hd, 1, **tensor_opts)
    k = torch.randn(k_rows, hd, 1, **tensor_opts)

    def to_dynamic_cute(t):
        # Wrap a torch tensor for CuTe, marking its layout dynamic along the
        # tensor's leading dimension.
        wrapped = ct.from_dlpack(t)
        return wrapped.mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))

    mQ = to_dynamic_cute(q)
    mK = to_dynamic_cute(k)

    torch_stream = torch.cuda.current_stream()
    stream = cuda.CUstream(torch_stream.cuda_stream)

    probe = BudgetProbe(head_dim=hd)
    print(f' Compiling hd={hd}...', flush=True)
    compiled = cute.compile(probe, mQ, mK, stream)
    compiled(mQ, mK, stream)

    # Wait for the launch to finish before announcing completion.
    torch.cuda.synchronize()
    print(' Done.', flush=True)
if __name__ == '__main__':
    # Probe each head_dim in turn; output for one value finishes before the
    # next begins because probe_hd synchronizes before returning.
    for head_dim in (64, 128, 256, 512):
        probe_hd(head_dim)