Files
nvfp4-megamoe-kernel/tests/test_tmem_addressing.py
biondizzle 97656a5cd1 Stage B: two MMAs + identity softmax — crash fixed, softmax output still wrong
Key fixes:
- PipelineUmmaAsync consumer group: 32*4=128 threads (not 4 warps)
- TMEM offsets computed from find_tmem_tensor_col_offset (not hardcoded)
- P fragment from p_tmem_s.outer + make_fragment_A (matching fmha.py)
- V SMEM aliasing via recast_ptr

Status:
- Stage A: cosine 0.999999 
- Stage B: runs without crash, identity softmax cosine -0.02 
- Diagnostics: TMEM layout inspection, bisection results
2026-05-20 20:26:25 +00:00

264 lines
12 KiB
Python

"""
TMEM Addressing Test: verify offset computation from layouts.
Allocates TMEM and computes the scores (S), P, and output (O) column offsets
from the QK accumulator and PV fragment sizes. Currently only the scores
region is exercised: the epilogue warps write a known value (1.0) via
tcgen05.st, read it back via tcgen05.ld, and dump the result to a debug
buffer. No MMA is issued, and there is no softmax and no V load.
This is a first step toward validating our offset arithmetic before wiring it into Stage B.
"""
import torch
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
import cuda.bindings.driver as cuda
class TmemAddressingTest:
    """Single-CTA kernel that round-trips a constant through TMEM.

    The host-side ``__call__`` builds the same QK / PV tiled MMAs as Stage B,
    derives the TMEM column offsets of the scores (S), P and output (O)
    regions from their fragment layouts, and launches ``_kernel``.  The
    kernel allocates TMEM, has the four epilogue warps write 1.0 into the
    scores region with ``tcgen05.st``, read it back with ``tcgen05.ld``, and
    store one value per thread into ``debug_buf``.

    NOTE(review): only the scores region (offset 0) is exercised so far; the
    P and O offsets are computed and printed but not yet written/read.
    """

    # Bounds applied to the requested TMEM column count; see
    # _round_up_tmem_cols.  NOTE(review): values taken from the PTX
    # tcgen05.alloc description — confirm against the target architecture.
    _TMEM_MIN_COLS = 32
    _TMEM_MAX_COLS = 512

    def __init__(self, mma_tiler_mn):
        """Record the tile shape, warp roles and named-barrier ids.

        Args:
            mma_tiler_mn: (M, N) MMA tile shape; the test driver in this
                file passes ``(128, 128)``.
        """
        self.acc_dtype = Float32
        self.qk_acc_dtype = Float32
        self.q_dtype = BFloat16
        self.o_dtype = BFloat16
        self.mma_tiler_mn = mma_tiler_mn
        # Placeholder K extent; overwritten with the QK tiler in __call__.
        self.mma_tiler = (*mma_tiler_mn, 1)
        self.cluster_shape_mn = (1, 1)
        self.cta_group = tcgen05.CtaGroup.ONE
        # Warp roles: warps 0-3 epilogue, warp 4 MMA, warp 5 TMA
        # (6 warps * 32 threads = 192 threads per CTA).
        self.epilogue_warp_id = (0, 1, 2, 3)
        self.mma_warp_id = 4
        self.tma_warp_id = 5
        self.threads_per_cta = 192
        self.tmem_alloc_sync_bar_id = 2
        self.tmem_dealloc_sync_bar_id = 3
        self.num_c_stage = 2

    @staticmethod
    def _round_up_tmem_cols(used_cols):
        """Return a legal TMEM allocation size covering ``used_cols`` columns.

        The column count is rounded up to a power of two and clamped to at
        least ``_TMEM_MIN_COLS``.  This is plain Python evaluated at trace
        time, so ``used_cols`` must be a compile-time integer.

        Raises:
            ValueError: if the rounded size exceeds ``_TMEM_MAX_COLS``.
        """
        cols = max(TmemAddressingTest._TMEM_MIN_COLS, int(used_cols))
        alloc_cols = 1 << (cols - 1).bit_length()
        if alloc_cols > TmemAddressingTest._TMEM_MAX_COLS:
            raise ValueError(
                f"TMEM layout needs {used_cols} columns; rounded allocation "
                f"{alloc_cols} exceeds {TmemAddressingTest._TMEM_MAX_COLS}")
        return alloc_cols

    @cute.jit
    def __call__(self, debug_buf: cute.Tensor, stream: cuda.CUstream):
        """Compute the TMEM layout at trace time and launch the kernel.

        Args:
            debug_buf: Tensor receiving the per-thread readback values; the
                driver in this file passes a 128-element float32 buffer.
            stream: CUDA stream the kernel is launched on.
        """
        self.a_dtype = BFloat16
        self.b_dtype = BFloat16
        self.a_major = cute.nvgpu.OperandMajorMode.K
        self.b_major = cute.nvgpu.OperandMajorMode.K
        self.c_layout = LayoutEnum.RowMajor

        # Create the same MMAs as Stage B to get the same fragment layouts.
        # QK takes its A operand from SMEM, PV takes its A operand from TMEM.
        qk_mma = utils.sm100.make_trivial_tiled_mma(
            self.a_dtype, self.b_dtype, self.a_major, self.b_major,
            self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
            tcgen05.OperandSource.SMEM)
        pv_mma = utils.sm100.make_trivial_tiled_mma(
            self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
            self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
            tcgen05.OperandSource.TMEM)

        # Each tiler covers four MMA instructions along K.
        qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
        self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
        pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
        self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
        self.mma_tiler = self.qk_mma_tiler
        self.cta_tile_shape_mnk = (
            self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
            self.qk_mma_tiler[1],
            self.qk_mma_tiler[2],
        )
        self.cluster_layout_vmnk = cute.tiled_divide(
            cute.make_layout((1, 1, 1)), (qk_mma.thr_id.shape,))

        # Compute TMEM fragment sizes from the accumulator layouts.
        qk_thr = qk_mma.get_slice(0)
        qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
        tStS = qk_thr.make_fragment_C(qk_acc_shape)
        qk_acc_cols = cute.size(tStS.layout, mode=[1])

        pv_thr = pv_mma.get_slice(0)
        pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
        tOtO = pv_thr.make_fragment_C(pv_acc_shape)
        pv_acc_cols = cute.size(tOtO.layout, mode=[1])

        # P operand size expressed in 32-bit columns:
        # tilePlikeFP32 = qk_mma_tiler[1] * q_dtype.width // 32
        tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32

        # Region offsets, in FP32 columns: S | P | O packed back to back.
        tmem_s_offset = 0
        tmem_p_offset = qk_acc_cols  # P right after QK accumulator
        tmem_o_offset = qk_acc_cols + tilePlikeFP32  # O right after P

        # Total allocation.  The raw sum (e.g. 128 + 64 + 128 = 320 for a
        # 128x128 tile) is generally not a legal allocation size, so round
        # it up before handing it to the allocator.
        tmem_alloc_cols = self._round_up_tmem_cols(tmem_o_offset + pv_acc_cols)

        # JIT-time prints — these appear during compilation
        print(f"[TMEM] qk_acc_cols = {qk_acc_cols}")
        print(f"[TMEM] tilePlikeFP32 = {tilePlikeFP32}")
        print(f"[TMEM] pv_acc_cols = {pv_acc_cols}")
        print(f"[TMEM] tmem_s_offset = {tmem_s_offset}")
        print(f"[TMEM] tmem_p_offset = {tmem_p_offset}")
        print(f"[TMEM] tmem_o_offset = {tmem_o_offset}")
        print(f"[TMEM] tmem_alloc_cols = {tmem_alloc_cols}")

        self._kernel(
            qk_mma, pv_mma, tStS, tOtO, tmem_alloc_cols,
            tmem_s_offset, tmem_p_offset, tmem_o_offset, tilePlikeFP32,
            debug_buf, self.cluster_layout_vmnk
        ).launch(grid=(1, 1, 1), block=[self.threads_per_cta, 1, 1], stream=stream)

    @cute.kernel
    def _kernel(self, qk_mma, pv_mma, tStS, tOtO, tmem_alloc_cols,
                tmem_s_offset, tmem_p_offset, tmem_o_offset, tilePlikeFP32,
                debug_buf, cl_vmnk):
        """Device kernel: write 1.0 to the scores region of TMEM and read it back.

        Args:
            qk_mma: QK tiled MMA, used to partition the identity tensor and
                to detect 2-CTA mode.
            pv_mma, tOtO, tmem_p_offset, tmem_o_offset, tilePlikeFP32:
                currently unused by the kernel body (P / O regions are not
                exercised yet).
            tStS: QK accumulator fragment; only its layout is used here.
            tmem_alloc_cols: number of TMEM columns to allocate.
            tmem_s_offset: column offset of the scores region.
            debug_buf: output tensor, one element written per epilogue thread.
            cl_vmnk: cluster layout passed to the pipeline init arrive/wait.
        """
        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
        tidx, _, _ = cute.arch.thread_idx()
        use_2cta = cute.size(qk_mma.thr_id.shape) == 2
        num_epi_threads = 32 * len(self.epilogue_warp_id)

        @cute.struct
        class SS:
            tmem_dealloc: cutlass.Int64
            holding: cutlass.Int32

        smem = utils.SmemAllocator()
        st = smem.allocate(SS)

        # Allocation barrier spans the MMA warp plus the epilogue warps.
        tmem_bar = pipeline.NamedBarrier(
            barrier_id=self.tmem_alloc_sync_bar_id,
            num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
        # Deallocation barrier spans only the epilogue warps, the only
        # warps that touch TMEM in this test.
        tmem_dealloc_bar = pipeline.NamedBarrier(
            barrier_id=self.tmem_dealloc_sync_bar_id,
            num_threads=num_epi_threads)
        tmem = utils.TmemAllocator(
            st.holding.ptr, barrier_for_retrieve=tmem_bar,
            allocator_warp_id=self.epilogue_warp_id[0],
            is_two_cta=use_2cta,
            two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)

        pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
        # Same keyword as the arrive call above.
        pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)

        # ── MMA WARP ──
        # No MMA is issued in this test.  The warp still has to take part in
        # the allocation barrier because tmem_bar counts its 32 threads.
        if warp_idx == self.mma_warp_id:
            tmem.wait_for_alloc()

        # ── EPILOGUE WARPS: allocate TMEM, write test values, read back ──
        if warp_idx < self.mma_warp_id:
            tmem.allocate(tmem_alloc_cols)
            tmem.wait_for_alloc()
            tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
            sfw_idx = tidx % num_epi_threads

            # Scores region: base the tensor on the pointer that was actually
            # allocated rather than on the placeholder fragment iterator.
            tStS0 = cute.make_tensor(tmem_ptr + tmem_s_offset, tStS.layout)

            # Identity coordinates partitioned like the QK accumulator; used
            # only to obtain the per-thread register tensor shapes.
            cS = cute.make_identity_tensor(
                (self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
            tScS = qk_mma.get_slice(0).partition_C(cS)

            # ── Write 1.0 to scores region ──
            tmem_store_atom = cute.make_copy_atom(
                tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype)
            tiled_store_s = tcgen05.make_tmem_copy(tmem_store_atom, tStS0)
            thr_store_s = tiled_store_s.get_slice(sfw_idx)
            tTMEM_STOREtS = thr_store_s.partition_D(tStS0)
            tTMEM_STOREcS = thr_store_s.partition_S(tScS)
            # Register tensor filled with 1.0.
            tTMEM_STORErS = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.acc_dtype)
            tTMEM_STORErS.fill(1.0)
            cute.copy(tiled_store_s, tTMEM_STORErS, tTMEM_STOREtS)
            cute.arch.fence_view_async_tmem_store()

            # ── Read back from scores region ──
            tmem_load_atom = cute.make_copy_atom(
                tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype)
            tiled_load_s = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
            thr_load_s = tiled_load_s.get_slice(sfw_idx)
            tTMEM_LOADtS = thr_load_s.partition_S(tStS0)
            tTMEM_LOADcS = thr_load_s.partition_D(tScS)
            tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.acc_dtype)
            cute.copy(tiled_load_s, tTMEM_LOADtS, tTMEM_LOADrS)
            cute.arch.fence_view_async_tmem_load()

            # Dump one value per thread to debug_buf for verification.
            # Only the epilogue warps reach this point (warps 0..3, i.e.
            # tidx 0..127), so each thread writes its own slot.
            debug_buf[tidx] = tTMEM_LOADrS[0]

            tmem.relinquish_alloc_permit()
            # Every epilogue warp must finish its TMEM reads before the
            # allocator warp releases the columns.
            tmem_dealloc_bar.arrive_and_wait()
            tmem.free(tmem_ptr)
