[Bugfix] Fix offline mode when using mistral_common (#9457)

This commit is contained in:
sasha0552
2024-10-19 01:12:32 +00:00
committed by GitHub
parent 0c9a5258f9
commit 337ed76671
2 changed files with 62 additions and 28 deletions

View File

@@ -4,6 +4,7 @@ from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
import huggingface_hub
from huggingface_hub import HfApi, hf_hub_download
from mistral_common.protocol.instruct.request import ChatCompletionRequest
# yapf: disable
@@ -24,6 +25,26 @@ class Encoding:
input_ids: List[int]
def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
repo_cache = os.path.join(
huggingface_hub.constants.HF_HUB_CACHE,
huggingface_hub.constants.REPO_ID_SEPARATOR.join(
["models", *repo_id.split("/")]))
if revision is None:
revision_file = os.path.join(repo_cache, "refs", "main")
if os.path.isfile(revision_file):
with open(revision_file) as file:
revision = file.read()
if revision:
revision_dir = os.path.join(repo_cache, "snapshots", revision)
if os.path.isdir(revision_dir):
return os.listdir(revision_dir)
return []
def find_tokenizer_file(files: List[str]):
file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$")
@@ -90,9 +111,16 @@ class MistralTokenizer:
@staticmethod
def _download_mistral_tokenizer_from_hf(tokenizer_name: str,
revision: Optional[str]) -> str:
api = HfApi()
repo_info = api.model_info(tokenizer_name)
files = [s.rfilename for s in repo_info.siblings]
try:
hf_api = HfApi()
files = hf_api.list_repo_files(repo_id=tokenizer_name,
revision=revision)
except ConnectionError as exc:
files = list_local_repo_files(repo_id=tokenizer_name,
revision=revision)
if len(files) == 0:
raise exc
filename = find_tokenizer_file(files)