D1.5: Move tTMrO def before softmax loop (CuTeDSL scoping)

This commit is contained in:
2026-05-24 02:32:39 +00:00
parent 5a34865062
commit 6ead708c7d

View File

@@ -291,6 +291,12 @@ class FmhaKernel:
row_max = -Float32.inf
row_sum = Float32(0.0)
# Define tTMrO UNCONDITIONALLY (CuTeDSL scoping rule).
# Used for O rescale (kt > 0) and O normalization (after loop).
tTMrO = cute.make_rmem_tensor(
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
)
scale_log2 = Float32(self.scale_softmax_log2)
# O rescale atoms (hand-constructed, using composition layout like CUTLASS correction_rescale)
@@ -368,9 +374,6 @@ class FmhaKernel:
_sP_nostage[(j0, j1), 0, (0, 0)] = BFloat16(0.0)
cute.arch.fence_proxy("async.shared", space="cta")
if kt > 0:
tTMrO = cute.make_rmem_tensor(
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
)
for i in range(n_corr_tiles):
tTMrO_i_ = tTMrO[None, i]
tTMrO_i_layout = cute.composition(