v29: FIX DEADLOCK - add V bytes to num_tma_load_bytes. V=I(128,128) cosine 1.0
This commit is contained in:
@@ -31,7 +31,7 @@ class StageBIdentitySoftmax:
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1])
|
||||
self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[1], self.qk_mma_tiler[1])
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
@@ -69,7 +69,7 @@ class StageBIdentitySoftmax:
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
|
||||
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem) + cute.size_in_bytes(self.b_dtype, cute.slice_(self.v_smem_s, (None, None, None, 0)))
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
@@ -240,12 +240,11 @@ class StageBIdentitySoftmax:
|
||||
s0_handle.commit()
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tCrV_s = tCrV[(None, None, None, 0)]
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
for kb in cutlass.range(nblk_pv, unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
|
||||
Reference in New Issue
Block a user