D1: fix O rescale identity tensor - use PV MMA shape not QK shape
This commit is contained in:
@@ -368,7 +368,8 @@ class FmhaKernel:
|
||||
corr_tile_size = 16
|
||||
n_corr_tiles = self.pv_n_tile // corr_tile_size
|
||||
if const_expr(self.n_kv_tiles > 1):
|
||||
tOcO = pv_thr.partition_C(cS)
|
||||
cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1]))
|
||||
tOcO = pv_thr.partition_C(cO)
|
||||
tOtO_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size)))
|
||||
tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size)))
|
||||
tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout)
|
||||
|
||||
Reference in New Issue
Block a user