Fix v_fmha layout to use pv_n_tile instead of head_dim for multi-PV-tile support
This commit is contained in:
@@ -92,7 +92,10 @@ class FmhaKernel:
|
||||
v_fmha = cute.make_tensor(
|
||||
v.iterator,
|
||||
cute.make_layout(
|
||||
(self.head_dim, self.s_k, 1),
|
||||
(self.pv_n_tile, self.s_k, 1),
|
||||
stride=(1, self.pv_n_tile, self.pv_n_tile * self.s_k),
|
||||
),
|
||||
)
|
||||
stride=(1, self.head_dim, self.head_dim * self.s_k),
|
||||
),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user