D1.4: Fix tTMrO placeholder - define only inside const_expr block
This commit is contained in:
@@ -366,12 +366,8 @@ class FmhaKernel:
|
||||
# O rescale atoms (hand-constructed, using composition layout like CUTLASS correction_rescale)
|
||||
# Only needed when there are multiple KV tiles (O must be rescaled per-kt).
|
||||
# With n_kv_tiles=1, no rescale is needed (kt is always 0).
|
||||
# Define placeholder values unconditionally for CuTeDSL scoping.
|
||||
corr_tile_size = 16
|
||||
n_corr_tiles = self.pv_n_tile // corr_tile_size
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
(cute.make_layout((1,)), 1), self.acc_dtype
|
||||
)
|
||||
if const_expr(self.n_kv_tiles > 1):
|
||||
tOcO = pv_thr.partition_C(cS)
|
||||
tOtO_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size)))
|
||||
|
||||
Reference in New Issue
Block a user