D1: Add K sub-tile loop for hd=512 (const_expr guarded, hd≤256 path unchanged)
This commit is contained in:
@@ -181,7 +181,6 @@ class FmhaKernel:
|
||||
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
|
||||
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
|
||||
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)]
|
||||
print(f"TMA: tAgQ shape={cute.shape(tAgQ)}, tBgK shape={cute.shape(tBgK)}, tVgV shape={cute.shape(tVgV)}")
|
||||
|
||||
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
@@ -213,50 +212,104 @@ class FmhaKernel:
|
||||
|
||||
# ===== TMA LOAD warp =====
|
||||
if warp_idx == self.tma_warp_id:
|
||||
qp.reset(); qh = qp.acquire_and_advance()
|
||||
cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
|
||||
qp.tail()
|
||||
kvp.reset(); pk = kvp.try_acquire()
|
||||
for kt in cutlass.range(0, self.n_kv_tiles, 1, unroll=1):
|
||||
kvh = kvp.acquire_and_advance(pk)
|
||||
cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
pk = cutlass.Boolean(1)
|
||||
kvp.tail()
|
||||
if const_expr(self.n_k_sub_tiles > 1):
|
||||
# K sub-tiling path (hd=512): serialized TMA loads with barrier sync
|
||||
_ksub_bar = pipeline.NamedBarrier(barrier_id=6, num_threads=64) # TMA + MMA warps
|
||||
_v_bar = pipeline.NamedBarrier(barrier_id=7, num_threads=64) # TMA + MMA warps
|
||||
for kt in range(self.n_kv_tiles):
|
||||
for k_sub in range(self.n_k_sub_tiles):
|
||||
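# Load Q[k_sub] → sQ (always stage 0; no tma_bar_ptr is passed, unlike the hd≤256 path)
# NOTE(review): completion relies on the bulk commit/wait group below — confirm that mechanism actually tracks TMA loads into shared memory.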
|
||||
cute.copy(tma_q, tAgQ[(None, Int32(k_sub))], tAsQ[(None, 0)])
|
||||
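# Load K[k_sub] → sK (always stage 0)
# NOTE(review): the global tile is indexed by k_sub only; the enclosing kt loop variable is not used for K — confirm the intended KV-tile offset.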
|
||||
cute.copy(tma_k, tBgK[(None, Int32(k_sub))], tBsK[(None, 0)])
|
||||
cute.arch.cp_async_bulk_commit_group()
|
||||
cute.arch.cp_async_bulk_wait_group(0, read=True)
|
||||
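# Sync with MMA warp: sQ and sK ready
# NOTE(review): after this barrier the loop immediately reloads stage 0 for the next k_sub; no visible sync confirms the MMA warp has finished reading it — confirm.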
|
||||
_ksub_bar.arrive_and_wait()
|
||||
# Load V[kt] → sV
|
||||
cute.copy(tma_v, tVgV[(None, Int32(kt))], tVsV[(None, 0)])
|
||||
cute.arch.cp_async_bulk_commit_group()
|
||||
cute.arch.cp_async_bulk_wait_group(0, read=True)
|
||||
# Sync with MMA warp: sV ready for PV
|
||||
_v_bar.arrive_and_wait()
|
||||
else:
|
||||
# Original pipeline path (hd≤256)
|
||||
qp.reset(); qh = qp.acquire_and_advance()
|
||||
cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
|
||||
qp.tail()
|
||||
kvp.reset(); pk = kvp.try_acquire()
|
||||
for kt in cutlass.range(0, self.n_kv_tiles, 1, unroll=1):
|
||||
kvh = kvp.acquire_and_advance(pk)
|
||||
cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
pk = cutlass.Boolean(1)
|
||||
kvp.tail()
|
||||
|
||||
# ===== MMA warp =====
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
qc.reset(); qh = qc.wait_and_advance(); qh.release()
|
||||
kvc.reset(); pk = kvc.try_wait()
|
||||
acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_st)
|
||||
for kt in range(self.n_kv_tiles):
|
||||
kvh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1)
|
||||
sh = s_prod.acquire_and_advance()
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
sh.commit()
|
||||
softmax_done_bar.arrive_and_wait()
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0)
|
||||
if not self.use_smem_p:
|
||||
# TMEM-P: PV reads P from TMEM
|
||||
for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh.index)], tOtO0)
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
else:
|
||||
# SMEM-P: PV reads P from SMEM
|
||||
for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh.index)], tOtO0)
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
kvh.release()
|
||||
acc_pipe.producer_commit(acc_st); acc_st.advance()
|
||||
final_o_bar.arrive()
|
||||
acc_pipe.producer_tail(acc_st)
|
||||
if const_expr(self.n_k_sub_tiles > 1):
|
||||
# K sub-tiling path (hd=512): serialized with barrier sync
|
||||
_ksub_bar = pipeline.NamedBarrier(barrier_id=6, num_threads=64)
|
||||
_v_bar = pipeline.NamedBarrier(barrier_id=7, num_threads=64)
|
||||
for kt in range(self.n_kv_tiles):
|
||||
for k_sub in range(self.n_k_sub_tiles):
|
||||
# Wait for TMA warp: sQ and sK ready
|
||||
_ksub_bar.arrive_and_wait()
|
||||
# QK GEMM: load Q from sQ, K from sK, accumulate into S
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, k_sub != 0)
|
||||
for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,0)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
# After all k_sub: S has full QK for this kt
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
# Signal softmax
|
||||
softmax_done_bar.arrive() # signal that S is ready
|
||||
# Wait for V load
|
||||
_v_bar.arrive_and_wait()
|
||||
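# PV GEMM
# NOTE(review): this path only calls softmax_done_bar.arrive() before the GEMM, whereas the hd≤256 path calls arrive_and_wait() — confirm P is ready when PV reads it.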
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0)
|
||||
if not self.use_smem_p:
|
||||
for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,0)], tOtO0)
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
else:
|
||||
for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,0)], tOtO0)
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
softmax_done_bar.arrive() # signal PV complete for this kt
|
||||
final_o_bar.arrive()
|
||||
else:
|
||||
# Original pipeline path (hd≤256)
|
||||
qc.reset(); qh = qc.wait_and_advance(); qh.release()
|
||||
kvc.reset(); pk = kvc.try_wait()
|
||||
acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_st)
|
||||
for kt in range(self.n_kv_tiles):
|
||||
kvh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1)
|
||||
sh = s_prod.acquire_and_advance()
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
sh.commit()
|
||||
softmax_done_bar.arrive_and_wait()
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0)
|
||||
if not self.use_smem_p:
|
||||
for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh.index)], tOtO0)
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
else:
|
||||
for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh.index)], tOtO0)
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
kvh.release()
|
||||
acc_pipe.producer_commit(acc_st); acc_st.advance()
|
||||
final_o_bar.arrive()
|
||||
acc_pipe.producer_tail(acc_st)
|
||||
|
||||
# ===== SOFTMAX + CORRECTION EPILOGUE warps =====
|
||||
if warp_idx < self.mma_warp_id:
|
||||
|
||||
Reference in New Issue
Block a user