Update deprecated Python 3.8 typing (#13971)

This commit is contained in:
Harry Mellor
2025-03-03 01:34:51 +00:00
committed by GitHub
parent bf33700ecd
commit cf069aa8aa
300 changed files with 2294 additions and 2347 deletions

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from itertools import accumulate, product
from typing import Callable, Dict, List, Optional
from typing import Callable, Optional
import pytest
import torch
@@ -179,7 +179,7 @@ def test_batched_rotary_embedding_multi_lora(
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
scaling_factors: List[int] = [1, 2, 4]
scaling_factors: list[int] = [1, 2, 4]
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
"rope_type": "linear",
"factor": tuple(scaling_factors)
@@ -234,7 +234,7 @@ def test_rope_module_cache():
})
settings = (HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
ROPE_SCALINGS, DTYPES)
rope_setting_id_map: Dict[str, int] = {}
rope_setting_id_map: dict[str, int] = {}
for setting in product(*settings):
head_size, rotary_dim, max_position, base, \
is_neox_stype, rope_scaling, dtype = setting