Update deprecated Python 3.8 typing (#13971)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import random
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -87,8 +87,8 @@ def ref_single_query_cached_kv_attention(
|
||||
block_table = block_tables_lst[i]
|
||||
seq_len = int(seq_lens_lst[i])
|
||||
|
||||
keys_lst: List[torch.Tensor] = []
|
||||
values_lst: List[torch.Tensor] = []
|
||||
keys_lst: list[torch.Tensor] = []
|
||||
values_lst: list[torch.Tensor] = []
|
||||
for j in range(seq_len):
|
||||
block_number = int(block_table[j // block_size])
|
||||
block_offset = j % block_size
|
||||
@@ -162,7 +162,7 @@ def test_paged_attention(
|
||||
kv_cache_factory,
|
||||
version: str,
|
||||
num_seqs: int,
|
||||
num_heads: Tuple[int, int],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
use_alibi: bool,
|
||||
block_size: int,
|
||||
@@ -331,7 +331,7 @@ def test_paged_attention(
|
||||
|
||||
|
||||
def ref_multi_query_kv_attention(
|
||||
cu_seq_lens: List[int],
|
||||
cu_seq_lens: list[int],
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
@@ -376,7 +376,7 @@ def ref_multi_query_kv_attention(
|
||||
@torch.inference_mode()
|
||||
def test_varlen_blocksparse_attention_prefill(
|
||||
num_seqs: int,
|
||||
num_heads: Tuple[int, int],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
blocksparse_local_blocks: int,
|
||||
blocksparse_vert_stride: int,
|
||||
|
||||
Reference in New Issue
Block a user