10-warp idle test: no crash but cosine 0.29 (6-warp gives 0.999999)

Adding 4 idle warps (4-7) to 320-thread CTA:
- No crash, no deadlock (idle warps just pass)
- But output is garbage: cosine 0.29 vs 0.999999

Same softmax+MMA code, same TMEM layout, same barriers.
Only difference: mma_warp_id=8 (was 4), threads_per_cta=320 (was 192)
and 4 idle warps 4-7.

Something in the pipeline/barrier system assumes the old 6-warp topology.
Need to identify which component uses threads_per_cta or warp_idx
in a way that breaks with more warps.
This commit is contained in:
2026-05-21 22:07:53 +00:00
parent 0d06e55770
commit dd364b6d4d

View File

@@ -1,9 +1,9 @@
"""Minimal test: 10-warp architecture with identity softmax (no vector, no correction math).
Goal: verify the 4 softmax + 4 epilogue + 1 MMA + 1 TMA pipeline works structurally.
Softmax warps (0-3): load S, identity softmax, store P.
Epilogue warps (4-7): read O from TMEM, store to GMEM via epilogue.
MMA warp (8): QK + PV.
TMA warp (9): load Q, K, V.
"""
FMHA v3 10-warp: Same logic as 6-warp but with 4 idle epilogue warps (4-7).
Testing: does adding 4 idle warps to a 320-thread CTA break anything?
softmax warps 0-3 (same as 6-warp's epilogue warps 0-3)
MMA warp 8 (was 4), TMA warp 9 (was 5)
Idle warps 4-7 (just pass, no TMEM/barrier/pipeline interaction)
"""
import math, torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
@@ -18,26 +18,21 @@ HEAD_DIM = 64
class FmhaV3TenWarp:
def __init__(self, s_k: int = 128):
self.s_k = s_k
self.acc_dtype = Float32; self.qk_acc_dtype = Float32; self.pv_acc_dtype = Float32
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
self.softmax_warp_ids = (0,1,2,3)
self.epilogue_warp_id = (4,5,6,7)
self.mma_warp_id = 8
self.tma_warp_id = 9
self.threads_per_warp = 32
self.threads_per_cta = 320
self.num_c_stage = 2
self.kv_stage = 2; self.q_stage = 1
self.tmem_s0_offset = 0; self.tmem_p0_offset = 32
self.scale_softmax_log2 = Float32(1.0 / math.sqrt(HEAD_DIM) * math.log2(math.e))
self.epilogue_warp_id = (0,1,2,3)
self.mma_warp_id = 8; self.tma_warp_id = 9
self.threads_per_cta = 320; self.num_c_stage = 2
self.kv_stage = 2; self.q_stage = 1; self.num_c_stage = 2
def _setup(self, qk_mma, pv_mma):
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (128, 128, qk_ik * 4)
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (128, HEAD_DIM, pv_ik * (128 // pv_ik))
self.mma_tiler = self.qk_mma_tiler
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), HEAD_DIM, self.qk_mma_tiler[2])
self.c_layout = LayoutEnum.ROW_MAJOR
@@ -52,6 +47,7 @@ class FmhaV3TenWarp:
tStS = qk_thr.make_fragment_C(qk_as)
pv_thr = pv_mma.get_slice(0); pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)
self.tmem_s0_offset = 0; self.tmem_p0_offset = 32
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
p_end = self.tmem_p0_offset + p_cols_fp32
o_after = max(self.qk_mma_tiler[1], p_end)
@@ -64,13 +60,14 @@ class FmhaV3TenWarp:
q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0))
self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta
self.kv_tx_bytes = cute.size_in_bytes(self.q_dtype, k_s) * cta
self.scale_softmax_log2 = Float32(1.0 / math.sqrt(HEAD_DIM) * math.log2(math.e))
@cute.jit
def __call__(self, q, k, v, c, stream):
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
v_fmha = cute.make_tensor(v.iterator, cute.make_layout((HEAD_DIM, self.s_k, 1), stride=(1, HEAD_DIM, HEAD_DIM * self.s_k)))
v_fmha = cute.make_tensor(v.iterator, cute.make_layout((HEAD_DIM, 128, 1), stride=(1, HEAD_DIM, HEAD_DIM * 128)))
self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM)
@@ -88,12 +85,8 @@ class FmhaV3TenWarp:
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx,_,_ = cute.arch.thread_idx()
is_softmax = warp_idx < 4
is_epilogue = warp_idx >= 4 and warp_idx < 8
is_mma = warp_idx == 8
is_tma = warp_idx == 9
if is_tma:
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c)
@cute.struct
@@ -107,11 +100,11 @@ class FmhaV3TenWarp:
qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*4)).make_participants()
softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32*4 + 32*1)
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,4),cta_layout_vmnk=cl_vmnk,defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*5)
tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=0,is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants()
softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32*len(self.epilogue_warp_id))
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk,is_relaxed=True)
sQ = smem.allocate_tensor(element_type=self.q_dtype,layout=q_smem_s.outer,byte_alignment=128,swizzle=q_smem_s.inner)
@@ -146,11 +139,12 @@ class FmhaV3TenWarp:
tOrP_base = pv_thr.make_fragment_A(tP)
tOrP = tOrP_base[(None,None,None,0)]
tOrP0 = cute.make_tensor(tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, tOrP.layout)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_as, self.num_acc_stage))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# TMA
if is_tma:
if warp_idx == self.tma_warp_id:
qp.reset(); qh = qp.acquire_and_advance()
cute.copy(tma_q,tAgQ[(None,qh.count)],tAsQ[(None,qh.index)],tma_bar_ptr=qh.barrier)
qp.tail()
@@ -165,7 +159,7 @@ class FmhaV3TenWarp:
kvp.tail()
# MMA
if is_mma:
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
qc.reset(); qh = qc.wait_and_advance(); qh.release()
kvc.reset(); pk = kvc.try_wait()
@@ -185,16 +179,18 @@ class FmhaV3TenWarp:
pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0)
for kb in cutlass.range(cute.size(tOrP0,mode=[2]), unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,vh.index)], tOtO0)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
cute.arch.fence_view_async_tmem_store()
vh.release()
acc_pipe.producer_commit(acc_st); acc_st.advance()
acc_pipe.producer_tail(acc_st)
# Softmax (identity)
if is_softmax:
# Softmax + Epilogue (warps 0-3, same as 6-warp)
if warp_idx < self.mma_warp_id and warp_idx < 4:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
sfw_idx = tidx % (32 * 4)
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
thr_load = tiled_tmem_load.get_slice(sfw_idx)
@@ -233,25 +229,19 @@ class FmhaV3TenWarp:
cute.arch.fence_view_async_tmem_store()
si_handle.release()
softmax_done_bar.arrive()
# tmem.relinquish_alloc_permit() done after epilogue
# Epilogue done by softmax warps (testing 10-warp structure)
# Correction/epilogue warps just participate in TMEM alloc barrier
if is_softmax:
# ... (softmax already handled above, add epilogue after softmax loop)
# Epilogue
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * 4)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
# tmem.free(tmem_ptr) # skip free - not required for correctness
if is_epilogue:
tmem.wait_for_alloc()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
# Idle warps 4-7: absolutely no TMEM/barrier/pipeline interaction
# (warp_idx >= 4 and warp_idx < 8 - these do nothing)
def test():