Files
nvfp4-megamoe-kernel/tests/archive/diag_tmem.py
biondizzle 9cbdc92744 Restructure: cutedsl/ -> dsv4/ with proper layering
- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4/layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: dsv4-inference package name
2026-05-21 17:30:44 +00:00

92 lines
4.4 KiB
Python

"""Diagnostic: Q1, Q2 for Stage B TMEM debugging.
Uses cute.compile with a dummy kernel that prints layout info at JIT time."""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05
from cutlass import Float32, BFloat16
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
@cute.jit
def diag_tmem(stream: cuda.CUstream):
    """Print TMEM layout diagnostics for the QK and PV MMA fragments.

    Builds two trivial tiled MMAs with BFloat16 operands, a Float32
    accumulator and a (128, 128) tile: ``qk_mma`` with
    ``OperandSource.SMEM`` and ``pv_mma`` with ``OperandSource.TMEM``.
    It then creates their accumulator fragments, a PV A-operand fragment
    built on the QK accumulator's iterator, and single-stage "staged"
    accumulator fragments. For each it prints the layout, size, cosize
    and the value returned by ``find_tmem_tensor_col_offset``.

    Args:
        stream: CUDA stream handle. It is not referenced in the body; it
            only forms the signature handed to ``cute.compile``.

    Returns:
        None. All results are emitted through ``print``.
    """
    a_dtype = BFloat16
    b_dtype = BFloat16
    a_major = cute.nvgpu.OperandMajorMode.K
    b_major = cute.nvgpu.OperandMajorMode.K

    qk_mma = utils.sm100.make_trivial_tiled_mma(
        a_dtype, b_dtype, a_major, b_major,
        Float32, tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM)
    pv_mma = utils.sm100.make_trivial_tiled_mma(
        a_dtype, b_dtype, cute.nvgpu.OperandMajorMode.K, b_major,
        Float32, tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.TMEM)

    # The K extent of each tiler is four times the MMA instruction's K size.
    qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
    pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
    mma_tiler = (128, 128, qk_inst_k * 4)
    pv_mma_tiler = (128, 128, pv_inst_k * 4)

    qk_thr = qk_mma.get_slice(0)
    pv_thr = pv_mma.get_slice(0)

    # Q1: QK accumulator C fragment
    qk_acc_shape = qk_thr.partition_shape_C(mma_tiler[:2])
    tStS = qk_thr.make_fragment_C(qk_acc_shape)
    print("=== Q1: QK accumulator C fragment ===")
    print(f" tStS.layout = {tStS.layout}")
    print(f" cute.size(tStS.layout) = {cute.size(tStS.layout)}")
    print(f" cute.cosize(tStS.layout) = {cute.cosize(tStS.layout)}")
    print(f" cute.size(mode=[0]) = {cute.size(tStS.layout, mode=[0])}")
    print(f" cute.size(mode=[1]) = {cute.size(tStS.layout, mode=[1])}")
    s_tmem_cols = find_tmem_tensor_col_offset(tStS)
    print(f" find_tmem_tensor_col_offset(tStS) = {s_tmem_cols}")

    # PV accumulator O fragment
    # Uses the PV tiler's own (M, N) extent; it equals mma_tiler[:2]
    # (both are (128, 128)), so the resulting shape is unchanged.
    pv_acc_shape = pv_thr.partition_shape_C(pv_mma_tiler[:2])
    tOtO = pv_thr.make_fragment_C(pv_acc_shape)
    print("=== PV accumulator O fragment ===")
    print(f" tOtO.layout = {tOtO.layout}")
    print(f" cute.size(tOtO.layout) = {cute.size(tOtO.layout)}")
    print(f" cute.cosize(tOtO.layout) = {cute.cosize(tOtO.layout)}")
    print(f" cute.size(mode=[0]) = {cute.size(tOtO.layout, mode=[0])}")
    print(f" cute.size(mode=[1]) = {cute.size(tOtO.layout, mode=[1])}")
    o_tmem_cols = find_tmem_tensor_col_offset(tOtO)
    print(f" find_tmem_tensor_col_offset(tOtO) = {o_tmem_cols}")

    # Q2: PV A-fragment (P operand from TMEM)
    # tP reuses the QK accumulator's iterator with the PV A-operand layout,
    # i.e. P is viewed at the same base as S.
    p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
    tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
    tOrP_base = pv_thr.make_fragment_A(tP)
    # Select index 0 of the last mode (built with a stage count of 1 above).
    tOrP_sliced = tOrP_base[(None, None, None, 0)]
    print("=== Q2: PV A-fragment (P operand) ===")
    print(f" tP.layout = {tP.layout}")
    print(f" cute.size(tP.layout) = {cute.size(tP.layout)}")
    print(f" cute.cosize(tP.layout) = {cute.cosize(tP.layout)}")
    print(f" tOrP_sliced.layout = {tOrP_sliced.layout}")
    print(f" cute.size(tOrP_sliced.layout) = {cute.size(tOrP_sliced.layout)}")
    print(f" cute.cosize(tOrP_sliced.layout) = {cute.cosize(tOrP_sliced.layout)}")
    p_tmem_cols = find_tmem_tensor_col_offset(tOrP_sliced)
    print(f" find_tmem_tensor_col_offset(tOrP_sliced) = {p_tmem_cols}")

    # Decompose 32800
    print(f" 32800 in hex = 0x{32800:04x}")
    print(f" 32800 - 0x8000 = {32800 - 0x8000}")
    print(f" 32800 & 0x0000FFFF = {32800 & 0x0000FFFF}")
    # The hex format spec and the bit operations are only valid for a plain
    # Python int, so all three prints sit behind the same isinstance guard.
    # Previously the hex print ran unguarded and could raise before the
    # guard was reached.
    if isinstance(p_tmem_cols, int):
        print(f" p_tmem_cols in hex = 0x{p_tmem_cols:04x}")
        print(f" p_tmem_cols & 0x0000FFFF = {p_tmem_cols & 0x0000FFFF}")
        print(f" p_tmem_cols >> 16 = {p_tmem_cols >> 16}")
    else:
        print(f" p_tmem_cols is {type(p_tmem_cols).__name__}, not int; skipping hex/bit decomposition")

    # Staged fragments
    # Append a trailing mode of extent 1 to each accumulator shape.
    tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
    tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
    print("=== Staged fragments ===")
    print(f" find_tmem_tensor_col_offset(tCtS_fake) = {find_tmem_tensor_col_offset(tCtS_fake)}")
    print(f" find_tmem_tensor_col_offset(tCtO_fake) = {find_tmem_tensor_col_offset(tCtO_fake)}")
    print(f" get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake]) = {utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch='sm_100')}")
# Wrap the raw handle of PyTorch's current CUDA stream in a driver-API
# CUstream object so it can be passed to cute.compile.
_torch_stream_handle = torch.cuda.current_stream().cuda_stream
stream = cuda.CUstream(_torch_stream_handle)

# Per the module docstring, the layout info is printed at JIT time, so the
# compile call is what produces the diagnostic output; the compiled object
# is kept but not launched.
print("Compiling diagnostics...", flush=True)
compiled = cute.compile(diag_tmem, stream)
print("Done. Results above.", flush=True)