D1: fix O rescale identity tensor - use PV MMA shape not QK shape

This commit is contained in:
2026-05-24 22:02:55 +00:00
parent f1aab1bfc1
commit 55c6903980

View File

@@ -368,7 +368,8 @@ class FmhaKernel:
corr_tile_size = 16
n_corr_tiles = self.pv_n_tile // corr_tile_size
if const_expr(self.n_kv_tiles > 1):
tOcO = pv_thr.partition_C(cS)
cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1]))
tOcO = pv_thr.partition_C(cO)
tOtO_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size)))
tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size)))
tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout)