Update Optional[x] to x | None and Union[x, y] to x | y (#26633)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-12 17:51:31 +01:00
committed by GitHub
parent 9bb38130cb
commit 8fcaaf6a16
944 changed files with 9490 additions and 10121 deletions

View File

@@ -4,10 +4,11 @@
import json
import os
import time
from collections.abc import Callable
from dataclasses import asdict
from functools import cache, partial
from pathlib import Path
from typing import Any, Callable, Literal, Optional, TypeVar, Union
from typing import Any, Literal, TypeVar
import huggingface_hub
from huggingface_hub import (
@@ -47,7 +48,7 @@ MISTRAL_CONFIG_NAME = "params.json"
logger = init_logger(__name__)
def _get_hf_token() -> Optional[str]:
def _get_hf_token() -> str | None:
"""
Get the HuggingFace token from environment variable.
@@ -108,10 +109,10 @@ _AUTO_CONFIG_KWARGS_OVERRIDES: dict[str, dict[str, Any]] = {
class HFConfigParser(ConfigParserBase):
def parse(
self,
model: Union[str, Path],
model: str | Path,
trust_remote_code: bool,
revision: Optional[str] = None,
code_revision: Optional[str] = None,
revision: str | None = None,
code_revision: str | None = None,
**kwargs,
) -> tuple[dict, PretrainedConfig]:
kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE
@@ -173,10 +174,10 @@ class HFConfigParser(ConfigParserBase):
class MistralConfigParser(ConfigParserBase):
def parse(
self,
model: Union[str, Path],
model: str | Path,
trust_remote_code: bool,
revision: Optional[str] = None,
code_revision: Optional[str] = None,
revision: str | None = None,
code_revision: str | None = None,
**kwargs,
) -> tuple[dict, PretrainedConfig]:
# This function loads a params.json config which
@@ -247,8 +248,8 @@ def register_config_parser(config_format: str):
... self,
... model: str | Path,
... trust_remote_code: bool,
... revision: Optional[str] = None,
... code_revision: Optional[str] = None,
... revision: str | None = None,
... code_revision: str | None = None,
... **kwargs,
... ) -> tuple[dict, PretrainedConfig]:
... raise NotImplementedError
@@ -310,9 +311,9 @@ def with_retry(
def list_repo_files(
repo_id: str,
*,
revision: Optional[str] = None,
repo_type: Optional[str] = None,
token: Union[str, bool, None] = None,
revision: str | None = None,
repo_type: str | None = None,
token: str | bool | None = None,
) -> list[str]:
def lookup_files() -> list[str]:
# directly list files if model is local
@@ -348,9 +349,9 @@ def file_exists(
repo_id: str,
file_name: str,
*,
repo_type: Optional[str] = None,
revision: Optional[str] = None,
token: Union[str, bool, None] = None,
repo_type: str | None = None,
revision: str | None = None,
token: str | bool | None = None,
) -> bool:
file_list = list_repo_files(
repo_id, repo_type=repo_type, revision=revision, token=token
@@ -360,7 +361,7 @@ def file_exists(
# In offline mode the result can be a false negative
def file_or_path_exists(
model: Union[str, Path], config_name: str, revision: Optional[str]
model: str | Path, config_name: str, revision: str | None
) -> bool:
if (local_path := Path(model)).exists():
return (local_path / config_name).is_file()
@@ -493,10 +494,10 @@ def maybe_override_with_speculators(
model: str,
tokenizer: str,
trust_remote_code: bool,
revision: Optional[str] = None,
vllm_speculative_config: Optional[dict[str, Any]] = None,
revision: str | None = None,
vllm_speculative_config: dict[str, Any] | None = None,
**kwargs,
) -> tuple[str, str, Optional[dict[str, Any]]]:
) -> tuple[str, str, dict[str, Any] | None]:
"""
Resolve model configuration when speculators are detected.
@@ -551,13 +552,13 @@ def maybe_override_with_speculators(
def get_config(
model: Union[str, Path],
model: str | Path,
trust_remote_code: bool,
revision: Optional[str] = None,
code_revision: Optional[str] = None,
config_format: Union[str, ConfigFormat] = "auto",
hf_overrides_kw: Optional[dict[str, Any]] = None,
hf_overrides_fn: Optional[Callable[[PretrainedConfig], PretrainedConfig]] = None,
revision: str | None = None,
code_revision: str | None = None,
config_format: str | ConfigFormat = "auto",
hf_overrides_kw: dict[str, Any] | None = None,
hf_overrides_fn: Callable[[PretrainedConfig], PretrainedConfig] | None = None,
**kwargs,
) -> PretrainedConfig:
# Separate model folder from file path for GGUF models
@@ -669,8 +670,8 @@ def get_config(
def try_get_local_file(
model: Union[str, Path], file_name: str, revision: Optional[str] = "main"
) -> Optional[Path]:
model: str | Path, file_name: str, revision: str | None = "main"
) -> Path | None:
file_path = Path(model) / file_name
if file_path.is_file():
return file_path
@@ -687,7 +688,7 @@ def try_get_local_file(
def get_hf_file_to_dict(
file_name: str, model: Union[str, Path], revision: Optional[str] = "main"
file_name: str, model: str | Path, revision: str | None = "main"
):
"""
Downloads a file from the Hugging Face Hub and returns
@@ -735,7 +736,7 @@ def get_hf_file_to_dict(
@cache
def get_pooling_config(model: str, revision: Optional[str] = "main") -> Optional[dict]:
def get_pooling_config(model: str, revision: str | None = "main") -> dict | None:
"""
This function gets the pooling and normalize
config from the model - only applies to
@@ -799,7 +800,7 @@ def get_pooling_config(model: str, revision: Optional[str] = "main") -> Optional
return None
def get_pooling_config_name(pooling_name: str) -> Union[str, None]:
def get_pooling_config_name(pooling_name: str) -> str | None:
if "pooling_mode_" in pooling_name:
pooling_name = pooling_name.replace("pooling_mode_", "")
@@ -820,7 +821,7 @@ def get_pooling_config_name(pooling_name: str) -> Union[str, None]:
@cache
def get_sentence_transformer_tokenizer_config(
model: Union[str, Path], revision: Optional[str] = "main"
model: str | Path, revision: str | None = "main"
):
"""
Returns the tokenization configuration dictionary for a
@@ -958,9 +959,9 @@ def maybe_register_config_serialize_by_value() -> None:
def get_hf_image_processor_config(
model: Union[str, Path],
hf_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
model: str | Path,
hf_token: bool | str | None = None,
revision: str | None = None,
**kwargs,
) -> dict[str, Any]:
# ModelScope does not provide an interface for image_processor
@@ -992,9 +993,9 @@ def get_hf_text_config(config: PretrainedConfig):
def try_get_generation_config(
model: str,
trust_remote_code: bool,
revision: Optional[str] = None,
config_format: Union[str, ConfigFormat] = "auto",
) -> Optional[GenerationConfig]:
revision: str | None = None,
config_format: str | ConfigFormat = "auto",
) -> GenerationConfig | None:
try:
return GenerationConfig.from_pretrained(
model,
@@ -1016,7 +1017,7 @@ def try_get_generation_config(
def try_get_safetensors_metadata(
model: str,
*,
revision: Optional[str] = None,
revision: str | None = None,
):
get_safetensors_metadata_partial = partial(
get_safetensors_metadata,
@@ -1034,10 +1035,10 @@ def try_get_safetensors_metadata(
def try_get_tokenizer_config(
pretrained_model_name_or_path: Union[str, os.PathLike],
pretrained_model_name_or_path: str | os.PathLike,
trust_remote_code: bool,
revision: Optional[str] = None,
) -> Optional[dict[str, Any]]:
revision: str | None = None,
) -> dict[str, Any] | None:
try:
return get_tokenizer_config(
pretrained_model_name_or_path,
@@ -1051,7 +1052,7 @@ def try_get_tokenizer_config(
def get_safetensors_params_metadata(
model: str,
*,
revision: Optional[str] = None,
revision: str | None = None,
) -> dict[str, Any]:
"""
Get the safetensors metadata for remote model repository.
@@ -1112,7 +1113,7 @@ def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int:
return max_position_embeddings
def get_model_path(model: Union[str, Path], revision: Optional[str] = None):
def get_model_path(model: str | Path, revision: str | None = None):
if os.path.exists(model):
return model
assert huggingface_hub.constants.HF_HUB_OFFLINE
@@ -1132,8 +1133,8 @@ def get_model_path(model: Union[str, Path], revision: Optional[str] = None):
def get_hf_file_bytes(
file_name: str, model: Union[str, Path], revision: Optional[str] = "main"
) -> Optional[bytes]:
file_name: str, model: str | Path, revision: str | None = "main"
) -> bytes | None:
"""Get file contents from HuggingFace repository as bytes."""
file_path = try_get_local_file(model=model, file_name=file_name, revision=revision)