fully revert FmhaKernel changes to debug regression

This commit is contained in:
2026-05-25 17:04:31 +00:00
parent 0b9f9da2f7
commit 13b5afc471

View File

@@ -16,7 +16,7 @@ import math
class FmhaKernel:
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1):
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True):
self.head_dim = head_dim
self.s_k = s_k
self.n_kv_tiles = s_k // 128
@@ -28,8 +28,6 @@ class FmhaKernel:
self.pv_n_tile = 128
self.n_pv_tiles = head_dim // self.pv_n_tile
self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64)
self.num_query_heads = num_query_heads
self.batch_size = batch_size
self.normalize = normalize # D5a: False = emit un-normalized O + lse
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16