diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index c3cf87a3..ae49935b 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -291,6 +291,12 @@ class FmhaKernel: row_max = -Float32.inf row_sum = Float32(0.0) + + # Define tTMrO UNCONDITIONALLY (CuTeDSL scoping rule). + # Used for O rescale (kt > 0) and O normalization (after loop). + tTMrO = cute.make_rmem_tensor( + (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype + ) scale_log2 = Float32(self.scale_softmax_log2) # O rescale atoms (hand-constructed, using composition layout like CUTLASS correction_rescale) @@ -368,9 +374,6 @@ class FmhaKernel: _sP_nostage[(j0, j1), 0, (0, 0)] = BFloat16(0.0) cute.arch.fence_proxy("async.shared", space="cta") if kt > 0: - tTMrO = cute.make_rmem_tensor( - (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype - ) for i in range(n_corr_tiles): tTMrO_i_ = tTMrO[None, i] tTMrO_i_layout = cute.composition(