diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index ae49935b..3a016fa9 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -291,12 +291,6 @@ class FmhaKernel: row_max = -Float32.inf row_sum = Float32(0.0) - - # Define tTMrO UNCONDITIONALLY (CuTeDSL scoping rule). - # Used for O rescale (kt > 0) and O normalization (after loop). - tTMrO = cute.make_rmem_tensor( - (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype - ) scale_log2 = Float32(self.scale_softmax_log2) # O rescale atoms (hand-constructed, using composition layout like CUTLASS correction_rescale) @@ -323,6 +317,12 @@ class FmhaKernel: tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i) n_corr_tiles = self.pv_n_tile // corr_tile_size + # Define tTMrO UNCONDITIONALLY (CuTeDSL scoping rule). + # Used for O rescale (kt > 0) and O normalization (after loop). + tTMrO = cute.make_rmem_tensor( + (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype + ) + for kt in range(self.n_kv_tiles): si_handle = s_cons.wait_and_advance()