[Core] Always use tensor cores for Flashinfer Decode Wrapper (#23214)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
@@ -137,9 +137,7 @@ def test_flashinfer_decode_with_paged_kv(
|
||||
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
|
||||
wrapper = flashinfer.\
|
||||
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
|
||||
use_tensor_cores=(
|
||||
(num_query_heads//num_kv_heads) > 4)
|
||||
)
|
||||
use_tensor_cores=True)
|
||||
wrapper.plan(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
@@ -411,7 +409,7 @@ def test_flashinfer_decode_with_paged_fp8_kv(
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
max_kv_len = max(kv_lens)
|
||||
scale = head_size**-0.5
|
||||
use_tensor_cores = (num_query_heads // num_kv_heads) > 4
|
||||
use_tensor_cores = True
|
||||
kv_cache_dtype = torch.float8_e4m3fn
|
||||
|
||||
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
|
||||
|
||||
Reference in New Issue
Block a user