[BUG] fix crash on flashinfer backend with cudagraph disabled, when attention group_size not in [1,2,4,8] (#7509)
This commit is contained in:
@@ -4,7 +4,7 @@ import flashinfer
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
NUM_HEADS = [(16, 16), (32, 8), (64, 8)]
|
||||
NUM_HEADS = [(16, 16), (32, 8), (64, 8), (6, 1)]
|
||||
HEAD_SIZES = [128, 256]
|
||||
BLOCK_SIZES = [16, 32]
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
@@ -123,7 +123,10 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
|
||||
|
||||
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
|
||||
wrapper = flashinfer.\
|
||||
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
|
||||
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
|
||||
use_tensor_cores=(
|
||||
(num_query_heads//num_kv_heads) not in (1, 2, 4, 8))
|
||||
)
|
||||
wrapper.begin_forward(kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
|
||||
Reference in New Issue
Block a user