Update deprecated Python 3.8 typing (#13971)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import random
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -85,8 +85,8 @@ def ref_single_query_cached_kv_attention(
|
||||
block_table = block_tables_lst[i]
|
||||
seq_len = int(seq_lens_lst[i])
|
||||
|
||||
keys_lst: List[torch.Tensor] = []
|
||||
values_lst: List[torch.Tensor] = []
|
||||
keys_lst: list[torch.Tensor] = []
|
||||
values_lst: list[torch.Tensor] = []
|
||||
for j in range(seq_len):
|
||||
block_number = int(block_table[j // block_size])
|
||||
block_offset = j % block_size
|
||||
@@ -133,7 +133,7 @@ def test_paged_attention(
|
||||
kv_cache_factory,
|
||||
version: str,
|
||||
num_seqs: int,
|
||||
num_heads: Tuple[int, int],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
use_alibi: bool,
|
||||
block_size: int,
|
||||
@@ -166,7 +166,7 @@ def test_paged_attention(
|
||||
|
||||
# Create the block tables.
|
||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||
block_tables_lst: List[List[int]] = []
|
||||
block_tables_lst: list[list[int]] = []
|
||||
for _ in range(num_seqs):
|
||||
block_table = [
|
||||
random.randint(0, NUM_BLOCKS - 1)
|
||||
@@ -334,7 +334,7 @@ def test_paged_attention(
|
||||
|
||||
|
||||
def ref_multi_query_kv_attention(
|
||||
cu_seq_lens: List[int],
|
||||
cu_seq_lens: list[int],
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
@@ -342,7 +342,7 @@ def ref_multi_query_kv_attention(
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
num_seqs = len(cu_seq_lens) - 1
|
||||
ref_outputs: List[torch.Tensor] = []
|
||||
ref_outputs: list[torch.Tensor] = []
|
||||
for i in range(num_seqs):
|
||||
start_idx = cu_seq_lens[i]
|
||||
end_idx = cu_seq_lens[i + 1]
|
||||
@@ -378,7 +378,7 @@ def ref_multi_query_kv_attention(
|
||||
@torch.inference_mode()
|
||||
def test_multi_query_kv_attention(
|
||||
num_seqs: int,
|
||||
num_heads: Tuple[int, int],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
|
||||
Reference in New Issue
Block a user