Update deprecated type hinting in model_loader (#18130)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-05-15 12:00:21 +01:00
committed by GitHub
parent a9944aabfa
commit 07ad27121f
12 changed files with 80 additions and 74 deletions

View File

@@ -3,7 +3,8 @@ import dataclasses
import glob
import os
import time
from typing import Generator, Iterable, List, Optional, Tuple, cast
from collections.abc import Generator, Iterable
from typing import Optional, cast
import huggingface_hub
import torch
@@ -92,7 +93,7 @@ class DefaultModelLoader(BaseModelLoader):
revision: Optional[str],
fall_back_to_pt: bool,
allow_patterns_overrides: Optional[list[str]],
) -> Tuple[str, List[str], bool]:
) -> tuple[str, list[str], bool]:
"""Prepare weights for the model.
If the model is not local, it will be downloaded."""
@@ -138,7 +139,7 @@ class DefaultModelLoader(BaseModelLoader):
else:
hf_folder = model_name_or_path
hf_weights_files: List[str] = []
hf_weights_files: list[str] = []
for pattern in allow_patterns:
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
if len(hf_weights_files) > 0:
@@ -173,7 +174,7 @@ class DefaultModelLoader(BaseModelLoader):
def _get_weights_iterator(
self, source: "Source"
) -> Generator[Tuple[str, torch.Tensor], None, None]:
) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights based on the load format."""
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
source.model_or_path, source.revision, source.fall_back_to_pt,
@@ -238,7 +239,7 @@ class DefaultModelLoader(BaseModelLoader):
self,
model_config: ModelConfig,
model: nn.Module,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
) -> Generator[tuple[str, torch.Tensor], None, None]:
primary_weights = DefaultModelLoader.Source(
model_config.model,
model_config.revision,