Files
nvfp4-megamoe-kernel/tests/diag_tmem.py
biondizzle 97656a5cd1 Stage B: two MMAs + identity softmax — crash fixed, softmax output still wrong
Key fixes:
- PipelineUmmaAsync consumer group: 32*4=128 threads (not 4 warps)
- TMEM offsets computed from find_tmem_tensor_col_offset (not hardcoded)
- P fragment from p_tmem_s.outer + make_fragment_A (matching fmha.py)
- V SMEM aliasing via recast_ptr

Status:
- Stage A: cosine 0.999999 
- Stage B: runs without crash, identity softmax cosine -0.02 
- Diagnostics: TMEM layout inspection, bisection results
2026-05-20 20:26:25 +00:00

92 lines
4.4 KiB
Python

"""Diagnostic: Q1, Q2 for Stage B TMEM debugging.
Uses cute.compile on a host-side @cute.jit function (no device kernel is launched) whose print calls emit layout info at JIT time."""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05
from cutlass import Float32, BFloat16
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
@cute.jit
def diag_tmem(stream: cuda.CUstream):
    """Print TMEM fragment layouts and column offsets for the QK and PV MMAs.

    Builds two tiled MMAs (QK with the A operand sourced from SMEM, PV with
    the A operand sourced from TMEM), creates their accumulator / operand
    fragments, and prints each fragment's layout, size, cosize and the value
    returned by ``find_tmem_tensor_col_offset``.  The report is produced by
    plain ``print`` calls; this function launches no device work itself.

    Args:
        stream: CUDA stream handle.  It is part of the signature but is not
            referenced anywhere in the body.
    """
    a_dtype = BFloat16
    b_dtype = BFloat16
    a_major = cute.nvgpu.OperandMajorMode.K
    b_major = cute.nvgpu.OperandMajorMode.K

    # QK MMA: both operands described as SMEM-sourced.
    qk_mma = utils.sm100.make_trivial_tiled_mma(
        a_dtype, b_dtype, a_major, b_major,
        Float32, tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM)
    # PV MMA: same tile, but the A operand (P) is sourced from TMEM.
    pv_mma = utils.sm100.make_trivial_tiled_mma(
        a_dtype, b_dtype, cute.nvgpu.OperandMajorMode.K, b_major,
        Float32, tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.TMEM)

    # K extent of each tiler is four times the MMA instruction's K size.
    qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
    pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
    mma_tiler = (128, 128, qk_inst_k * 4)
    pv_mma_tiler = (128, 128, pv_inst_k * 4)

    qk_thr = qk_mma.get_slice(0)
    pv_thr = pv_mma.get_slice(0)

    # Q1: QK accumulator C fragment
    qk_acc_shape = qk_thr.partition_shape_C(mma_tiler[:2])
    tStS = qk_thr.make_fragment_C(qk_acc_shape)
    print("=== Q1: QK accumulator C fragment ===")
    print(f" tStS.layout = {tStS.layout}")
    print(f" cute.size(tStS.layout) = {cute.size(tStS.layout)}")
    print(f" cute.cosize(tStS.layout) = {cute.cosize(tStS.layout)}")
    print(f" cute.size(mode=[0]) = {cute.size(tStS.layout, mode=[0])}")
    print(f" cute.size(mode=[1]) = {cute.size(tStS.layout, mode=[1])}")
    s_tmem_cols = find_tmem_tensor_col_offset(tStS)
    print(f" find_tmem_tensor_col_offset(tStS) = {s_tmem_cols}")

    # PV accumulator O fragment.  The shape is taken from the PV tiler so the
    # PV fragment no longer depends on the QK tiler (both M/N extents are
    # (128, 128), so the resulting shape is the same as before).
    pv_acc_shape = pv_thr.partition_shape_C(pv_mma_tiler[:2])
    tOtO = pv_thr.make_fragment_C(pv_acc_shape)
    print("=== PV accumulator O fragment ===")
    print(f" tOtO.layout = {tOtO.layout}")
    print(f" cute.size(tOtO.layout) = {cute.size(tOtO.layout)}")
    print(f" cute.cosize(tOtO.layout) = {cute.cosize(tOtO.layout)}")
    print(f" cute.size(mode=[0]) = {cute.size(tOtO.layout, mode=[0])}")
    print(f" cute.size(mode=[1]) = {cute.size(tOtO.layout, mode=[1])}")
    o_tmem_cols = find_tmem_tensor_col_offset(tOtO)
    print(f" find_tmem_tensor_col_offset(tOtO) = {o_tmem_cols}")

    # Q2: PV A-fragment (P operand from TMEM)
    # The P tensor reuses the S accumulator's iterator with the outer part of
    # the PV A-operand layout, so P is a view over the same base as tStS.
    p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
    tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
    tOrP_base = pv_thr.make_fragment_A(tP)
    # Index 0 of the last mode; presumably the stage mode, given the stage
    # count of 1 passed above -- TODO confirm.
    tOrP_sliced = tOrP_base[(None, None, None, 0)]
    print("=== Q2: PV A-fragment (P operand) ===")
    print(f" tP.layout = {tP.layout}")
    print(f" cute.size(tP.layout) = {cute.size(tP.layout)}")
    print(f" cute.cosize(tP.layout) = {cute.cosize(tP.layout)}")
    print(f" tOrP_sliced.layout = {tOrP_sliced.layout}")
    print(f" cute.size(tOrP_sliced.layout) = {cute.size(tOrP_sliced.layout)}")
    print(f" cute.cosize(tOrP_sliced.layout) = {cute.cosize(tOrP_sliced.layout)}")
    p_tmem_cols = find_tmem_tensor_col_offset(tOrP_sliced)
    print(f" find_tmem_tensor_col_offset(tOrP_sliced) = {p_tmem_cols}")

    # Decompose 32800 (== 0x8020).  NOTE(review): presumably an offset value
    # seen while debugging Stage B -- confirm where it came from.
    print(f" 32800 in hex = 0x{32800:04x}")
    print(f" 32800 - 0x8000 = {32800 - 0x8000}")
    print(f" 32800 & 0x0000FFFF = {32800 & 0x0000FFFF}")
    # The hex format spec and the bit operations all need a plain Python int.
    # Formatting with ``04x`` previously ran before this guard and could raise
    # for a non-int offset, aborting the rest of the report; the raw value is
    # already printed above, so nothing is lost by skipping the breakdown.
    if isinstance(p_tmem_cols, int):
        print(f" p_tmem_cols in hex = 0x{p_tmem_cols:04x}")
        print(f" p_tmem_cols & 0x0000FFFF = {p_tmem_cols & 0x0000FFFF}")
        print(f" p_tmem_cols >> 16 = {p_tmem_cols >> 16}")

    # Staged fragments: accumulator shapes with one extra mode of extent 1
    # appended, built from the tiled MMAs rather than the thread slices.
    tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
    tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
    print("=== Staged fragments ===")
    print(f" find_tmem_tensor_col_offset(tCtS_fake) = {find_tmem_tensor_col_offset(tCtS_fake)}")
    print(f" find_tmem_tensor_col_offset(tCtO_fake) = {find_tmem_tensor_col_offset(tCtO_fake)}")
    print(f" get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake]) = {utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch='sm_100')}")
# Script driver: wrap PyTorch's current CUDA stream in a driver-API handle and
# compile the diagnostic, which is what triggers its printed report.
_torch_stream_handle = torch.cuda.current_stream().cuda_stream
stream = cuda.CUstream(_torch_stream_handle)

print("Compiling diagnostics...", flush=True)
# The compiled object is kept but never invoked; only compilation is needed.
compiled = cute.compile(diag_tmem, stream)
print("Done. Results above.", flush=True)