Update deprecated Python 3.8 typing (#13971)

This commit is contained in:
Harry Mellor
2025-03-03 01:34:51 +00:00
committed by GitHub
parent bf33700ecd
commit cf069aa8aa
300 changed files with 2294 additions and 2347 deletions

View File

@@ -3,7 +3,6 @@
Run `pytest tests/kernels/test_semi_structured.py`.
"""
from typing import Tuple, Type
import pytest
import torch
@@ -79,7 +78,7 @@ def check_compress_decompress_invariance(dtype: torch.dtype, b: torch.Tensor,
def make_rand_sparse_tensors(
dtype: torch.dtype, m: int, n: int, k: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
a = torch.randn((m, k), device='cuda')
b = torch.randn((n, k), device='cuda').t()
@@ -167,7 +166,7 @@ MNK_FACTORS = [
@pytest.mark.parametrize("m, n, k", MNK_FACTORS)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: Type[torch.dtype],
def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: type[torch.dtype],
use_bias: bool):
# Create tensors