Update deprecated Python 3.8 typing (#13971)
This commit is contained in:
@@ -3,7 +3,6 @@
|
||||
|
||||
Run `pytest tests/kernels/test_cutlass.py`.
|
||||
"""
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -71,7 +70,7 @@ def cutlass_fp8_gemm_helper(m: int,
|
||||
a_scale_group_shape: tuple,
|
||||
b_scale_group_shape: tuple,
|
||||
use_bias: bool,
|
||||
out_dtype: Type[torch.dtype] = torch.bfloat16,
|
||||
out_dtype: type[torch.dtype] = torch.bfloat16,
|
||||
device: str = "cuda"):
|
||||
# Test for a cutlass kernel with per-token activation quantization
|
||||
# and per-output channel weight quantization.
|
||||
@@ -109,7 +108,7 @@ def cutlass_int8_gemm_helper(m: int,
|
||||
a_scale_group_shape: tuple,
|
||||
b_scale_group_shape: tuple,
|
||||
use_bias: bool,
|
||||
out_dtype: Type[torch.dtype] = torch.bfloat16,
|
||||
out_dtype: type[torch.dtype] = torch.bfloat16,
|
||||
device: str = "cuda"):
|
||||
# Test for a cutlass kernel with per-token activation quantization
|
||||
# and per-output channel weight quantization.
|
||||
@@ -187,7 +186,7 @@ def test_cutlass_int8_gemm(m: int, n: int, k: int, a_scale_group_shape,
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
out_dtype: Type[torch.dtype],
|
||||
out_dtype: type[torch.dtype],
|
||||
use_bias: bool):
|
||||
cutlass_int8_gemm_helper(512,
|
||||
512,
|
||||
@@ -208,7 +207,7 @@ def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
out_dtype: Type[torch.dtype],
|
||||
out_dtype: type[torch.dtype],
|
||||
use_bias: bool):
|
||||
cutlass_fp8_gemm_helper(512,
|
||||
512,
|
||||
@@ -227,7 +226,7 @@ def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape,
|
||||
reason="FP8 blockwise is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_blockwise_scale_gemm_dtype(a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
out_dtype: Type[torch.dtype],
|
||||
out_dtype: type[torch.dtype],
|
||||
use_bias: bool):
|
||||
cutlass_fp8_gemm_helper(512,
|
||||
512,
|
||||
|
||||
Reference in New Issue
Block a user