Update deprecated type hinting in model_loader (#18130)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-05-15 12:00:21 +01:00
committed by GitHub
parent a9944aabfa
commit 07ad27121f
12 changed files with 80 additions and 74 deletions

View File

@@ -2,7 +2,8 @@
# ruff: noqa: SIM117
import glob
import os
from typing import Generator, List, Optional, Tuple
from collections.abc import Generator
from typing import Optional
import torch
from torch import nn
@@ -48,7 +49,7 @@ class RunaiModelStreamerLoader(BaseModelLoader):
os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url
def _prepare_weights(self, model_name_or_path: str,
revision: Optional[str]) -> List[str]:
revision: Optional[str]) -> list[str]:
"""Prepare weights for the model.
If the model is not local, it will be downloaded."""
@@ -87,7 +88,7 @@ class RunaiModelStreamerLoader(BaseModelLoader):
def _get_weights_iterator(
self, model_or_path: str,
revision: str) -> Generator[Tuple[str, torch.Tensor], None, None]:
revision: str) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights based on the load format."""
hf_weights_files = self._prepare_weights(model_or_path, revision)
return runai_safetensors_weights_iterator(