Update deprecated Python 3.8 typing (#13971)

Author: Harry Mellor
Date: 2025-03-03 01:34:51 +00:00
Committed by: GitHub
Parent: bf33700ecd
Commit: cf069aa8aa
300 changed files with 2294 additions and 2347 deletions
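The pattern applied across the tree: with Python 3.8 support dropped, the deprecated typing.List / typing.Tuple aliases are replaced by the built-in generics that PEP 585 made usable in annotations from Python 3.9 onward; Optional has no built-in replacement and keeps its import. A minimal before/after sketch follows (hypothetical function, not taken from this diff):

# Before (Python 3.8-compatible spelling):
# from typing import List, Optional, Tuple
# def first_elems(pairs: List[Tuple[int, int]], default: Optional[int] = None) -> List[int]:
#     ...

# After (Python 3.9+, the style this commit adopts):
from typing import Optional

def first_elems(pairs: list[tuple[int, int]], default: Optional[int] = None) -> list[int]:
    # Built-in list/tuple are subscriptable in annotations since PEP 585 (Python 3.9).
    if pairs:
        return [a for a, _ in pairs]
    return [default] if default is not None else []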


@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 import random
-from typing import List, Optional, Tuple
+from typing import Optional
 import pytest
 import torch
@@ -85,8 +85,8 @@ def ref_single_query_cached_kv_attention(
         block_table = block_tables_lst[i]
         seq_len = int(seq_lens_lst[i])
-        keys_lst: List[torch.Tensor] = []
-        values_lst: List[torch.Tensor] = []
+        keys_lst: list[torch.Tensor] = []
+        values_lst: list[torch.Tensor] = []
         for j in range(seq_len):
             block_number = int(block_table[j // block_size])
             block_offset = j % block_size
@@ -133,7 +133,7 @@ def test_paged_attention(
     kv_cache_factory,
     version: str,
     num_seqs: int,
-    num_heads: Tuple[int, int],
+    num_heads: tuple[int, int],
     head_size: int,
     use_alibi: bool,
     block_size: int,
@@ -166,7 +166,7 @@ def test_paged_attention(
     # Create the block tables.
     max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
-    block_tables_lst: List[List[int]] = []
+    block_tables_lst: list[list[int]] = []
     for _ in range(num_seqs):
         block_table = [
             random.randint(0, NUM_BLOCKS - 1)
@@ -334,7 +334,7 @@ def test_paged_attention(
 def ref_multi_query_kv_attention(
-    cu_seq_lens: List[int],
+    cu_seq_lens: list[int],
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
@@ -342,7 +342,7 @@ def ref_multi_query_kv_attention(
     dtype: torch.dtype,
 ) -> torch.Tensor:
     num_seqs = len(cu_seq_lens) - 1
-    ref_outputs: List[torch.Tensor] = []
+    ref_outputs: list[torch.Tensor] = []
     for i in range(num_seqs):
         start_idx = cu_seq_lens[i]
         end_idx = cu_seq_lens[i + 1]
@@ -378,7 +378,7 @@ def ref_multi_query_kv_attention(
 @torch.inference_mode()
 def test_multi_query_kv_attention(
     num_seqs: int,
-    num_heads: Tuple[int, int],
+    num_heads: tuple[int, int],
     head_size: int,
     dtype: torch.dtype,
     seed: int,