Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -6,7 +6,6 @@ Run `pytest tests/kernels/quantization/test_cutlass_w4a8.py`.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -60,10 +59,10 @@ SCHEDULES = [
|
||||
class TypeConfig:
|
||||
act_type: torch.dtype
|
||||
weight_type: ScalarType
|
||||
output_type: Optional[torch.dtype]
|
||||
group_scale_type: Optional[torch.dtype]
|
||||
channel_scale_type: Optional[torch.dtype]
|
||||
token_scale_type: Optional[torch.dtype]
|
||||
output_type: torch.dtype | None
|
||||
group_scale_type: torch.dtype | None
|
||||
channel_scale_type: torch.dtype | None
|
||||
token_scale_type: torch.dtype | None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -80,7 +79,7 @@ class Tensors:
|
||||
# (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints,
|
||||
# Ch Scales Type, Tok Scales Type)
|
||||
TestTypeTuple = tuple[
|
||||
list[torch.dtype], ScalarType, Optional[torch.dtype], Optional[torch.dtype], bool
|
||||
list[torch.dtype], ScalarType, torch.dtype | None, torch.dtype | None, bool
|
||||
]
|
||||
TEST_TYPES = [
|
||||
*(
|
||||
@@ -116,8 +115,8 @@ def cutlass_quantize_and_pack(
|
||||
atype: torch.dtype,
|
||||
w: torch.Tensor,
|
||||
wtype: ScalarType,
|
||||
stype: Optional[torch.dtype],
|
||||
group_size: Optional[int],
|
||||
stype: torch.dtype | None,
|
||||
group_size: int | None,
|
||||
zero_points: bool = False,
|
||||
):
|
||||
assert wtype.is_integer(), "TODO: support floating point weights"
|
||||
@@ -143,7 +142,7 @@ def cutlass_quantize_and_pack(
|
||||
|
||||
|
||||
def create_test_tensors(
|
||||
shape: tuple[int, int, int], types: TypeConfig, group_size: Optional[int]
|
||||
shape: tuple[int, int, int], types: TypeConfig, group_size: int | None
|
||||
) -> Tensors:
|
||||
m, n, k = shape
|
||||
|
||||
@@ -185,8 +184,8 @@ def create_test_tensors(
|
||||
def mm_test_helper(
|
||||
types: TypeConfig,
|
||||
tensors: Tensors,
|
||||
group_size: Optional[int] = None,
|
||||
schedule: Optional[str] = None,
|
||||
group_size: int | None = None,
|
||||
schedule: str | None = None,
|
||||
):
|
||||
# CUTLASS upstream uses fp8 with fastaccum as reference
|
||||
# https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu#L406
|
||||
|
||||
Reference in New Issue
Block a user