fix: revert to composition layout for hand-constructed atoms (matching CUTLASS)

This commit is contained in:
2026-05-23 02:54:54 +00:00
parent 3a2d3c66da
commit 9c331de7ba

View File

@@ -234,12 +234,13 @@ class FmhaV3StageCMulti:
row_sum = Float32(0.0)
scale_log2 = Float32(self.scale_softmax_log2)
# O rescale atoms (hand-constructed, for per-tile O *= acc_scale)
# Use logical_divide (not composition) to match get_tmem_load_op layout
# O rescale atoms (hand-constructed, for per-tile O *= acc_scale; uses a composition layout like CUTLASS correction_rescale)
corr_tile_size = 16
tOcO = pv_thr.partition_C(cS)
tOtO_epi_r = cute.logical_divide(tOtO0, cute.make_layout((128, corr_tile_size)))
tOcO_epi_r = cute.logical_divide(tOcO, cute.make_layout((128, corr_tile_size)))
tOtO_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size)))
tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size)))
tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout)
tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout)
tmem_load_o_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)),
self.acc_dtype,
@@ -248,13 +249,13 @@ class FmhaV3StageCMulti:
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)),
self.acc_dtype,
)
tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_epi_r[(None, None), 0])
tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_epi_r[(None, None), 0])
tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_i)
tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_i)
thr_tmem_load_o = tiled_tmem_load_o.get_slice(sfw_idx)
thr_tmem_store_o = tiled_tmem_store_o.get_slice(sfw_idx)
tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_epi_r[(None, None), None])
tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_epi_r[(None, None), None])
tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_epi_r[(None, None), None])
tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_i)
tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_i)
tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i)
n_corr_tiles = HEAD_DIM // corr_tile_size
for kt in range(self.n_kv_tiles):