Update deprecated type hinting in model_executor/layers (#18056)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-05-13 12:17:23 +01:00
committed by GitHub
parent 906f0598fc
commit 6223dd8114
87 changed files with 523 additions and 523 deletions

View File

@@ -5,7 +5,7 @@ import functools
import json
import logging
import os
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional
import torch
@@ -18,7 +18,7 @@ logger = logging.getLogger(__name__)
def apply_w8a8_block_int8_linear(
input: torch.Tensor,
weight: torch.Tensor,
block_size: List[int],
block_size: list[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
@@ -43,7 +43,7 @@ def apply_w8a8_block_int8_linear(
def input_to_int8(
x: torch.Tensor,
dtype: torch.dtype = torch.int8) -> Tuple[torch.Tensor, torch.Tensor]:
dtype: torch.dtype = torch.int8) -> tuple[torch.Tensor, torch.Tensor]:
"""This function quantizes input values to int8 values with
tensor-wise quantization."""
iinfo = torch.iinfo(dtype)
@@ -58,7 +58,7 @@ def input_to_int8(
def block_dequant(
x_q_block: torch.Tensor,
x_s: torch.Tensor,
block_size: List[int],
block_size: list[int],
) -> torch.Tensor:
"""This function conducts block-wise dequantization.
The inputs are block-wise quantization tensor `x_q_block`,
@@ -211,7 +211,7 @@ def per_token_group_quant_int8(
group_size: int,
eps: float = 1e-10,
dtype: torch.dtype = torch.int8,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed int8 values and returns the
@@ -225,7 +225,7 @@ def per_token_group_quant_int8(
is supported for now.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
"""
assert (x.shape[-1] % group_size == 0
@@ -358,7 +358,7 @@ def _w8a8_block_int8_matmul(
@functools.lru_cache
def get_w8a8_block_int8_configs(N: int, K: int, block_n: int,
block_k: int) -> Optional[Dict[int, Any]]:
block_k: int) -> Optional[dict[int, Any]]:
"""
Return optimized configurations for the w8a8 block fp8 kernel.
@@ -399,7 +399,7 @@ def w8a8_block_int8_matmul(
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: List[int],
block_size: list[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
"""This function performs matrix multiplication with block-wise