Update deprecated Python 3.8 typing (#13971)

This commit is contained in:
Harry Mellor
2025-03-03 01:34:51 +00:00
committed by GitHub
parent bf33700ecd
commit cf069aa8aa
300 changed files with 2294 additions and 2347 deletions

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple
from typing import Optional
import pytest
import torch
@@ -24,8 +24,8 @@ def ref_paged_attn(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
query_lens: List[int],
kv_lens: List[int],
query_lens: list[int],
kv_lens: list[int],
block_tables: torch.Tensor,
scale: float,
sliding_window: Optional[int] = None,
@@ -35,7 +35,7 @@ def ref_paged_attn(
block_tables = block_tables.cpu().numpy()
_, block_size, num_kv_heads, head_size = key_cache.shape
outputs: List[torch.Tensor] = []
outputs: list[torch.Tensor] = []
start_idx = 0
for i in range(num_seqs):
query_len = query_lens[i]
@@ -88,8 +88,8 @@ def ref_paged_attn(
@torch.inference_mode()
def test_flash_attn_with_paged_kv(
use_out: bool,
kv_lens: List[int],
num_heads: Tuple[int, int],
kv_lens: list[int],
num_heads: tuple[int, int],
head_size: int,
dtype: torch.dtype,
block_size: int,
@@ -174,8 +174,8 @@ def test_flash_attn_with_paged_kv(
@torch.inference_mode()
def test_varlen_with_paged_kv(
use_out: bool,
seq_lens: List[Tuple[int, int]],
num_heads: Tuple[int, int],
seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int],
head_size: int,
sliding_window: Optional[int],
dtype: torch.dtype,