Update deprecated type hinting in model_loader (#18130)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -2,7 +2,8 @@
|
||||
# ruff: noqa: SIM117
|
||||
import glob
|
||||
import os
|
||||
from typing import Generator, List, Optional, Tuple
|
||||
from collections.abc import Generator
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -48,7 +49,7 @@ class RunaiModelStreamerLoader(BaseModelLoader):
|
||||
os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url
|
||||
|
||||
def _prepare_weights(self, model_name_or_path: str,
|
||||
revision: Optional[str]) -> List[str]:
|
||||
revision: Optional[str]) -> list[str]:
|
||||
"""Prepare weights for the model.
|
||||
|
||||
If the model is not local, it will be downloaded."""
|
||||
@@ -87,7 +88,7 @@ class RunaiModelStreamerLoader(BaseModelLoader):
|
||||
|
||||
def _get_weights_iterator(
|
||||
self, model_or_path: str,
|
||||
revision: str) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
revision: str) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""Get an iterator for the model weights based on the load format."""
|
||||
hf_weights_files = self._prepare_weights(model_or_path, revision)
|
||||
return runai_safetensors_weights_iterator(
|
||||
|
||||
Reference in New Issue
Block a user