Files
nvfp4-megamoe-kernel/tests/test_b_afrag2.py
biondizzle 467ade37b2 Stage B: C-fragment vs A-fragment TMEM layout mismatch diagnosed
Key finding: the C-fragment and A-fragment use different physical TMEM address
mappings. St32x32bOp with a C-fragment target writes to C-layout addresses, but
the PV MMA reads from A-layout addresses. The forward FMHA recast was validated
only for FP16, not BF16.

Working: FP32 ld/st roundtrip, BF16 elemwise, BF16 recast ld S0->st S1 (all cos 0.999999)
Broken: C-frag st + A-frag read (NaN), A-frag store + PV MMA (cos -0.02)
Next: Fix register data flow (128 FP16/thread load vs 64 BF16/thread store mismatch)
2026-05-21 00:12:47 +00:00

217 lines
16 KiB
Python

"""Stage B: Store P via A-fragment layout with recast C-fragment iterator.
Matching the backward FMHA pattern exactly:
1. tOrP = pv_thr.make_fragment_A(tP)[None,None,None,0] (A-fragment layout)
2. tdVrP_iter = cute.recast_ptr(tStS.iterator, dtype=BF16) (C-fragment base, recast to BF16)
3. tdVrP = cute.make_tensor(tdVrP_iter + offset, tOrP.layout)
4. make_tmem_copy(St32x32bOp(Repetition(8)), BF16, tdVrP)
5. Store BF16 registers to tdVrP
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class StageBAfrag2:
    """Stage B probe kernel: S = Q @ K^T, P = bf16(S) kept in TMEM, O = P @ V.

    One CTA with six warps: warps 0-3 convert S to P and run the TMA-store
    epilogue, warp 4 issues both MMAs, and warp 5 issues the TMA loads. P is
    stored to TMEM through the PV MMA's A-fragment layout, and V aliases the K
    shared-memory buffer, so the kernel computes (Q @ K^T) @ K.
    """

    def __init__(self, mma_tiler_mn):
        """Record the static kernel configuration.

        Args:
            mma_tiler_mn: (M, N) MMA tile shape, e.g. ``(128, 128)``.
        """
        # Data types: the QK product accumulates in FP32; Q/K/V/P and O are BF16.
        self.qk_acc_dtype = Float32
        self.q_dtype = BFloat16
        self.o_dtype = BFloat16
        self.c_dtype = BFloat16
        self.acc_dtype = Float32
        self.mma_tiler_mn = mma_tiler_mn
        # Placeholder K extent; replaced by the QK tiler in ``_setup``.
        self.mma_tiler = (*mma_tiler_mn, 1)
        self.cluster_shape_mn = (1, 1)
        self.cta_group = tcgen05.CtaGroup.ONE
        # Warp roles: 0-3 softmax/epilogue, 4 MMA, 5 TMA (6 warps = 192 threads).
        self.epilogue_warp_id = (0, 1, 2, 3)
        self.mma_warp_id = 4
        self.tma_warp_id = 5
        self.threads_per_cta = 192
        self.num_c_stage = 2
        self.use_2cta_instrs = False
        self.epilog_sync_bar_id = 1

    def _setup(self, qk_mma, pv_mma):
        """Derive tilers, shared-memory/TMEM layouts and TMEM offsets from the two MMAs."""
        # MMA tilers: the K extent is four MMA instructions deep.
        qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
        self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
        pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
        # NOTE(review): P's full K extent is the S tile width (mma_tiler_mn[1]).
        # With pv_inst_k * 4 the P A-fragment may cover fewer columns than S. This
        # is the suspected source of the 128-vs-64 elements/thread mismatch. TODO confirm.
        self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
        self.mma_tiler = self.qk_mma_tiler
        self.cta_tile_shape_mnk = (
            self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
            self.qk_mma_tiler[1],
            self.qk_mma_tiler[2],
        )
        self.cluster_layout_vmnk = cute.tiled_divide(
            cute.make_layout((1, 1, 1)), (qk_mma.thr_id.shape,)
        )
        # Shared-memory layouts (one stage each). V reinterprets the K (B) buffer.
        self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
        self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
        self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
        # Layout of P as the PV MMA's A operand; P itself lives in TMEM.
        self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
        c_layout = LayoutEnum.ROW_MAJOR
        self.c_layout = c_layout
        self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
            self.cta_tile_shape_mnk, False, c_layout, self.o_dtype
        )
        self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
        self.num_ab_stage = 1
        self.num_acc_stage = 1
        # TMEM column footprints of the S (QK) and O (PV) accumulators.
        qk_thr = qk_mma.get_slice(0)
        qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
        tStS = qk_thr.make_fragment_C(qk_acc_shape)
        self.s_cols = find_tmem_tensor_col_offset(tStS)
        pv_thr = pv_mma.get_slice(0)
        pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
        tOtO = pv_thr.make_fragment_C(pv_acc_shape)
        self.o_cols = find_tmem_tensor_col_offset(tOtO)
        # TMEM column offsets. S starts at 0. P overwrites the start of S, which is
        # read into registers before P is stored. O follows two S-sized regions.
        self.tmem_s0_offset = 0
        self.tmem_p0_offset = 0
        self.tmem_o0_offset = self.s_cols * 2
        self.tmem_alloc_cols = 512
        tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
        tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
        self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
        # Bytes delivered per TMA stage (one A tile + one B tile); used as the barrier tx count.
        a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
        b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
        self.num_tma_load_bytes = (
            cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)
        ) * cute.size(qk_mma.thr_id.shape)

    @cute.jit
    def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
        """Build the tiled MMAs and TMA atoms, then launch a single CTA.

        Args:
            a: Q tensor, TMA-loaded as the A operand of the QK MMA.
            b: KV tensor, the B operand of the QK MMA (its SMEM tile is reused as V).
            c: output tensor written by the TMA-store epilogue.
            stream: CUDA stream to launch on.
        """
        # QK MMA: both operands from shared memory.
        qk_mma = utils.sm100.make_trivial_tiled_mma(
            self.q_dtype,
            self.q_dtype,
            LayoutEnum.from_tensor(a).mma_major_mode(),
            LayoutEnum.from_tensor(b).mma_major_mode(),
            self.qk_acc_dtype,
            self.cta_group,
            self.mma_tiler_mn,
            tcgen05.OperandSource.SMEM,
        )
        # PV MMA: the A operand (P) is sourced from TMEM, K-major.
        pv_mma = utils.sm100.make_trivial_tiled_mma(
            self.q_dtype,
            self.q_dtype,
            cute.nvgpu.OperandMajorMode.K,
            LayoutEnum.from_tensor(b).mma_major_mode(),
            self.qk_acc_dtype,
            self.cta_group,
            self.mma_tiler_mn,
            tcgen05.OperandSource.TMEM,
        )
        self._setup(qk_mma, pv_mma)
        a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
        b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
        tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
            utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
            a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape,
        )
        tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
            utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
            b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape,
        )
        epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
        tma_c, tma_tc = cpasync.make_tiled_tma_atom(
            cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile
        )
        self._kernel(
            qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
            self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s,
            self.p_tmem_s, self.c_smem_s, self.epi_tile,
        ).launch(grid=(1, 1, 1), block=[self.threads_per_cta, 1, 1], stream=stream)

    @cute.kernel
    def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
        """Warp-specialized kernel body (TMA warp, MMA warp, softmax/epilogue warps)."""
        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
        tidx, _, _ = cute.arch.thread_idx()
        if warp_idx == self.tma_warp_id:
            cpasync.prefetch_descriptor(tma_a)
            cpasync.prefetch_descriptor(tma_b)
            cpasync.prefetch_descriptor(tma_c)

        @cute.struct
        class SS:
            ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
            mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
            acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
            tmem_dealloc: cutlass.Int64
            holding: cutlass.Int32

        smem = utils.SmemAllocator()
        st = smem.allocate(SS)
        # TMA -> MMA pipeline for the Q/K tiles.
        ab_p, ab_c = pipeline.PipelineTmaUmma.create(
            barrier_storage=st.ab_bar.data_ptr(),
            num_stages=self.num_ab_stage,
            producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
            consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
            tx_count=self.num_tma_load_bytes,
            cta_layout_vmnk=cl_vmnk,
            defer_sync=True,
        ).make_participants()
        # MMA <-> softmax handshake: commit = "S ready", release = "P ready".
        mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
            barrier_storage=st.mma_si_bar.data_ptr(),
            num_stages=1,
            producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
            consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
            cta_layout_vmnk=cl_vmnk,
            defer_sync=True,
        ).make_participants()
        # MMA -> epilogue pipeline for the O accumulator.
        acc_pipe = pipeline.PipelineUmmaAsync.create(
            barrier_storage=st.acc_bar.data_ptr(),
            num_stages=self.num_acc_stage,
            producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
            consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)),
            cta_layout_vmnk=cl_vmnk,
            defer_sync=True,
        )
        tmem_bar = pipeline.NamedBarrier(
            barrier_id=2, num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))
        )
        tmem = utils.TmemAllocator(
            st.holding.ptr,
            barrier_for_retrieve=tmem_bar,
            allocator_warp_id=self.epilogue_warp_id[0],
            is_two_cta=False,
            two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr,
        )
        pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)

        # ── Shared-memory tensors ──
        sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
        sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
        # V reinterprets the K buffer. NOTE(review): with a single AB stage, sB holds
        # only the last K tile loaded when the PV MMA runs. TODO confirm this is the intended V.
        sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
        sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
        sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)

        # ── Global tiles and TMA partitions ──
        gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None))
        gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None))
        gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None))
        k_cnt = cute.size(gA, mode=[3])
        qk_thr = qk_mma.get_slice(0)
        tCgA = qk_thr.partition_A(gA)
        tCgB = qk_thr.partition_B(gB)
        tCgC = qk_thr.partition_C(gC)
        a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0, 0, None, 0)).shape)
        tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA, 0, 3), cute.group_modes(tCgA, 0, 3))
        b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0, None, 0, 0)).shape)
        tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB, 0, 3), cute.group_modes(tCgB, 0, 3))
        tAgA = tAgA[(None, 0, None, 0)]
        tBgB = tBgB[(None, 0, None, 0)]
        tCrA = qk_mma.make_fragment_A(sA)
        tCrB = qk_mma.make_fragment_B(sB)
        tCrV = pv_mma.make_fragment_B(sV)

        # ── TMEM accumulators ──
        qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
        tStS = qk_thr.make_fragment_C(qk_acc_shape)
        tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
        pv_thr = pv_mma.get_slice(0)
        pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
        tOtO = pv_thr.make_fragment_C(pv_acc_shape)
        tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)

        # ── P A-fragment (backward FMHA pattern) ──
        # 1. A-fragment layout of P, built over the S TMEM base recast to BF16.
        tP_iter = cute.recast_ptr(tStS.iterator, dtype=self.q_dtype)
        tP = cute.make_tensor(tP_iter, p_tmem_s.outer)
        tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0]
        # 2. Recast the C-fragment (S) iterator to BF16 (as in backward FMHA).
        tdVrP_iter = cute.recast_ptr(tStS.iterator, dtype=self.q_dtype)
        # tmem_p0_offset is in FP32 TMEM columns. Scale it by the FP32/BF16 width
        # ratio to express it in BF16 elements for the recast pointers.
        p_elem_offset = self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset
        # 3. Store target: recast iterator + P offset, with the A-fragment layout.
        tdVrP = cute.make_tensor(tdVrP_iter + p_elem_offset, tOrP.layout)
        # PV MMA's A operand must read P at the same offset it is stored to (tdVrP).
        tOrP0 = cute.make_tensor(tOrP.iterator + p_elem_offset, tOrP.layout)

        # ── TMEM load of S (C-fragment, FP32) ──
        tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
        tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0)
        sfw = tidx % (32 * len(self.epilogue_warp_id))
        thr_ld = tiled_ld.get_slice(sfw)
        tLdS = thr_ld.partition_S(tStS0)
        cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
        tScS = qk_thr.partition_C(cS_id)
        tLdcS = thr_ld.partition_D(tScS)

        # ── TMEM store of P via the A-fragment layout (backward FMHA pattern) ──
        tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(8)), self.q_dtype)
        tiled_st = tcgen05.make_tmem_copy(tmem_st, tdVrP)
        thr_st = tiled_st.get_slice(sfw)
        tStP = thr_st.partition_D(tdVrP)
        # Source identity for the store (A-fragment shape).
        cS_P = cute.make_identity_tensor((self.qk_mma_tiler[0], self.pv_mma_tiler[2]))
        tScS_P = pv_thr.partition_A(cS_P)
        tStcS = thr_st.partition_S(tScS_P)
        tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
        # Trace-time diagnostics.
        print(f'[A2] tdVrP.layout: {tdVrP.layout}')
        print(f'[A2] tOrP0.layout: {tOrP0.layout}')
        pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)

        # ── TMA warp: stream Q and K tiles into shared memory ──
        if warp_idx == self.tma_warp_id:
            ab_p.reset()
            peek = ab_p.try_acquire()
            for kt in cutlass.range(k_cnt, unroll=1):
                h = ab_p.acquire_and_advance(peek)
                cute.copy(tma_a, tAgA[(None, h.count)], tAsA[(None, h.index)], tma_bar_ptr=h.barrier)
                cute.copy(tma_b, tBgB[(None, h.count)], tBsB[(None, h.index)], tma_bar_ptr=h.barrier)
                peek = cutlass.Boolean(1)
                if h.count + 1 < k_cnt:
                    peek = ab_p.try_acquire()
            ab_p.tail()

        # ── MMA warp: S = Q @ K^T, then O = P @ V ──
        if warp_idx == self.mma_warp_id:
            tmem.wait_for_alloc()
            ab_c.reset()
            peek = ab_c.try_wait()
            s0_handle = mma_si_prod.acquire_and_advance()
            acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
            acc_pipe.producer_acquire(acc_prod_st)
            # The first QK k-block overwrites S; later k-blocks accumulate.
            qk_mma.set(tcgen05.Field.ACCUMULATE, False)
            for kt in range(k_cnt):
                h = ab_c.wait_and_advance(peek)
                nblk = cute.size(tCrA, mode=[2])
                for kb in cutlass.range(nblk, unroll_full=True):
                    cute.gemm(qk_mma, tStS0, tCrA[(None, None, kb, h.index)], tCrB[(None, None, kb, h.index)], tStS0)
                    qk_mma.set(tcgen05.Field.ACCUMULATE, True)
                h.release()
                peek = cutlass.Boolean(1)
                if h.count + 1 < k_cnt:
                    peek = ab_c.try_wait()
            cute.arch.fence_view_async_tmem_store()
            # Signal the softmax warps that S is ready.
            s0_handle.commit()
            # PV MMA: block until the softmax warps release the stage, i.e. P is in TMEM.
            mma_si_prod.acquire_and_advance()
            # Same protocol as QK. The first PV k-block must overwrite O rather than
            # accumulate into O columns that this kernel never zeroes.
            pv_mma.set(tcgen05.Field.ACCUMULATE, False)
            tCrV_s = tCrV[(None, None, None, 0)]
            nblk_pv = cute.size(tOrP0, mode=[2])
            for kb in cutlass.range(nblk_pv, unroll_full=True):
                cute.gemm(pv_mma, tOtO0, tOrP0[(None, None, kb)], tCrV_s[(None, None, kb)], tOtO0)
                pv_mma.set(tcgen05.Field.ACCUMULATE, True)
            acc_pipe.producer_commit(acc_prod_st)
            acc_prod_st.advance()
            acc_pipe.producer_tail(acc_prod_st)

        # ── Softmax/epilogue warps: S (FP32) -> P (BF16) in TMEM, then write O ──
        if warp_idx < self.mma_warp_id:
            tmem.allocate(self.tmem_alloc_cols)
            tmem.wait_for_alloc()
            tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
            si_handle = mma_si_cons.wait_and_advance()
            # Load this thread's FP32 S values from TMEM.
            rLd = cute.make_rmem_tensor(tLdcS.shape, self.qk_acc_dtype)
            cute.copy(tiled_ld, tLdS, rLd)
            cute.arch.fence_view_async_tmem_load()
            # Convert FP32 -> BF16 into a true BF16 register fragment (backward style,
            # not a recast). Convert only as many elements as both fragments hold:
            # the load fragment can be larger than the A-fragment store fragment
            # (NOTE(review): suspected 128-vs-64 elements/thread mismatch). Indexing
            # past rBf16 would write outside the register fragment.
            rBf16 = cute.make_rmem_tensor(tStcS.shape, self.q_dtype)
            n_cvt = min(cute.size(rLd), cute.size(rBf16))
            for i in cutlass.range(n_cvt, vectorize=True):
                rBf16[i] = rLd[i].to(self.q_dtype)
            # Store BF16 P to TMEM via the A-fragment layout.
            cute.copy(tiled_st, rBf16, tStP)
            cute.arch.fence_view_async_tmem_store()
            # Releasing the stage tells the MMA warp that P is ready.
            si_handle.release()
            # Epilogue: O accumulator (TMEM) -> sC -> global via TMA store.
            tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
            acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
            c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
            c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
            acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
                self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0,
                const_expr(lambda x: x), (0, 0, 0), acc_cons_st, acc_pipe, c_pipe,
            )
            c_pipe.producer_tail()
            tmem.relinquish_alloc_permit()
            tmem.free(tmem_ptr)
def test():
    """Run Stage B once on a 128x128x128 problem and check it against PyTorch.

    Q and KV are random BF16 CUDA tensors. KV serves as both K and V and no
    softmax is applied, so the reference is (Q @ KV^T) @ KV in FP32. The kernel
    output is compared by cosine similarity. A score below 0.99 prints FAIL and
    raises AssertionError, so test runners see the failure instead of a silent pass.
    """
    torch.manual_seed(42)
    m, n, k = 128, 128, 128
    q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
    kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
    # NOTE(review): O is (m, k); allocating c as (m, n) works only because n == k here.
    c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
    q2d = q[:, :, 0].float()
    kv2d = kv[:, :, 0].float()
    ref = q2d @ kv2d.T @ kv2d
    import cutlass.torch as ct
    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
    mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    kernel = StageBAfrag2(mma_tiler_mn=(128, 128))
    print('Compiling...', flush=True)
    compiled = cute.compile(kernel, mQ, mK, mC, stream)
    print('Running...', flush=True)
    compiled(mQ, mK, mC, stream)
    torch.cuda.synchronize()
    out = c[:, :, 0].float()
    cos = torch.nn.functional.cosine_similarity(
        out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
    ).item()
    passed = cos >= 0.99
    print('Stage B A-frag2 (backward FMHA pattern): cos={:.6f} {}'.format(cos, 'PASS' if passed else 'FAIL'))
    assert passed, 'Stage B A-frag2 cosine similarity {:.6f} is below 0.99'.format(cos)
if __name__ == "__main__":
    # Script entry point: run the end-to-end Stage B check.
    test()