diff --git a/dsv4/kernels/attention/fmha_smem_acc.py b/dsv4/kernels/attention/fmha_smem_acc.py index 69bc40ed..c999c0bd 100644 --- a/dsv4/kernels/attention/fmha_smem_acc.py +++ b/dsv4/kernels/attention/fmha_smem_acc.py @@ -126,7 +126,7 @@ class FmhaKernel: cute.size_in_bytes(self.q_dtype, v_s)) * cta @cute.jit - def __call__(self, q, k, v, c, stream, lse=None, swa_len=None, sink_bias=None, row_sums=None, c_direct=None): + def __call__(self, q, k, v, c, stream, lse=None, swa_len=None, sink_bias=None, row_sums=None): self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype self.a_major = LayoutEnum.from_tensor(q).mma_major_mode() self.b_major = LayoutEnum.from_tensor(k).mma_major_mode() @@ -171,14 +171,11 @@ class FmhaKernel: # For single-head (n_h=1): grid=(1,1,1) — backward compatible if const_expr(row_sums is None): row_sums = cute.make_tensor(lse.iterator, lse.layout) - # c_direct: simple GMEM tensor for direct writes (SMEM accumulator path) - if const_expr(c_direct is None): - c_direct = cute.make_tensor(c.iterator, cute.make_layout((1, 1, 1), stride=(1, 1, 1))) - self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse,swa_len,sink_bias,row_sums,c_direct).launch(grid=(1,1,self.batch_size),block=[self.threads_per_cta,1,1],stream=stream) + self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse,swa_len,sink_bias,row_sums).launch(grid=(1,1,self.batch_size),block=[self.threads_per_cta,1,1],stream=stream) @cute.kernel - def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE, swa_len, mSinkBias, mRowSums, mCDirect): + def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE, swa_len, mSinkBias, mRowSums): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) tidx,_,_ = cute.arch.thread_idx() if warp_idx == self.tma_warp_id: @@ -609,17 +606,42 @@ class FmhaKernel: ) c_pipe.producer_tail() else: - # Path 2: write sO_acc (FP32) -> GMEM directly from registers - # Each thread handles one row (sfw_idx) of the output. - # Read sO_acc row, normalize, cast to BF16, write to GMEM via mCDirect. - for col in cutlass.range(0, self.pv_n_tile, unroll=1): - row = sfw_idx - if row < Int32(128): + # Path 2: write sO_acc (FP32) -> sC -> GMEM via TMA + # Follow the CUTLASS FMHA reference pattern: + # 1. Cast sO_acc (FP32) -> sC (BF16) using sC's layout indexing + # 2. Use flat_divide on mC to create gO, then tma_partition, then copy + + # Step 1: Cast sO_acc -> sC (BF16) + for row in cutlass.range(0, 128, unroll=1): + for col in cutlass.range(0, self.pv_n_tile, unroll=1): val = sO_acc[row, col] if const_expr(self.normalize): inv_row_sum = Float32(1.0) / row_sum val = val * inv_row_sum - mCDirect[Int32(row), Int32(col), Int32(0)] = val.to(self.o_dtype) + sC[(row, col), Int32(0), Int32(0)] = val.to(self.o_dtype) + cute.arch.fence_proxy("async.shared", space="cta") + + # Step 2: TMA store sC -> GMEM (CUTLASS FMHA reference pattern) + # Create gO from mC via flat_divide + slice (same as CUTLASS reference) + gO_qdl = cute.flat_divide( + mC, cute.select(self.pv_mma_tiler, mode=[0, 1]) + ) + gO = gO_qdl[None, None, None, Int32(0), (Int32(0), Int32(0))] + tOsO, tOgO = cpasync.tma_partition( + tma_c, 0, cute.make_layout(1), + cute.group_modes(sC, 0, 2), + cute.group_modes(gO, 0, 2), + ) + # Wait for all epilogue warps to finish writing to sC + epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=self.epilog_sync_bar_id, + num_threads=32 * len(self.epilogue_warp_id), + ) + epilog_sync_barrier.arrive_and_wait() + if warp_idx == self.epilogue_warp_id[0]: + cute.copy(tma_c, tOsO[None, Int32(0)], tOgO[None, Int32(0)]) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) # Compute LSE: lse = ln(row_sum) + row_max * ln(2) # Only when emitting un-normalized output (D5a path). diff --git a/tests/unit/test_smem_acc.py b/tests/unit/test_smem_acc.py index ffebeb29..643257e8 100644 --- a/tests/unit/test_smem_acc.py +++ b/tests/unit/test_smem_acc.py @@ -51,7 +51,7 @@ def test_smem_acc(hd=64, s_k=256, use_smem_p=False, normalize=False): mRS = ct.from_dlpack(row_sums_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(row_sums_tensor)) print(f' hd={hd}, s_k={s_k} ({n_kv_tiles} KV tiles, pv_n_tile={pv_n_tile}, n_pv_tiles={n_pv_tiles}): Compiling...', flush=True) - compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, lse=mLSE, row_sums=mRS, c_direct=mC) + compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, lse=mLSE, row_sums=mRS) for nt in range(n_pv_tiles): v_start = nt * pv_n_tile @@ -69,7 +69,7 @@ def test_smem_acc(hd=64, s_k=256, use_smem_p=False, normalize=False): mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor)) mRS = ct.from_dlpack(row_sums_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(row_sums_tensor)) - compiled(mQ, mK, mV, mC, stream, lse=mLSE, row_sums=mRS, c_direct=mC) + compiled(mQ, mK, mV, mC, stream, lse=mLSE, row_sums=mRS) torch.cuda.synchronize() c[:, v_start:v_end, :] = c_tile