- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4/layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: package name set to dsv4-inference
92 lines
4.4 KiB
Python
92 lines
4.4 KiB
Python
"""Diagnostic: Q1, Q2 for Stage B TMEM debugging.
|
|
Uses cute.compile with a dummy kernel that prints layout info at JIT time."""
|
|
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils
|
|
from cutlass.cute.nvgpu import tcgen05
|
|
from cutlass import Float32, BFloat16
|
|
from cutlass.utils import LayoutEnum
|
|
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
|
import cuda.bindings.driver as cuda
|
|
|
|
|
|
@cute.jit
def diag_tmem(stream: cuda.CUstream):
    """Print TMEM fragment layouts and column offsets for the QK and PV MMAs.

    Builds two tcgen05 tiled MMAs (BF16 inputs, FP32 accumulate, 128x128,
    CtaGroup.ONE):

    * ``qk_mma`` reads both operands from SMEM.
    * ``pv_mma`` reads its A operand from TMEM.

    While ``cute.compile`` traces this function, it prints:

    * Q1: the QK accumulator fragment ``tStS`` (layout, size, cosize, mode
      sizes and ``find_tmem_tensor_col_offset``).
    * The same data for the PV accumulator fragment ``tOtO``.
    * Q2: the PV A-operand fragment built by re-viewing the storage behind
      ``tStS.iterator`` with the P operand layout, plus a bit-level
      decomposition of its column offset.
    * Column offsets of fragments with an extra trailing mode of size 1, and
      the total that ``utils.get_num_tmem_alloc_cols`` reports for both.

    No kernel is launched; all output comes from ``print`` calls executed
    during tracing.

    Args:
        stream: CUDA stream in the compiled signature. It is not otherwise
            used by this diagnostic.
    """
    a_dtype = BFloat16
    b_dtype = BFloat16
    a_major = cute.nvgpu.OperandMajorMode.K
    b_major = cute.nvgpu.OperandMajorMode.K

    # QK MMA: both operands sourced from SMEM.
    qk_mma = utils.sm100.make_trivial_tiled_mma(
        a_dtype, b_dtype, a_major, b_major,
        Float32, tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM)
    # PV MMA: A operand (P) sourced from TMEM.
    pv_mma = utils.sm100.make_trivial_tiled_mma(
        a_dtype, b_dtype, a_major, b_major,
        Float32, tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.TMEM)

    # Tile K extent = 4x the per-instruction K of each MMA.
    qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
    pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
    mma_tiler = (128, 128, qk_inst_k * 4)
    pv_mma_tiler = (128, 128, pv_inst_k * 4)

    qk_thr = qk_mma.get_slice(0)
    pv_thr = pv_mma.get_slice(0)

    # Q1: QK accumulator C fragment
    qk_acc_shape = qk_thr.partition_shape_C(mma_tiler[:2])
    tStS = qk_thr.make_fragment_C(qk_acc_shape)
    print("=== Q1: QK accumulator C fragment ===")
    print(f"  tStS.layout = {tStS.layout}")
    print(f"  cute.size(tStS.layout) = {cute.size(tStS.layout)}")
    print(f"  cute.cosize(tStS.layout) = {cute.cosize(tStS.layout)}")
    print(f"  cute.size(mode=[0]) = {cute.size(tStS.layout, mode=[0])}")
    print(f"  cute.size(mode=[1]) = {cute.size(tStS.layout, mode=[1])}")
    s_tmem_cols = find_tmem_tensor_col_offset(tStS)
    print(f"  find_tmem_tensor_col_offset(tStS) = {s_tmem_cols}")

    # PV accumulator O fragment. Partition with the PV tiler for consistency
    # with pv_mma; its MN extent is (128, 128), same as mma_tiler[:2].
    pv_acc_shape = pv_thr.partition_shape_C(pv_mma_tiler[:2])
    tOtO = pv_thr.make_fragment_C(pv_acc_shape)
    print("=== PV accumulator O fragment ===")
    print(f"  tOtO.layout = {tOtO.layout}")
    print(f"  cute.size(tOtO.layout) = {cute.size(tOtO.layout)}")
    print(f"  cute.cosize(tOtO.layout) = {cute.cosize(tOtO.layout)}")
    print(f"  cute.size(mode=[0]) = {cute.size(tOtO.layout, mode=[0])}")
    print(f"  cute.size(mode=[1]) = {cute.size(tOtO.layout, mode=[1])}")
    o_tmem_cols = find_tmem_tensor_col_offset(tOtO)
    print(f"  find_tmem_tensor_col_offset(tOtO) = {o_tmem_cols}")

    # Q2: PV A-fragment (P operand from TMEM). Reuse the QK accumulator's
    # storage (tStS.iterator) with the P operand's layout.
    p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
    tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
    tOrP_base = pv_thr.make_fragment_A(tP)
    tOrP_sliced = tOrP_base[(None, None, None, 0)]  # first index of the last mode
    print("=== Q2: PV A-fragment (P operand) ===")
    print(f"  tP.layout = {tP.layout}")
    print(f"  cute.size(tP.layout) = {cute.size(tP.layout)}")
    print(f"  cute.cosize(tP.layout) = {cute.cosize(tP.layout)}")
    print(f"  tOrP_sliced.layout = {tOrP_sliced.layout}")
    print(f"  cute.size(tOrP_sliced.layout) = {cute.size(tOrP_sliced.layout)}")
    print(f"  cute.cosize(tOrP_sliced.layout) = {cute.cosize(tOrP_sliced.layout)}")
    p_tmem_cols = find_tmem_tensor_col_offset(tOrP_sliced)
    print(f"  find_tmem_tensor_col_offset(tOrP_sliced) = {p_tmem_cols}")

    # Decompose the reference value 32800 (== 0x8020). Its origin is not shown
    # here; presumably a column offset seen in an earlier run. TODO confirm.
    print(f"  32800 in hex = 0x{32800:04x}")
    print(f"  32800 - 0x8000 = {32800 - 0x8000}")
    print(f"  32800 & 0x0000FFFF = {32800 & 0x0000FFFF}")
    # Hex formatting and bit ops need a Python int. Previously the hex print
    # ran before this guard and raised whenever p_tmem_cols was not an int.
    if isinstance(p_tmem_cols, int):
        print(f"  p_tmem_cols in hex = 0x{p_tmem_cols:04x}")
        print(f"  p_tmem_cols & 0x0000FFFF = {p_tmem_cols & 0x0000FFFF}")
        print(f"  p_tmem_cols >> 16 = {p_tmem_cols >> 16}")
    else:
        print(f"  p_tmem_cols is {type(p_tmem_cols).__name__}, not int; "
              "skipping hex decomposition")

    # Accumulator fragments with an extra trailing mode of size 1.
    tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
    tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
    print("=== Staged fragments ===")
    print(f"  find_tmem_tensor_col_offset(tCtS_fake) = {find_tmem_tensor_col_offset(tCtS_fake)}")
    print(f"  find_tmem_tensor_col_offset(tCtO_fake) = {find_tmem_tensor_col_offset(tCtO_fake)}")
    total_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch='sm_100')
    print(f"  get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake]) = {total_alloc_cols}")
|
|
|
|
|
|
# Wrap PyTorch's current CUDA stream handle as a cuda-python CUstream, then
# JIT-compile diag_tmem. Its diagnostics are printed while it is traced, so
# they appear between the two progress messages below.
_current_stream_handle = torch.cuda.current_stream().cuda_stream
stream = cuda.CUstream(_current_stream_handle)

print("Compiling diagnostics...", flush=True)
compiled = cute.compile(diag_tmem, stream)
print("Done. Results above.", flush=True)
|