Add TMA shape diagnostic

This commit is contained in:
2026-05-22 21:52:51 +00:00
parent 845ad98b22
commit be27720cb2

View File

@@ -0,0 +1,106 @@
"""
Diagnostic: print tBgK and tVgV shapes after tma_partition.
Just need to see how many modes and which is the KV tile dim.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, const_expr
from cutlass.utils import LayoutEnum
import cuda.bindings.driver as cuda
import cutlass.torch as ct
import math
HEAD_DIM = 64
class TmaShapeDiag:
def __init__(self, s_k=256):
self.s_k = s_k
self.n_kv_tiles = s_k // 128
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
self.kv_stage = 2; self.q_stage = 1
self.threads_per_cta = 192
self.qk_mma_tiler = (128, 128, 4)
self.pv_mma_tiler = (128, HEAD_DIM, 4)
@cute.jit
def __call__(self, q, k, v, stream):
a_major = LayoutEnum.from_tensor(q).mma_major_mode()
b_major = LayoutEnum.from_tensor(k).mma_major_mode()
v_fmha = cute.make_tensor(
v.iterator,
cute.make_layout(
(HEAD_DIM, self.s_k, 1),
stride=(1, HEAD_DIM, HEAD_DIM * self.s_k),
),
)
v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, a_major, b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, v_major, self.qk_acc_dtype, self.cta_group, (128,HEAD_DIM), tcgen05.OperandSource.TMEM)
q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage)
k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage)
v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage)
q_s = cute.slice_(q_smem_s,(None,None,None,0))
k_s = cute.slice_(k_smem_s,(None,None,None,0))
v_s = cute.slice_(v_smem_s,(None,None,None,0))
tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,(1,1,1))
tma_k,mK = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,qk_mma.thr_id),k,k_s,self.qk_mma_tiler,qk_mma,(1,1,1))
tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,(1,1,1))
gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None))
gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None))
gV = cute.local_tile(mV,cute.slice_(self.pv_mma_tiler,(0,None,None)),(None,None,None))
qk_thr = qk_mma.get_slice(0); pv_thr = pv_mma.get_slice(0)
tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK)
tCgV = pv_thr.partition_B(gV)
a_lay = cute.make_layout(1)
tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(q_s,0,2),cute.group_modes(tCgQ,0,3))
b_lay = cute.make_layout(1)
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(k_s,0,2),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(v_s,0,2),cute.group_modes(tCgV,0,3))
print(f"tAgQ shape: {cute.shape(tAgQ)} rank: {tAgQ.rank}")
print(f"tBgK shape: {cute.shape(tBgK)} rank: {tBgK.rank}")
print(f"tVgV shape: {cute.shape(tVgV)} rank: {tVgV.rank}")
print(f"tAsQ shape: {cute.shape(tAsQ)} rank: {tAsQ.rank}")
print(f"tBsK shape: {cute.shape(tBsK)} rank: {tBsK.rank}")
print(f"tVsV shape: {cute.shape(tVsV)} rank: {tVsV.rank}")
# Print size of each mode
for i in range(tBgK.rank):
try:
print(f" tBgK mode {i} size: {cute.size(tBgK, mode=[i])}")
except:
print(f" tBgK mode {i}: error getting size")
for i in range(tVgV.rank):
try:
print(f" tVgV mode {i} size: {cute.size(tVgV, mode=[i])}")
except:
print(f" tVgV mode {i}: error getting size")
def test():
n = 256
m, hd = 128, HEAD_DIM
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device='cuda')
v = torch.randn(n, hd, dtype=torch.bfloat16, device='cuda')
v_kernel = v.unsqueeze(-1)
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
diag = TmaShapeDiag(s_k=n)
compiled = cute.compile(diag, mQ, mK, mV, stream)
compiled(mQ, mK, mV, stream)
torch.cuda.synchronize()
if __name__ == '__main__':
test()