D1: Layout diagnostic for SMEM-P

This commit is contained in:
2026-05-24 03:30:29 +00:00
parent dd7356afc6
commit d0aec403e4

View File

@@ -0,0 +1,35 @@
"""D1: Print SMEM-P layout diagnostics for hd=128."""
import torch, math, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.torch as ct
import cuda.bindings.driver as cuda
from cutlass.cute.nvgpu import tcgen05
from cutlass import Float32, BFloat16
from cutlass.utils import LayoutEnum
for hd in [64, 128, 256]:
q = torch.randn(128, hd, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(128, hd, 1, dtype=torch.bfloat16, device='cuda')
a_major = LayoutEnum.from_tensor(ct.from_dlpack(q)).mma_major_mode()
b_major = LayoutEnum.from_tensor(ct.from_dlpack(k)).mma_major_mode()
qk_mma = utils.sm100.make_trivial_tiled_mma(BFloat16, BFloat16, a_major, b_major, Float32, tcgen05.CtaGroup.ONE, (128,128), tcgen05.OperandSource.SMEM)
pv_n_tile = min(hd, 256)
pv_a_major = a_major # SMEM-P path uses a_major
pv_mma = utils.sm100.make_trivial_tiled_mma(BFloat16, BFloat16, pv_a_major, b_major, Float32, tcgen05.CtaGroup.ONE, (128,pv_n_tile), tcgen05.OperandSource.SMEM)
qk_mma_tiler = (128, 128, cute.size(qk_mma.shape_mnk, mode=[2]) * 4)
pv_mma_tiler = (128, pv_n_tile, cute.size(pv_mma.shape_mnk, mode=[2]) * (128 // cute.size(pv_mma.shape_mnk, mode=[2])))
p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
# QK C-fragment layout
qk_thr = qk_mma.get_slice(0)
qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
print(f'--- hd={hd} ---')
print(f' p_tmem_s.outer = {p_tmem_s.outer}')
print(f' p_smem_s.outer = {p_smem_s.outer}')
print(f' tStS.layout = {tStS.layout}')
print(f' pv_mma_tiler = {pv_mma_tiler}')
print()