Fix ALL loops: use self.n_kv_tiles everywhere
The TMA load loop (cutlass.range) and the MMA consumer loop (range) also used cute.size(gK, mode=[3]), which returns 1 for all n. Fixed all 3 loops:
1. TMA load loop (cutlass.range, line 215)
2. MMA consumer loop (range, line 231)
3. Softmax loop (range, line 324)
This was causing the deadlock — the MMA warp only produced S[0] while softmax waited for S[1].
This commit is contained in:
@@ -212,7 +212,7 @@ class FmhaV3StageCMulti:
|
||||
cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
|
||||
qp.tail()
|
||||
kvp.reset(); pk = kvp.try_acquire()
|
||||
for kt in cutlass.range(0, n_kv_tiles, 1, unroll=1):
|
||||
for kt in cutlass.range(0, self.n_kv_tiles, 1, unroll=1):
|
||||
kvh = kvp.acquire_and_advance(pk)
|
||||
cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
@@ -228,7 +228,7 @@ class FmhaV3StageCMulti:
|
||||
kvc.reset(); pk = kvc.try_wait()
|
||||
acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_st)
|
||||
for kt in range(n_kv_tiles):
|
||||
for kt in range(self.n_kv_tiles):
|
||||
kvh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1)
|
||||
sh = s_prod.acquire_and_advance()
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
|
||||
Reference in New Issue
Block a user