Update deprecated type hinting in model_executor/layers (#18056)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -5,7 +5,7 @@ import functools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -18,7 +18,7 @@ logger = logging.getLogger(__name__)
|
||||
def apply_w8a8_block_int8_linear(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
block_size: List[int],
|
||||
block_size: list[int],
|
||||
weight_scale: torch.Tensor,
|
||||
input_scale: Optional[torch.Tensor] = None,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
@@ -43,7 +43,7 @@ def apply_w8a8_block_int8_linear(
|
||||
|
||||
def input_to_int8(
|
||||
x: torch.Tensor,
|
||||
dtype: torch.dtype = torch.int8) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
dtype: torch.dtype = torch.int8) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""This function quantizes input values to int8 values with
|
||||
tensor-wise quantization."""
|
||||
iinfo = torch.iinfo(dtype)
|
||||
@@ -58,7 +58,7 @@ def input_to_int8(
|
||||
def block_dequant(
|
||||
x_q_block: torch.Tensor,
|
||||
x_s: torch.Tensor,
|
||||
block_size: List[int],
|
||||
block_size: list[int],
|
||||
) -> torch.Tensor:
|
||||
"""This function conducts block-wise dequantization.
|
||||
The inputs are block-wise quantization tensor `x_q_block`,
|
||||
@@ -211,7 +211,7 @@ def per_token_group_quant_int8(
|
||||
group_size: int,
|
||||
eps: float = 1e-10,
|
||||
dtype: torch.dtype = torch.int8,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Function to perform per-token-group quantization on an input tensor `x`.
|
||||
|
||||
It converts the tensor values into signed int8 values and returns the
|
||||
@@ -225,7 +225,7 @@ def per_token_group_quant_int8(
|
||||
is supported for now.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
|
||||
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
|
||||
scaling factor for quantization.
|
||||
"""
|
||||
assert (x.shape[-1] % group_size == 0
|
||||
@@ -358,7 +358,7 @@ def _w8a8_block_int8_matmul(
|
||||
|
||||
@functools.lru_cache
|
||||
def get_w8a8_block_int8_configs(N: int, K: int, block_n: int,
|
||||
block_k: int) -> Optional[Dict[int, Any]]:
|
||||
block_k: int) -> Optional[dict[int, Any]]:
|
||||
"""
|
||||
Return optimized configurations for the w8a8 block fp8 kernel.
|
||||
|
||||
@@ -399,7 +399,7 @@ def w8a8_block_int8_matmul(
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
block_size: List[int],
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
"""This function performs matrix multiplication with block-wise
|
||||
|
||||
Reference in New Issue
Block a user