Update deprecated Python 3.8 typing (#13971)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import flashinfer
|
||||
import pytest
|
||||
@@ -19,8 +19,8 @@ def ref_paged_attn(
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
query_lens: List[int],
|
||||
kv_lens: List[int],
|
||||
query_lens: list[int],
|
||||
kv_lens: list[int],
|
||||
block_tables: torch.Tensor,
|
||||
scale: float,
|
||||
sliding_window: Optional[int] = None,
|
||||
@@ -30,7 +30,7 @@ def ref_paged_attn(
|
||||
block_tables = block_tables.cpu().numpy()
|
||||
_, block_size, num_kv_heads, head_size = key_cache.shape
|
||||
|
||||
outputs: List[torch.Tensor] = []
|
||||
outputs: list[torch.Tensor] = []
|
||||
start_idx = 0
|
||||
for i in range(num_seqs):
|
||||
query_len = query_lens[i]
|
||||
@@ -78,8 +78,8 @@ def ref_paged_attn(
|
||||
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
|
||||
@torch.inference_mode
|
||||
def test_flashinfer_decode_with_paged_kv(
|
||||
kv_lens: List[int],
|
||||
num_heads: Tuple[int, int],
|
||||
kv_lens: list[int],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
@@ -168,8 +168,8 @@ def test_flashinfer_decode_with_paged_kv(
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
|
||||
@torch.inference_mode
|
||||
def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
|
||||
num_heads: Tuple[int, int],
|
||||
def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int, dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: Optional[float]) -> None:
|
||||
@@ -270,7 +270,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
|
||||
def test_flashinfer_prefill_with_paged_fp8_kv(
|
||||
seq_lens: List[Tuple[int, int]], num_heads: Tuple[int, int],
|
||||
seq_lens: list[tuple[int, int]], num_heads: tuple[int, int],
|
||||
head_size: int, dtype: torch.dtype, block_size: int,
|
||||
soft_cap: Optional[float]) -> None:
|
||||
pytest.skip("TODO: fix the accuracy issue")
|
||||
@@ -378,8 +378,8 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
|
||||
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
|
||||
@torch.inference_mode
|
||||
def test_flashinfer_decode_with_paged_fp8_kv(
|
||||
kv_lens: List[int],
|
||||
num_heads: Tuple[int, int],
|
||||
kv_lens: list[int],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
|
||||
Reference in New Issue
Block a user