Update deprecated Python 3.8 typing (#13971)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -39,7 +39,7 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
|
||||
def ref_rms_norm(rms_norm_layer: RMSNorm,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor]) \
|
||||
-> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
-> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
if residual is not None:
|
||||
residual = residual.clone()
|
||||
out, residual = rms_norm_layer.forward_native(x, residual)
|
||||
@@ -54,7 +54,7 @@ def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm,
|
||||
quant_dtype: torch.dtype,
|
||||
residual: Optional[torch.Tensor],
|
||||
scale_ub: Optional[torch.Tensor]) \
|
||||
-> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
-> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
if scale_ub is not None:
|
||||
assert quant_dtype == torch.float8_e4m3fn
|
||||
|
||||
@@ -78,7 +78,7 @@ def ref_impl(rms_norm_layer: RMSNorm,
|
||||
quant_dtype: torch.dtype,
|
||||
residual: Optional[torch.Tensor],
|
||||
scale_ub: Optional[torch.Tensor]) \
|
||||
-> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
-> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
return ref_dynamic_per_token_quant(rms_norm_layer, x, quant_dtype,
|
||||
residual, scale_ub)
|
||||
|
||||
@@ -88,7 +88,7 @@ def ops_dynamic_per_token_quant(weight: torch.Tensor,
|
||||
quant_dtype: torch.dtype,
|
||||
residual: Optional[torch.Tensor],
|
||||
scale_ub: Optional[torch.Tensor]) \
|
||||
-> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
-> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
if residual is not None:
|
||||
residual = residual.clone()
|
||||
out, scales = ops.rms_norm_dynamic_per_token_quant(x, weight, EPS,
|
||||
@@ -102,7 +102,7 @@ def ops_impl(weight: torch.Tensor,
|
||||
quant_dtype: torch.dtype,
|
||||
residual: Optional[torch.Tensor],
|
||||
scale_ub: Optional[torch.Tensor]) \
|
||||
-> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
-> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual,
|
||||
scale_ub)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user