Files
nvfp4-megamoe-kernel/tests/test_tmem_layout_diag.py

82 lines
3.3 KiB
Python

"""
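Diagnostic: compare the TMEM layout of the QK C-fragment with that of the PV A-fragment.
The intended output is the TMEM column index of each logical element; the current
version only writes fragment sizes and layout extents to GMEM and prints them on the host.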
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05
from cutlass import Float32, BFloat16, Int32
from cutlass.utils import LayoutEnum
import cutlass.torch as ct
import cuda.bindings.driver as cuda
class TmemLayoutDiag:
    """Diagnostic that builds a QK MMA and a PV MMA and reports fragment/layout sizes.

    ``__call__`` constructs both tiled MMAs, derives the QK C-fragment and a
    PV A-operand view that reuses the C-fragment's iterator, and has thread 0
    store a handful of sizes into two output tensors. ``run`` allocates those
    tensors, compiles and invokes ``__call__``, and prints the results.
    """

    def __init__(self):
        # NOTE(review): the original comment here said "just 1 warp", but the
        # value is 128; also, this attribute is not read anywhere in this class.
        self.threads_per_cta = 128 # just 1 warp for simplicity

    @cute.jit
    def __call__(self, qk_cols, pv16_cols, stream):
        """Build the QK / PV16 MMAs and write size diagnostics to the outputs.

        Args:
            qk_cols: output tensor; indices 0..2 are written (QK fragment size,
                QK instruction K, first entry of the QK accumulator shape).
            pv16_cols: output tensor; indices 0..2 are written (PV operand
                tensor size and the extents of modes 0 and 1 of its layout).
            stream: accepted for the call signature; not referenced in the body.

        NOTE(review): ``cute.arch.thread_idx()`` is queried directly inside this
        ``@cute.jit`` function and no separate kernel launch is visible here —
        confirm this executes on the device as intended.
        """
        # qk_cols / pv16_cols: earlier comments described these as (128,) tensors,
        # but `run` passes 4-element tensors and only indices 0..2 are written.

        # QK MMA with a (128, 128) tile and OperandSource.SMEM.
        # NOTE(review): the major-mode arguments are LayoutEnum values here but
        # the PV call below passes an OperandMajorMode for the first one —
        # confirm which type make_trivial_tiled_mma expects.
        qk_mma = utils.sm100.make_trivial_tiled_mma(
            BFloat16, BFloat16, LayoutEnum.ROW_MAJOR, LayoutEnum.ROW_MAJOR,
            Float32, tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM)
        # PV MMA with a (128, 16) tile and OperandSource.TMEM.
        pv16_mma = utils.sm100.make_trivial_tiled_mma(
            BFloat16, BFloat16, cute.nvgpu.OperandMajorMode.K, LayoutEnum.ROW_MAJOR,
            Float32, tcgen05.CtaGroup.ONE, (128, 16), tcgen05.OperandSource.TMEM)
        # Extent of mode 2 (the K mode) of the QK MMA's shape_mnk.
        qk_inst_k = int(cute.size(qk_mma.shape_mnk, mode=[2]))
        # QK tiler uses 4x the instruction K; the PV tiler reuses qk_inst_k as
        # its second extent.
        qk_mma_tiler = (128, 128, int(qk_inst_k * 4))
        pv16_mma_tiler = (128, qk_inst_k, 128)
        # Layout for the PV A operand, obtained from the SMEM-layout helper.
        # The trailing 1 is presumably the stage count — TODO confirm.
        p_tmem_16 = utils.sm100.make_smem_layout_a(pv16_mma, pv16_mma_tiler, BFloat16, 1)
        # tStS: QK C-fragment for slice 0 over the first two tiler extents.
        qk_thr = qk_mma.get_slice(0)
        qk_acc = qk_thr.partition_shape_C(qk_mma_tiler[:2])
        tStS = qk_thr.make_fragment_C(qk_acc)
        # tP16: tensor that reuses tStS's iterator with the outer part of the
        # PV A-operand layout, i.e. a second view starting at the same place.
        tP16 = cute.make_tensor(tStS.iterator, p_tmem_16.outer)
        pv16_thr = pv16_mma.get_slice(0)
        # tOrP16 is built but not used afterwards in this function.
        tOrP16 = pv16_thr.make_fragment_A(tP16)
        # Only thread 0 writes the diagnostics.
        tidx, _, _ = cute.arch.thread_idx()
        if tidx == 0:
            # Index 0: total sizes of the QK fragment and of the PV operand tensor.
            qk_cols[0] = Int32(cute.size(tStS))
            pv16_cols[0] = Int32(cute.size(tP16))
            # Index 1 of qk_cols: QK instruction K.
            qk_cols[1] = Int32(qk_inst_k)
            # Indices 1 and 2 of pv16_cols: extents of modes 0 and 1 of the
            # outer layout (0 is stored for mode 1 when ndim < 2).
            pv16_cols[1] = Int32(cute.size(p_tmem_16.outer, mode=[0]))
            pv16_cols[2] = Int32(cute.size(p_tmem_16.outer, mode=[1])) if p_tmem_16.outer.ndim >= 2 else Int32(0)
            # Index 2 of qk_cols: first entry of the QK accumulator shape, or 0
            # when that shape object is not indexable.
            qk_cols[2] = Int32(qk_acc[0]) if hasattr(qk_acc, '__getitem__') else Int32(0)

    def run(self):
        """Allocate the output tensors, compile and invoke the diagnostic, print results.

        Both outputs are 4-element int32 CUDA tensors initialised to zero, so
        index 3, which ``__call__`` never writes, stays 0.
        """
        qk_cols = torch.zeros(4, dtype=torch.int32, device='cuda')
        pv16_cols = torch.zeros(4, dtype=torch.int32, device='cuda')
        # Wrap the torch tensors for the DSL via DLPack, with dynamic layouts.
        mQ = ct.from_dlpack(qk_cols).mark_layout_dynamic()
        mP = ct.from_dlpack(pv16_cols).mark_layout_dynamic()
        # Reuse torch's current CUDA stream as a driver-API stream handle.
        stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
        print("Compiling...", flush=True)
        compiled = cute.compile(self, mQ, mP, stream)
        compiled(mQ, mP, stream)
        # Wait for completion before reading the results back on the host.
        torch.cuda.synchronize()
        print(f"QK results: {qk_cols.tolist()}")
        print(f"PV16 results: {pv16_cols.tolist()}")
if __name__ == "__main__":
    # Script entry point: build the diagnostic and execute it once.
    diag = TmemLayoutDiag()
    diag.run()