Files
nvfp4-megamoe-kernel/tests/archive/test_tmem_layout_diag.py
biondizzle 9cbdc92744 Restructure: cutedsl/ -> dsv4/ with proper layering
- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4/layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: dsv4-inference package name
2026-05-21 17:30:44 +00:00

82 lines
3.3 KiB
Python

"""
Diagnostic: compare the TMEM layouts of the QK C-fragment and the PV A-fragment.
Writes fragment sizes and layout-mode extents (not per-element TMEM column
indices) to small int32 GMEM tensors, then prints them.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05
from cutlass import Float32, BFloat16, Int32
from cutlass.utils import LayoutEnum
import cutlass.torch as ct
import cuda.bindings.driver as cuda
class TmemLayoutDiag:
    """Report how a BF16 PV A-operand layout maps onto a QK TMEM accumulator.

    ``__call__`` (host, ``@cute.jit``) does four things:

    - builds a 128x128 QK tiled MMA and a 128x16 PV tiled MMA whose A operand
      is sourced from TMEM;
    - aliases the PV A-operand layout onto the QK accumulator fragment;
    - computes static sizes/extents of the resulting layouts;
    - launches ``kernel``.

    ``kernel`` (device, ``@cute.kernel``) stores those values to GMEM from
    thread 0. ``run`` compiles, launches, and prints the results.
    """

    def __init__(self):
        # 128 threads == 4 warps. Only thread 0 writes output.
        self.threads_per_cta = 128
        # Filled with Python ints by __call__ at trace time; read by kernel
        # as compile-time constants.
        self._qk_vals = ()
        self._pv16_vals = ()

    @cute.jit
    def __call__(self, qk_cols, pv16_cols, stream):
        """Compute layout statistics on the host and launch the writer kernel.

        Thread indexing and GMEM stores are device-only operations, so they
        live in ``kernel``. This host function only builds (static) layouts
        and launches a single CTA.

        Args:
            qk_cols: int32 GMEM tensor with >= 3 elements. Receives
                [size(tStS), qk_inst_k, size of mode 0 of the QK acc shape].
            pv16_cols: int32 GMEM tensor with >= 4 elements. Receives
                [size(tP16), size(p_tmem_16.outer, mode 0),
                 size(p_tmem_16.outer, mode 1) (0 if rank < 2),
                 size(tOrP16)].
            stream: CUDA stream used for the kernel launch.
        """
        # N extent of the PV MMA (the "16" in "pv16").
        pv_n = 16

        # QK MMA: (M, N) = (128, 128), operands from SMEM, FP32 accumulator.
        # NOTE(review): this call passes LayoutEnum values as the operand
        # major modes, while the PV call below passes an OperandMajorMode
        # for A -- confirm make_trivial_tiled_mma accepts both forms.
        qk_mma = utils.sm100.make_trivial_tiled_mma(
            BFloat16, BFloat16, LayoutEnum.ROW_MAJOR, LayoutEnum.ROW_MAJOR,
            Float32, tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM)
        # PV MMA: (M, N) = (128, pv_n), A operand (P) sourced from TMEM.
        pv16_mma = utils.sm100.make_trivial_tiled_mma(
            BFloat16, BFloat16, cute.nvgpu.OperandMajorMode.K, LayoutEnum.ROW_MAJOR,
            Float32, tcgen05.CtaGroup.ONE, (128, pv_n), tcgen05.OperandSource.TMEM)

        # K extent of one QK MMA instruction; the QK tile spans 4 of them.
        qk_inst_k = int(cute.size(qk_mma.shape_mnk, mode=[2]))
        qk_mma_tiler = (128, 128, int(qk_inst_k * 4))
        # PV tiler is (M, N, K). N is the PV MMA's own N extent; K is the QK
        # N extent (the columns of P).
        pv16_mma_tiler = (128, pv_n, qk_mma_tiler[1])

        # Single-stage BF16 A-operand layout for the PV MMA.
        p_tmem_16 = utils.sm100.make_smem_layout_a(pv16_mma, pv16_mma_tiler, BFloat16, 1)

        # QK accumulator (C) fragment for thread slice 0.
        qk_thr = qk_mma.get_slice(0)
        qk_acc = qk_thr.partition_shape_C(qk_mma_tiler[:2])
        tStS = qk_thr.make_fragment_C(qk_acc)

        # Alias the PV A-operand layout onto the QK accumulator's storage,
        # then partition it as the PV MMA's A fragment.
        tP16 = cute.make_tensor(tStS.iterator, p_tmem_16.outer)
        pv16_thr = pv16_mma.get_slice(0)
        tOrP16 = pv16_thr.make_fragment_A(tP16)

        p_outer = p_tmem_16.outer
        # Mode 0 of a partitioned shape can itself be a tuple, so take its
        # size instead of converting qk_acc[0] directly to an integer.
        self._qk_vals = (
            cute.size(tStS),
            qk_inst_k,
            cute.size(qk_acc, mode=[0]),
        )
        self._pv16_vals = (
            cute.size(tP16),
            cute.size(p_outer, mode=[0]),
            cute.size(p_outer, mode=[1]) if cute.rank(p_outer) >= 2 else 0,
            cute.size(tOrP16),
        )

        self.kernel(qk_cols, pv16_cols).launch(
            grid=[1, 1, 1],
            block=[self.threads_per_cta, 1, 1],
            stream=stream,
        )

    @cute.kernel
    def kernel(self, qk_cols: cute.Tensor, pv16_cols: cute.Tensor):
        """Store the host-computed layout statistics to GMEM from thread 0."""
        tidx, _, _ = cute.arch.thread_idx()
        if tidx == 0:
            qk_cols[0] = Int32(self._qk_vals[0])
            qk_cols[1] = Int32(self._qk_vals[1])
            qk_cols[2] = Int32(self._qk_vals[2])
            pv16_cols[0] = Int32(self._pv16_vals[0])
            pv16_cols[1] = Int32(self._pv16_vals[1])
            pv16_cols[2] = Int32(self._pv16_vals[2])
            pv16_cols[3] = Int32(self._pv16_vals[3])

    def run(self):
        """Compile and launch the diagnostic, then print both result buffers."""
        # Four int32 slots each; __call__ writes indices 0-2 (QK) and 0-3 (PV).
        qk_cols = torch.zeros(4, dtype=torch.int32, device='cuda')
        pv16_cols = torch.zeros(4, dtype=torch.int32, device='cuda')
        mQ = ct.from_dlpack(qk_cols).mark_layout_dynamic()
        mP = ct.from_dlpack(pv16_cols).mark_layout_dynamic()
        stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
        print("Compiling...", flush=True)
        compiled = cute.compile(self, mQ, mP, stream)
        compiled(mQ, mP, stream)
        torch.cuda.synchronize()
        print(f"QK results: {qk_cols.tolist()}")
        print(f"PV16 results: {pv16_cols.tolist()}")
if __name__ == "__main__":
    # Script entry point: build the diagnostic and run it once.
    diag = TmemLayoutDiag()
    diag.run()