Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-12 17:51:31 +01:00
committed by GitHub
parent 9bb38130cb
commit 8fcaaf6a16
944 changed files with 9490 additions and 10121 deletions

View File

@@ -7,7 +7,6 @@ Run `pytest tests/kernels/quantization/test_machete_mm.py`.
import math
from dataclasses import dataclass, fields
from typing import Optional
import pytest
import torch
@@ -50,11 +49,11 @@ MNK_SHAPES = [
class TypeConfig:
act_type: torch.dtype
weight_type: ScalarType
output_type: Optional[torch.dtype]
group_scale_type: Optional[torch.dtype]
group_zero_type: Optional[torch.dtype]
channel_scale_type: Optional[torch.dtype]
token_scale_type: Optional[torch.dtype]
output_type: torch.dtype | None
group_scale_type: torch.dtype | None
group_zero_type: torch.dtype | None
channel_scale_type: torch.dtype | None
token_scale_type: torch.dtype | None
@dataclass
@@ -63,10 +62,10 @@ class Tensors:
a_ref: torch.Tensor
a: torch.Tensor
w_q: torch.Tensor
w_g_s: Optional[torch.Tensor]
w_g_zp: Optional[torch.Tensor]
w_ch_s: Optional[torch.Tensor]
w_tok_s: Optional[torch.Tensor]
w_g_s: torch.Tensor | None
w_g_zp: torch.Tensor | None
w_ch_s: torch.Tensor | None
w_tok_s: torch.Tensor | None
# (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints,
@@ -74,7 +73,7 @@ class Tensors:
# NOTE: None "Scale Type" means the act type is floating point
# None "Output Type" means the output type is the same as the act type
TestTypeTuple = tuple[
list[torch.dtype], ScalarType, Optional[torch.dtype], Optional[torch.dtype], bool
list[torch.dtype], ScalarType, torch.dtype | None, torch.dtype | None, bool
]
TEST_TYPES = [
# GPTQ style
@@ -139,11 +138,11 @@ def rand_data(shape, dtype=torch.float16, scale=1, offset=0):
return torch.randint(-8, 7, shape, dtype=dtype, device="cuda")
def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor):
def maybe_convert_zeropoints(zps: torch.Tensor | None, s: torch.Tensor):
return zps if zps is None else -1 * s * (zps.to(s.dtype))
def group_size_valid(shape: tuple[int, int, int], group_size: Optional[int]) -> bool:
def group_size_valid(shape: tuple[int, int, int], group_size: int | None) -> bool:
return group_size is None or group_size == -1 or shape[2] % group_size == 0
@@ -151,8 +150,8 @@ def machete_quantize_and_pack(
atype: torch.dtype,
w: torch.Tensor,
wtype: ScalarType,
stype: Optional[torch.dtype],
group_size: Optional[int],
stype: torch.dtype | None,
group_size: int | None,
zero_points: bool = False,
):
assert wtype.is_integer(), "TODO: support floating point weights"
@@ -178,8 +177,8 @@ def machete_quantize_and_pack(
def create_test_tensors(
shape: tuple[int, int, int],
types: TypeConfig,
group_size: Optional[int],
subset_stride_factor: Optional[int] = None,
group_size: int | None,
subset_stride_factor: int | None = None,
) -> Tensors:
m, n, k = shape
factor = subset_stride_factor or 1
@@ -243,8 +242,8 @@ def create_test_tensors(
def machete_mm_test_helper(
types: TypeConfig,
tensors: Tensors,
group_size: Optional[int] = None,
schedule: Optional[str] = None,
group_size: int | None = None,
schedule: str | None = None,
):
output_ref = torch.matmul(tensors.a_ref, tensors.w_ref)
output_ref_type = output_ref.dtype
@@ -294,7 +293,7 @@ def machete_mm_test_helper(
@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("types", TEST_TYPES)
def test_machete_all_schedules(shape, types: TypeConfig):
group_sizes: list[Optional[int]] = []
group_sizes: list[int | None] = []
if types.group_scale_type is None:
group_sizes = [None]
else:
@@ -323,7 +322,7 @@ def test_machete_all_schedules(shape, types: TypeConfig):
@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("types", TEST_TYPES)
def test_machete_heuristic(shape, types: TypeConfig):
group_sizes: list[Optional[int]] = []
group_sizes: list[int | None] = []
if types.group_scale_type is None:
group_sizes = [None]
else: