D1.5: Move tTMrO def before softmax loop (CuTeDSL scoping)
This commit is contained in:
@@ -291,6 +291,12 @@ class FmhaKernel:
|
||||
|
||||
row_max = -Float32.inf
|
||||
row_sum = Float32(0.0)
|
||||
|
||||
# Define tTMrO UNCONDITIONALLY (CuTeDSL scoping rule).
|
||||
# Used for O rescale (kt > 0) and O normalization (after loop).
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
|
||||
)
|
||||
scale_log2 = Float32(self.scale_softmax_log2)
|
||||
|
||||
# O rescale atoms (hand-constructed, using composition layout like CUTLASS correction_rescale)
|
||||
@@ -368,9 +374,6 @@ class FmhaKernel:
|
||||
_sP_nostage[(j0, j1), 0, (0, 0)] = BFloat16(0.0)
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
if kt > 0:
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
|
||||
)
|
||||
for i in range(n_corr_tiles):
|
||||
tTMrO_i_ = tTMrO[None, i]
|
||||
tTMrO_i_layout = cute.composition(
|
||||
|
||||
Reference in New Issue
Block a user