93 lines
3.7 KiB
Python
93 lines
3.7 KiB
Python
"""D1 diagnostic: print the TMEM column budget and the MMA, GMEM-tile and accumulator-fragment shapes of the FMHA setup at head dim hd=256."""
|
|
import torch, math
|
|
import cutlass, cutlass.cute as cute
|
|
import cutlass.torch as ct
|
|
import cutlass.utils as utils
|
|
from cutlass.cute.nvgpu import tcgen05
|
|
from cutlass.utils import LayoutEnum
|
|
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
|
import cuda.bindings.driver as cuda
|
|
|
|
# Problem size: head dim, number of KV rows, number of query rows.
hd, n_kv, m = 256, 128, 128

# Fixed seed so the random Q/K/V inputs are reproducible between runs.
torch.manual_seed(42)
_bf16_on_gpu = dict(dtype=torch.bfloat16, device='cuda')
q = torch.randn(m, hd, 1, **_bf16_on_gpu)
k = torch.randn(n_kv, hd, 1, **_bf16_on_gpu)
v = torch.randn(n_kv, hd, **_bf16_on_gpu)
# V is created 2-D; the kernel-facing copy gets a trailing unit mode like Q/K/C.
v_kernel = v.unsqueeze(-1)
c = torch.zeros(m, hd, 1, **_bf16_on_gpu)

# Operand element type, QK accumulator type, and CTA-group mode shared by both MMAs.
q_dtype = cutlass.BFloat16
qk_acc_dtype = cutlass.Float32
cta_group = tcgen05.CtaGroup.ONE


def _as_dynamic_cute_tensor(t):
    """Wrap torch tensor *t* as a CuTe tensor with its layout marked dynamic at *t*'s leading dim."""
    return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))


# Same wrapping order as the torch tensors above: Q, K, V (3-D view), C.
mQ, mK, mV, mC = (_as_dynamic_cute_tensor(t) for t in (q, k, v_kernel, c))
|
|
|
|
# MMA operand major modes inferred from the CuTe tensors' layouts.
a_major = LayoutEnum.from_tensor(mQ).mma_major_mode()
b_major = LayoutEnum.from_tensor(mK).mma_major_mode()

# PV MMA N-tile: the head dim, capped at 256 (presumably the largest single-MMA N).
pv_n_tile = min(hd, 256)

# (hd, n_kv, 1) view of the contiguous (n_kv, hd) V buffer, i.e. V^T with the
# head dim as mode 0; its major mode is passed as the B-operand mode of the PV MMA.
# Shape and strides must use the full head dim `hd` (V's row pitch), not the
# per-MMA tile `pv_n_tile`: the two only coincide while hd <= 256.
v_fmha = cute.make_tensor(
    mV.iterator,
    cute.make_layout((hd, n_kv, 1), stride=(1, hd, hd * n_kv)),
)
v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()

print(f"a_major={a_major}, b_major={b_major}, v_major={v_major}")
print(f"Q shape: {q.shape}, K shape: {k.shape}, V shape: {v.shape}")
print(f"mQ shape: {mQ.shape}, mK shape: {mK.shape}, mV shape: {mV.shape}, mC shape: {mC.shape}")
|
|
|
|
# Q·K^T MMA: operands sourced from SMEM, 128x128 (M, N) tile, fp32 accumulator.
qk_mma = utils.sm100.make_trivial_tiled_mma(
    q_dtype, q_dtype, a_major, b_major, qk_acc_dtype, cta_group,
    (128, 128), tcgen05.OperandSource.SMEM,
)
print(f"\nQK MMA shape_mnk: {qk_mma.shape_mnk}")

# P·V MMA: the first operand (P) is declared K-major and sourced from TMEM.
# The major-mode enum is taken from the same `tcgen05` namespace as
# CtaGroup / OperandSource above (and as LayoutEnum.mma_major_mode() returns).
pv_mma = utils.sm100.make_trivial_tiled_mma(
    q_dtype, q_dtype, tcgen05.OperandMajorMode.K, v_major, qk_acc_dtype, cta_group,
    (128, pv_n_tile), tcgen05.OperandSource.TMEM,
)
print(f"PV MMA shape_mnk: {pv_mma.shape_mnk}")

