Update deprecated Python 3.8 typing (#13971)

This commit is contained in:
Harry Mellor
2025-03-03 01:34:51 +00:00
committed by GitHub
parent bf33700ecd
commit cf069aa8aa
300 changed files with 2294 additions and 2347 deletions

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple
from typing import Optional
import pytest
import torch
@@ -25,7 +25,7 @@ DTYPES = [torch.float16, torch.bfloat16]
@torch.inference_mode()
def test_merge_kernel(
num_tokens: int,
num_heads: Tuple[int, int],
num_heads: tuple[int, int],
head_size: int,
dtype: torch.dtype,
):
@@ -85,8 +85,8 @@ CASES = [
@pytest.mark.parametrize("fa_version", [2, 3])
@torch.inference_mode()
def test_cascade(
seq_lens_and_common_prefix: Tuple[List[Tuple[int, int]], int],
num_heads: Tuple[int, int],
seq_lens_and_common_prefix: tuple[list[tuple[int, int]], int],
num_heads: tuple[int, int],
head_size: int,
dtype: torch.dtype,
block_size: int,