Fix: add scale_softmax_log2, use O TMEM rescale for C9 normalization
- scale_softmax_log2 was missing from _setup (patch artifact) - C9 normalization: load O from TMEM, multiply by 1/row_sum, store back instead of trying to capture runtime value in const_expr lambda - Then use standard epilogue_tma_store with identity transform
This commit is contained in:
@@ -63,6 +63,7 @@ class FmhaV3Softmax:
|
||||
q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0))
|
||||
self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta
|
||||
self.kv_tx_bytes = cute.size_in_bytes(self.q_dtype, k_s) * cta
|
||||
self.scale_softmax_log2 = Float32(1.0 / math.sqrt(HEAD_DIM) * math.log2(math.e))
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, q, k, v, c, stream):
|
||||
@@ -348,16 +349,34 @@ class FmhaV3Softmax:
|
||||
|
||||
row_sum = row_sum + tile_sum
|
||||
|
||||
# --- C9: Final normalization + epilogue TMA store ---
|
||||
# --- C9: Final normalization via O TMEM rescale ---
|
||||
# After all KV tiles, O = sum(P_i @ V_i) but unnormalized.
|
||||
# Load O, multiply by 1/row_sum, store O. Then use identity epilogue.
|
||||
inv_row_sum = cutlass.Float32(1.0) / row_sum
|
||||
|
||||
tTMrO_final = cute.make_rmem_tensor((tTMEM_LOADcO.shape, o_col_tiles), self.qk_acc_dtype)
|
||||
for i in range(o_col_tiles):
|
||||
tTMrO_i_ = tTMrO_final[None, i]
|
||||
tTMrO_i_layout = cute.composition(tTMrO_i_.layout, cute.make_layout(tTMrO_final.shape[0]))
|
||||
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
|
||||
tTMEM_LOADtO_i = cute.make_tensor(
|
||||
tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout)
|
||||
tTMEM_STOREtO_i = cute.make_tensor(
|
||||
tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout)
|
||||
cute.copy(o_tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i)
|
||||
for j in cutlass.range(cute.size(tTMrO_i), vectorize=True):
|
||||
tTMrO_i[j] = tTMrO_i[j] * inv_row_sum
|
||||
cute.copy(o_tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# Now O in TMEM is normalized. Use standard epilogue_tma_store with identity.
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
# C9: Normalize by 1/row_sum
|
||||
inv_row_sum = cutlass.Float32(1.0) / row_sum
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0,
|
||||
const_expr(lambda x, s=inv_row_sum: x * s),
|
||||
const_expr(lambda x: x),
|
||||
(0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
|
||||
Reference in New Issue
Block a user