Update deprecated Python 3.8 typing (#13971)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -25,7 +25,7 @@ DTYPES = [torch.float16, torch.bfloat16]
|
||||
@torch.inference_mode()
|
||||
def test_merge_kernel(
|
||||
num_tokens: int,
|
||||
num_heads: Tuple[int, int],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
@@ -85,8 +85,8 @@ CASES = [
|
||||
@pytest.mark.parametrize("fa_version", [2, 3])
|
||||
@torch.inference_mode()
|
||||
def test_cascade(
|
||||
seq_lens_and_common_prefix: Tuple[List[Tuple[int, int]], int],
|
||||
num_heads: Tuple[int, int],
|
||||
seq_lens_and_common_prefix: tuple[list[tuple[int, int]], int],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
|
||||
Reference in New Issue
Block a user