diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py
index d11482a5..509cd56f 100644
--- a/dsv4/kernels/attention/fmha.py
+++ b/dsv4/kernels/attention/fmha.py
@@ -37,6 +37,7 @@ class FmhaKernel:
     def _setup(self, qk_mma, pv_mma):
         qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
         self.qk_mma_tiler = (128, 128, qk_ik * 4)
+        print(f"_setup: head_dim={self.head_dim}, qk_ik={qk_ik}, qk_mma_tiler={self.qk_mma_tiler}")
         pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
         self.pv_mma_tiler = (128, self.pv_n_tile, pv_ik * (128 // pv_ik))
         self.mma_tiler = self.qk_mma_tiler