diff --git a/tests/unit/test_d1_hd512_merge.py b/tests/unit/test_d1_hd512_merge.py index fcdd061f..4e200cc5 100644 --- a/tests/unit/test_d1_hd512_merge.py +++ b/tests/unit/test_d1_hd512_merge.py @@ -37,7 +37,7 @@ def test(): # Use the hd=256 kernel (no k_sub path) with k_tile=256 # Call once per k_sub tile, merge results via online softmax - kernel = FmhaKernel(head_dim=k_tile, s_k=n, normalize=False) + kernel = FmhaKernel(head_dim=k_tile, s_k=n, use_smem_p=False, normalize=False) pv_n_tile = kernel.pv_n_tile n_pv_tiles = kernel.n_pv_tiles stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)