D1.2: TMEM budget probe with real tensor major modes

This commit is contained in:
2026-05-23 06:37:34 +00:00
parent 8b351bd871
commit 3490601e02

View File

@@ -1,7 +1,6 @@
"""TMEM column budget probe for FMHA at various head_dims.
Prints find_tmem_tensor_col_offset(tOtO) and related shapes so we can
plan the SMEM-P path and verify TMEM fits in 512 columns at hd=512.
Uses real tensors to get correct OperandMajorMode values, same as FmhaKernel.
"""
import torch, math
import cutlass, cutlass.cute as cute, cutlass.utils as utils
@@ -9,22 +8,34 @@ from cutlass.cute.nvgpu import tcgen05
from cutlass import Float32, BFloat16, Int32
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cutlass.torch as ct
# Use the same major modes as FmhaKernel:
# a_major = OperandMajorMode.M (Q is M-major, typical for row-major Q)
# b_major = OperandMajorMode.K (K is K-major, typical for column-major K)
# For PV TMEM-P: a_major = OperandMajorMode.K (P in TMEM, K-major)
# For PV SMEM-P: a_major = OperandMajorMode.M (P from SMEM, M-major)
MN = cute.nvgpu.OperandMajorMode.MN
K = cute.nvgpu.OperandMajorMode.K
def probe_hd(hd):
print(f"\n=== HEAD_DIM={hd} ===")
m = 128 # M tile
n = 128 # KV length (s_k)
# QK MMA: always (128, 128), SMEM A + SMEM B
# Create dummy tensors to extract major modes (same as FmhaKernel)
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device='cuda')
v = torch.randn(n, hd, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
# V with FMHA layout: (head_dim, s_k, 1)
v_fmha_layout = cute.make_layout(
(hd, n, 1), stride=(1, hd, hd * n),
)
# Get major modes from the actual tensors
a_major = LayoutEnum.from_tensor(mQ).mma_major_mode()
b_major = LayoutEnum.from_tensor(mK).mma_major_mode()
# QK MMA: always (128, 128)
qk_mma = utils.sm100.make_trivial_tiled_mma(
BFloat16, BFloat16, MN, K,
BFloat16, BFloat16, a_major, b_major,
Float32, tcgen05.CtaGroup.ONE, (128, 128),
tcgen05.OperandSource.SMEM,
)
@@ -35,8 +46,9 @@ def probe_hd(hd):
print(f" QK C-fragment: qk_as={qk_as}, tStS.layout shape={cute.shape(tStS)}, s_cols={s_cols}")
# PV MMA: (128, hd), TMEM-P (P from TMEM, K-major)
pv_a_major_tmem = cute.nvgpu.OperandMajorMode.K
pv_mma_tmem = utils.sm100.make_trivial_tiled_mma(
BFloat16, BFloat16, K, K,
BFloat16, BFloat16, pv_a_major_tmem, b_major,
Float32, tcgen05.CtaGroup.ONE, (128, hd),
tcgen05.OperandSource.TMEM,
)
@@ -46,9 +58,9 @@ def probe_hd(hd):
o_cols_tmem = find_tmem_tensor_col_offset(tOtO_tmem)
print(f" PV C-fragment (TMEM-P): pv_as={pv_as_tmem}, tOtO.layout shape={cute.shape(tOtO_tmem)}, o_cols={o_cols_tmem}")
# PV MMA: (128, hd), SMEM-P (P from SMEM, M-major)
# PV MMA: (128, hd), SMEM-P (P from SMEM, same a_major as Q)
pv_mma_smem = utils.sm100.make_trivial_tiled_mma(
BFloat16, BFloat16, MN, K,
BFloat16, BFloat16, a_major, b_major,
Float32, tcgen05.CtaGroup.ONE, (128, hd),
tcgen05.OperandSource.SMEM,
)
@@ -77,8 +89,7 @@ def probe_hd(hd):
total_tmem_p = tmem_o0_tmem_p + o_cols_tmem
print(f" TMEM-P total: S(0) + P({tmem_p0}) + O({tmem_o0_tmem_p}) + O_size({o_cols_tmem}) = {total_tmem_p} / 512 cols {'' if total_tmem_p <= 512 else '❌ OVER BUDGET'}")
# SMEM-P: P not in TMEM. S and O share TMEM (sequential).
# After softmax consumes S, PV writes O starting at col 0.
# SMEM-P: P not in TMEM. S and O sequential (after softmax, S is dead, O reuses S space).
total_smem_p = o_cols_smem
print(f" SMEM-P total (O at 0, reuses S): {total_smem_p} / 512 cols {'' if total_smem_p <= 512 else '❌ OVER BUDGET'}")
@@ -86,7 +97,7 @@ def probe_hd(hd):
if hd > 256:
pv_n_tile = 256
pv_mma_split = utils.sm100.make_trivial_tiled_mma(
BFloat16, BFloat16, MN, K,
BFloat16, BFloat16, a_major, b_major,
Float32, tcgen05.CtaGroup.ONE, (128, pv_n_tile),
tcgen05.OperandSource.SMEM,
)