TMA shape diag: pure Python, no JIT

This commit is contained in:
2026-05-22 21:55:03 +00:00
parent 2f670e33d1
commit fd6b1e82d8

View File

@@ -1,173 +1,110 @@
"""
Diagnostic: print tBgK and tVgV shapes BEFORE pre-slicing.
The TmaShapeDiag kernel path prints at JIT trace time, where Python print gives static shape info.
The diag() path runs the same partitioning in plain Python, outside the kernel.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, const_expr
from cutlass import Float32, BFloat16, Int32
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
import cutlass.torch as ct
import math
# Head dimension used throughout this file: it sets the softmax scale,
# the N extent of the PV MMA tiler, and the layout built for V.
HEAD_DIM = 64
class TmaShapeDiag:
    """Diagnostic that prints TMA partition shapes for Q, K and V.

    Calling an instance builds the QK and PV tiled MMAs and their TMA atoms,
    then launches ``_diag_kernel`` on a (1, 1, 1) grid. The kernel body only
    partitions the tensors and prints their shapes before and after the
    pre-slice; it contains no copy or MMA calls.
    """

    def __init__(self, s_k=256):
        """Store the static configuration for the diagnostic.

        Args:
            s_k: K/V sequence length. ``n_kv_tiles`` is ``s_k // 128``.
        """
        self.s_k = s_k
        self.n_kv_tiles = s_k // 128

        # Element types; q/o/c dtypes are overwritten in __call__ from the
        # actual tensor arguments.
        self.acc_dtype = Float32
        self.qk_acc_dtype = Float32
        self.q_dtype = BFloat16
        self.o_dtype = BFloat16
        self.c_dtype = BFloat16

        self.use_2cta_instrs = False
        self.epilog_sync_bar_id = 1
        self.cluster_shape_mn = (1, 1)
        self.cta_group = tcgen05.CtaGroup.ONE

        # Warp ids 0-5 are assigned below; threads_per_cta is the launch
        # block size used in __call__.
        self.epilogue_warp_id = (0, 1, 2, 3)
        self.mma_warp_id = 4
        self.tma_warp_id = 5
        self.threads_per_cta = 192

        # Stage counts.
        self.kv_stage = 2
        self.q_stage = 1
        self.num_c_stage = 2

        self.scale_softmax = 1.0 / math.sqrt(HEAD_DIM)
        self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e)

    def _setup(self, qk_mma, pv_mma):
        """Derive tilers, SMEM layouts, TMEM offsets and TMA byte counts.

        All results are stored as attributes on ``self``.

        Args:
            qk_mma: Tiled MMA built in ``__call__`` from the Q/K major modes.
            pv_mma: Tiled MMA built in ``__call__`` from the V major mode.
        """
        # MMA tilers: K extent is derived from each MMA's instruction K.
        qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
        self.qk_mma_tiler = (128, 128, qk_ik * 4)
        pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
        # pv_ik * (128 // pv_ik) is 128 whenever pv_ik divides 128.
        self.pv_mma_tiler = (128, HEAD_DIM, pv_ik * (128 // pv_ik))
        self.mma_tiler = self.qk_mma_tiler

        self.cluster_layout_vmnk = cute.tiled_divide(
            cute.make_layout((1, 1, 1)),
            (qk_mma.thr_id.shape,),
        )
        self.cta_tile_shape_mnk = (
            self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
            HEAD_DIM,
            self.qk_mma_tiler[2],
        )
        # NOTE(review): this overwrites the c_layout that __call__ derives
        # from tensor `c` just before calling _setup — confirm intended.
        self.c_layout = LayoutEnum.ROW_MAJOR
        self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
            self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype
        )
        self.num_ab_stage = 1
        self.num_acc_stage = 1

        # Staged shared-memory layouts for the operands and the epilogue.
        self.q_smem_s = utils.sm100.make_smem_layout_a(
            qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage
        )
        self.k_smem_s = utils.sm100.make_smem_layout_b(
            qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage
        )
        self.v_smem_s = utils.sm100.make_smem_layout_b(
            pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage
        )
        # NOTE(review): literal 2 here matches num_c_stage set in __init__.
        self.c_smem_s = utils.sm100.make_smem_layout_epi(
            self.o_dtype, self.c_layout, self.epi_tile, 2
        )
        self.p_tmem_s = utils.sm100.make_smem_layout_a(
            pv_mma, self.pv_mma_tiler, self.q_dtype, 1
        )

        # Accumulator fragments; only tOtO is consumed below (for the TMEM
        # column count). tStS is not used further in this method.
        qk_thr = qk_mma.get_slice(0)
        qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
        tStS = qk_thr.make_fragment_C(qk_as)
        pv_thr = pv_mma.get_slice(0)
        pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
        tOtO = pv_thr.make_fragment_C(pv_as)

        # TMEM column offsets: S at 0, P at 32, O after whichever of S / P
        # ends later, rounded up to a multiple of 32.
        self.tmem_s0_offset = 0
        self.tmem_p0_offset = 32
        p_cols_fp32 = (
            self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
        )
        p_end = self.tmem_p0_offset + p_cols_fp32
        s_cols = self.qk_mma_tiler[1]
        o_after = max(s_cols, p_end)
        self.tmem_o0_offset = ((o_after + 31) // 32) * 32
        o_cols = find_tmem_tensor_col_offset(tOtO)
        total = self.tmem_o0_offset + o_cols

        # Smallest power of two that is >= total.
        self.num_tmem_alloc_cols = 1
        while self.num_tmem_alloc_cols < total:
            self.num_tmem_alloc_cols *= 2

        # Byte counts per TMA transaction, computed from the layouts sliced
        # at index 0 of the last mode.
        cta = cute.size(qk_mma.thr_id.shape)
        q_s = cute.slice_(self.q_smem_s, (None, None, None, 0))
        k_s = cute.slice_(self.k_smem_s, (None, None, None, 0))
        v_s = cute.slice_(self.v_smem_s, (None, None, None, 0))
        self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta
        self.kv_tx_bytes = (
            cute.size_in_bytes(self.q_dtype, k_s)
            + cute.size_in_bytes(self.q_dtype, v_s)
        ) * cta

    @cute.jit
    def __call__(self, q, k, v, c, stream):
        """Build MMAs and TMA atoms, then launch the diagnostic kernel.

        Args:
            q: Q tensor; its element type becomes ``q_dtype``.
            k: K tensor.
            v: V tensor; re-viewed as (HEAD_DIM, s_k, 1) before use.
            c: Output tensor; its element type becomes ``o_dtype``/``c_dtype``.
            stream: Stream passed to the kernel launch.
        """
        self.q_dtype = q.element_type
        self.o_dtype = c.element_type
        self.c_dtype = self.o_dtype
        self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
        self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()

        # Re-view V with shape (HEAD_DIM, s_k, 1) and unit stride on mode 0.
        v_fmha = cute.make_tensor(
            v.iterator,
            cute.make_layout(
                (HEAD_DIM, self.s_k, 1),
                stride=(1, HEAD_DIM, HEAD_DIM * self.s_k),
            ),
        )
        self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()
        self.c_layout = LayoutEnum.from_tensor(c)

        qk_mma = utils.sm100.make_trivial_tiled_mma(
            self.q_dtype,
            self.q_dtype,
            self.a_major,
            self.b_major,
            self.qk_acc_dtype,
            self.cta_group,
            (128, 128),
            tcgen05.OperandSource.SMEM,
        )
        pv_mma = utils.sm100.make_trivial_tiled_mma(
            self.q_dtype,
            self.q_dtype,
            cute.nvgpu.OperandMajorMode.K,
            self.v_major,
            self.qk_acc_dtype,
            self.cta_group,
            (128, HEAD_DIM),
            tcgen05.OperandSource.TMEM,
        )
        self._setup(qk_mma, pv_mma)

        q_s = cute.slice_(self.q_smem_s, (None, None, None, 0))
        k_s = cute.slice_(self.k_smem_s, (None, None, None, 0))
        v_s = cute.slice_(self.v_smem_s, (None, None, None, 0))

        tma_q, mQ = cute.nvgpu.make_tiled_tma_atom_A(
            utils.sm100.cluster_shape_to_tma_atom_A(
                self.cluster_shape_mn, qk_mma.thr_id
            ),
            q,
            q_s,
            self.qk_mma_tiler,
            qk_mma,
            self.cluster_layout_vmnk.shape,
        )
        tma_k, mK = cute.nvgpu.make_tiled_tma_atom_B(
            utils.sm100.cluster_shape_to_tma_atom_B(
                self.cluster_shape_mn, qk_mma.thr_id
            ),
            k,
            k_s,
            self.qk_mma_tiler,
            qk_mma,
            self.cluster_layout_vmnk.shape,
        )
        tma_v, mV = cute.nvgpu.make_tiled_tma_atom_B(
            utils.sm100.cluster_shape_to_tma_atom_B(
                self.cluster_shape_mn, pv_mma.thr_id
            ),
            v_fmha,
            v_s,
            self.pv_mma_tiler,
            pv_mma,
            self.cluster_layout_vmnk.shape,
        )

        # The store atom for C is built but not passed to the kernel.
        epi_s = cute.select(self.c_smem_s, mode=[0, 1])
        tma_c, mC = cpasync.make_tiled_tma_atom(
            cpasync.CopyBulkTensorTileS2GOp(), c, epi_s, self.epi_tile
        )

        # Stop here — just check shapes
        self._diag_kernel(
            qk_mma,
            pv_mma,
            tma_q,
            mQ,
            tma_k,
            mK,
            tma_v,
            mV,
            self.cluster_layout_vmnk,
        ).launch(
            grid=(1, 1, 1),
            block=[self.threads_per_cta, 1, 1],
            stream=stream,
        )

    @cute.kernel
    def _diag_kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, cl_vmnk):
        """Partition Q/K/V for TMA and print the resulting shapes.

        Args:
            qk_mma: Tiled MMA used to partition Q (A operand) and K (B operand).
            pv_mma: Tiled MMA used to partition V (B operand).
            tma_q, tma_k, tma_v: TMA atoms for Q, K and V.
            mQ, mK, mV: Tensors returned alongside the TMA atoms.
            cl_vmnk: Cluster layout, sliced to build the multicast layouts.
        """
        q_s = cute.slice_(self.q_smem_s, (None, None, None, 0))
        k_s = cute.slice_(self.k_smem_s, (None, None, None, 0))
        v_s = cute.slice_(self.v_smem_s, (None, None, None, 0))
        sQ = cute.make_tensor(BFloat16, q_s.outer)
        sK = cute.make_tensor(BFloat16, k_s.outer)
        sV = cute.make_tensor(BFloat16, v_s.outer)

        gQ = cute.local_tile(
            mQ, cute.slice_(self.qk_mma_tiler, (None, 0, None)), (None, None, None)
        )
        gK = cute.local_tile(
            mK, cute.slice_(self.qk_mma_tiler, (0, None, None)), (None, None, None)
        )
        gV = cute.local_tile(
            mV, cute.slice_(self.pv_mma_tiler, (0, None, None)), (None, None, None)
        )

        qk_thr = qk_mma.get_slice(0)
        pv_thr = pv_mma.get_slice(0)
        tCgQ = qk_thr.partition_A(gQ)
        tCgK = qk_thr.partition_B(gK)
        tCgV = pv_thr.partition_B(gV)

        a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0, 0, None, 0)).shape)
        tAsQ, tAgQ = cpasync.tma_partition(
            tma_q, 0, a_lay, cute.group_modes(sQ, 0, 3), cute.group_modes(tCgQ, 0, 3)
        )
        b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0, None, 0, 0)).shape)
        tBsK, tBgK = cpasync.tma_partition(
            tma_k, 0, b_lay, cute.group_modes(sK, 0, 3), cute.group_modes(tCgK, 0, 3)
        )
        tVsV, tVgV = cpasync.tma_partition(
            tma_v, 0, b_lay, cute.group_modes(sV, 0, 3), cute.group_modes(tCgV, 0, 3)
        )

        # Print shapes BEFORE any pre-slice
        print(f"=== TMA partition shapes (n_kv_tiles={self.n_kv_tiles}) ===")
        print(f"tAgQ: shape={cute.shape(tAgQ)} rank={tAgQ.rank}")
        print(f"tBgK: shape={cute.shape(tBgK)} rank={tBgK.rank}")
        print(f"tVgV: shape={cute.shape(tVgV)} rank={tVgV.rank}")
        print(f"tAsQ: shape={cute.shape(tAsQ)} rank={tAsQ.rank}")
        print(f"tBsK: shape={cute.shape(tBsK)} rank={tBsK.rank}")
        print(f"tVsV: shape={cute.shape(tVsV)} rank={tVsV.rank}")

        # Print per-mode sizes
        for name, t in [("tAgQ", tAgQ), ("tBgK", tBgK), ("tVgV", tVgV)]:
            for i in range(t.rank):
                sz = cute.size(t, mode=[i])
                print(f" {name} mode {i} size={sz}")

        # Also print after the original pre-slice
        tAgQ2 = tAgQ[(None, 0, None, 0)]
        tBgK2 = tBgK[(None, None, 0, 0)]
        tVgV2 = tVgV[(None, 0, None, 0)]
        print(f"\nAfter pre-slice:")
        print(f"tAgQ[(None,0,None,0)]: shape={cute.shape(tAgQ2)} rank={tAgQ2.rank}")
        print(f"tBgK[(None,None,0,0)]: shape={cute.shape(tBgK2)} rank={tBgK2.rank}")
        print(f"tVgV[(None,0,None,0)]: shape={cute.shape(tVgV2)} rank={tVgV2.rank}")
        for name, t in [("tAgQ2", tAgQ2), ("tBgK2", tBgK2), ("tVgV2", tVgV2)]:
            for i in range(t.rank):
                sz = cute.size(t, mode=[i])
                print(f" {name} mode {i} size={sz}")
# NOTE(review): `def test():` is followed directly by `def diag():` with no
# body in between, and none of the statements below carry indentation. This
# span appears to interleave two revisions of the same entry point (one that
# compiles and runs TmaShapeDiag, one that partitions in plain Python).
# Confirm which lines belong to the intended function before running.
def test():
def diag():
# Problem sizes: n is the K/V sequence length, m the Q length, hd the head dim.
n = 256
m, hd = 128, HEAD_DIM
s_k = n
n_kv_tiles = s_k // 128
# Random bf16 inputs on the CUDA device; Q, K and C carry a trailing size-1 mode.
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device='cuda')
v = torch.randn(n, hd, dtype=torch.bfloat16, device='cuda')
v_kernel = v.unsqueeze(-1)
c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
# Hard-coded tiler shapes, cluster shape and dtypes used for this diagnostic.
qk_mma_tiler = (128, 128, 4)
pv_mma_tiler = (128, HEAD_DIM, 4)
cluster_shape_mn = (1, 1)
cta_group = tcgen05.CtaGroup.ONE
qk_acc_dtype = Float32
q_dtype = BFloat16
a_major = LayoutEnum.from_tensor(q).mma_major_mode()
b_major = LayoutEnum.from_tensor(k).mma_major_mode()
# View V with shape (HEAD_DIM, s_k, 1) and strides (1, HEAD_DIM, HEAD_DIM * s_k).
# NOTE(review): v_kernel is a torch tensor here, whereas TmaShapeDiag.__call__
# passes v.iterator as the first argument — confirm make_tensor accepts this.
v_fmha = cute.make_tensor(
v_kernel,
cute.make_layout(
(HEAD_DIM, s_k, 1),
stride=(1, HEAD_DIM, HEAD_DIM * s_k),
),
)
v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()
# Tiled MMAs: QK reads both operands from SMEM; PV takes its A operand from TMEM.
qk_mma = utils.sm100.make_trivial_tiled_mma(q_dtype, q_dtype, a_major, b_major, qk_acc_dtype, cta_group, (128,128), tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(q_dtype, q_dtype, cute.nvgpu.OperandMajorMode.K, v_major, qk_acc_dtype, cta_group, (128,HEAD_DIM), tcgen05.OperandSource.TMEM)
kv_stage = 2; q_stage = 1
# Staged SMEM layouts, then the slice at index 0 of the last mode of each.
q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, qk_mma_tiler, q_dtype, q_stage)
k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, qk_mma_tiler, q_dtype, kv_stage)
v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, pv_mma_tiler, q_dtype, kv_stage)
q_s = cute.slice_(q_smem_s,(None,None,None,0))
k_s = cute.slice_(k_smem_s,(None,None,None,0))
v_s = cute.slice_(v_smem_s,(None,None,None,0))
cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
# Wrap the torch tensors via from_dlpack and mark their layouts dynamic.
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
# Compile and run TmaShapeDiag on the current CUDA stream.
# NOTE(review): the local name `diag` shadows the function name `diag`.
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
diag = TmaShapeDiag(s_k=n)
print('Compiling...', flush=True)
compiled = cute.compile(diag, mQ, mK, mV, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mV, mC, stream)
torch.cuda.synchronize()
print('Done.')
# Build the TMA atoms and partitions directly here, outside any @cute.kernel function.
tma_q,tma_mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(cluster_shape_mn,qk_mma.thr_id),mQ,q_s,qk_mma_tiler,qk_mma,cluster_layout_vmnk.shape)
tma_k,tma_mK = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(cluster_shape_mn,qk_mma.thr_id),mK,k_s,qk_mma_tiler,qk_mma,cluster_layout_vmnk.shape)
tma_v,tma_mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(cluster_shape_mn,pv_mma.thr_id),mV,v_s,pv_mma_tiler,pv_mma,cluster_layout_vmnk.shape)
gQ = cute.local_tile(tma_mQ,cute.slice_(qk_mma_tiler,(None,0,None)),(None,None,None))
gK = cute.local_tile(tma_mK,cute.slice_(qk_mma_tiler,(0,None,None)),(None,None,None))
gV = cute.local_tile(tma_mV,cute.slice_(pv_mma_tiler,(0,None,None)),(None,None,None))
qk_thr = qk_mma.get_slice(0); pv_thr = pv_mma.get_slice(0)
tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK)
tCgV = pv_thr.partition_B(gV)
a_lay = cute.make_layout(cute.slice_(cluster_layout_vmnk,(0,0,None,0)).shape)
b_lay = cute.make_layout(cute.slice_(cluster_layout_vmnk,(0,None,0,0)).shape)
# NOTE(review): q_s / k_s / v_s passed to group_modes below are the layouts
# already sliced at index 0 of the last mode (see cute.slice_ above), not the
# full staged layouts q_smem_s / k_smem_s / v_smem_s. The earlier comment
# here said "full SMEM layouts (not sliced)" — confirm which is intended.
tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(q_s,0,3),cute.group_modes(tCgQ,0,3))
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(k_s,0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(v_s,0,3),cute.group_modes(tCgV,0,3))
# Shapes of the global-side (g) and shared-side (s) partitions before any pre-slice.
print(f"=== TMA partition shapes (n_kv_tiles={n_kv_tiles}) ===")
print(f"tAgQ: shape={cute.shape(tAgQ)}")
print(f"tBgK: shape={cute.shape(tBgK)}")
print(f"tVgV: shape={cute.shape(tVgV)}")
print(f"tAsQ: shape={cute.shape(tAsQ)}")
print(f"tBsK: shape={cute.shape(tBsK)}")
print(f"tVsV: shape={cute.shape(tVsV)}")
# Per-mode sizes of the global-side partitions.
for name, t in [("tAgQ", tAgQ), ("tBgK", tBgK), ("tVgV", tVgV)]:
for i in range(t.rank):
sz = cute.size(t, mode=[i])
print(f" {name} mode {i} size={sz}")
# After pre-slice
# Q and V are indexed with (None,0,None,0); K uses (None,None,0,0).
tAgQ2 = tAgQ[(None,0,None,0)]
tBgK2 = tBgK[(None,None,0,0)]
tVgV2 = tVgV[(None,0,None,0)]
print(f"\nAfter pre-slice (None,0,None,0) / (None,None,0,0) / (None,0,None,0):")
print(f"tAgQ: shape={cute.shape(tAgQ2)}")
print(f"tBgK: shape={cute.shape(tBgK2)}")
print(f"tVgV: shape={cute.shape(tVgV2)}")
# Per-mode sizes of the pre-sliced partitions.
for name, t in [("tAgQ2", tAgQ2), ("tBgK2", tBgK2), ("tVgV2", tVgV2)]:
for i in range(t.rank):
sz = cute.size(t, mode=[i])
print(f" {name} mode {i} size={sz}")
# Script entry point.
# NOTE(review): both test() and diag() are invoked here, but only one of the
# two `def` lines above has a body — confirm which call is intended.
if __name__ == '__main__':
test()
diag()