D1.4: Fix merge test - use use_smem_p=False for hd=256 kernel (SMEM budget)
This commit is contained in:
@@ -37,7 +37,7 @@ def test():
|
||||
|
||||
# Use the hd=256 kernel (no k_sub path) with k_tile=256
|
||||
# Call once per k_sub tile, merge results via online softmax
|
||||
kernel = FmhaKernel(head_dim=k_tile, s_k=n, normalize=False)
|
||||
kernel = FmhaKernel(head_dim=k_tile, s_k=n, use_smem_p=False, normalize=False)
|
||||
pv_n_tile = kernel.pv_n_tile
|
||||
n_pv_tiles = kernel.n_pv_tiles
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
Reference in New Issue
Block a user