Update deprecated Python 3.8 typing (#13971)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from itertools import accumulate, product
|
||||
from typing import Callable, Dict, List, Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -179,7 +179,7 @@ def test_batched_rotary_embedding_multi_lora(
|
||||
torch.set_default_device(device)
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
scaling_factors: List[int] = [1, 2, 4]
|
||||
scaling_factors: list[int] = [1, 2, 4]
|
||||
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
|
||||
"rope_type": "linear",
|
||||
"factor": tuple(scaling_factors)
|
||||
@@ -234,7 +234,7 @@ def test_rope_module_cache():
|
||||
})
|
||||
settings = (HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
|
||||
ROPE_SCALINGS, DTYPES)
|
||||
rope_setting_id_map: Dict[str, int] = {}
|
||||
rope_setting_id_map: dict[str, int] = {}
|
||||
for setting in product(*settings):
|
||||
head_size, rotary_dim, max_position, base, \
|
||||
is_neox_stype, rope_scaling, dtype = setting
|
||||
|
||||
Reference in New Issue
Block a user