Update deprecated type hinting in model_executor/layers (#18056)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor authored on 2025-05-13 12:17:23 +01:00 (committed by GitHub)
parent 906f0598fc
commit 6223dd8114
87 changed files with 523 additions and 523 deletions
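
The change is mechanical throughout: PEP 585 (Python 3.9) made the builtin
container types usable as generics, deprecating typing.List, typing.Dict, and
typing.Tuple in annotations. Only names with no builtin equivalent, such as
Any, Optional, and Union, still need a typing import. A minimal sketch of the
before/after style (the function below is illustrative, not taken from the
diff):

    # Deprecated pre-PEP 585 spelling:
    from typing import Dict, List, Optional, Tuple

    def lookup(keys: List[str], table: Dict[str, int],
               default: Optional[int] = None) -> Tuple[int, ...]:
        return tuple(table.get(k, default or 0) for k in keys)

    # Python >= 3.9 spelling: builtin generics; typing is only needed
    # for names without builtin equivalents.
    from typing import Optional

    def lookup(keys: list[str], table: dict[str, int],
               default: Optional[int] = None) -> tuple[int, ...]:
        return tuple(table.get(k, default or 0) for k in keys)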


@@ -4,7 +4,7 @@
 import functools
 import json
 import os
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Optional, Union

 import torch
@@ -32,7 +32,7 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
 def apply_w8a8_block_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
-    block_size: List[int],
+    block_size: list[int],
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
@@ -95,7 +95,7 @@ def apply_w8a8_block_fp8_linear(
 def apply_w8a8_block_fp8_linear_fake(
     input: torch.Tensor,
     weight: torch.Tensor,
-    block_size: List[int],
+    block_size: list[int],
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
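
The paired `_fake` signature must stay in sync with the real op because it is
registered through direct_register_custom_op: the fake implementation is what
lets torch.compile trace output shapes and dtypes without running the kernel.
A minimal sketch of that general pattern using the plain torch.library API
(vLLM wraps this in its own helper; the op name here is made up):

    import torch

    @torch.library.custom_op("demo::scale", mutates_args=())
    def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
        return x * factor

    @scale.register_fake
    def _(x: torch.Tensor, factor: float) -> torch.Tensor:
        # Only shape/dtype matter during tracing; no real compute.
        return torch.empty_like(x)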
@@ -114,7 +114,7 @@ direct_register_custom_op(
 def input_to_float8(
     x: torch.Tensor,
     dtype: Optional[torch.dtype] = None
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     """This function quantizes input values to float8 values
     with tensor-wise quantization."""
     dtype = current_platform.fp8_dtype() if dtype is None else dtype
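
For intuition, tensor-wise float8 quantization as the docstring describes it
fits in a few lines of plain PyTorch (an illustrative reference, not the vLLM
implementation; it assumes a dtype with torch.finfo support such as
torch.float8_e4m3fn):

    import torch

    def to_float8(x: torch.Tensor,
                  dtype: torch.dtype = torch.float8_e4m3fn
                  ) -> tuple[torch.Tensor, torch.Tensor]:
        finfo = torch.finfo(dtype)
        # One scale for the whole tensor, sized so the max value
        # saturates the fp8 range.
        scale = finfo.max / x.abs().max().clamp(min=1e-12)
        x_q = (x * scale).clamp(finfo.min, finfo.max).to(dtype)
        return x_q, scale.reciprocal()  # inverse scale for dequantization

    x_q, s = to_float8(torch.randn(8, 16))
    x_dq = x_q.to(torch.float32) * s  # approximate round-trip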
@@ -129,7 +129,7 @@ def input_to_float8(
 def block_quant_to_tensor_quant(
     x_q_block: torch.Tensor,
     x_s: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     """This function converts block-wise quantization to tensor-wise
     quantization. The inputs are block-wise quantization tensor `x_q_block`,
     block-wise quantization scale and the block size.
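
Conceptually the conversion dequantizes each block with its own scale and then
requantizes the whole tensor with a single scale. A naive sketch, assuming
`x_s` holds one scale per (block_n x block_k) tile and taking the block size
as an explicit argument (both layout and signature assumed for illustration):

    import torch

    def block_to_tensor_quant(x_q_block: torch.Tensor, x_s: torch.Tensor,
                              block_size: list[int]
                              ) -> tuple[torch.Tensor, torch.Tensor]:
        block_n, block_k = block_size
        # Broadcast each block's scale over its tile, then dequantize.
        s_full = x_s.repeat_interleave(block_n, 0).repeat_interleave(block_k, 1)
        rows, cols = x_q_block.shape
        x_dq = x_q_block.to(torch.float32) * s_full[:rows, :cols]
        return to_float8(x_dq)  # reuse the tensor-wise sketch above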
@@ -247,7 +247,7 @@ def per_token_group_quant_fp8(
     eps: float = 1e-10,
     dtype: Optional[torch.dtype] = None,
     column_major_scales: bool = False,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
     It converts the tensor values into signed float8 values and returns the
     quantized tensor along with the scaling factor used for quantization.
@@ -258,7 +258,7 @@ def per_token_group_quant_fp8(
         dtype: The dtype of the output tensor. Note that only
             `torch.float8_e4m3fn` is supported for now.
     Returns:
-        Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
+        tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
         scaling factor for quantization.
     """
     dtype = current_platform.fp8_dtype() if dtype is None else dtype
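
Per-token-group quantization gives every group of contiguous values in a row
its own scale, which tracks local dynamic range far better than one
tensor-wide scale. A naive PyTorch reference (illustrative only; the real
implementation is a fused Triton kernel, and the row length is assumed
divisible by the group size):

    import torch

    def per_group_quant(x: torch.Tensor, group_size: int,
                        eps: float = 1e-10,
                        dtype: torch.dtype = torch.float8_e4m3fn
                        ) -> tuple[torch.Tensor, torch.Tensor]:
        finfo = torch.finfo(dtype)
        g = x.reshape(-1, group_size)
        # One scale per group, floored by eps to avoid division by zero.
        scale = g.abs().amax(dim=-1, keepdim=True).clamp(min=eps) / finfo.max
        x_q = (g / scale).clamp(finfo.min, finfo.max).to(dtype)
        return x_q.reshape_as(x), scale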
@@ -412,7 +412,7 @@ def _w8a8_block_fp8_matmul(
 @functools.lru_cache
 def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
-                               block_k: int) -> Optional[Dict[int, Any]]:
+                               block_k: int) -> Optional[dict[int, Any]]:
     """
     Return optimized configurations for the w8a8 block fp8 kernel.
     The return value will be a dictionary that maps an irregular grid of
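
Two details of this signature work together: functools.lru_cache requires
hashable arguments, which the plain int parameters satisfy, and the Optional
return models the case where no tuned configuration exists for the shape. A
minimal sketch of the pattern (the JSON-file lookup and its file name are
assumptions, mirroring the json import at the top of the file):

    import functools
    import json
    import os
    from typing import Any, Optional

    @functools.lru_cache
    def load_configs(N: int, K: int) -> Optional[dict[int, Any]]:
        path = f"configs_N{N}_K{K}.json"  # hypothetical file name
        if not os.path.exists(path):
            return None  # caller falls back to a default config
        with open(path) as f:
            return {int(m): cfg for m, cfg in json.load(f).items()}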
@@ -452,7 +452,7 @@ def w8a8_block_fp8_matmul(
     B: torch.Tensor,
     As: torch.Tensor,
     Bs: torch.Tensor,
-    block_size: List[int],
+    block_size: list[int],
     output_dtype: torch.dtype = torch.float16,
 ) -> torch.Tensor:
     """This function performs matrix multiplication with block-wise