Files
nvfp4-megamoe-kernel/tests/test_afrag_roundtrip.py
biondizzle 467ade37b2 Stage B: C-fragment vs A-fragment TMEM layout mismatch diagnosed
Key finding: C-fragment and A-fragment use different physical TMEM address
mappings. St32x32bOp with C-fragment writes to C-layout addresses, but PV MMA
reads from A-layout addresses. Forward FMHA recast validated FP16 only, not BF16.

Working: FP32 ld/st roundtrip, BF16 elemwise, BF16 recast ld S0->st S1 (all cos 0.999999)
Broken: C-frag st + A-frag read (NaN), A-frag store + PV MMA (cos -0.02)
Next: Fix register data flow (128 FP16/thread load vs 64 BF16/thread store mismatch)
2026-05-21 00:12:47 +00:00

176 lines
13 KiB
Python

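"""Test: ld FP32 from S0, st BF16 to A-fragment layout tdVrP,
ld BF16 back from tdVrP, epi the result.
If this works, the A-fragment store is correct and the issue is in the PV MMA.

Current state: the "ld BF16 back from tdVrP" step is not implemented yet. After
the BF16 store to tdVrP, the kernel runs the TMA-store epilogue on the FP32 S0
accumulator, so the printed cosine reflects the S0 region as read by the
epilogue, not a BF16 readback of the A-fragment."""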
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class AFragRoundtrip:
    """Single-CTA SM100 diagnostic: QK GEMM into TMEM, a BF16 store of S into the
    PV-MMA A-operand TMEM layout, then a TMA-store epilogue.

    Warp roles (set in ``__init__``; 192 threads per CTA):
      * warp 5 (``tma_warp_id``): TMA loads of the A and B tiles into shared memory;
      * warp 4 (``mma_warp_id``): tcgen05 MMA of A and B into the S accumulator in TMEM;
      * warps 0-3 (``epilogue_warp_id``): load S as FP32, convert to BF16, store it
        through the A-fragment view ``tdVrP``, then run the TMA-store epilogue on S.

    No step reads the BF16 A-fragment back; the epilogue stores the FP32 S region.
    Only one CTA is launched and the load partitions are indexed at tile 0, so the
    problem is expected to fit a single ``mma_tiler_mn`` tile in M and N.
    """
    def __init__(self, mma_tiler_mn):
        """Record dtypes, warp roles and stage counts; ``mma_tiler_mn`` is the (M, N) MMA tile."""
        self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
        # c_dtype / acc_dtype / use_2cta_instrs are not read in this file; presumably
        # consumed by utils.gemm.sm100.epilogue_tma_store, which receives `self`.
        self.c_dtype = BFloat16; self.acc_dtype = Float32
        # mma_tiler gets its real K extent in _setup (K=1 is a placeholder here).
        self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
        self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
        self.epilogue_warp_id = (0, 1, 2, 3); self.mma_warp_id = 4; self.tma_warp_id = 5
        # 6 warps x 32 threads = 192. epilog_sync_bar_id is presumably the named-barrier
        # id used inside epilogue_tma_store (TODO confirm); the TMEM barrier below uses id 2.
        # NOTE(review): _setup hard-codes 2 epilogue smem stages; keep in sync with num_c_stage.
        self.threads_per_cta = 192; self.num_c_stage = 2; self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
    def _setup(self, qk_mma, pv_mma):
        """Derive tilers, shared/TMEM layouts, TMEM column count and TMA byte count.

        Called at trace time from ``__call__``. The K extent of each MMA tiler is
        four times the instruction K of the corresponding tiled MMA.
        """
        qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
        self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
        pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
        self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
        self.mma_tiler = self.qk_mma_tiler
        # Per-CTA tile: M split across the MMA's thread-id (CTA) count, N and K unchanged.
        self.cta_tile_shape_mnk = (self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), self.qk_mma_tiler[1], self.qk_mma_tiler[2])
        self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
        # Single-stage (last arg = 1) shared-memory layouts for the QK operands.
        self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
        self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
        # A-operand layout of the PV MMA; _kernel uses its `.outer` to view TMEM as BF16 P.
        self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
        c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
        self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
        self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
        self.num_ab_stage = 1; self.num_acc_stage = 1
        qk_thr = qk_mma.get_slice(0); qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
        tStS = qk_thr.make_fragment_C(qk_acc_shape); self.s_cols = find_tmem_tensor_col_offset(tStS)
        # NOTE(review): tcgen05.alloc expects a power-of-two column count >= 32; confirm
        # s_cols satisfies this for tilers other than the tested (128, 128).
        self.tmem_alloc_cols = self.s_cols # Only need S region
        a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
        # Bytes the TMA warp delivers per k-tile (one A stage + one B stage, per CTA in the MMA).
        self.num_tma_load_bytes = (cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)) * cute.size(qk_mma.thr_id.shape)
    @cute.jit
    def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
        """Build the QK/PV tiled MMAs and the TMA atoms, then launch a single CTA.

        ``a``/``b`` are the QK GEMM inputs (their major modes select the QK operand
        majorness), ``c`` receives the TMA-store epilogue output, and ``stream`` is
        the CUDA stream used for the launch.
        """
        # QK MMA: both operands from shared memory, FP32 accumulator.
        qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, LayoutEnum.from_tensor(a).mma_major_mode(), LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
        # PV MMA: K-major A operand sourced from TMEM (the P fragment); only its layouts are used here.
        pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
        self._setup(qk_mma, pv_mma)
        a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
        tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
        tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
        # Output TMA: shared-memory -> global bulk tensor store of one epilogue tile.
        epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
        tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
        # One CTA only: this diagnostic computes a single output tile.
        self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
    @cute.kernel
    def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, p_tmem_s, c_smem_s, epi_tile):
        """Device kernel; see the class docstring for the warp roles.

        Arguments: the two tiled MMAs, TMA atom/tensor pairs for A (``tma_a``/``mA``),
        B (``tma_b``/``mB``) and C (``tma_c``/``mC``), the cluster layout, the shared
        memory layouts for A, B and C, the P (PV A-operand) layout used to view TMEM
        as BF16, and the epilogue tile.
        """
        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()); tidx, _, _ = cute.arch.thread_idx()
        # The TMA warp prefetches all three TMA descriptors.
        if warp_idx == self.tma_warp_id: cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
        # Shared-memory barrier / bookkeeping storage.
        @cute.struct
        class SS:
            # ab_bar: TMA->MMA pipeline (2 mbarriers per stage); mma_si_bar: MMA->epilogue S handoff
            ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]; mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
            # acc_bar: MMA->epilogue_tma_store pipeline; tmem_dealloc: 2-CTA dealloc mbarrier
            # (is_two_cta=False below); holding: slot the TMEM allocator writes the address into
            acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]; tmem_dealloc: cutlass.Int64; holding: cutlass.Int32
        smem = utils.SmemAllocator(); st = smem.allocate(SS)
        # TMA producer (one thread) -> MMA consumer (one thread), tx-counted by the bytes per k-tile.
        ab_p, ab_c = pipeline.PipelineTmaUmma.create(barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants()
        # MMA -> epilogue handoff of S for the FP32 load; all 128 epilogue threads consume.
        mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants()
        # MMA -> epilogue_tma_store pipeline; consumer count is one per epilogue warp (4).
        acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True)
        # Named barrier 2 over the MMA warp + 4 epilogue warps (160 threads) for TMEM pointer retrieval.
        tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
        tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False, two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
        pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
        sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
        sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
        sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
        # Global tiles of A (M,K), B (N,K) and C (M,N), keeping all tile indices open.
        gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
        gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
        gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
        k_cnt = cute.size(gA, mode=[3])  # number of K tiles
        qk_thr = qk_mma.get_slice(0); tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
        a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
        tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
        b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
        tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
        # Fix the M/N tile and L index at 0; keep the K-tile mode for the load loop.
        tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
        tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
        qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
        # S accumulator view in TMEM.
        # NOTE(review): built from make_fragment_C's iterator, not tmem.retrieve_ptr(). The MMA
        # warp, the FP32 load (tLdS) and the BF16 store (tdVrP) all address TMEM through it, while
        # only the epilogue's tCtS0 uses the retrieved pointer. This presumably relies on the
        # allocation starting at TMEM column 0 -- TODO confirm, or rebase these views on tmem_ptr.
        tStS = qk_thr.make_fragment_C(qk_acc_shape)
        tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)
        # A-fragment for pv_mma
        pv_thr = pv_mma.get_slice(0)
        # Reinterpret the S iterator as BF16, so P is laid out starting at the same TMEM base as S.
        # NOTE(review): the BF16 store below therefore presumably overwrites part of the S region
        # before the epilogue reads S back as FP32.
        tP_iter = cute.recast_ptr(tStS.iterator, dtype=self.q_dtype)
        tP = cute.make_tensor(tP_iter, p_tmem_s.outer)
        tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0]  # stage 0 of the A fragment
        tdVrP = cute.make_tensor(tOrP.iterator, tOrP.layout)
        # TMEM ld (C-fragment)
        tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
        tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0)
        # Thread index within the 4 epilogue warps (0..127).
        sfw = tidx % (32 * len(self.epilogue_warp_id))
        thr_ld = tiled_ld.get_slice(sfw)
        tLdS = thr_ld.partition_S(tStS0)
        # Coordinate (identity) tensor partitioned like the load; only its shape is used,
        # to size the register fragments below.
        cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
        tScS = qk_thr.partition_C(cS_id)
        tLdcS = thr_ld.partition_D(tScS)
        # TMEM st (A-fragment layout, BF16)
        # NOTE(review): rBf16 takes its shape from the FP32 ld partition (tLdcS), while the
        # destination tStP comes from this St32x32b Repetition(8) partition of tdVrP; nothing
        # here checks that the two per-thread shapes agree.
        tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(8)), self.q_dtype)
        tiled_st = tcgen05.make_tmem_copy(tmem_st, tdVrP)
        thr_st = tiled_st.get_slice(sfw)
        tStP = thr_st.partition_D(tdVrP)
        # Accumulator layout with an extra trailing mode of size 1 (one acc stage); rebased
        # onto tmem_ptr in the epilogue for epilogue_tma_store.
        tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
        pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
        # TMA
        if warp_idx == self.tma_warp_id:
            # Probe the first slot without blocking, then load A and B for each K tile.
            ab_p.reset(); peek = ab_p.try_acquire()
            for kt in cutlass.range(k_cnt, unroll=1):
                h = ab_p.acquire_and_advance(peek); cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
                cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier); peek = cutlass.Boolean(1)
                if h.count+1<k_cnt: peek = ab_p.try_acquire()
            ab_p.tail()
        # MMA
        if warp_idx == self.mma_warp_id:
            # Waits for the epilogue warps' TMEM allocation but does not retrieve the pointer;
            # the GEMM below targets tStS0.
            tmem.wait_for_alloc(); ab_c.reset(); peek = ab_c.try_wait()
            s0_handle = mma_si_prod.acquire_and_advance()
            acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
            # First k-block overwrites S; every later one accumulates.
            acc_pipe.producer_acquire(acc_prod_st); qk_mma.set(tcgen05.Field.ACCUMULATE, False)
            for kt in range(k_cnt):
                h = ab_c.wait_and_advance(peek); nblk = cute.size(tCrA, mode=[2])
                for kb in cutlass.range(nblk, unroll_full=True):
                    cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
                    qk_mma.set(tcgen05.Field.ACCUMULATE, True)
                h.release(); peek = cutlass.Boolean(1)
                if h.count+1<k_cnt: peek = ab_c.try_wait()
            # Publish the finished S to both consumers: the mma_si handoff (FP32 load) and
            # acc_pipe (epilogue_tma_store).
            cute.arch.fence_view_async_tmem_store(); s0_handle.commit()
            acc_pipe.producer_commit(acc_prod_st); acc_prod_st.advance(); acc_pipe.producer_tail(acc_prod_st)
        # EPILOGUE WARPS: ld FP32 (S0) → BF16 → st A-frag → TMA epilogue of S0 read as FP32.
        # NOTE(review): the planned "ld A-frag BF16 → FP32" readback step is not implemented.
        if warp_idx < self.mma_warp_id:
            # Called from all four epilogue warps; the allocator was built with
            # allocator_warp_id=epilogue_warp_id[0] -- presumably it restricts the actual
            # alloc/free to that warp (TODO confirm).
            tmem.allocate(self.tmem_alloc_cols); tmem.wait_for_alloc(); tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
            si_handle = mma_si_cons.wait_and_advance()
            # 1. ld FP32 from S0
            rLd = cute.make_rmem_tensor(tLdcS.shape, self.qk_acc_dtype); cute.copy(tiled_ld, tLdS, rLd)
            cute.arch.fence_view_async_tmem_load()
            # 2. Convert FP32 → BF16
            rBf16 = cute.make_rmem_tensor(tLdcS.shape, self.q_dtype)
            for i in cutlass.range(cute.size(rLd), vectorize=True):
                rBf16[i] = rLd[i].to(self.q_dtype)
            # 3. st BF16 to A-fragment layout
            cute.copy(tiled_st, rBf16, tStP)
            cute.arch.fence_view_async_tmem_store()
            # 4. Store to A-frag done; release the S handoff and check whether the S0 epilogue
            #    still works. The BF16 A-fragment is not read back here.
            si_handle.release()
            # S accumulator rebased on the retrieved TMEM pointer, in the acc-stage layout.
            tCtS0 = cute.make_tensor(tmem_ptr, tCtS_fake.layout)
            acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
            c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
            c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
            # Passes an identity element-wise op; (0,0,0) is presumably the output tile
            # coordinate -- TODO confirm against utils.gemm.sm100.epilogue_tma_store.
            acc_cons_st = utils.gemm.sm100.epilogue_tma_store(self, tidx, warp_idx, tma_c, tCtS0, sC, tCgC, epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
            c_pipe.producer_tail(); tmem.relinquish_alloc_permit(); tmem.free(tmem_ptr)
def test():
    """Run ``AFragRoundtrip`` on one 128x128 tile and report how closely it matches torch.

    Builds seeded random BF16 inputs ``q`` (m, k, 1) and ``kv`` (n, k, 1), a float32
    reference ``q @ kv.T``, runs the compiled kernel into ``c`` (m, n, 1), and prints
    the cosine similarity between ``c`` and the reference. Needs a CUDA device.
    Diagnostic only: the similarity is printed, not asserted.
    """
    torch.manual_seed(42)
    m = n = k = 128
    bf16, dev = torch.bfloat16, 'cuda'
    q = torch.randn(m, k, 1, dtype=bf16, device=dev)
    kv = torch.randn(n, k, 1, dtype=bf16, device=dev)
    c = torch.zeros(m, n, 1, dtype=bf16, device=dev)
    # float32 reference of S = q @ kv^T on the single trailing batch slice
    ref = torch.matmul(q[:, :, 0].float(), kv[:, :, 0].float().t())

    import cutlass.torch as ct

    def as_cute(t):
        # DLPack-wrap and mark the layout dynamic along the tensor's leading dim.
        return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))

    mQ = as_cute(q)
    mK = as_cute(kv)
    mC = as_cute(c)
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    kernel = AFragRoundtrip(mma_tiler_mn=(128, 128))

    print('Compiling...', flush=True)
    compiled = cute.compile(kernel, mQ, mK, mC, stream)

    print('Running...', flush=True)
    compiled(mQ, mK, mC, stream)
    torch.cuda.synchronize()

    got = c[:, :, 0].float().flatten()[None]
    want = ref.flatten()[None]
    cos = torch.nn.functional.cosine_similarity(got, want).item()
    print('A-frag roundtrip: cos={:.6f} (expect 0.999 from Stage A)'.format(cos))


if __name__ == '__main__':
    test()