Update deprecated type hinting in models (#18132)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-05-15 06:06:50 +01:00
committed by GitHub
parent 83f74c698f
commit 26d0419309
130 changed files with 971 additions and 901 deletions

View File

@@ -1,9 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
import itertools
from collections.abc import Iterable, Mapping
from dataclasses import dataclass, field
from typing import (Callable, Dict, Iterable, List, Literal, Mapping, Optional,
Protocol, Set, Tuple, Union, overload)
from typing import Callable, Literal, Optional, Protocol, Union, overload
import torch
import torch.nn as nn
@@ -58,8 +58,8 @@ class WeightsMapper:
return key
def apply(
self, weights: Iterable[Tuple[str, torch.Tensor]]
) -> Iterable[Tuple[str, torch.Tensor]]:
self, weights: Iterable[tuple[str, torch.Tensor]]
) -> Iterable[tuple[str, torch.Tensor]]:
return ((out_name, data) for name, data in weights
if (out_name := self._map_name(name)) is not None)
@@ -84,8 +84,8 @@ class AutoWeightsLoader:
self,
module: nn.Module,
*,
skip_prefixes: Optional[List[str]] = None,
ignore_unexpected_prefixes: Optional[List[str]] = None,
skip_prefixes: Optional[list[str]] = None,
ignore_unexpected_prefixes: Optional[list[str]] = None,
) -> None:
super().__init__()
@@ -95,8 +95,8 @@ class AutoWeightsLoader:
def _groupby_prefix(
self,
weights: Iterable[Tuple[str, torch.Tensor]],
) -> Iterable[Tuple[str, Iterable[Tuple[str, torch.Tensor]]]]:
weights: Iterable[tuple[str, torch.Tensor]],
) -> Iterable[tuple[str, Iterable[tuple[str, torch.Tensor]]]]:
weights_by_parts = ((weight_name.split(".", 1), weight_data)
for weight_name, weight_data in weights)
@@ -129,7 +129,7 @@ class AutoWeightsLoader:
self,
base_prefix: str,
param: nn.Parameter,
weights: Iterable[Tuple[str, torch.Tensor]],
weights: Iterable[tuple[str, torch.Tensor]],
) -> Iterable[str]:
for weight_name, weight_data in weights:
weight_qualname = self._get_qualname(base_prefix, weight_name)
@@ -159,7 +159,7 @@ class AutoWeightsLoader:
yield weight_qualname
def _add_loadable_non_param_tensors(self, module: nn.Module,
child_params: Dict[str, torch.Tensor]):
child_params: dict[str, torch.Tensor]):
"""
Add tensor names that are not in the model params that may be in the
safetensors, e.g., batch normalization stats.
@@ -182,7 +182,7 @@ class AutoWeightsLoader:
self,
base_prefix: str,
module: nn.Module,
weights: Iterable[Tuple[str, torch.Tensor]],
weights: Iterable[tuple[str, torch.Tensor]],
) -> Iterable[str]:
if isinstance(module, PPMissingLayer):
return
@@ -251,10 +251,10 @@ class AutoWeightsLoader:
def load_weights(
self,
weights: Iterable[Tuple[str, torch.Tensor]],
weights: Iterable[tuple[str, torch.Tensor]],
*,
mapper: Optional[WeightsMapper] = None,
) -> Set[str]:
) -> set[str]:
if mapper is not None:
weights = mapper.apply(weights)
@@ -292,13 +292,13 @@ def flatten_bn(x: torch.Tensor) -> torch.Tensor:
@overload
def flatten_bn(x: List[torch.Tensor]) -> List[torch.Tensor]:
def flatten_bn(x: list[torch.Tensor]) -> list[torch.Tensor]:
...
@overload
def flatten_bn(
x: Union[List[torch.Tensor], torch.Tensor],
x: Union[list[torch.Tensor], torch.Tensor],
*,
concat: Literal[True],
) -> torch.Tensor:
@@ -307,18 +307,18 @@ def flatten_bn(
@overload
def flatten_bn(
x: Union[List[torch.Tensor], torch.Tensor],
x: Union[list[torch.Tensor], torch.Tensor],
*,
concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
) -> Union[list[torch.Tensor], torch.Tensor]:
...
def flatten_bn(
x: Union[List[torch.Tensor], torch.Tensor],
x: Union[list[torch.Tensor], torch.Tensor],
*,
concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
) -> Union[list[torch.Tensor], torch.Tensor]:
"""
Flatten the ``B`` and ``N`` dimensions of batched multimodal inputs.
@@ -442,7 +442,7 @@ def merge_multimodal_embeddings(
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
multimodal_embeddings: NestedTensors,
placeholder_token_id: Union[int, List[int]],
placeholder_token_id: Union[int, list[int]],
) -> torch.Tensor:
"""
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
@@ -596,7 +596,7 @@ def make_layers(
num_hidden_layers: int,
layer_fn: LayerFn,
prefix: str,
) -> Tuple[int, int, torch.nn.ModuleList]:
) -> tuple[int, int, torch.nn.ModuleList]:
"""Make a list of layers with the given layer function, taking
pipeline parallelism into account.
"""
@@ -614,10 +614,10 @@ def make_layers(
# NOTE: don't use lru_cache here because it can prevent garbage collection
_model_to_pp_missing_layer_names: Dict[int, List[str]] = {}
_model_to_pp_missing_layer_names: dict[int, list[str]] = {}
def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]:
def get_pp_missing_layer_names(model: torch.nn.Module) -> list[str]:
"""Get the names of the missing layers in a pipeline parallel model."""
model_id = id(model)
if model_id in _model_to_pp_missing_layer_names:
@@ -645,7 +645,7 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
for missing_layer_name in get_pp_missing_layer_names(model))
def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
def make_empty_intermediate_tensors_factory(keys: list[str], hidden_size: int):
def make_empty_intermediate_tensors(
batch_size: int,
@@ -684,7 +684,7 @@ def extract_layer_index(layer_name: str) -> int:
- "model.encoder.layers.0.sub.1" -> ValueError
"""
subnames = layer_name.split(".")
int_vals: List[int] = []
int_vals: list[int] = []
for subname in subnames:
try:
int_vals.append(int(subname))