Update deprecated Python 3.8 typing (#13971)

This commit is contained in:
Harry Mellor
2025-03-03 01:34:51 +00:00
committed by GitHub
parent bf33700ecd
commit cf069aa8aa
300 changed files with 2294 additions and 2347 deletions

View File

@@ -3,7 +3,6 @@
Run `pytest tests/kernels/test_cutlass.py`.
"""
from typing import Type
import pytest
import torch
@@ -71,7 +70,7 @@ def cutlass_fp8_gemm_helper(m: int,
a_scale_group_shape: tuple,
b_scale_group_shape: tuple,
use_bias: bool,
out_dtype: Type[torch.dtype] = torch.bfloat16,
out_dtype: type[torch.dtype] = torch.bfloat16,
device: str = "cuda"):
# Test for a cutlass kernel with per-token activation quantization
# and per-output channel weight quantization.
@@ -109,7 +108,7 @@ def cutlass_int8_gemm_helper(m: int,
a_scale_group_shape: tuple,
b_scale_group_shape: tuple,
use_bias: bool,
out_dtype: Type[torch.dtype] = torch.bfloat16,
out_dtype: type[torch.dtype] = torch.bfloat16,
device: str = "cuda"):
# Test for a cutlass kernel with per-token activation quantization
# and per-output channel weight quantization.
@@ -187,7 +186,7 @@ def test_cutlass_int8_gemm(m: int, n: int, k: int, a_scale_group_shape,
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
b_scale_group_shape,
out_dtype: Type[torch.dtype],
out_dtype: type[torch.dtype],
use_bias: bool):
cutlass_int8_gemm_helper(512,
512,
@@ -208,7 +207,7 @@ def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape,
b_scale_group_shape,
out_dtype: Type[torch.dtype],
out_dtype: type[torch.dtype],
use_bias: bool):
cutlass_fp8_gemm_helper(512,
512,
@@ -227,7 +226,7 @@ def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape,
reason="FP8 blockwise is not supported on this GPU type.")
def test_cutlass_fp8_blockwise_scale_gemm_dtype(a_scale_group_shape,
b_scale_group_shape,
out_dtype: Type[torch.dtype],
out_dtype: type[torch.dtype],
use_bias: bool):
cutlass_fp8_gemm_helper(512,
512,