Update deprecated Python 3.8 typing (#13971)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -19,7 +19,7 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
|
||||
def ref_dynamic_per_token_quant(x: torch.tensor,
|
||||
quant_dtype: torch.dtype,
|
||||
scale_ub: Optional[torch.tensor] = None) \
|
||||
-> Tuple[torch.tensor, torch.tensor]:
|
||||
-> tuple[torch.tensor, torch.tensor]:
|
||||
|
||||
assert quant_dtype in [torch.int8, FP8_DTYPE]
|
||||
if scale_ub is not None:
|
||||
@@ -68,7 +68,7 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
|
||||
# ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant
|
||||
# kernel
|
||||
def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
|
||||
-> Tuple[torch.tensor, torch.tensor]:
|
||||
-> tuple[torch.tensor, torch.tensor]:
|
||||
|
||||
fp8_traits = torch.finfo(FP8_DTYPE)
|
||||
fp8_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \
|
||||
|
||||
Reference in New Issue
Block a user