Update deprecated type hinting in model_loader (#18130)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -3,7 +3,8 @@ import dataclasses
|
||||
import glob
|
||||
import os
|
||||
import time
|
||||
from typing import Generator, Iterable, List, Optional, Tuple, cast
|
||||
from collections.abc import Generator, Iterable
|
||||
from typing import Optional, cast
|
||||
|
||||
import huggingface_hub
|
||||
import torch
|
||||
@@ -92,7 +93,7 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
revision: Optional[str],
|
||||
fall_back_to_pt: bool,
|
||||
allow_patterns_overrides: Optional[list[str]],
|
||||
) -> Tuple[str, List[str], bool]:
|
||||
) -> tuple[str, list[str], bool]:
|
||||
"""Prepare weights for the model.
|
||||
|
||||
If the model is not local, it will be downloaded."""
|
||||
@@ -138,7 +139,7 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
else:
|
||||
hf_folder = model_name_or_path
|
||||
|
||||
hf_weights_files: List[str] = []
|
||||
hf_weights_files: list[str] = []
|
||||
for pattern in allow_patterns:
|
||||
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
|
||||
if len(hf_weights_files) > 0:
|
||||
@@ -173,7 +174,7 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
|
||||
def _get_weights_iterator(
|
||||
self, source: "Source"
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""Get an iterator for the model weights based on the load format."""
|
||||
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
|
||||
source.model_or_path, source.revision, source.fall_back_to_pt,
|
||||
@@ -238,7 +239,7 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
model: nn.Module,
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
primary_weights = DefaultModelLoader.Source(
|
||||
model_config.model,
|
||||
model_config.revision,
|
||||
|
||||
Reference in New Issue
Block a user