fix: revert the hand-constructed O-rescale TMEM load/store copy atoms to a composition layout (matching CUTLASS correction_rescale)
This commit is contained in:
@@ -234,12 +234,13 @@ class FmhaV3StageCMulti:
|
||||
row_sum = Float32(0.0)
|
||||
scale_log2 = Float32(self.scale_softmax_log2)
|
||||
|
||||
# O rescale atoms (hand-constructed, for per-tile O *= acc_scale)
|
||||
# Use logical_divide (not composition) to match get_tmem_load_op layout
|
||||
# O rescale atoms (hand-constructed, using composition layout like CUTLASS correction_rescale)
|
||||
corr_tile_size = 16
|
||||
tOcO = pv_thr.partition_C(cS)
|
||||
tOtO_epi_r = cute.logical_divide(tOtO0, cute.make_layout((128, corr_tile_size)))
|
||||
tOcO_epi_r = cute.logical_divide(tOcO, cute.make_layout((128, corr_tile_size)))
|
||||
tOtO_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size)))
|
||||
tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size)))
|
||||
tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout)
|
||||
tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout)
|
||||
tmem_load_o_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)),
|
||||
self.acc_dtype,
|
||||
@@ -248,13 +249,13 @@ class FmhaV3StageCMulti:
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)),
|
||||
self.acc_dtype,
|
||||
)
|
||||
tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_epi_r[(None, None), 0])
|
||||
tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_epi_r[(None, None), 0])
|
||||
tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_i)
|
||||
tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_i)
|
||||
thr_tmem_load_o = tiled_tmem_load_o.get_slice(sfw_idx)
|
||||
thr_tmem_store_o = tiled_tmem_store_o.get_slice(sfw_idx)
|
||||
tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_epi_r[(None, None), None])
|
||||
tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_epi_r[(None, None), None])
|
||||
tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_epi_r[(None, None), None])
|
||||
tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_i)
|
||||
tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_i)
|
||||
tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i)
|
||||
n_corr_tiles = HEAD_DIM // corr_tile_size
|
||||
|
||||
for kt in range(self.n_kv_tiles):
|
||||
|
||||
Reference in New Issue
Block a user