D1.4: Fix merge test - pass use_smem_p=False to the hd=256 kernel (SMEM budget)

This commit is contained in:
2026-05-24 16:36:48 +00:00
parent d70f083e17
commit bd08bfee8e

View File

@@ -37,7 +37,7 @@ def test():
# Use the hd=256 kernel (no k_sub path) with k_tile=256
# Call once per k_sub tile, merge results via online softmax
-kernel = FmhaKernel(head_dim=k_tile, s_k=n, normalize=False)
+kernel = FmhaKernel(head_dim=k_tile, s_k=n, use_smem_p=False, normalize=False)
pv_n_tile = kernel.pv_n_tile
n_pv_tiles = kernel.n_pv_tiles
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)