Kept: - example10 (CUTLASS LLM, O rescale + final normalize) - example9 (SSA kv_coord version) - working_softmax_maybe.py (working softmax snapshot from before the nuke) - test_fmha_v3_stage_c.py (identity softmax baseline, n=128 cos 0.999998) - test_fmha_v3.py (Stage A+B baseline) - layertest.py, cudagraph_test.py (required) - test_cutedsl.py, test_fp4_roundtrip.py (NVFP4 tests) Archived: diag_tma_*, example8, test_diag_multitile, test_reference_fmha, test_ref_minimal, test_tma_coord, test_fmha_v3_diag*, test_fmha_v3_12w, test_dense_router, test_interleave*, test_fused_step1, test_router, test_cache, test_compile_custom_op, test_custom_op, test_layer_schedule
103 lines
4.5 KiB
Python
103 lines
4.5 KiB
Python
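"""Minimal TMA multi-tile test: does the TMA tile coordinate propagate at runtime?

This tiny kernel TMA-loads K from two different KV tiles (s_k = 256 with
128-row tiles) and writes a few values from each to an output buffer, so the
host can check whether the data loaded for each tile actually differs.
"""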
|
|
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
|
from cutlass.cute.nvgpu import cpasync, tcgen05
|
|
from cutlass import Float32, BFloat16, Int32, Boolean
|
|
from cutlass.utils import LayoutEnum
|
|
import cuda.bindings.driver as cuda
|
|
import cutlass.torch as ct
|
|
import math
|
|
|
|
# Head dimension of the K tensor built by the host-side test; one K tile of
# the QK MMA covers it exactly.
HEAD_DIM: int = 64
|
|
|
|
class TmaMultiTileTest:
    """Diagnostic: TMA-load every KV tile of K and dump a few values per tile.

    The kernel runs as a single 32-thread CTA. For each KV tile ``kt`` it
    issues one TMA load of that tile into shared memory, waits for the load
    on an mbarrier, and writes four values of the tile to
    ``out_buf[kt * 4 : kt * 4 + 4]`` so the host can compare them with K.
    """

    def __init__(self, s_k=256):
        # Sequence length of K; kept for reference only -- the kernel takes
        # the tile count from the shape of the K tensor it is called with.
        self.s_k = s_k
        self.q_dtype = BFloat16
        self.cta_group = tcgen05.CtaGroup.ONE
        # Stages in the K SMEM layout; this test only ever fills stage 0.
        self.kv_stage = 2

    def _setup(self, qk_mma):
        """Derive the CTA tiler, cluster layout, K SMEM layout and byte counts."""
        qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
        # (M, N, K) CTA tile; the K extent is four MMA instructions deep.
        self.qk_mma_tiler = (128, 128, qk_ik * 4)
        self.cluster_layout_vmnk = cute.tiled_divide(
            cute.make_layout((1, 1, 1)), (qk_mma.thr_id.shape,)
        )
        self.k_smem_s = utils.sm100.make_smem_layout_b(
            qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage
        )
        cta = cute.size(qk_mma.thr_id.shape)
        k_s = cute.slice_(self.k_smem_s, (None, None, None, 0))
        stage_bytes = cute.size_in_bytes(self.q_dtype, k_s)
        # Transaction bytes one TMA load of a K stage signals on the mbarrier.
        self.kv_tx_bytes = stage_bytes * cta
        # Dynamic SMEM for all K stages, plus slack for the 1024-byte alignment
        # of the swizzled K buffer and for the 8-byte mbarrier placed after it.
        self.smem_bytes = stage_bytes * self.kv_stage + 1024 + 64

    @cute.jit
    def __call__(self, k, out_buf, stream):
        """Build the QK tiled MMA and the K TMA atom, then launch one CTA.

        Args:
            k: K tensor (from DLPack); its layout selects the B-operand major mode.
            out_buf: output tensor receiving four values per KV tile.
            stream: CUDA stream to launch on.
        """
        b_major = LayoutEnum.from_tensor(k).mma_major_mode()
        qk_mma = utils.sm100.make_trivial_tiled_mma(
            self.q_dtype,
            self.q_dtype,
            cute.nvgpu.OperandMajorMode.K,
            b_major,
            Float32,
            self.cta_group,
            (128, 128),
            tcgen05.OperandSource.SMEM,
        )
        self._setup(qk_mma)

        k_s = cute.slice_(self.k_smem_s, (None, None, None, 0))
        tma_k, mK = cute.nvgpu.make_tiled_tma_atom_B(
            utils.sm100.cluster_shape_to_tma_atom_B(
                self.cluster_layout_vmnk.shape, qk_mma.thr_id
            ),
            k,
            k_s,
            self.qk_mma_tiler,
            qk_mma,
            self.cluster_layout_vmnk.shape,
        )

        self._kernel(
            qk_mma, tma_k, mK, out_buf, self.cluster_layout_vmnk, self.k_smem_s
        ).launch(
            grid=(1, 1, 1),
            block=[32, 1, 1],
            # K stages and the TMA mbarrier are carved out of dynamic SMEM.
            smem=self.smem_bytes,
            stream=stream,
        )

    @cute.kernel
    def _kernel(self, qk_mma, tma_k, mK, out_buf, cl_vmnk, k_smem_s):
        """Load each KV tile of K via TMA and copy 4 values per tile to out_buf."""
        tidx, _, _ = cute.arch.thread_idx()
        cpasync.prefetch_descriptor(tma_k)

        # Real shared-memory storage: the staged, swizzled K buffer plus one
        # mbarrier that tracks completion of each TMA load.
        smem = utils.SmemAllocator()
        sK_staged = smem.allocate_tensor(
            element_type=self.q_dtype,
            layout=k_smem_s.outer,
            byte_alignment=1024,
            swizzle=k_smem_s.inner,
        )
        sK = sK_staged[(None, None, None, 0)]  # stage 0 only
        tma_mbar = smem.allocate_array(cutlass.Int64)

        # Expected gK shape: (bN, bK, RestN, RestK, RestL) -- all-None coord
        # keeps every tile index.
        gK = cute.local_tile(
            mK, cute.slice_(self.qk_mma_tiler, (0, None, None)), (None, None, None)
        )

        qk_thr = qk_mma.get_slice(0)
        tCgK = qk_thr.partition_B(gK)
        b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0, None, 0, 0)).shape)
        # Partition the *staged* SMEM buffer so tBsK keeps a stage mode:
        # tBsK ~ (TMA, STAGE), tBgK ~ (TMA, RestN, RestK, RestL).
        tBsK, tBgK = cpasync.tma_partition(
            tma_k,
            0,
            b_lay,
            cute.group_modes(sK_staged, 0, 3),
            cute.group_modes(tCgK, 0, 3),
        )

        # Test different slice patterns
        tBgK_nn = tBgK[(None, None, 0, 0)]  # (TMA, RestN): one entry per KV tile
        tBgK_n0 = tBgK[(None, 0, None, 0)]  # KV tile fixed to 0, RestK free

        # The KV tile count is RestN (mode 1 of tBgK_nn). gK mode 3 is RestK
        # (head_dim / bK, i.e. 1 for HEAD_DIM=64) and would load only tile 0.
        n_kv_tiles = cute.size(tBgK_nn, mode=[1])

        print(f"tBgK shape: {cute.shape(tBgK)}", end="")
        print(f" tBgK_nn shape: {cute.shape(tBgK_nn)}", end="")
        print(f" tBgK_n0 shape: {cute.shape(tBgK_n0)}", end="")
        print(f" n_kv_tiles: {n_kv_tiles}")

        # One arrival (thread 0's arrive_and_expect_tx) completes each phase,
        # together with the TMA transaction bytes.
        if tidx == 0:
            cute.arch.mbarrier_init(tma_mbar, 1)
            cute.arch.mbarrier_init_fence()
        cute.arch.sync_threads()

        for kt in range(n_kv_tiles):
            if tidx == 0:
                cute.arch.mbarrier_arrive_and_expect_tx(tma_mbar, self.kv_tx_bytes)
                cute.copy(
                    tma_k, tBgK_nn[(None, kt)], tBsK[(None, 0)], tma_bar_ptr=tma_mbar
                )
            # The barrier phase flips once per completed load: wait on parity kt % 2.
            cute.arch.mbarrier_wait(tma_mbar, kt % 2)
            if tidx == 0:
                # Assumes sK mode 0 is the (N, K) MMA atom: coordinate (i, 0)
                # is row i of the tile at head-dim 0. TODO confirm via layout print.
                for i in range(4):
                    out_buf[kt * 4 + i] = sK[((i, 0), 0, 0)]
            # Keep all threads on the same barrier phase and stop the next TMA
            # from overwriting sK before this tile has been read.
            cute.arch.sync_threads()
|
|
|
|
|
|
def test(sizes=(256,)):
    """Run the multi-tile TMA kernel and verify each tile's values on the host.

    For every sequence length ``n`` in *sizes*, K is random BF16 of shape
    ``(n, HEAD_DIM, 1)``. The kernel is expected to write, for KV tile ``t``,
    the values ``k[t*128 : t*128 + 4, 0, 0]`` into ``out_buf[t*4 : t*4 + 4]``.

    Args:
        sizes: sequence lengths to test; each must be a positive multiple of 128.

    Raises:
        ValueError: if a size is not a positive multiple of 128.
        AssertionError: if the kernel output does not match K.
    """
    tile_n = 128  # KV rows per tile (N extent of the QK MMA tiler)
    for n in sizes:
        if n <= 0 or n % tile_n:
            raise ValueError(f'n={n} must be a positive multiple of {tile_n}')
        n_tiles = n // tile_n

        torch.manual_seed(42)
        k = torch.randn(n, HEAD_DIM, 1, dtype=torch.bfloat16, device='cuda')
        out_buf = torch.zeros(n_tiles * 4, dtype=torch.bfloat16, device='cuda')

        mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
        mO = ct.from_dlpack(out_buf).mark_layout_dynamic(leading_dim=ct.get_leading_dim(out_buf))
        stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

        kernel = TmaMultiTileTest(s_k=n)
        print(f'n={n}: Compiling...', flush=True)
        compiled = cute.compile(kernel, mK, mO, stream)
        compiled(mK, mO, stream)
        torch.cuda.synchronize()

        print(f'out_buf: {out_buf.tolist()}')
        # First 4 rows (at head-dim 0) of every tile, in tile order.
        per_tile = [k[t * tile_n : t * tile_n + 4, 0, 0] for t in range(n_tiles)]
        for t, vals in enumerate(per_tile):
            print(f'tile{t} first 4: {vals.tolist()}')

        expected = torch.cat(per_tile)
        # BF16 values are copied verbatim (GMEM -> SMEM -> GMEM), so compare
        # exactly. Unloaded tiles would show up as zeros here.
        assert torch.equal(out_buf, expected), (
            f'n={n}: kernel output {out_buf.tolist()} != expected {expected.tolist()}'
        )
|
|
|
|
if __name__ == '__main__':
    # Run the diagnostic only when executed as a script, never on import.
    test()
|