Update deprecated Python 3.8 typing (#13971)
This commit is contained in:
@@ -3,7 +3,6 @@
|
||||
|
||||
Run `pytest tests/kernels/test_semi_structured.py`.
|
||||
"""
|
||||
from typing import Tuple, Type
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -79,7 +78,7 @@ def check_compress_decompress_invariance(dtype: torch.dtype, b: torch.Tensor,
|
||||
|
||||
def make_rand_sparse_tensors(
|
||||
dtype: torch.dtype, m: int, n: int, k: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
a = torch.randn((m, k), device='cuda')
|
||||
b = torch.randn((n, k), device='cuda').t()
|
||||
|
||||
@@ -167,7 +166,7 @@ MNK_FACTORS = [
|
||||
@pytest.mark.parametrize("m, n, k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: Type[torch.dtype],
|
||||
def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: type[torch.dtype],
|
||||
use_bias: bool):
|
||||
|
||||
# Create tensors
|
||||
|
||||
Reference in New Issue
Block a user