Update deprecated type hinting in model_executor/layers (#18056)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
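The pattern applied throughout is the usual PEP 585-era cleanup: `typing.List` and `typing.Tuple` become the builtin `list` and `tuple` generics, and `Mapping` is imported from `collections.abc` rather than `typing`. A minimal illustrative sketch of the before/after (the function name is hypothetical, not taken from the diff):

    # Before: typing aliases, soft-deprecated since Python 3.9 (PEP 585)
    from typing import List, Mapping, Optional, Tuple

    def pick(xs: List[int], weights: Optional[Mapping[str, int]]) -> Tuple[int, int]:
        ...

    # After: builtin generics plus collections.abc
    from collections.abc import Mapping
    from typing import Optional

    def pick(xs: list[int], weights: Optional[Mapping[str, int]]) -> tuple[int, int]:
        ...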
@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 """This file is used for /tests and /benchmarks"""
+from collections.abc import Mapping
 from types import MappingProxyType
-from typing import List, Mapping, Optional, Tuple
+from typing import Optional
 
 import numpy
 import torch
@@ -15,7 +16,7 @@ SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
 
 
 # Normalize the group_shape to the full extent for any dims that are -1
-def _normalize_quant_group_shape(x: torch.Tensor, group_shape: Tuple[int,
+def _normalize_quant_group_shape(x: torch.Tensor, group_shape: tuple[int,
                                                                      int]):
     # -1 means full extent
     return (group_shape[0] if group_shape[0] > 0 else x.shape[-2],
@@ -56,9 +57,9 @@ def group_broadcast(t, shape):
 # (i.e. per-token-per-group)
 def scaled_quantize(
     x: torch.Tensor,
-    group_shape: Tuple[int, int],
+    group_shape: tuple[int, int],
     quant_dtype: torch.dtype,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     group_shape = _normalize_quant_group_shape(x, group_shape)
     assert quant_dtype.is_floating_point, \
         "currently `scaled_quantize` only supports floating point dtypes " \
@@ -97,9 +98,9 @@ def scaled_quantize(
 def scaled_dequantize(
     x_q: torch.Tensor,
     x_s: torch.Tensor,
-    group_shape: Optional[Tuple[int, int]] = None,
+    group_shape: Optional[tuple[int, int]] = None,
     out_dtype: torch.dtype = torch.float32,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     if group_shape is not None:
         group_shape = _normalize_quant_group_shape(x_q, group_shape)
 
@@ -173,8 +174,8 @@ def unpack_quantized_values_into_int32(w_q: torch.Tensor,
 
 def is_layer_skipped(
     prefix: str,
-    ignored_layers: List[str],
-    fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
+    ignored_layers: list[str],
+    fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
 ) -> bool:
     # prefix: model.layers.0.self_attn.q_proj
     # proj_name: q_proj
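Side note on the unchanged default in that last hunk: `MappingProxyType({})` in the `is_layer_skipped` signature is the standard way to supply a read-only empty mapping as a default argument, avoiding the mutable-default pitfall. A small self-contained illustration (the function name here is hypothetical):

    from collections.abc import Mapping
    from types import MappingProxyType

    def fused_members(proj: str,
                      fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
                      ) -> list[str]:
        # The shared default is immutable, so no caller can mutate it by accident.
        return list(fused_mapping.get(proj, []))

    print(fused_members("q_proj"))                                   # []
    print(fused_members("qkv_proj", {"qkv_proj": ["q_proj", "k_proj"]}))  # ['q_proj', 'k_proj']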
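For context on what the `group_shape` tuples mean, here is a self-contained sketch of the `-1` normalization the diff touches. The first return element is copied from the hunk above; the second element is assumed to follow the same full-extent rule for the last dimension:

    import torch

    def _normalize_quant_group_shape(x: torch.Tensor, group_shape: tuple[int, int]):
        # -1 means full extent along that dimension
        return (group_shape[0] if group_shape[0] > 0 else x.shape[-2],
                group_shape[1] if group_shape[1] > 0 else x.shape[-1])

    x = torch.randn(16, 256)
    print(_normalize_quant_group_shape(x, (-1, -1)))  # (16, 256): per-tensor scale
    print(_normalize_quant_group_shape(x, (1, -1)))   # (1, 256):  per-token scale
    print(_normalize_quant_group_shape(x, (1, 64)))   # (1, 64):   per-token-per-group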