D1.5: Move tTMrO allocation after the O rescale atoms (fix: tTMrO referenced tTMEM_LOADcO before it was defined)
This commit is contained in:
@@ -291,12 +291,6 @@ class FmhaKernel:
|
||||
|
||||
row_max = -Float32.inf
|
||||
row_sum = Float32(0.0)
|
||||
|
||||
# Define tTMrO UNCONDITIONALLY (CuTeDSL scoping rule).
|
||||
# Used for O rescale (kt > 0) and O normalization (after loop).
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
|
||||
)
|
||||
scale_log2 = Float32(self.scale_softmax_log2)
|
||||
|
||||
# O rescale atoms (hand-constructed, using composition layout like CUTLASS correction_rescale)
|
||||
@@ -323,6 +317,12 @@ class FmhaKernel:
|
||||
tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i)
|
||||
n_corr_tiles = self.pv_n_tile // corr_tile_size
|
||||
|
||||
# Define tTMrO UNCONDITIONALLY (CuTeDSL scoping rule).
|
||||
# Used for O rescale (kt > 0) and O normalization (after loop).
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
|
||||
)
|
||||
|
||||
for kt in range(self.n_kv_tiles):
|
||||
si_handle = s_cons.wait_and_advance()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user