From 2c36cd0d32fae1eebab942c1a15a1787e2d7f26f Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 09:04:01 +0000 Subject: [PATCH] Stage D1: Multi-PV-tile support for hd>256 (tcgen05 MMA max N=256) --- dsv4/kernels/attention/fmha.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index c7f3a3bc..14bad065 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -15,11 +15,10 @@ import math class FmhaKernel: - def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None): + def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None): self.head_dim = head_dim self.s_k = s_k self.n_kv_tiles = s_k // 128 - self.pv_n_tile = min(head_dim, 256) # tcgen05 MMA max N=256 self.n_pv_tiles = head_dim // self.pv_n_tile self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64) @@ -32,7 +31,6 @@ class FmhaKernel: self.kv_stage = 2; self.q_stage = 1; self.num_c_stage = 2 self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim) self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e) - def _setup(self, qk_mma, pv_mma): qk_ik = cute.size(qk_mma.shape_mnk, mode=[2]) self.qk_mma_tiler = (128, 128, qk_ik * 4)