# K-extent of a single MMA instruction; the QK tiler spans 4 instructions along K.
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
qk_mma_tiler = (128, 128, qk_ik * 4)

# The PV K-extent is the KV width of the S/P tile (the QK tiler's N), rounded
# down to a multiple of the PV instruction K.
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
pv_k_inner = pv_ik * (qk_mma_tiler[1] // pv_ik)
pv_mma_tiler = (128, pv_n_tile, pv_k_inner)

print(f"qk_ik={qk_ik}, qk_mma_tiler={qk_mma_tiler}")
print(f"pv_ik={pv_ik}, pv_mma_tiler={pv_mma_tiler}")
|
|
|
|
# GMEM tile views. With an all-None coordinate, local_tile yields
# (tile_0, tile_1, Rest_0, Rest_1, Rest_L) for each (mode_0, mode_1, L) tensor.
gQ = cute.local_tile(mQ, cute.slice_(qk_mma_tiler, (None, 0, None)), (None, None, None))
gK = cute.local_tile(mK, cute.slice_(qk_mma_tiler, (0, None, None)), (None, None, None))
# The PV tiler's (N, K) slice is (head-dim tile, KV tile), which matches the
# (hd, n_kv, 1) view v_fmha -- not mV, whose modes are (n_kv, hd, 1).
gV = cute.local_tile(v_fmha, cute.slice_(pv_mma_tiler, (0, None, None)), (None, None, None))
gC = cute.local_tile(mC, cute.slice_(pv_mma_tiler, (None, None, 0)), (None, None, None))
print(f"\ngQ shape: {gQ.shape}")
print(f"gK shape: {gK.shape}")
print(f"gV shape: {gV.shape}")
print(f"gC shape: {gC.shape}")
# gK is (bN, bK, RestN, RestK, RestL): mode 2 counts KV tiles, mode 3 counts
# head-dim K-blocks (mode 3 was previously mislabelled as the KV tile count).
print(f"n_kv_tiles: {cute.size(gK, mode=[2])}")
print(f"hd_k_blocks: {cute.size(gK, mode=[3])}")
|
|
|
|
# TMEM budget
# TMEM column capacity used for the "Fits" check (512 32-bit columns).
TMEM_TOTAL_COLS = 512

# Accumulator fragments for thread-slice 0 of each MMA: S (QK) and O (PV).
qk_thr = qk_mma.get_slice(0)
qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
pv_thr = pv_mma.get_slice(0)
pv_as = pv_thr.partition_shape_C(pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)

# P starts at column 32; its width is the PV K-extent of q_dtype elements
# packed into accumulator-width (fp32) columns.
tmem_p0_offset = 32
p_cols_fp32 = pv_mma_tiler[2] * q_dtype.width // qk_acc_dtype.width
o_cols = find_tmem_tensor_col_offset(tOtO)
s_cols = qk_mma_tiler[1]
p_end = tmem_p0_offset + p_cols_fp32
# O goes after whichever of S / P ends later (assumes S starts at column 0,
# so P overlaps S), rounded up to a 32-column boundary.
o_after = max(s_cols, p_end)
tmem_o0_offset = ((o_after + 31) // 32) * 32
total = tmem_o0_offset + o_cols

# tcgen05.alloc takes a power-of-two column count in [32, 512], so the real
# allocation is `total` rounded up to the next power of two (minimum 32).
alloc_cols = 32
while alloc_cols < total:
    alloc_cols *= 2

print(f"\n=== TMEM Budget (hd={hd}) ===")
print(f"S={s_cols}, P_fp32={p_cols_fp32}, P_end={p_end}, O_offset={tmem_o0_offset}, O_cols={o_cols}")
print(f"Total={total}, Fits={total<=TMEM_TOTAL_COLS}")
print(f"Alloc_cols(pow2)={alloc_cols}")
|
|
|
|
# Accumulator (C) fragment info for both MMAs: S = Q·K^T (tStS) and O = P·V (tOtO).
for _lead, _label, _frag in (("\n", "tStS", tStS), ("", "tOtO", tOtO)):
    print(f"{_lead}{_label} shape: {_frag.shape}, cosize: {cute.cosize(_frag.layout)}")
|