Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from types import MappingProxyType
|
||||
from typing import ClassVar, NamedTuple, Optional
|
||||
from typing import ClassVar, NamedTuple
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
@@ -91,7 +91,7 @@ class QuantKey:
|
||||
|
||||
dtype: torch.dtype
|
||||
scale: ScaleDesc
|
||||
scale2: Optional[ScaleDesc] = None
|
||||
scale2: ScaleDesc | None = None
|
||||
symmetric: bool = True
|
||||
|
||||
def __str__(self):
|
||||
@@ -205,7 +205,7 @@ def scaled_quantize(
|
||||
def scaled_dequantize(
|
||||
x_q: torch.Tensor,
|
||||
x_s: torch.Tensor,
|
||||
group_shape: Optional[GroupShape] = None,
|
||||
group_shape: GroupShape | None = None,
|
||||
out_dtype: torch.dtype = torch.float32,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if group_shape is not None:
|
||||
@@ -336,7 +336,7 @@ def permute_rows(
|
||||
q_w: torch.Tensor,
|
||||
w_ref: torch.Tensor,
|
||||
group_size: int,
|
||||
test_perm: Optional[torch.Tensor] = None,
|
||||
test_perm: torch.Tensor | None = None,
|
||||
):
|
||||
assert q_w.shape == w_ref.shape
|
||||
|
||||
@@ -365,7 +365,7 @@ def permute_rows(
|
||||
def quantize_weights(
|
||||
w: torch.Tensor,
|
||||
quant_type: ScalarType,
|
||||
group_size: Optional[int],
|
||||
group_size: int | None,
|
||||
zero_points: bool = False,
|
||||
ref_zero_points_after_scales: bool = False,
|
||||
):
|
||||
@@ -466,7 +466,7 @@ def gptq_quantize_weights(
|
||||
quant_type: ScalarType,
|
||||
group_size: int,
|
||||
act_order: bool,
|
||||
test_perm: Optional[torch.Tensor] = None,
|
||||
test_perm: torch.Tensor | None = None,
|
||||
):
|
||||
size_k, _ = w.shape
|
||||
|
||||
|
||||
Reference in New Issue
Block a user