D1.5: Move the tTMrO definition after the O rescale atoms (fixes its tTMEM_LOADcO reference)

This commit is contained in:
2026-05-24 02:39:18 +00:00
parent bfd598b937
commit bb4c35facb

View File

@@ -291,12 +291,6 @@ class FmhaKernel:
row_max = -Float32.inf
row_sum = Float32(0.0)
# Define tTMrO UNCONDITIONALLY (CuTeDSL scoping rule).
# Used for O rescale (kt > 0) and O normalization (after loop).
tTMrO = cute.make_rmem_tensor(
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
)
scale_log2 = Float32(self.scale_softmax_log2)
# O rescale atoms (hand-constructed, using composition layout like CUTLASS correction_rescale)
@@ -323,6 +317,12 @@ class FmhaKernel:
tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i)
n_corr_tiles = self.pv_n_tile // corr_tile_size
# Define tTMrO UNCONDITIONALLY (CuTeDSL scoping rule).
# Used for O rescale (kt > 0) and O normalization (after loop).
tTMrO = cute.make_rmem_tensor(
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
)
for kt in range(self.n_kv_tiles):
si_handle = s_cons.wait_and_advance()