D1: Layout diagnostic for SMEM-P
This commit is contained in:
35
tests/unit/test_d1_layout_diag.py
Normal file
35
tests/unit/test_d1_layout_diag.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""D1: Print SMEM-P layout diagnostics for hd=128."""
|
||||
import torch, math, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.torch as ct
|
||||
import cuda.bindings.driver as cuda
|
||||
from cutlass.cute.nvgpu import tcgen05
|
||||
from cutlass import Float32, BFloat16
|
||||
from cutlass.utils import LayoutEnum
|
||||
|
||||
for hd in [64, 128, 256]:
|
||||
q = torch.randn(128, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
k = torch.randn(128, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
a_major = LayoutEnum.from_tensor(ct.from_dlpack(q)).mma_major_mode()
|
||||
b_major = LayoutEnum.from_tensor(ct.from_dlpack(k)).mma_major_mode()
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(BFloat16, BFloat16, a_major, b_major, Float32, tcgen05.CtaGroup.ONE, (128,128), tcgen05.OperandSource.SMEM)
|
||||
pv_n_tile = min(hd, 256)
|
||||
pv_a_major = a_major # SMEM-P path uses a_major
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(BFloat16, BFloat16, pv_a_major, b_major, Float32, tcgen05.CtaGroup.ONE, (128,pv_n_tile), tcgen05.OperandSource.SMEM)
|
||||
|
||||
qk_mma_tiler = (128, 128, cute.size(qk_mma.shape_mnk, mode=[2]) * 4)
|
||||
pv_mma_tiler = (128, pv_n_tile, cute.size(pv_mma.shape_mnk, mode=[2]) * (128 // cute.size(pv_mma.shape_mnk, mode=[2])))
|
||||
|
||||
p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
|
||||
p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
|
||||
|
||||
# QK C-fragment layout
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_as)
|
||||
|
||||
print(f'--- hd={hd} ---')
|
||||
print(f' p_tmem_s.outer = {p_tmem_s.outer}')
|
||||
print(f' p_smem_s.outer = {p_smem_s.outer}')
|
||||
print(f' tStS.layout = {tStS.layout}')
|
||||
print(f' pv_mma_tiler = {pv_mma_tiler}')
|
||||
print()
|
||||
Reference in New Issue
Block a user