Files
nvfp4-megamoe-kernel/tests/test_mma_si_only.py
biondizzle 7a8945eb76 Stage B: pipeline deadlock fixed, V MN-major applied, PV output garbage
Pipeline deadlock fixed:
- No cta_layout_vmnk on mma_si PipelineUmmaAsync
- TMA warp excluded from tmem.wait_for_alloc
- PipelineTmaStore (not TmaStorePipeline)

Bug 1 (V MN-major): fix applied
- PV MMA uses v_major=OperandMajorMode.MN
- V shaped (64,128) strides(1,64) via as_strided

Bug 2 (softmax packing): C-fragment composition store applied
- FP32 to BF16 packing works
- St32x32bOp uses Float32 (not BFloat16)

Bug 3 (PV garbage): investigating
- PV MMA cosine ~0.01 against reference
- Suspected TMEM layout mismatch between softmax P store and PV A-fragment read

Test results:
- test_mma_si_only: cosine 0.999999 PASS
- test_mma_si_pv: cosine 0.01 FAIL (pipeline works, PV output wrong)
2026-05-21 04:10:07 +00:00

248 lines
13 KiB
Python

"""
Minimal test: Stage A + mma_si pipeline (no PV, no V).
If this deadlocks, the mma_si pipeline is broken.
If this passes, the deadlock is caused by adding V/PV.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class MmaSiTest:
    """Single-CTA tcgen05 test kernel: ``A @ B^T`` plus an extra ``mma_si`` pipeline handshake.

    Warp roles (``threads_per_cta = 192`` -> 6 warps):
      * warps 0-3 (``epilogue_warp_id``): TMEM allocation, ``mma_si`` consumer,
        accumulator consumer and TMA store of C.
      * warp 4 (``mma_warp_id``): issues the MMAs; ``mma_si`` producer and
        accumulator producer.
      * warp 5 (``tma_warp_id``): TMA loads of A and B.

    The ``mma_si`` pipeline carries no data in this test. It only reproduces the
    acquire/commit/wait/release sequence the comments attribute to "v27", so a
    hang in that handshake can be separated from the V/PV stages.
    """

    def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
        """Record the static configuration; shape-dependent state is built in ``_setup``.

        Args:
            mma_tiler_mn: (M, N) of the MMA tile. K is derived later in ``_setup``.
            use_2cta_instrs: select ``tcgen05.CtaGroup.TWO`` instead of ``ONE``.
                NOTE(review): ``cluster_shape_mn`` is fixed to (1, 1) and the
                launch uses a single CTA, so ``True`` is presumably untested.
            use_tma_store: stored on the instance; not read directly in this class
                (presumably consumed by the epilogue helper via ``self`` — confirm).
        """
        self.acc_dtype = Float32; self.qk_acc_dtype = Float32
        # Defaults only: __call__ overwrites q/o/c dtypes from the actual tensors.
        self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
        self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
        # K=1 is a placeholder; _setup replaces mma_tiler with the real K extent.
        self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
        self.cluster_shape_mn = (1, 1)
        self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
        self.epilogue_warp_id = (0, 1, 2, 3)
        self.mma_warp_id = 4; self.tma_warp_id = 5
        self.threads_per_cta = 192
        # Named-barrier IDs. Only tmem_alloc_sync_bar_id is used directly below;
        # the others are presumably read by the epilogue helper through `self`.
        self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
        # Number of C staging buffers: sizes both the epilogue SMEM layout and the
        # PipelineTmaStore created in _kernel.
        self.num_c_stage = 2

    def _setup(self, tiled_mma):
        """Derive tile shapes, SMEM/TMEM layouts and the TMA byte count from ``tiled_mma``.

        Must run after ``__call__`` has set ``q_dtype``, ``o_dtype`` and ``c_layout``,
        which are read here.
        """
        mma_inst_k = cute.size(tiled_mma.shape_mnk, mode=[2])
        # One MMA tile spans 4 MMA instructions along K.
        self.mma_tiler = (*self.mma_tiler_mn, mma_inst_k * 4)
        self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (tiled_mma.thr_id.shape,))
        # Per-CTA tile: M is divided by the number of CTAs cooperating on one MMA.
        self.cta_tile_shape_mnk = (
            self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
            self.mma_tiler[1], self.mma_tiler[2])
        self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
            self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
        self.num_ab_stage = 1; self.num_acc_stage = 1
        # Stage counts come from the attributes (not literals) so the SMEM layouts
        # always agree with the pipelines in _kernel that index them.
        self.a_smem_s = utils.sm100.make_smem_layout_a(tiled_mma, self.mma_tiler, self.q_dtype, self.num_ab_stage)
        self.b_smem_s = utils.sm100.make_smem_layout_b(tiled_mma, self.mma_tiler, self.q_dtype, self.num_ab_stage)
        self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, self.num_c_stage)
        # Size the TMEM allocation from a placeholder accumulator fragment.
        acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
        tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
        self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake, arch="sm_100")
        # Bytes delivered per A/B stage (one stage of A + B, times CTAs per MMA);
        # used as the ab pipeline's tx_count in _kernel.
        a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
        b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
        self.num_tma_load_bytes = (
            cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)
        ) * cute.size(tiled_mma.thr_id.shape)

    @cute.jit
    def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
        """Host-side JIT entry: build the tiled MMA and TMA atoms, then launch ``_kernel``.

        Args:
            a: A operand; its element type becomes ``q_dtype``.
            b: B operand (same element type as A is assumed by the tiled MMA).
            c: output written by TMA store; its element type becomes ``o_dtype``.
            stream: CUDA stream used for the launch.

        ``test()`` passes (M, K, 1), (N, K, 1) and (M, N, 1) tensors; other
        shapes are untested here.
        """
        self.q_dtype = a.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
        self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
        self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
        self.c_layout = LayoutEnum.from_tensor(c)
        # Both operands are read from SMEM; accumulation in qk_acc_dtype.
        tiled_mma = utils.sm100.make_trivial_tiled_mma(
            self.q_dtype, self.q_dtype, self.a_major, self.b_major,
            self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
        self._setup(tiled_mma)
        # One stage of the A/B SMEM layouts defines the TMA load box.
        a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
        b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
        tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
            utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id),
            a, a_smem, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
        tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
            utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id),
            b, b_smem, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
        # C store atom: SMEM -> GMEM, one epilogue tile (stage mode dropped).
        epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
        tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
        # NOTE: grid is fixed at a single CTA, so only one mma_tiler-sized output
        # tile is produced.
        self._kernel(tiled_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
                     self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile
                     ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)

    @cute.kernel
    def _kernel(self, tiled_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, c_smem_s, epi_tile):
        """Device kernel: TMA load -> tcgen05 MMA -> mma_si handshake -> TMA-store epilogue.

        ``mA``/``mB``/``mC`` are the TMA tensors returned alongside the atoms in
        ``__call__``; ``cl_vmnk`` is the cluster layout built in ``_setup``.
        """
        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
        tidx, _, _ = cute.arch.thread_idx()
        use_2cta = cute.size(tiled_mma.thr_id.shape) == 2
        # Prefetch the TMA descriptors from the loader warp only.
        if warp_idx == self.tma_warp_id:
            cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)

        # Shared-memory storage for pipeline barriers and TMEM bookkeeping.
        @cute.struct
        class SS:
            # two mbarriers per A/B stage (presumably full + empty)
            ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
            mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] # ADDED: mma_si
            acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
            # used only as the two-CTA TMEM dealloc mbarrier
            tmem_dealloc: cutlass.Int64
            # handed to TmemAllocator (presumably receives the allocated TMEM address)
            holding: cutlass.Int32
        smem = utils.SmemAllocator(); st = smem.allocate(SS)

        # A/B pipeline: TMA producer, single-thread MMA consumer; each stage
        # completes after num_tma_load_bytes arrive.
        ab_p, ab_c = pipeline.PipelineTmaUmma.create(
            barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
            producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
            consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
            tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
        ).make_participants()
        # ADDED: mma_si pipeline (same as v27). Producer: MMA warp; consumers: all
        # 128 epilogue threads. Deliberately created without cta_layout_vmnk
        # (the commit notes this as part of the deadlock fix).
        mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
            barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
            producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
            consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
        ).make_participants()
        # Accumulator pipeline: MMA warp -> epilogue. Consumer count is one per
        # epilogue warp, doubled for 2-CTA MMA.
        acc_pipe = pipeline.PipelineUmmaAsync.create(
            barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
            producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
            consumer_group=pipeline.CooperativeGroup(
                pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
            cta_layout_vmnk=cl_vmnk, defer_sync=True)
        # TMEM-allocation barrier spans the MMA warp + epilogue warps (160 threads);
        # the TMA warp never calls wait_for_alloc, so it is not counted.
        tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
                                         num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
        tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
                                   allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
                                   two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
        # Pairs with pipeline_init_wait below (pipelines above use defer_sync=True).
        pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)

        # SMEM operand/staging tensors.
        sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
        sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
        sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
        # Tile the global tensors by the MMA tiler (A drops N, B drops M, C drops K).
        gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
        gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
        gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
        # Number of K tiles the loader/MMA loops iterate over.
        k_cnt = cute.size(gA, mode=[3])
        thr_mma = tiled_mma.get_slice(0)
        tCgA = thr_mma.partition_A(gA); tCgB = thr_mma.partition_B(gB); tCgC = thr_mma.partition_C(gC)
        a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
        tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
        b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
        tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
        # Fix the non-K tile coordinates to 0; the remaining mode is indexed by K tile below.
        tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
        tCrA = tiled_mma.make_fragment_A(sA); tCrB = tiled_mma.make_fragment_B(sB)
        acc_shape = thr_mma.partition_shape_C(self.mma_tiler[:2])
        tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
        pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)

        # TMA WARP: stream A/B K-tiles into SMEM.
        if warp_idx == self.tma_warp_id:
            # `peek` is a non-blocking probe of the next empty slot, fed to acquire.
            ab_p.reset(); peek = ab_p.try_acquire()
            for kt in cutlass.range(k_cnt, unroll=1):
                h = ab_p.acquire_and_advance(peek)
                cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
                cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
                peek = cutlass.Boolean(1)
                if h.count+1<k_cnt: peek = ab_p.try_acquire()
            ab_p.tail()

        # MMA WARP — same as Stage A but with mma_si pipeline added (just acquire/commit, no PV)
        if warp_idx == self.mma_warp_id:
            # TMEM is allocated by epilogue warp 0 (allocator_warp_id); wait for it.
            tmem.wait_for_alloc()
            tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
            tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
            tCtAcc = tCtAcc_base[(None, None, None, 0)]
            ab_c.reset(); peek = ab_c.try_wait()
            acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
            acc_pipe.producer_acquire(acc_prod_st)
            # First MMA overwrites the accumulator; every later one accumulates.
            tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
            # ADDED: mma_si acquire (just like v27)
            s0_handle = mma_si_prod.acquire_and_advance()
            for kt in range(k_cnt):
                h = ab_c.wait_and_advance(peek)
                nblk = cute.size(tCrA, mode=[2])
                for kb in cutlass.range(nblk, unroll_full=True):
                    cute.gemm(tiled_mma, tCtAcc, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tCtAcc)
                    tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
                # Hand the SMEM stage back to the TMA warp.
                h.release(); peek = cutlass.Boolean(1)
                if h.count+1<k_cnt: peek = ab_c.try_wait()
            cute.arch.fence_view_async_tmem_store()
            # ADDED: mma_si commit + second acquire (like v27)
            s0_handle.commit()
            # Blocks until the epilogue warps release the mma_si stage.
            s0_handle = mma_si_prod.acquire_and_advance() # wait for "softmax"
            # In real use, softmax would happen here. For this test, just release immediately.
            # The epilogue will do mma_si_cons wait_and_advance then release.
            # After the second acquire returns, continue to acc_pipe commit.
            acc_pipe.producer_commit(acc_prod_st)
            acc_prod_st.advance()
            acc_pipe.producer_tail(acc_prod_st)

        # EPILOGUE WARPS — same as Stage A but with mma_si wait+release added
        if warp_idx < self.mma_warp_id:
            # Called by all epilogue warps; TmemAllocator was configured with
            # allocator_warp_id=epilogue_warp_id[0] (presumably only it allocates).
            tmem.allocate(self.num_tmem_alloc_cols)
            tmem.wait_for_alloc()
            tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
            tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
            # ADDED: mma_si wait + release (simulating softmax)
            si_handle = mma_si_cons.wait_and_advance()
            # (no actual softmax — just release immediately)
            si_handle.release()
            acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
            # C store pipeline: all 128 epilogue threads act as producers.
            c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
            c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
            # TMEM -> registers -> SMEM -> TMA store of C via the library epilogue
            # helper (identity fusion lambda); returns the advanced consumer state.
            acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
                self, tidx, warp_idx, tma_c, tCtAcc_base, sC, tCgC,
                epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
            c_pipe.producer_tail()
            tmem.relinquish_alloc_permit()
            tmem.free(tmem_ptr)
def test():
    """Run the MMA + mma_si kernel once and compare it with a torch float32 reference.

    Builds random bf16 operands ``a`` (m, k, 1) and ``b`` (n, k, 1), computes
    ``a[:, :, 0] @ b[:, :, 0].T`` in float32 as the reference, runs the compiled
    kernel into ``c`` (m, n, 1) and checks the cosine similarity of the result.

    Raises:
        AssertionError: if the cosine similarity is below 0.99 (or NaN). Without
            this, pytest would report the test as passing even when the printed
            verdict is FAIL.
    """
    min_cos = 0.99
    torch.manual_seed(42)
    m, n, k = 128, 128, 64
    a = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
    b = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
    c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
    ref = a[:,:,0].float() @ b[:,:,0].float().T
    import cutlass.torch as ct
    mA = ct.from_dlpack(a).mark_layout_dynamic(leading_dim=ct.get_leading_dim(a))
    mB = ct.from_dlpack(b).mark_layout_dynamic(leading_dim=ct.get_leading_dim(b))
    mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    kernel = MmaSiTest(mma_tiler_mn=(128, 128))
    print('Compiling...', flush=True)
    compiled = cute.compile(kernel, mA, mB, mC, stream)
    print('Running...', flush=True)
    compiled(mA, mB, mC, stream)
    torch.cuda.synchronize()
    out = c[:,:,0].float()
    cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
    # NaN compares False, so garbage output also counts as a failure.
    passed = cos >= min_cos
    print('MMA+mma_si only test: cosine {:.6f} {}'.format(cos, 'PASS' if passed else 'FAIL'))
    assert passed, f'MMA+mma_si only test: cosine {cos:.6f} is below {min_cos}'
if __name__ == '__main__':
    # Script entry point: allow running the check directly, outside pytest.
    test()