Add c_direct GMEM tensor for direct writes in SMEM accumulator path
This commit is contained in:
@@ -126,7 +126,7 @@ class FmhaKernel:
|
||||
cute.size_in_bytes(self.q_dtype, v_s)) * cta
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, q, k, v, c, stream, lse=None, swa_len=None, sink_bias=None, row_sums=None):
|
||||
def __call__(self, q, k, v, c, stream, lse=None, swa_len=None, sink_bias=None, row_sums=None, c_direct=None):
|
||||
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
|
||||
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
|
||||
@@ -171,11 +171,14 @@ class FmhaKernel:
|
||||
# For single-head (n_h=1): grid=(1,1,1) — backward compatible
|
||||
if const_expr(row_sums is None):
|
||||
row_sums = cute.make_tensor(lse.iterator, lse.layout)
|
||||
# c_direct: simple GMEM tensor for direct writes (SMEM accumulator path)
|
||||
if const_expr(c_direct is None):
|
||||
c_direct = cute.make_tensor(c.iterator, cute.make_layout((1, 1, 1), stride=(1, 1, 1)))
|
||||
|
||||
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse,swa_len,sink_bias,row_sums).launch(grid=(1,1,self.batch_size),block=[self.threads_per_cta,1,1],stream=stream)
|
||||
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse,swa_len,sink_bias,row_sums,c_direct).launch(grid=(1,1,self.batch_size),block=[self.threads_per_cta,1,1],stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE, swa_len, mSinkBias, mRowSums):
|
||||
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE, swa_len, mSinkBias, mRowSums, mCDirect):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx,_,_ = cute.arch.thread_idx()
|
||||
if warp_idx == self.tma_warp_id:
|
||||
@@ -608,22 +611,15 @@ class FmhaKernel:
|
||||
else:
|
||||
# Path 2: write sO_acc (FP32) -> GMEM directly from registers
|
||||
# Each thread handles one row (sfw_idx) of the output.
|
||||
# Read sO_acc row, normalize, cast to BF16, write to GMEM.
|
||||
#
|
||||
# For GMEM writes, use gC (the local_tile of the output tensor).
|
||||
# gC is created from mC with TMA-compatible layout.
|
||||
# We write to it using cute.copy with a universal copy atom.
|
||||
#
|
||||
# Actually, the simplest approach: write each element directly
|
||||
# to the GMEM output tensor using scalar stores.
|
||||
# gC[sfw_idx, col, 0] = BF16(sO_acc[sfw_idx, col] / row_sum)
|
||||
# Read sO_acc row, normalize, cast to BF16, write to GMEM via mCDirect.
|
||||
for col in cutlass.range(0, self.pv_n_tile, unroll=1):
|
||||
row = sfw_idx
|
||||
if row < Int32(128):
|
||||
val = sO_acc[row, col]
|
||||
if const_expr(not self.normalize):
|
||||
val = val / row_sum
|
||||
gC[Int32(row), Int32(col), Int32(0)] = val.to(self.o_dtype)
|
||||
if const_expr(self.normalize):
|
||||
inv_row_sum = Float32(1.0) / row_sum
|
||||
val = val * inv_row_sum
|
||||
mCDirect[Int32(row), Int32(col), Int32(0)] = val.to(self.o_dtype)
|
||||
|
||||
# Compute LSE: lse = ln(row_sum) + row_max * ln(2)
|
||||
# Only when emitting un-normalized output (D5a path).
|
||||
|
||||
@@ -51,7 +51,7 @@ def test_smem_acc(hd=64, s_k=256, use_smem_p=False, normalize=False):
|
||||
mRS = ct.from_dlpack(row_sums_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(row_sums_tensor))
|
||||
|
||||
print(f' hd={hd}, s_k={s_k} ({n_kv_tiles} KV tiles, pv_n_tile={pv_n_tile}, n_pv_tiles={n_pv_tiles}): Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, lse=mLSE, row_sums=mRS)
|
||||
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, lse=mLSE, row_sums=mRS, c_direct=mC)
|
||||
|
||||
for nt in range(n_pv_tiles):
|
||||
v_start = nt * pv_n_tile
|
||||
@@ -69,7 +69,7 @@ def test_smem_acc(hd=64, s_k=256, use_smem_p=False, normalize=False):
|
||||
mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor))
|
||||
mRS = ct.from_dlpack(row_sums_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(row_sums_tensor))
|
||||
|
||||
compiled(mQ, mK, mV, mC, stream, lse=mLSE, row_sums=mRS)
|
||||
compiled(mQ, mK, mV, mC, stream, lse=mLSE, row_sums=mRS, c_direct=mC)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
c[:, v_start:v_end, :] = c_tile
|
||||
|
||||
Reference in New Issue
Block a user