Files
nvfp4-megamoe-kernel/tests/archive/test_tma_coord.py
biondizzle 524f0bdfb4 Clean up: archive diagnostics and superseded tests
Kept:
- example10 (CUTLASS LLM, O rescale + final normalize)
- example9 (SSA kv_coord version)
- working_softmax_maybe.py (working softmax snapshot from before the nuke)
- test_fmha_v3_stage_c.py (identity softmax baseline, n=128 cos 0.999998)
- test_fmha_v3.py (Stage A+B baseline)
- layertest.py, cudagraph_test.py (required)
- test_cutedsl.py, test_fp4_roundtrip.py (NVFP4 tests)

Archived: diag_tma_*, example8, test_diag_multitile, test_reference_fmha,
test_ref_minimal, test_tma_coord, test_fmha_v3_diag*, test_fmha_v3_12w,
test_dense_router, test_interleave*, test_fused_step1, test_router,
test_cache, test_compile_custom_op, test_custom_op, test_layer_schedule
2026-05-23 00:17:07 +00:00

103 lines
4.5 KiB
Python

"""Minimal TMA multi-tile test: does the TMA coordinate propagate at runtime?
This is a tiny kernel that loads K from two different tiles and checks if the data is different.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean
from cutlass.utils import LayoutEnum
import cuda.bindings.driver as cuda
import cutlass.torch as ct
import math
# Size of dim 1 of the K tensor built in test(): k has shape (n, HEAD_DIM, 1).
HEAD_DIM: int = 64
class TmaMultiTileTest:
    """Tiny CuTe DSL kernel that TMA-loads K one tile at a time.

    The host entry point (``__call__``) builds a BF16 x BF16 -> F32 tcgen05
    tiled MMA and derives a TMA atom for its B operand (K). It then launches
    ``_kernel`` on a single 32-thread CTA. The kernel copies each K tile
    into shared memory and writes 4 values per tile to ``out_buf``, so the
    host can compare the data loaded for different tile coordinates.
    """

    def __init__(self, s_k=256):
        # Stored for reference only; nothing in this class reads it.
        self.s_k = s_k
        self.q_dtype = BFloat16
        self.cta_group = tcgen05.CtaGroup.ONE
        # Stage count of the K shared-memory layout built in _setup; the
        # kernel only ever uses stage 0.
        self.kv_stage = 2

    def _setup(self, qk_mma):
        """Derive the CTA tiler, cluster layout, K SMEM layout and stage byte count."""
        qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
        # (M, N, K) CTA tile: 128 x 128, K extent = 4 MMA instructions.
        self.qk_mma_tiler = (128, 128, qk_ik * 4)
        # 1x1x1 cluster, tiled by the MMA's thread-id shape.
        self.cluster_layout_vmnk = cute.tiled_divide(
            cute.make_layout((1, 1, 1)), (qk_mma.thr_id.shape,))
        self.k_smem_s = utils.sm100.make_smem_layout_b(
            qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage)
        cta = cute.size(qk_mma.thr_id.shape)
        k_s = cute.slice_(self.k_smem_s, (None, None, None, 0))
        # Bytes occupied by one K stage, times the number of CTAs in the MMA.
        # NOTE(review): not read anywhere in this class -- presumably intended
        # as the mbarrier expect-tx count once the TMA load is synchronized.
        self.kv_tx_bytes = cute.size_in_bytes(self.q_dtype, k_s) * cta

    @cute.jit
    def __call__(self, k, out_buf, stream):
        """Host entry point: build the MMA and K TMA atom, then launch ``_kernel``.

        Args:
            k: CuTe tensor holding K (test() passes a (s_k, HEAD_DIM, 1) BF16 tensor).
            out_buf: CuTe tensor that receives 4 values per K tile.
            stream: CUDA stream the kernel is launched on.
        """
        b_major = LayoutEnum.from_tensor(k).mma_major_mode()
        qk_mma = utils.sm100.make_trivial_tiled_mma(
            self.q_dtype, self.q_dtype,
            cute.nvgpu.OperandMajorMode.K, b_major,
            Float32, self.cta_group, (128, 128), tcgen05.OperandSource.SMEM)
        self._setup(qk_mma)
        k_s = cute.slice_(self.k_smem_s, (None, None, None, 0))
        tma_k, mK = cute.nvgpu.make_tiled_tma_atom_B(
            utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_layout_vmnk.shape, qk_mma.thr_id),
            k, k_s, self.qk_mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
        self._kernel(qk_mma, tma_k, mK, out_buf, self.cluster_layout_vmnk, self.k_smem_s).launch(
            grid=(1, 1, 1), block=[32, 1, 1], stream=stream)

    @cute.kernel
    def _kernel(self, qk_mma, tma_k, mK, out_buf, cl_vmnk, k_smem_s):
        """Device code: TMA-copy each K tile into SMEM and dump 4 values per tile.

        NOTE(review): the loads are not synchronized. ``tma_bar_ptr=None``
        means no mbarrier tracks the asynchronous TMA transfer, so the SMEM
        reads that follow each copy may observe stale data. Every thread of
        the 32-thread block also executes the copy; TMA loads are normally
        issued by a single elected thread.
        """
        cpasync.prefetch_descriptor(tma_k)
        # NOTE(review): cute.make_tensor normally takes an iterator/pointer as
        # its first argument, not a dtype -- confirm this yields an SMEM tensor.
        sK = cute.make_tensor(self.q_dtype, cute.slice_(k_smem_s, (None, None, None, 0)), byte_alignment=128)
        gK = cute.local_tile(mK, cute.slice_(self.qk_mma_tiler, (0, None, None)), (None, None, None))
        qk_thr = qk_mma.get_slice(0)
        tCgK = qk_thr.partition_B(gK)
        b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0, None, 0, 0)).shape)
        tBsK, tBgK = cpasync.tma_partition(
            tma_k, 0, b_lay, cute.group_modes(sK, 0, 3), cute.group_modes(tCgK, 0, 3))
        # Test different slice patterns
        tBgK_nn = tBgK[(None, None, 0, 0)]  # Mode 1 = GMEM tiles (should work for multi-tile)
        tBgK_n0 = tBgK[(None, 0, None, 0)]  # Mode 1 fixed to 0 (broken for multi-tile)
        # Bound the copy loop by the extent of the mode it actually indexes
        # (tBgK_nn mode 1, the tiles along s_k). gK is presumably laid out
        # (bN, bK, RestN, RestK, RestL) as in the CUTLASS Blackwell examples,
        # so cute.size(gK, mode=[3]) would count RestK (head-dim tiles) instead.
        n_kv_tiles = cute.size(tBgK_nn, mode=[1])
        # Python print() here runs while the kernel is traced, not on device.
        print(f"tBgK shape: {cute.shape(tBgK)}", end="")
        print(f" tBgK_nn shape: {cute.shape(tBgK_nn)}", end="")
        print(f" tBgK_n0 shape: {cute.shape(tBgK_n0)}", end="")
        print(f" n_kv_tiles: {n_kv_tiles}")
        # Load K from each tile using the (None,None,0,0) slice
        # and write the first 4 elements to out_buf for verification
        for kt in range(n_kv_tiles):
            cute.copy(tma_k, tBgK_nn[(None, Int32(kt))], tBsK[(None, 0)], tma_bar_ptr=None)
            # Read first 4 BF16 values from SMEM and write to out_buf.
            # NOTE(review): assumes SMEM coordinate (0, i, 0, 0, 0) is row i,
            # column 0 of the K tile (what test() compares against) -- confirm
            # under the swizzled SMEM layout.
            for i in range(4):
                out_buf[kt * 4 + i] = sK[0, i, 0, 0, 0]
def test():
    """Run the TMA multi-tile diagnostic and check that different tiles load different data.

    For each sequence length ``n``, this builds a random BF16 K tensor of
    shape ``(n, HEAD_DIM, 1)`` on the current CUDA device. It then compiles
    and launches ``TmaMultiTileTest`` and prints ``out_buf``, which holds the
    4 values the kernel read per K tile. It also prints the first 4 values of
    column 0 of each 128-row slice of ``k`` for manual comparison.

    Raises:
        AssertionError: if two consecutive tiles produced identical snapshots,
            i.e. changing the TMA tile coordinate did not change the loaded data.
    """
    tile_rows = 128  # N extent of qk_mma_tiler in TmaMultiTileTest._setup
    vals_per_tile = 4  # the kernel writes 4 values per K tile into out_buf
    for n in [256]:
        torch.manual_seed(42)
        k = torch.randn(n, HEAD_DIM, 1, dtype=torch.bfloat16, device='cuda')
        # Round up so a trailing partial tile still has a slot in out_buf.
        n_tiles = (n + tile_rows - 1) // tile_rows
        out_buf = torch.zeros(n_tiles * vals_per_tile, dtype=torch.bfloat16, device='cuda')
        mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
        mO = ct.from_dlpack(out_buf).mark_layout_dynamic(leading_dim=ct.get_leading_dim(out_buf))
        stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
        kernel = TmaMultiTileTest(s_k=n)
        print(f'n={n}: Compiling...', flush=True)
        compiled = cute.compile(kernel, mK, mO, stream)
        compiled(mK, mO, stream)
        torch.cuda.synchronize()
        print(f'out_buf: {out_buf.tolist()}')
        # Reference values: first rows of column 0 of each tile of k.
        for t in range(n_tiles):
            ref = k[t * tile_rows:(t + 1) * tile_rows, 0, 0][:vals_per_tile].tolist()
            print(f'tile{t} first 4: {ref}')
        # The actual check this diagnostic exists for: each tile's snapshot
        # must differ from the previous one, otherwise the TMA tile
        # coordinate did not propagate (or the kernel wrote nothing).
        snapshots = out_buf.view(n_tiles, vals_per_tile)
        for t in range(1, n_tiles):
            assert not torch.equal(snapshots[t - 1], snapshots[t]), (
                f'n={n}: tile {t - 1} and tile {t} loaded identical data; '
                'TMA tile coordinate did not propagate'
            )


if __name__ == '__main__':
    test()