Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -7,7 +7,6 @@ Run `pytest tests/kernels/quantization/test_machete_mm.py`.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -50,11 +49,11 @@ MNK_SHAPES = [
|
||||
class TypeConfig:
|
||||
act_type: torch.dtype
|
||||
weight_type: ScalarType
|
||||
output_type: Optional[torch.dtype]
|
||||
group_scale_type: Optional[torch.dtype]
|
||||
group_zero_type: Optional[torch.dtype]
|
||||
channel_scale_type: Optional[torch.dtype]
|
||||
token_scale_type: Optional[torch.dtype]
|
||||
output_type: torch.dtype | None
|
||||
group_scale_type: torch.dtype | None
|
||||
group_zero_type: torch.dtype | None
|
||||
channel_scale_type: torch.dtype | None
|
||||
token_scale_type: torch.dtype | None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -63,10 +62,10 @@ class Tensors:
|
||||
a_ref: torch.Tensor
|
||||
a: torch.Tensor
|
||||
w_q: torch.Tensor
|
||||
w_g_s: Optional[torch.Tensor]
|
||||
w_g_zp: Optional[torch.Tensor]
|
||||
w_ch_s: Optional[torch.Tensor]
|
||||
w_tok_s: Optional[torch.Tensor]
|
||||
w_g_s: torch.Tensor | None
|
||||
w_g_zp: torch.Tensor | None
|
||||
w_ch_s: torch.Tensor | None
|
||||
w_tok_s: torch.Tensor | None
|
||||
|
||||
|
||||
# (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints,
|
||||
@@ -74,7 +73,7 @@ class Tensors:
|
||||
# NOTE: None "Scale Type" means the act type is floating point
|
||||
# None "Output Type" means the output type is the same as the act type
|
||||
TestTypeTuple = tuple[
|
||||
list[torch.dtype], ScalarType, Optional[torch.dtype], Optional[torch.dtype], bool
|
||||
list[torch.dtype], ScalarType, torch.dtype | None, torch.dtype | None, bool
|
||||
]
|
||||
TEST_TYPES = [
|
||||
# GPTQ style
|
||||
@@ -139,11 +138,11 @@ def rand_data(shape, dtype=torch.float16, scale=1, offset=0):
|
||||
return torch.randint(-8, 7, shape, dtype=dtype, device="cuda")
|
||||
|
||||
|
||||
def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor):
|
||||
def maybe_convert_zeropoints(zps: torch.Tensor | None, s: torch.Tensor):
|
||||
return zps if zps is None else -1 * s * (zps.to(s.dtype))
|
||||
|
||||
|
||||
def group_size_valid(shape: tuple[int, int, int], group_size: Optional[int]) -> bool:
|
||||
def group_size_valid(shape: tuple[int, int, int], group_size: int | None) -> bool:
|
||||
return group_size is None or group_size == -1 or shape[2] % group_size == 0
|
||||
|
||||
|
||||
@@ -151,8 +150,8 @@ def machete_quantize_and_pack(
|
||||
atype: torch.dtype,
|
||||
w: torch.Tensor,
|
||||
wtype: ScalarType,
|
||||
stype: Optional[torch.dtype],
|
||||
group_size: Optional[int],
|
||||
stype: torch.dtype | None,
|
||||
group_size: int | None,
|
||||
zero_points: bool = False,
|
||||
):
|
||||
assert wtype.is_integer(), "TODO: support floating point weights"
|
||||
@@ -178,8 +177,8 @@ def machete_quantize_and_pack(
|
||||
def create_test_tensors(
|
||||
shape: tuple[int, int, int],
|
||||
types: TypeConfig,
|
||||
group_size: Optional[int],
|
||||
subset_stride_factor: Optional[int] = None,
|
||||
group_size: int | None,
|
||||
subset_stride_factor: int | None = None,
|
||||
) -> Tensors:
|
||||
m, n, k = shape
|
||||
factor = subset_stride_factor or 1
|
||||
@@ -243,8 +242,8 @@ def create_test_tensors(
|
||||
def machete_mm_test_helper(
|
||||
types: TypeConfig,
|
||||
tensors: Tensors,
|
||||
group_size: Optional[int] = None,
|
||||
schedule: Optional[str] = None,
|
||||
group_size: int | None = None,
|
||||
schedule: str | None = None,
|
||||
):
|
||||
output_ref = torch.matmul(tensors.a_ref, tensors.w_ref)
|
||||
output_ref_type = output_ref.dtype
|
||||
@@ -294,7 +293,7 @@ def machete_mm_test_helper(
|
||||
@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x))
|
||||
@pytest.mark.parametrize("types", TEST_TYPES)
|
||||
def test_machete_all_schedules(shape, types: TypeConfig):
|
||||
group_sizes: list[Optional[int]] = []
|
||||
group_sizes: list[int | None] = []
|
||||
if types.group_scale_type is None:
|
||||
group_sizes = [None]
|
||||
else:
|
||||
@@ -323,7 +322,7 @@ def test_machete_all_schedules(shape, types: TypeConfig):
|
||||
@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x))
|
||||
@pytest.mark.parametrize("types", TEST_TYPES)
|
||||
def test_machete_heuristic(shape, types: TypeConfig):
|
||||
group_sizes: list[Optional[int]] = []
|
||||
group_sizes: list[int | None] = []
|
||||
if types.group_scale_type is None:
|
||||
group_sizes = [None]
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user