[Core] Always use tensor cores for Flashinfer Decode Wrapper (#23214)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
Pavani Majety
2025-08-21 13:02:11 -07:00
committed by GitHub
parent 3496274663
commit 1d353b6352
5 changed files with 31 additions and 64 deletions

View File

@@ -137,9 +137,7 @@ def test_flashinfer_decode_with_paged_kv(
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.\
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
-use_tensor_cores=(
-(num_query_heads//num_kv_heads) > 4)
-)
+use_tensor_cores=True)
wrapper.plan(
kv_indptr,
kv_indices,
@@ -411,7 +409,7 @@ def test_flashinfer_decode_with_paged_fp8_kv(
assert num_query_heads % num_kv_heads == 0
max_kv_len = max(kv_lens)
scale = head_size**-0.5
-use_tensor_cores = (num_query_heads // num_kv_heads) > 4
+use_tensor_cores = True
kv_cache_dtype = torch.float8_e4m3fn
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)