82 lines
3.3 KiB
Python
82 lines
3.3 KiB
Python
"""
|
|
Diagnostic: Compare TMEM column mapping between QK C-fragment and PV A-fragment.
|
|
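Goal: write the TMEM column index for each logical element to GMEM. At present only fragment sizes, the QK instruction K extent, and layout-shape metadata are written.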
|
|
"""
|
|
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils
|
|
from cutlass.cute.nvgpu import tcgen05
|
|
from cutlass import Float32, BFloat16, Int32
|
|
from cutlass.utils import LayoutEnum
|
|
import cutlass.torch as ct
|
|
import cuda.bindings.driver as cuda
|
|
|
|
|
|
class TmemLayoutDiag:
    """Diagnostic for the TMEM layouts of the QK C-fragment and the PV A-fragment.

    ``__call__`` (host side, ``@cute.jit``) builds a (128, 128) QK tiled MMA and
    a (128, 16) PV tiled MMA whose A operand is sourced from TMEM. From these it
    derives the QK accumulator fragment ``tStS`` and a P tensor ``tP16`` that
    shares ``tStS``'s iterator, then reduces everything to static integers.
    ``kernel`` (device side) has thread 0 store those integers to GMEM.

    Slots written (other slots of the output tensors are left untouched):
        qk_cols[0]    size of tStS
        qk_cols[1]    K extent of the QK MMA shape (``qk_inst_k``)
        qk_cols[2]    size of mode 0 of the QK accumulator partition shape
        pv16_cols[0]  size of tP16
        pv16_cols[1]  size of mode 0 of ``p_tmem_16.outer``
        pv16_cols[2]  size of mode 1 of ``p_tmem_16.outer`` (0 if its rank < 2)
    """

    def __init__(self):
        # 128 threads = 4 warps; only thread 0 writes anything.
        self.threads_per_cta = 128

    @cute.kernel
    def kernel(
        self,
        qk_cols: cute.Tensor,
        pv16_cols: cute.Tensor,
        qk_size: cutlass.Constexpr,
        qk_inst_k: cutlass.Constexpr,
        qk_acc_mode0: cutlass.Constexpr,
        pv16_size: cutlass.Constexpr,
        p_outer_mode0: cutlass.Constexpr,
        p_outer_mode1: cutlass.Constexpr,
    ):
        """Device side: thread 0 stores the host-computed diagnostics to GMEM."""
        tidx, _, _ = cute.arch.thread_idx()
        if tidx == 0:
            qk_cols[0] = Int32(qk_size)
            qk_cols[1] = Int32(qk_inst_k)
            qk_cols[2] = Int32(qk_acc_mode0)
            pv16_cols[0] = Int32(pv16_size)
            pv16_cols[1] = Int32(p_outer_mode0)
            pv16_cols[2] = Int32(p_outer_mode1)

    @cute.jit
    def __call__(self, qk_cols, pv16_cols, stream):
        """Build the layouts on the host and launch the single-CTA writer kernel.

        Args:
            qk_cols: GMEM tensor receiving the QK-side diagnostics.
            pv16_cols: GMEM tensor receiving the PV-side diagnostics.
            stream: CUDA stream the kernel is launched on.
        """
        # make_trivial_tiled_mma expects operand major modes, not LayoutEnum
        # values; a row-major operand maps to the K-major MMA mode.
        k_major = LayoutEnum.ROW_MAJOR.mma_major_mode()

        # QK MMA: (128, 128) tile, both operands from SMEM.
        qk_mma = utils.sm100.make_trivial_tiled_mma(
            BFloat16, BFloat16, k_major, k_major,
            Float32, tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM,
        )
        # PV MMA: (128, 16) tile, A operand (P) read from TMEM.
        pv16_mma = utils.sm100.make_trivial_tiled_mma(
            BFloat16, BFloat16, cute.nvgpu.OperandMajorMode.K, k_major,
            Float32, tcgen05.CtaGroup.ONE, (128, 16), tcgen05.OperandSource.TMEM,
        )

        qk_inst_k = int(cute.size(qk_mma.shape_mnk, mode=[2]))
        qk_mma_tiler = (128, 128, qk_inst_k * 4)
        pv16_mma_tiler = (128, qk_inst_k, 128)

        # Single-stage A-operand layout for the PV MMA.
        p_tmem_16 = utils.sm100.make_smem_layout_a(pv16_mma, pv16_mma_tiler, BFloat16, 1)

        # QK accumulator (C) fragment for slice 0.
        qk_thr = qk_mma.get_slice(0)
        qk_acc = qk_thr.partition_shape_C(qk_mma_tiler[:2])
        tStS = qk_thr.make_fragment_C(qk_acc)

        # P viewed over tStS's iterator with the PV A-operand layout.
        tP16 = cute.make_tensor(tStS.iterator, p_tmem_16.outer)
        pv16_thr = pv16_mma.get_slice(0)
        # Not written to GMEM yet; kept so tracing still exercises
        # make_fragment_A on tP16.
        tOrP16 = pv16_thr.make_fragment_A(tP16)

        p_outer = p_tmem_16.outer
        p_outer_mode1 = (
            int(cute.size(p_outer, mode=[1])) if cute.rank(p_outer) >= 2 else 0
        )

        # The original stored these from host-side jit code, where
        # thread_idx() and GMEM element stores have no kernel to run in.
        # Compute the static values here and let a launched kernel write them.
        self.kernel(
            qk_cols,
            pv16_cols,
            int(cute.size(tStS)),
            qk_inst_k,
            # Mode 0 of a partition shape may itself be nested (e.g. (M, N)),
            # so take its size rather than converting the mode to an integer.
            int(cute.size(qk_acc, mode=[0])),
            int(cute.size(tP16)),
            int(cute.size(p_outer, mode=[0])),
            p_outer_mode1,
        ).launch(
            grid=[1, 1, 1],
            block=[self.threads_per_cta, 1, 1],
            stream=stream,
        )

    def run(self):
        """Allocate the int32 output buffers, compile and launch, then print them."""
        qk_cols = torch.zeros(4, dtype=torch.int32, device='cuda')
        pv16_cols = torch.zeros(4, dtype=torch.int32, device='cuda')

        mQ = ct.from_dlpack(qk_cols).mark_layout_dynamic()
        mP = ct.from_dlpack(pv16_cols).mark_layout_dynamic()
        stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

        print("Compiling...", flush=True)
        compiled = cute.compile(self, mQ, mP, stream)
        compiled(mQ, mP, stream)
        torch.cuda.synchronize()

        print(f"QK results: {qk_cols.tolist()}")
        print(f"PV16 results: {pv16_cols.tolist()}")
if __name__ == "__main__":
    # Script entry point: construct the diagnostic and run it once.
    diagnostic = TmemLayoutDiag()
    diagnostic.run()