Update deprecated type hinting in model_executor/layers (#18056)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -32,7 +32,7 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
|
||||
def apply_w8a8_block_fp8_linear(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
block_size: List[int],
|
||||
block_size: list[int],
|
||||
weight_scale: torch.Tensor,
|
||||
input_scale: Optional[torch.Tensor] = None,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
@@ -95,7 +95,7 @@ def apply_w8a8_block_fp8_linear(
|
||||
def apply_w8a8_block_fp8_linear_fake(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
block_size: List[int],
|
||||
block_size: list[int],
|
||||
weight_scale: torch.Tensor,
|
||||
input_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
@@ -114,7 +114,7 @@ direct_register_custom_op(
|
||||
def input_to_float8(
|
||||
x: torch.Tensor,
|
||||
dtype: Optional[torch.dtype] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""This function quantizes input values to float8 values "
|
||||
"with tensor-wise quantization."""
|
||||
dtype = current_platform.fp8_dtype() if dtype is None else dtype
|
||||
@@ -129,7 +129,7 @@ def input_to_float8(
|
||||
def block_quant_to_tensor_quant(
|
||||
x_q_block: torch.Tensor,
|
||||
x_s: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""This function converts block-wise quantization to tensor-wise
|
||||
quantization. The inputs are block-wise quantization tensor `x_q_block`,
|
||||
block-wise quantization scale and the block size.
|
||||
@@ -247,7 +247,7 @@ def per_token_group_quant_fp8(
|
||||
eps: float = 1e-10,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
column_major_scales: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Function to perform per-token-group quantization on an input tensor `x`.
|
||||
It converts the tensor values into signed float8 values and returns the
|
||||
quantized tensor along with the scaling factor used for quantization.
|
||||
@@ -258,7 +258,7 @@ def per_token_group_quant_fp8(
|
||||
dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn`
|
||||
is supported for now.
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
|
||||
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
|
||||
scaling factor for quantization.
|
||||
"""
|
||||
dtype = current_platform.fp8_dtype() if dtype is None else dtype
|
||||
@@ -412,7 +412,7 @@ def _w8a8_block_fp8_matmul(
|
||||
|
||||
@functools.lru_cache
|
||||
def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
|
||||
block_k: int) -> Optional[Dict[int, Any]]:
|
||||
block_k: int) -> Optional[dict[int, Any]]:
|
||||
"""
|
||||
Return optimized configurations for the w8a8 block fp8 kernel.
|
||||
The return value will be a dictionary that maps an irregular grid of
|
||||
@@ -452,7 +452,7 @@ def w8a8_block_fp8_matmul(
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
block_size: List[int],
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
"""This function performs matrix multiplication with block-wise
|
||||
|
||||
Reference in New Issue
Block a user