From 39367265e57513f3cd9c145266c0dfe2d64988ea Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 03:53:19 +0000 Subject: [PATCH] D1: FIX qk_mma_tiler K-dim = head_dim (was hardcoded to 64, broke hd>64) --- dsv4/kernels/attention/fmha.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 509cd56f..b41c344f 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -36,8 +36,9 @@ class FmhaKernel: def _setup(self, qk_mma, pv_mma): qk_ik = cute.size(qk_mma.shape_mnk, mode=[2]) - self.qk_mma_tiler = (128, 128, qk_ik * 4) - print(f"_setup: head_dim={self.head_dim}, qk_ik={qk_ik}, qk_mma_tiler={self.qk_mma_tiler}") + # QK GEMM K-dim = head_dim. Each MMA sub-tile covers qk_ik*4 elements. + # The tiler K must be head_dim so the QK loop iterates over all K sub-tiles. + self.qk_mma_tiler = (128, 128, self.head_dim) pv_ik = cute.size(pv_mma.shape_mnk, mode=[2]) self.pv_mma_tiler = (128, self.pv_n_tile, pv_ik * (128 // pv_ik)) self.mma_tiler = self.qk_mma_tiler