Update deprecated type hinting in models (#18132)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import itertools
|
||||
from collections.abc import Iterable, Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from typing import (Callable, Dict, Iterable, List, Literal, Mapping, Optional,
|
||||
Protocol, Set, Tuple, Union, overload)
|
||||
from typing import Callable, Literal, Optional, Protocol, Union, overload
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -58,8 +58,8 @@ class WeightsMapper:
|
||||
return key
|
||||
|
||||
def apply(
|
||||
self, weights: Iterable[Tuple[str, torch.Tensor]]
|
||||
) -> Iterable[Tuple[str, torch.Tensor]]:
|
||||
self, weights: Iterable[tuple[str, torch.Tensor]]
|
||||
) -> Iterable[tuple[str, torch.Tensor]]:
|
||||
return ((out_name, data) for name, data in weights
|
||||
if (out_name := self._map_name(name)) is not None)
|
||||
|
||||
@@ -84,8 +84,8 @@ class AutoWeightsLoader:
|
||||
self,
|
||||
module: nn.Module,
|
||||
*,
|
||||
skip_prefixes: Optional[List[str]] = None,
|
||||
ignore_unexpected_prefixes: Optional[List[str]] = None,
|
||||
skip_prefixes: Optional[list[str]] = None,
|
||||
ignore_unexpected_prefixes: Optional[list[str]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -95,8 +95,8 @@ class AutoWeightsLoader:
|
||||
|
||||
def _groupby_prefix(
|
||||
self,
|
||||
weights: Iterable[Tuple[str, torch.Tensor]],
|
||||
) -> Iterable[Tuple[str, Iterable[Tuple[str, torch.Tensor]]]]:
|
||||
weights: Iterable[tuple[str, torch.Tensor]],
|
||||
) -> Iterable[tuple[str, Iterable[tuple[str, torch.Tensor]]]]:
|
||||
weights_by_parts = ((weight_name.split(".", 1), weight_data)
|
||||
for weight_name, weight_data in weights)
|
||||
|
||||
@@ -129,7 +129,7 @@ class AutoWeightsLoader:
|
||||
self,
|
||||
base_prefix: str,
|
||||
param: nn.Parameter,
|
||||
weights: Iterable[Tuple[str, torch.Tensor]],
|
||||
weights: Iterable[tuple[str, torch.Tensor]],
|
||||
) -> Iterable[str]:
|
||||
for weight_name, weight_data in weights:
|
||||
weight_qualname = self._get_qualname(base_prefix, weight_name)
|
||||
@@ -159,7 +159,7 @@ class AutoWeightsLoader:
|
||||
yield weight_qualname
|
||||
|
||||
def _add_loadable_non_param_tensors(self, module: nn.Module,
|
||||
child_params: Dict[str, torch.Tensor]):
|
||||
child_params: dict[str, torch.Tensor]):
|
||||
"""
|
||||
Add names of tensors that are not model parameters but may still be present
|
||||
in the safetensors, e.g., batch normalization stats.
|
||||
@@ -182,7 +182,7 @@ class AutoWeightsLoader:
|
||||
self,
|
||||
base_prefix: str,
|
||||
module: nn.Module,
|
||||
weights: Iterable[Tuple[str, torch.Tensor]],
|
||||
weights: Iterable[tuple[str, torch.Tensor]],
|
||||
) -> Iterable[str]:
|
||||
if isinstance(module, PPMissingLayer):
|
||||
return
|
||||
@@ -251,10 +251,10 @@ class AutoWeightsLoader:
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
weights: Iterable[Tuple[str, torch.Tensor]],
|
||||
weights: Iterable[tuple[str, torch.Tensor]],
|
||||
*,
|
||||
mapper: Optional[WeightsMapper] = None,
|
||||
) -> Set[str]:
|
||||
) -> set[str]:
|
||||
if mapper is not None:
|
||||
weights = mapper.apply(weights)
|
||||
|
||||
@@ -292,13 +292,13 @@ def flatten_bn(x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
|
||||
@overload
|
||||
def flatten_bn(x: List[torch.Tensor]) -> List[torch.Tensor]:
|
||||
def flatten_bn(x: list[torch.Tensor]) -> list[torch.Tensor]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def flatten_bn(
|
||||
x: Union[List[torch.Tensor], torch.Tensor],
|
||||
x: Union[list[torch.Tensor], torch.Tensor],
|
||||
*,
|
||||
concat: Literal[True],
|
||||
) -> torch.Tensor:
|
||||
@@ -307,18 +307,18 @@ def flatten_bn(
|
||||
|
||||
@overload
|
||||
def flatten_bn(
|
||||
x: Union[List[torch.Tensor], torch.Tensor],
|
||||
x: Union[list[torch.Tensor], torch.Tensor],
|
||||
*,
|
||||
concat: bool = False,
|
||||
) -> Union[List[torch.Tensor], torch.Tensor]:
|
||||
) -> Union[list[torch.Tensor], torch.Tensor]:
|
||||
...
|
||||
|
||||
|
||||
def flatten_bn(
|
||||
x: Union[List[torch.Tensor], torch.Tensor],
|
||||
x: Union[list[torch.Tensor], torch.Tensor],
|
||||
*,
|
||||
concat: bool = False,
|
||||
) -> Union[List[torch.Tensor], torch.Tensor]:
|
||||
) -> Union[list[torch.Tensor], torch.Tensor]:
|
||||
"""
|
||||
Flatten the ``B`` and ``N`` dimensions of batched multimodal inputs.
|
||||
|
||||
@@ -442,7 +442,7 @@ def merge_multimodal_embeddings(
|
||||
input_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
multimodal_embeddings: NestedTensors,
|
||||
placeholder_token_id: Union[int, List[int]],
|
||||
placeholder_token_id: Union[int, list[int]],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
|
||||
@@ -596,7 +596,7 @@ def make_layers(
|
||||
num_hidden_layers: int,
|
||||
layer_fn: LayerFn,
|
||||
prefix: str,
|
||||
) -> Tuple[int, int, torch.nn.ModuleList]:
|
||||
) -> tuple[int, int, torch.nn.ModuleList]:
|
||||
"""Make a list of layers with the given layer function, taking
|
||||
pipeline parallelism into account.
|
||||
"""
|
||||
@@ -614,10 +614,10 @@ def make_layers(
|
||||
|
||||
|
||||
# NOTE: don't use lru_cache here because it can prevent garbage collection
|
||||
_model_to_pp_missing_layer_names: Dict[int, List[str]] = {}
|
||||
_model_to_pp_missing_layer_names: dict[int, list[str]] = {}
|
||||
|
||||
|
||||
def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]:
|
||||
def get_pp_missing_layer_names(model: torch.nn.Module) -> list[str]:
|
||||
"""Get the names of the missing layers in a pipeline parallel model."""
|
||||
model_id = id(model)
|
||||
if model_id in _model_to_pp_missing_layer_names:
|
||||
@@ -645,7 +645,7 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
|
||||
for missing_layer_name in get_pp_missing_layer_names(model))
|
||||
|
||||
|
||||
def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
|
||||
def make_empty_intermediate_tensors_factory(keys: list[str], hidden_size: int):
|
||||
|
||||
def make_empty_intermediate_tensors(
|
||||
batch_size: int,
|
||||
@@ -684,7 +684,7 @@ def extract_layer_index(layer_name: str) -> int:
|
||||
- "model.encoder.layers.0.sub.1" -> ValueError
|
||||
"""
|
||||
subnames = layer_name.split(".")
|
||||
int_vals: List[int] = []
|
||||
int_vals: list[int] = []
|
||||
for subname in subnames:
|
||||
try:
|
||||
int_vals.append(int(subname))
|
||||
|
||||
Reference in New Issue
Block a user