From 13b5afc471e7e2ea56ac8f457cbdc7d09edeb2fb Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 25 May 2026 17:04:31 +0000 Subject: [PATCH] fully revert FmhaKernel changes to debug regression --- dsv4/kernels/attention/fmha.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 05355ea1..f42693d4 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -16,7 +16,7 @@ import math class FmhaKernel: - def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1): + def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True): self.head_dim = head_dim self.s_k = s_k self.n_kv_tiles = s_k // 128 @@ -28,8 +28,6 @@ class FmhaKernel: self.pv_n_tile = 128 self.n_pv_tiles = head_dim // self.pv_n_tile self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64) - self.num_query_heads = num_query_heads - self.batch_size = batch_size self.normalize = normalize # D5a: False = emit un-normalized O + lse self.acc_dtype = Float32; self.qk_acc_dtype = Float32 self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16