D1: FIX qk_mma_tiler K-dim = head_dim (was hardcoded to 64, broke hd>64)
This commit is contained in:
@@ -36,8 +36,9 @@ class FmhaKernel:
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (128, 128, qk_ik * 4)
|
||||
print(f"_setup: head_dim={self.head_dim}, qk_ik={qk_ik}, qk_mma_tiler={self.qk_mma_tiler}")
|
||||
# QK GEMM K-dim = head_dim. Each MMA sub-tile covers qk_ik*4 elements.
|
||||
# The tiler K must be head_dim so the QK loop iterates over all K sub-tiles.
|
||||
self.qk_mma_tiler = (128, 128, self.head_dim)
|
||||
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
self.pv_mma_tiler = (128, self.pv_n_tile, pv_ik * (128 // pv_ik))
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
|
||||
Reference in New Issue
Block a user