Files
nvfp4-megamoe-kernel/tests/test_bf16_pack.py
biondizzle 467ade37b2 Stage B: C-fragment vs A-fragment TMEM layout mismatch diagnosed
Key finding: C-fragment and A-fragment use different physical TMEM address
mappings. St32x32bOp with C-fragment writes to C-layout addresses, but PV MMA
reads from A-layout addresses. Forward FMHA recast validated FP16 only, not BF16.

Working: FP32 ld/st roundtrip, BF16 elemwise, BF16 recast ld S0->st S1 (all cos 0.999999)
Broken: C-frag st + A-frag read (NaN), A-frag store + PV MMA (cos -0.02)
Next: Fix register data flow (128 FP16/thread load vs 64 BF16/thread store mismatch)
2026-05-21 00:12:47 +00:00

238 lines
13 KiB
Python

"""Absolute minimal: ld FP32 from S0, st FP32 to S1, epi reads S1.
No recast, no BF16, no packing. Pure FP32 copy between TMEM regions."""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class BF16PackTest:
    """Single-CTA SM100 test kernel: TMEM-to-TMEM copy of a GEMM tile with BF16 rounding.

    Flow (one CTA, one output tile):
      1. The TMA warp loads tiles of A and B into shared memory.
      2. The MMA warp runs tcgen05 MMA, accumulating S0 (FP32) in TMEM.
      3. The epilogue warps load S0 with an FP32 tcgen05 ld, round every value
         through BF16 in registers, and store it back as FP32 into TMEM region
         S1 (``s_cols`` columns after S0) with an FP32 tcgen05 st.
      4. The epilogue warps run ``utils.gemm.sm100.epilogue_tma_store`` on S1,
         writing the BF16 result to C.

    Warp roles: warps 0-3 epilogue, warp 4 MMA, warp 5 TMA (192 threads).
    """
    def __init__(self, mma_tiler_mn):
        """Record dtypes, tile shape and warp-role configuration.

        Args:
            mma_tiler_mn: (M, N) MMA tile shape. The K extent is filled in later
                by ``_setup`` from the MMA instruction shape.
        """
        # Accumulator (S0/S1) is FP32; operands and output are BF16.
        self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
        # NOTE(review): c_dtype, acc_dtype, use_2cta_instrs and epilog_sync_bar_id are not read
        # in this file; presumably consumed by utils.gemm.sm100.epilogue_tma_store(self, ...).
        # Confirm before removing.
        self.c_dtype = BFloat16; self.acc_dtype = Float32
        # Placeholder K of 1; _setup() overwrites self.mma_tiler with the real K extent.
        self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
        self.cluster_shape_mn = (1, 1)
        self.cta_group = tcgen05.CtaGroup.ONE
        # 4 epilogue warps + 1 MMA warp + 1 TMA warp = 6 warps = 192 threads.
        self.epilogue_warp_id = (0, 1, 2, 3)
        self.mma_warp_id = 4; self.tma_warp_id = 5
        self.threads_per_cta = 192
        self.num_c_stage = 2; self.use_2cta_instrs = False
        self.epilog_sync_bar_id = 1
    def _setup(self, qk_mma):
        """Derive tile shapes, SMEM layouts and TMEM/TMA sizes from the tiled MMA.

        Called from ``__call__`` once the tiled MMA exists. Sets ``mma_tiler`` to
        (M, N, 4 * instruction K), the A/B/C SMEM layouts, the epilogue tile,
        ``s_cols`` (TMEM columns of one S tile, also used as the S0 -> S1 offset),
        ``tmem_alloc_cols`` (room for S0 and S1) and ``num_tma_load_bytes``.
        """
        qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
        # K tile is four MMA instructions deep.
        self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
        self.mma_tiler = self.qk_mma_tiler
        self.cta_tile_shape_mnk = (
            self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
            self.qk_mma_tiler[1], self.qk_mma_tiler[2])
        self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
        # Single-stage A/B SMEM layouts (last argument is the stage count).
        self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
        self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
        # NOTE(review): self.c_layout is not read in this file; presumably used by the epilogue helper.
        c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
        # False = not using 2-CTA instructions.
        self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
            self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
        # 2 epilogue SMEM stages; must agree with self.num_c_stage used for the TMA-store pipeline.
        self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
        self.num_ab_stage = 1; self.num_acc_stage = 1
        # Build a throwaway TMEM accumulator fragment only to measure its column footprint.
        qk_thr = qk_mma.get_slice(0)
        qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
        tStS = qk_thr.make_fragment_C(qk_acc_shape)
        self.s_cols = find_tmem_tensor_col_offset(tStS)
        # S0 and S1 are allocated back to back.
        self.tmem_alloc_cols = self.s_cols * 2
        # Stage 0 of the A/B layouts: the bytes one TMA stage delivers.
        a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
        b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
        # Expected transaction bytes per stage: one A tile plus one B tile, times CTAs in the MMA.
        self.num_tma_load_bytes = (
            cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)
        ) * cute.size(qk_mma.thr_id.shape)
    @cute.jit
    def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
        """Host entry point: build the tiled MMA and TMA atoms, then launch one CTA.

        Args:
            a: Operand A (``test()`` passes q with shape M x K x 1).
            b: Operand B (``test()`` passes kv with shape N x K x 1).
            c: Output tensor (M x N x 1), written by the TMA-store epilogue.
            stream: CUDA stream to launch on.
        """
        # BF16 x BF16 -> FP32 MMA with both operands read from SMEM.
        qk_mma = utils.sm100.make_trivial_tiled_mma(
            self.q_dtype, self.q_dtype,
            LayoutEnum.from_tensor(a).mma_major_mode(),
            LayoutEnum.from_tensor(b).mma_major_mode(),
            self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
            tcgen05.OperandSource.SMEM)
        self._setup(qk_mma)
        a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
        b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
        # TMA load atoms for A and B, one stage's SMEM layout each.
        tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
            utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
            a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
        tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
            utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
            b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
        # TMA store atom for C (SMEM -> GMEM), tiled by the epilogue tile.
        epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
        tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
        # grid=(1,1,1): a single CTA computes the single output tile.
        self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
                     self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile
                     ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
    @cute.kernel
    def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
                a_smem_s, b_smem_s, c_smem_s, epi_tile):
        """Device kernel: GEMM into S0, BF16-rounded copy S0 -> S1, epilogue from S1.

        ``mA``/``mB``/``mC`` are the TMA-side tensors returned by the TMA atom
        constructors in ``__call__`` (not the raw user tensors).
        """
        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
        tidx, _, _ = cute.arch.thread_idx()
        if warp_idx == self.tma_warp_id:
            cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
        # Shared-memory barrier storage and TMEM bookkeeping.
        @cute.struct
        class SS:
            # A/B load pipeline: two mbarriers per stage.
            ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
            # MMA -> epilogue "S0 ready" handoff: one stage, two mbarriers.
            mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
            # Accumulator pipeline consumed inside epilogue_tma_store.
            acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
            # Passed as two_cta_tmem_dealloc_mbar_ptr; presumably unused since is_two_cta=False.
            tmem_dealloc: cutlass.Int64
            # Slot handed to TmemAllocator; presumably receives the allocated TMEM address.
            holding: cutlass.Int32
        smem = utils.SmemAllocator(); st = smem.allocate(SS)
        # TMA warp (producer) -> MMA warp (consumer group of 1 thread) for A/B tiles.
        # defer_sync=True: init sync happens via pipeline_init_arrive / pipeline_init_wait below.
        ab_p, ab_c = pipeline.PipelineTmaUmma.create(
            barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
            producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
            consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
            tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
        ).make_participants()
        # MMA warp -> all 128 epilogue threads: signals that S0 is complete in TMEM.
        mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
            barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
            producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
            consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
            cta_layout_vmnk=cl_vmnk, defer_sync=True
        ).make_participants()
        # MMA warp -> epilogue accumulator pipeline (consumer count = number of epilogue warps),
        # waited on inside epilogue_tma_store.
        acc_pipe = pipeline.PipelineUmmaAsync.create(
            barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
            producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
            consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)),
            cta_layout_vmnk=cl_vmnk, defer_sync=True)
        # Named barrier 2 spans the MMA warp plus the 4 epilogue warps (160 threads) for the TMEM allocation handoff.
        tmem_bar = pipeline.NamedBarrier(barrier_id=2,
            num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
        # Epilogue warp 0 is the TMEM allocator; single-CTA mode.
        tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
            allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False,
            two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
        pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
        # Swizzled SMEM buffers for the A/B operands and the C epilogue staging.
        sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
        sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
        sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
        # Tile the TMA-side tensors by the (M, K), (N, K) and (M, N) projections of mma_tiler.
        gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
        gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
        gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
        # Number of K tiles iterated by the TMA and MMA loops.
        k_cnt = cute.size(gA, mode=[3])
        qk_thr = qk_mma.get_slice(0)
        tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
        a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
        tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
        b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
        tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
        # Pin the first tile/batch coordinates, keeping the K-tile mode indexed by h.count in the TMA loop.
        tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
        # MMA operand views of the SMEM buffers.
        tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
        # TMEM accumulator fragment. S0 uses the fragment's own iterator; S1 sits s_cols columns later.
        # NOTE(review): the MMA warp writes S0 through this iterator rather than through tmem_ptr,
        # which presumably relies on the allocation starting at TMEM column 0. Confirm.
        qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
        tStS = qk_thr.make_fragment_C(qk_acc_shape)
        tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)
        tStS1 = cute.make_tensor(tStS.iterator + self.s_cols, tStS.layout)
        # FP32 32x32b TMEM load (from S0) and store (to S1) atoms, both tiled over the same
        # C-fragment layout, so each thread loads and stores the same accumulator coordinates.
        tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
        tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
        tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0)
        tiled_st = tcgen05.make_tmem_copy(tmem_st, tStS1)
        # Thread index within the 128 epilogue threads.
        sfw = tidx % (32 * len(self.epilogue_warp_id))
        thr_ld = tiled_ld.get_slice(sfw)
        thr_st = tiled_st.get_slice(sfw)
        tLdS = thr_ld.partition_S(tStS0)
        # NOTE: rebinds tStS (the whole fragment above) to this thread's S1 store destination.
        tStS = thr_st.partition_D(tStS1)
        # Coordinate (identity) tensors, used here only to size the register fragments.
        cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
        tScS = qk_thr.partition_C(cS_id)
        tLdcS = thr_ld.partition_D(tScS)
        tStcS = thr_st.partition_S(tScS)
        # Layout-only accumulator tensor; rebased onto tmem_ptr + s_cols for the epilogue.
        tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
        pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
        # --- TMA warp: stream every K tile of A and B into SMEM. ---
        if warp_idx == self.tma_warp_id:
            ab_p.reset(); peek = ab_p.try_acquire()
            for kt in cutlass.range(k_cnt, unroll=1):
                h = ab_p.acquire_and_advance(peek)
                cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
                cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
                # Only peek the next stage if another K tile remains.
                peek = cutlass.Boolean(1)
                if h.count+1<k_cnt: peek = ab_p.try_acquire()
            ab_p.tail()
        # --- MMA warp: accumulate S0 over all K tiles, then signal the epilogue. ---
        if warp_idx == self.mma_warp_id:
            tmem.wait_for_alloc()
            ab_c.reset(); peek = ab_c.try_wait()
            s0_handle = mma_si_prod.acquire_and_advance()
            acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
            acc_pipe.producer_acquire(acc_prod_st)
            # The first MMA overwrites S0; every later one accumulates into it.
            qk_mma.set(tcgen05.Field.ACCUMULATE, False)
            for kt in range(k_cnt):
                h = ab_c.wait_and_advance(peek)
                nblk = cute.size(tCrA, mode=[2])
                for kb in cutlass.range(nblk, unroll_full=True):
                    cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
                    qk_mma.set(tcgen05.Field.ACCUMULATE, True)
                # Return the SMEM stage to the TMA warp.
                h.release(); peek = cutlass.Boolean(1)
                if h.count+1<k_cnt: peek = ab_c.try_wait()
            cute.arch.fence_view_async_tmem_store()
            # Publish S0 both to the S0 -> S1 copy (mma_si) and to the epilogue's acc pipeline.
            s0_handle.commit()
            acc_pipe.producer_commit(acc_prod_st)
            acc_prod_st.advance()
            acc_pipe.producer_tail(acc_prod_st)
        # --- Epilogue warps: copy S0 -> S1 with BF16 rounding, then store S1 to C. ---
        if warp_idx < self.mma_warp_id:
            # Allocate columns for S0 + S1 (allocator_warp_id is epilogue warp 0) and fetch the base.
            tmem.allocate(self.tmem_alloc_cols)
            tmem.wait_for_alloc()
            tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
            # Wait for the MMA warp to finish S0.
            si_handle = mma_si_cons.wait_and_advance()
            # FP32 TMEM ld -> FP32 TMEM st with no TMEM tensor recast; the values themselves
            # are rounded through BF16 in registers below.
            rLd = cute.make_rmem_tensor(tLdcS.shape, self.qk_acc_dtype)
            cute.copy(tiled_ld, tLdS, rLd)
            cute.arch.fence_view_async_tmem_load()
            # Store-side register fragment. ld and st share the C-fragment layout and identity
            # tensor, so rSt is expected to match rLd element for element (the loop indexes both by i).
            rSt = cute.make_rmem_tensor(tStcS.shape, self.qk_acc_dtype)
            # Round each FP32 value to BF16 and widen back to FP32 (no packing into BF16 pairs).
            for i in cutlass.range(cute.size(rLd), vectorize=True):
                rSt[i] = rLd[i].to(self.q_dtype).to(self.qk_acc_dtype)
            cute.copy(tiled_st, rSt, tStS)
            cute.arch.fence_view_async_tmem_store()
            si_handle.release()
            # The epilogue reads S1: the accumulator layout based at tmem_ptr + s_cols.
            tCtS1 = cute.make_tensor(tmem_ptr + self.s_cols, tCtS_fake.layout)
            acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
            c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
            c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
            # lambda x: x is passed as an identity epilogue op; (0,0,0) is presumably the tile
            # coordinate. Confirm against the helper's signature.
            acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
                self, tidx, warp_idx, tma_c, tCtS1, sC, tCgC,
                epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
            # Wait for pending TMA stores, then give back the TMEM allocation.
            c_pipe.producer_tail()
            tmem.relinquish_alloc_permit()
            tmem.free(tmem_ptr)
