Fix v_fmha layout to use pv_n_tile instead of head_dim for multi-PV-tile support

This commit is contained in:
2026-05-23 09:02:01 +00:00
parent fcdfc4239c
commit eedcfd7d21

View File

@@ -92,7 +92,10 @@ class FmhaKernel:
 v_fmha = cute.make_tensor(
 v.iterator,
 cute.make_layout(
-(self.head_dim, self.s_k, 1),
+(self.pv_n_tile, self.s_k, 1),
+stride=(1, self.pv_n_tile, self.pv_n_tile * self.s_k),
+),
+)
 stride=(1, self.head_dim, self.head_dim * self.s_k),
 ),
 )