def test_tmem_addressing():
    """Compile and run the TMEM addressing kernel, then check the readback.

    The kernel writes 1.0 into the scores region of TMEM and each of the 128
    epilogue threads stores the value it read back into ``debug_buf``.  The
    test passes only if every one of those 128 entries equals 1.0.

    Raises:
        AssertionError: if any entry of the debug buffer is not 1.0.
    """
    import cutlass.torch as cutlass_torch

    num_threads = 128  # one debug slot per epilogue thread
    device = torch.device("cuda")
    debug_buf = torch.zeros(num_threads, dtype=torch.float32, device=device)
    mD = cutlass_torch.from_dlpack(debug_buf).mark_layout_dynamic(
        leading_dim=cutlass_torch.get_leading_dim(debug_buf))
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    kernel = TmemAddressingTest(mma_tiler_mn=(128, 128))
    print("Compiling TMEM addressing test...", flush=True)
    compiled = cute.compile(kernel, mD, stream)
    print("Running...", flush=True)
    compiled(mD, stream)
    torch.cuda.synchronize()

    print("Debug buffer (first 16 values):", debug_buf[:16].tolist())
    # All values should be 1.0 if addressing is correct.  A merely non-zero
    # readback is not enough: it could be stale or mis-addressed data.
    nonzero = int((debug_buf[:num_threads] != 0).sum().item())
    ones = int((debug_buf[:num_threads] == 1.0).sum().item())
    print(f"Non-zero: {nonzero}/128, Ones: {ones}/128")
    if ones == num_threads:
        print("PASS: TMEM addressing works — read back 1.0 from every thread")
    else:
        print("FAIL: TMEM addressing broken — not every thread read back 1.0")
    assert ones == num_threads, (
        f"expected 1.0 from all {num_threads} threads, got {ones} "
        f"({nonzero} non-zero)")
# Run the TMEM addressing test when this file is executed as a script.
if __name__ == "__main__":
    test_tmem_addressing()