def test():
    """End-to-end check of BF16PackTest on a single 128x128x128 tile.

    Builds random BF16 inputs ``q`` (M x K x 1) and ``kv`` (N x K x 1) and runs
    the kernel: S0 = q @ kv^T in TMEM, a BF16-rounded copy to S1, then the
    epilogue from S1 into ``c``. The output ``c`` is compared against an FP32
    torch reference by cosine similarity. Requires a CUDA device the kernel
    supports.

    Raises:
        AssertionError: if the cosine similarity is below the threshold. This
            includes NaN output, because ``nan >= threshold`` is False.
    """
    # Minimum cosine similarity against the FP32 reference that counts as a pass.
    threshold = 0.99
    torch.manual_seed(42)
    m, n, k = 128, 128, 128
    q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
    kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
    c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
    ref = q[:, :, 0].float() @ kv[:, :, 0].float().T
    import cutlass.torch as ct
    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
    mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    # The kernel launches a single CTA, so the MMA tile must cover the whole (m, n) problem.
    kernel = BF16PackTest(mma_tiler_mn=(128, 128))
    print('Compiling...', flush=True)
    compiled = cute.compile(kernel, mQ, mK, mC, stream)
    print('Running...', flush=True)
    compiled(mQ, mK, mC, stream)
    torch.cuda.synchronize()
    out = c[:, :, 0].float()
    cos = torch.nn.functional.cosine_similarity(
        out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
    passed = cos >= threshold
    print('BF16 elemwise ld→st to S1: cos={:.6f} {}'.format(cos, 'PASS' if passed else 'FAIL'))
    # pytest collects this function (tests/test_bf16_pack.py::test). Assert so a bad result
    # is reported as a failure instead of only being printed.
    assert passed, f'cosine similarity {cos:.6f} below {threshold}'
if __name__ == "__main__":
    # Script entry point: run the end-to-end check directly, outside pytest.
    test()