Further reduce the HTTP calls to huggingface.co (#13107)
This commit is contained in:
committed by
GitHub
parent
d59def4730
commit
7c4033acd4
@@ -4,12 +4,14 @@ import enum
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
from functools import cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Literal, Optional, Type, Union
|
from typing import Any, Callable, Dict, Literal, Optional, Type, Union
|
||||||
|
|
||||||
import huggingface_hub
|
import huggingface_hub
|
||||||
from huggingface_hub import (file_exists, hf_hub_download, list_repo_files,
|
from huggingface_hub import hf_hub_download
|
||||||
try_to_load_from_cache)
|
from huggingface_hub import list_repo_files as hf_list_repo_files
|
||||||
|
from huggingface_hub import try_to_load_from_cache
|
||||||
from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
|
from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
|
||||||
HFValidationError, LocalEntryNotFoundError,
|
HFValidationError, LocalEntryNotFoundError,
|
||||||
RepositoryNotFoundError,
|
RepositoryNotFoundError,
|
||||||
@@ -86,6 +88,65 @@ class ConfigFormat(str, enum.Enum):
|
|||||||
MISTRAL = "mistral"
|
MISTRAL = "mistral"
|
||||||
|
|
||||||
|
|
||||||
|
def with_retry(func: Callable[[], Any],
|
||||||
|
log_msg: str,
|
||||||
|
max_retries: int = 2,
|
||||||
|
retry_delay: int = 2):
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
|
return func()
|
||||||
|
except Exception as e:
|
||||||
|
if attempt == max_retries - 1:
|
||||||
|
logger.error("%s: %s", log_msg, e)
|
||||||
|
raise
|
||||||
|
logger.error("%s: %s, retrying %d of %d", log_msg, e, attempt + 1,
|
||||||
|
max_retries)
|
||||||
|
time.sleep(retry_delay)
|
||||||
|
retry_delay *= 2
|
||||||
|
|
||||||
|
|
||||||
|
# @cache doesn't cache exceptions
|
||||||
|
@cache
|
||||||
|
def list_repo_files(
|
||||||
|
repo_id: str,
|
||||||
|
*,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
repo_type: Optional[str] = None,
|
||||||
|
token: Union[str, bool, None] = None,
|
||||||
|
) -> list[str]:
|
||||||
|
|
||||||
|
def lookup_files():
|
||||||
|
try:
|
||||||
|
return hf_list_repo_files(repo_id,
|
||||||
|
revision=revision,
|
||||||
|
repo_type=repo_type,
|
||||||
|
token=token)
|
||||||
|
except huggingface_hub.errors.OfflineModeIsEnabled:
|
||||||
|
# Don't raise in offline mode,
|
||||||
|
# all we know is that we don't have this
|
||||||
|
# file cached.
|
||||||
|
return []
|
||||||
|
|
||||||
|
return with_retry(lookup_files, "Error retrieving file list")
|
||||||
|
|
||||||
|
|
||||||
|
def file_exists(
|
||||||
|
repo_id: str,
|
||||||
|
file_name: str,
|
||||||
|
*,
|
||||||
|
repo_type: Optional[str] = None,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
token: Union[str, bool, None] = None,
|
||||||
|
) -> bool:
|
||||||
|
|
||||||
|
file_list = list_repo_files(repo_id,
|
||||||
|
repo_type=repo_type,
|
||||||
|
revision=revision,
|
||||||
|
token=token)
|
||||||
|
return file_name in file_list
|
||||||
|
|
||||||
|
|
||||||
|
# In offline mode the result can be a false negative
|
||||||
def file_or_path_exists(model: Union[str, Path], config_name: str,
|
def file_or_path_exists(model: Union[str, Path], config_name: str,
|
||||||
revision: Optional[str]) -> bool:
|
revision: Optional[str]) -> bool:
|
||||||
if Path(model).exists():
|
if Path(model).exists():
|
||||||
@@ -103,31 +164,10 @@ def file_or_path_exists(model: Union[str, Path], config_name: str,
|
|||||||
# hf_hub. This will fail in offline mode.
|
# hf_hub. This will fail in offline mode.
|
||||||
|
|
||||||
# Call HF to check if the file exists
|
# Call HF to check if the file exists
|
||||||
# 2 retries and exponential backoff
|
return file_exists(str(model),
|
||||||
max_retries = 2
|
|
||||||
retry_delay = 2
|
|
||||||
for attempt in range(max_retries):
|
|
||||||
try:
|
|
||||||
return file_exists(model,
|
|
||||||
config_name,
|
config_name,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
token=HF_TOKEN)
|
token=HF_TOKEN)
|
||||||
except huggingface_hub.errors.OfflineModeIsEnabled:
|
|
||||||
# Don't raise in offline mode,
|
|
||||||
# all we know is that we don't have this
|
|
||||||
# file cached.
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
"Error checking file existence: %s, retrying %d of %d", e,
|
|
||||||
attempt + 1, max_retries)
|
|
||||||
if attempt == max_retries - 1:
|
|
||||||
logger.error("Error checking file existence: %s", e)
|
|
||||||
raise
|
|
||||||
time.sleep(retry_delay)
|
|
||||||
retry_delay *= 2
|
|
||||||
continue
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def patch_rope_scaling(config: PretrainedConfig) -> None:
|
def patch_rope_scaling(config: PretrainedConfig) -> None:
|
||||||
@@ -208,32 +248,7 @@ def get_config(
|
|||||||
revision=revision):
|
revision=revision):
|
||||||
config_format = ConfigFormat.MISTRAL
|
config_format = ConfigFormat.MISTRAL
|
||||||
else:
|
else:
|
||||||
# If we're in offline mode and found no valid config format, then
|
raise ValueError(f"No supported config format found in {model}.")
|
||||||
# raise an offline mode error to indicate to the user that they
|
|
||||||
# don't have files cached and may need to go online.
|
|
||||||
# This is conveniently triggered by calling file_exists().
|
|
||||||
|
|
||||||
# Call HF to check if the file exists
|
|
||||||
# 2 retries and exponential backoff
|
|
||||||
max_retries = 2
|
|
||||||
retry_delay = 2
|
|
||||||
for attempt in range(max_retries):
|
|
||||||
try:
|
|
||||||
file_exists(model,
|
|
||||||
HF_CONFIG_NAME,
|
|
||||||
revision=revision,
|
|
||||||
token=HF_TOKEN)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
"Error checking file existence: %s, retrying %d of %d",
|
|
||||||
e, attempt + 1, max_retries)
|
|
||||||
if attempt == max_retries:
|
|
||||||
logger.error("Error checking file existence: %s", e)
|
|
||||||
raise e
|
|
||||||
time.sleep(retry_delay)
|
|
||||||
retry_delay *= 2
|
|
||||||
|
|
||||||
raise ValueError(f"No supported config format found in {model}")
|
|
||||||
|
|
||||||
if config_format == ConfigFormat.HF:
|
if config_format == ConfigFormat.HF:
|
||||||
config_dict, _ = PretrainedConfig.get_config_dict(
|
config_dict, _ = PretrainedConfig.get_config_dict(
|
||||||
@@ -339,10 +354,11 @@ def get_hf_file_to_dict(file_name: str,
|
|||||||
file_name=file_name,
|
file_name=file_name,
|
||||||
revision=revision)
|
revision=revision)
|
||||||
|
|
||||||
if file_path is None and file_or_path_exists(
|
if file_path is None:
|
||||||
model=model, config_name=file_name, revision=revision):
|
|
||||||
try:
|
try:
|
||||||
hf_hub_file = hf_hub_download(model, file_name, revision=revision)
|
hf_hub_file = hf_hub_download(model, file_name, revision=revision)
|
||||||
|
except huggingface_hub.errors.OfflineModeIsEnabled:
|
||||||
|
return None
|
||||||
except (RepositoryNotFoundError, RevisionNotFoundError,
|
except (RepositoryNotFoundError, RevisionNotFoundError,
|
||||||
EntryNotFoundError, LocalEntryNotFoundError) as e:
|
EntryNotFoundError, LocalEntryNotFoundError) as e:
|
||||||
logger.debug("File or repository not found in hf_hub_download", e)
|
logger.debug("File or repository not found in hf_hub_download", e)
|
||||||
@@ -363,6 +379,7 @@ def get_hf_file_to_dict(file_name: str,
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@cache
|
||||||
def get_pooling_config(model: str, revision: Optional[str] = 'main'):
|
def get_pooling_config(model: str, revision: Optional[str] = 'main'):
|
||||||
"""
|
"""
|
||||||
This function gets the pooling and normalize
|
This function gets the pooling and normalize
|
||||||
@@ -390,6 +407,8 @@ def get_pooling_config(model: str, revision: Optional[str] = 'main'):
|
|||||||
if modules_dict is None:
|
if modules_dict is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
logger.info("Found sentence-transformers modules configuration.")
|
||||||
|
|
||||||
pooling = next((item for item in modules_dict
|
pooling = next((item for item in modules_dict
|
||||||
if item["type"] == "sentence_transformers.models.Pooling"),
|
if item["type"] == "sentence_transformers.models.Pooling"),
|
||||||
None)
|
None)
|
||||||
@@ -408,6 +427,7 @@ def get_pooling_config(model: str, revision: Optional[str] = 'main'):
|
|||||||
if pooling_type_name is not None:
|
if pooling_type_name is not None:
|
||||||
pooling_type_name = get_pooling_config_name(pooling_type_name)
|
pooling_type_name = get_pooling_config_name(pooling_type_name)
|
||||||
|
|
||||||
|
logger.info("Found pooling configuration.")
|
||||||
return {"pooling_type": pooling_type_name, "normalize": normalize}
|
return {"pooling_type": pooling_type_name, "normalize": normalize}
|
||||||
|
|
||||||
return None
|
return None
|
||||||
@@ -435,6 +455,7 @@ def get_pooling_config_name(pooling_name: str) -> Union[str, None]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@cache
|
||||||
def get_sentence_transformer_tokenizer_config(model: str,
|
def get_sentence_transformer_tokenizer_config(model: str,
|
||||||
revision: Optional[str] = 'main'
|
revision: Optional[str] = 'main'
|
||||||
):
|
):
|
||||||
@@ -491,6 +512,8 @@ def get_sentence_transformer_tokenizer_config(model: str,
|
|||||||
if not encoder_dict:
|
if not encoder_dict:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
logger.info("Found sentence-transformers tokenize configuration.")
|
||||||
|
|
||||||
if all(k in encoder_dict for k in ("max_seq_length", "do_lower_case")):
|
if all(k in encoder_dict for k in ("max_seq_length", "do_lower_case")):
|
||||||
return encoder_dict
|
return encoder_dict
|
||||||
return None
|
return None
|
||||||
|
|||||||
Reference in New Issue
Block a user