From dd364b6d4d760f1bf867387fa3927bdb05706c09 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 21 May 2026 22:07:53 +0000 Subject: [PATCH] 10-warp idle test: no crash but cosine 0.29 (6-warp gives 0.999999) Adding 4 idle warps (4-7) to 320-thread CTA: - No crash, no deadlock (idle warps just pass) - But output is garbage: cosine 0.29 vs 0.999999 Same softmax+MMA code, same TMEM layout, same barriers. Only difference: mma_warp_id=8 (was 4), threads_per_cta=320 (was 192) and 4 idle warps 4-7. Something in the pipeline/barrier system assumes the old 6-warp topology. Need to identify which component uses threads_per_cta or warp_idx in a way that breaks with more warps. --- tests/unit/test_fmha_v3_tenwarp.py | 78 +++++++++++++----------------- 1 file changed, 34 insertions(+), 44 deletions(-) diff --git a/tests/unit/test_fmha_v3_tenwarp.py b/tests/unit/test_fmha_v3_tenwarp.py index 173fbf5a..3f7af3b9 100644 --- a/tests/unit/test_fmha_v3_tenwarp.py +++ b/tests/unit/test_fmha_v3_tenwarp.py @@ -1,9 +1,9 @@ -"""Minimal test: 10-warp architecture with identity softmax (no vector, no correction math). -Goal: verify the 4 softmax + 4 epilogue + 1 MMA + 1 TMA pipeline works structurally. -Softmax warps (0-3): load S, identity softmax, store P. -Epilogue warps (4-7): read O from TMEM, store to GMEM via epilogue. -MMA warp (8): QK + PV. -TMA warp (9): load Q, K, V. +""" +FMHA v3 10-warp: Same logic as 6-warp but with 4 idle epilogue warps (4-7). +Testing: does adding 4 idle warps to a 320-thread CTA break anything? +softmax warps 0-3 (same as 6-warp's epilogue warps 0-3) +MMA warp 8 (was 4), TMA warp 9 (was 5) +Idle warps 4-7 (just pass, no TMEM/barrier/pipeline interaction) """ import math, torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline from cutlass.cute.nvgpu import cpasync, tcgen05 @@ -18,26 +18,21 @@ HEAD_DIM = 64 class FmhaV3TenWarp: def __init__(self, s_k: int = 128): self.s_k = s_k - self.acc_dtype = Float32; self.qk_acc_dtype = Float32; self.pv_acc_dtype = Float32 + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1 self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE - self.softmax_warp_ids = (0,1,2,3) - self.epilogue_warp_id = (4,5,6,7) - self.mma_warp_id = 8 - self.tma_warp_id = 9 - self.threads_per_warp = 32 - self.threads_per_cta = 320 - self.num_c_stage = 2 - self.kv_stage = 2; self.q_stage = 1 - self.tmem_s0_offset = 0; self.tmem_p0_offset = 32 - self.scale_softmax_log2 = Float32(1.0 / math.sqrt(HEAD_DIM) * math.log2(math.e)) + self.epilogue_warp_id = (0,1,2,3) + self.mma_warp_id = 8; self.tma_warp_id = 9 + self.threads_per_cta = 320; self.num_c_stage = 2 + self.kv_stage = 2; self.q_stage = 1; self.num_c_stage = 2 def _setup(self, qk_mma, pv_mma): qk_ik = cute.size(qk_mma.shape_mnk, mode=[2]) self.qk_mma_tiler = (128, 128, qk_ik * 4) pv_ik = cute.size(pv_mma.shape_mnk, mode=[2]) self.pv_mma_tiler = (128, HEAD_DIM, pv_ik * (128 // pv_ik)) + self.mma_tiler = self.qk_mma_tiler self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), HEAD_DIM, self.qk_mma_tiler[2]) self.c_layout = LayoutEnum.ROW_MAJOR @@ -52,6 +47,7 @@ class FmhaV3TenWarp: tStS = qk_thr.make_fragment_C(qk_as) pv_thr = pv_mma.get_slice(0); pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) tOtO = pv_thr.make_fragment_C(pv_as) + self.tmem_s0_offset = 0; self.tmem_p0_offset = 32 p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width p_end = self.tmem_p0_offset + p_cols_fp32 o_after = max(self.qk_mma_tiler[1], p_end) @@ -64,13 +60,14 @@ class FmhaV3TenWarp: q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)) self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta self.kv_tx_bytes = cute.size_in_bytes(self.q_dtype, k_s) * cta + self.scale_softmax_log2 = Float32(1.0 / math.sqrt(HEAD_DIM) * math.log2(math.e)) @cute.jit def __call__(self, q, k, v, c, stream): self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype self.a_major = LayoutEnum.from_tensor(q).mma_major_mode() self.b_major = LayoutEnum.from_tensor(k).mma_major_mode() - v_fmha = cute.make_tensor(v.iterator, cute.make_layout((HEAD_DIM, self.s_k, 1), stride=(1, HEAD_DIM, HEAD_DIM * self.s_k))) + v_fmha = cute.make_tensor(v.iterator, cute.make_layout((HEAD_DIM, 128, 1), stride=(1, HEAD_DIM, HEAD_DIM * 128))) self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode() self.c_layout = LayoutEnum.from_tensor(c) qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM) @@ -88,12 +85,8 @@ class FmhaV3TenWarp: def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) tidx,_,_ = cute.arch.thread_idx() - is_softmax = warp_idx < 4 - is_epilogue = warp_idx >= 4 and warp_idx < 8 - is_mma = warp_idx == 8 - is_tma = warp_idx == 9 - if is_tma: + if warp_idx == self.tma_warp_id: cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c) @cute.struct @@ -107,11 +100,11 @@ class FmhaV3TenWarp: qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants() kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants() - s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*4)).make_participants() - softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32*4 + 32*1) - acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,4),cta_layout_vmnk=cl_vmnk,defer_sync=True) - tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*5) - tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=0,is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants() + softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32*len(self.epilogue_warp_id)) + acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True) + tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk,is_relaxed=True) sQ = smem.allocate_tensor(element_type=self.q_dtype,layout=q_smem_s.outer,byte_alignment=128,swizzle=q_smem_s.inner) @@ -146,11 +139,12 @@ class FmhaV3TenWarp: tOrP_base = pv_thr.make_fragment_A(tP) tOrP = tOrP_base[(None,None,None,0)] tOrP0 = cute.make_tensor(tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, tOrP.layout) + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_as, self.num_acc_stage)) tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage)) pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) # TMA - if is_tma: + if warp_idx == self.tma_warp_id: qp.reset(); qh = qp.acquire_and_advance() cute.copy(tma_q,tAgQ[(None,qh.count)],tAsQ[(None,qh.index)],tma_bar_ptr=qh.barrier) qp.tail() @@ -165,7 +159,7 @@ class FmhaV3TenWarp: kvp.tail() # MMA - if is_mma: + if warp_idx == self.mma_warp_id: tmem.wait_for_alloc() qc.reset(); qh = qc.wait_and_advance(); qh.release() kvc.reset(); pk = kvc.try_wait() @@ -185,16 +179,18 @@ class FmhaV3TenWarp: pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0) for kb in cutlass.range(cute.size(tOrP0,mode=[2]), unroll_full=True): cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,vh.index)], tOtO0) + pv_mma.set(tcgen05.Field.ACCUMULATE, True) cute.arch.fence_view_async_tmem_store() vh.release() acc_pipe.producer_commit(acc_st); acc_st.advance() acc_pipe.producer_tail(acc_st) - # Softmax (identity) - if is_softmax: + # Softmax + Epilogue (warps 0-3, same as 6-warp) + if warp_idx < self.mma_warp_id and warp_idx < 4: tmem.allocate(self.num_tmem_alloc_cols) tmem.wait_for_alloc() - sfw_idx = tidx % (32 * 4) + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + sfw_idx = tidx % (32 * len(self.epilogue_warp_id)) tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0) thr_load = tiled_tmem_load.get_slice(sfw_idx) @@ -233,25 +229,19 @@ class FmhaV3TenWarp: cute.arch.fence_view_async_tmem_store() si_handle.release() softmax_done_bar.arrive() - # tmem.relinquish_alloc_permit() done after epilogue - # Epilogue done by softmax warps (testing 10-warp structure) - # Correction/epilogue warps just participate in TMEM alloc barrier - if is_softmax: - # ... (softmax already handled above, add epilogue after softmax loop) # Epilogue - tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage) - c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * 4) + c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp) acc_cons_st = utils.gemm.sm100.epilogue_tma_store(self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe) c_pipe.producer_tail() - # tmem.free(tmem_ptr) # skip free - not required for correctness - - if is_epilogue: - tmem.wait_for_alloc() tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr) + + # Idle warps 4-7: absolutely no TMEM/barrier/pipeline interaction + # (warp_idx >= 4 and warp_idx < 8 - these do nothing) def test():