diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 1ae02921..c7f3a3bc 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -92,7 +92,10 @@ class FmhaKernel: v_fmha = cute.make_tensor( v.iterator, cute.make_layout( - (self.head_dim, self.s_k, 1), + (self.pv_n_tile, self.s_k, 1), + stride=(1, self.pv_n_tile, self.pv_n_tile * self.s_k), + ), + ) stride=(1, self.head_dim, self.head_dim * self.s_k), ), )