[Model] Add Mistral Tokenization to improve robustness and chat encoding (#7739)
committed by GitHub
parent 9606c7197d
commit 6fc4e6e07a

vllm/transformers_utils/tokenizer.py
@@ -1,4 +1,5 @@
 import os
+import warnings
 from pathlib import Path
 from typing import Optional, Union
 
@@ -9,12 +10,14 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.transformers_utils.tokenizers import BaichuanTokenizer
+from vllm.transformers_utils.tokenizers import (BaichuanTokenizer,
+                                                MistralTokenizer)
 from vllm.utils import make_async
 
 logger = init_logger(__name__)
 
-AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
+AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
+                     MistralTokenizer]
 
 
 def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
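
For orientation (not part of the diff): a minimal sketch of what the widened AnyTokenizer union means for callers. Only the union itself comes from the hunk above; the helper name below is made up. Since MistralTokenizer is not a HuggingFace PreTrainedTokenizerFast, downstream code that branches on tokenizer type needs an explicit case for it.

from typing import Union

from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from vllm.transformers_utils.tokenizers import MistralTokenizer

AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
                     MistralTokenizer]


def describe(tokenizer: AnyTokenizer) -> str:
    # Hypothetical helper: MistralTokenizer is neither a slow nor a fast
    # HuggingFace tokenizer, so it gets its own branch.
    if isinstance(tokenizer, MistralTokenizer):
        return "mistral"
    if isinstance(tokenizer, PreTrainedTokenizerFast):
        return "fast"
    return "slow"
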
@@ -99,45 +102,64 @@ def get_tokenizer(
         kwargs["gguf_file"] = Path(tokenizer_name).name
         tokenizer_name = Path(tokenizer_name).parent
 
-    try:
-        tokenizer = AutoTokenizer.from_pretrained(
-            tokenizer_name,
-            *args,
-            trust_remote_code=trust_remote_code,
-            revision=revision,
-            **kwargs)
-    except ValueError as e:
-        # If the error pertains to the tokenizer class not existing or not
-        # currently being imported, suggest using the --trust-remote-code flag.
-        if (not trust_remote_code and
-            ("does not exist or is not currently imported." in str(e)
-             or "requires you to execute the tokenizer file" in str(e))):
-            err_msg = (
-                "Failed to load the tokenizer. If the tokenizer is a custom "
-                "tokenizer not yet available in the HuggingFace transformers "
-                "library, consider setting `trust_remote_code=True` in LLM "
-                "or using the `--trust-remote-code` flag in the CLI.")
-            raise RuntimeError(err_msg) from e
-        else:
-            raise e
-    except AttributeError as e:
-        if "BaichuanTokenizer" in str(e):
-            # This is for the error "'BaichuanTokenizer' object has no
-            # attribute 'sp_model'".
-            tokenizer = BaichuanTokenizer.from_pretrained(
-                tokenizer_name,
-                *args,
-                trust_remote_code=trust_remote_code,
-                revision=revision,
-                **kwargs)
-        else:
-            raise e
-
-    if not isinstance(tokenizer, PreTrainedTokenizerFast):
-        logger.warning(
-            "Using a slow tokenizer. This might cause a significant "
-            "slowdown. Consider using a fast tokenizer instead.")
-    return get_cached_tokenizer(tokenizer)
+    # if tokenizer is from official mistral org
+    is_from_mistral_org = str(tokenizer_name).split("/")[0] == "mistralai"
+    if is_from_mistral_org and tokenizer_mode != "mistral":
+        warnings.warn(
+            'It is strongly recommended to run mistral models with '
+            '`--tokenizer_mode "mistral"` to ensure correct '
+            'encoding and decoding.',
+            FutureWarning,
+            stacklevel=2)
+
+    if tokenizer_mode == "mistral":
+        tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name),
+                                                     revision=revision)
+    else:
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(
+                tokenizer_name,
+                *args,
+                trust_remote_code=trust_remote_code,
+                revision=revision,
+                **kwargs,
+            )
+        except ValueError as e:
+            # If the error pertains to the tokenizer class not existing or not
+            # currently being imported,
+            # suggest using the --trust-remote-code flag.
+            if not trust_remote_code and (
+                    "does not exist or is not currently imported." in str(e)
+                    or "requires you to execute the tokenizer file" in str(e)):
+                err_msg = ("Failed to load the tokenizer. If the tokenizer "
+                           "is a custom tokenizer not yet available in the "
+                           "HuggingFace transformers library, consider "
+                           "setting `trust_remote_code=True` in LLM or using "
+                           "the `--trust-remote-code` flag in the CLI.")
+                raise RuntimeError(err_msg) from e
+            else:
+                raise e
+        except AttributeError as e:
+            if "BaichuanTokenizer" in str(e):
+                # This is for the error "'BaichuanTokenizer' object has no
+                # attribute 'sp_model'".
+                tokenizer = BaichuanTokenizer.from_pretrained(
+                    tokenizer_name,
+                    *args,
+                    trust_remote_code=trust_remote_code,
+                    revision=revision,
+                    **kwargs,
+                )
+            else:
+                raise e
+
+        if not isinstance(tokenizer, PreTrainedTokenizerFast):
+            logger.warning(
+                "Using a slow tokenizer. This might cause a significant "
+                "slowdown. Consider using a fast tokenizer instead.")
+        tokenizer = get_cached_tokenizer(tokenizer)
+
+    return tokenizer
 
 
 def get_lora_tokenizer(lora_request: LoRARequest, *args,
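
For context, a minimal usage sketch of the new code path (not part of the diff): the model name is illustrative, and it assumes get_tokenizer in vllm/transformers_utils/tokenizer.py keeps the tokenizer_mode keyword referenced in the hunk above.

from vllm.transformers_utils.tokenizer import get_tokenizer

# Opting in explicitly routes through MistralTokenizer.from_pretrained(...).
tokenizer = get_tokenizer("mistralai/Mistral-7B-Instruct-v0.3",
                          tokenizer_mode="mistral")

# Loading a mistralai/* repo without tokenizer_mode="mistral" still uses
# AutoTokenizer, but now emits a FutureWarning recommending
# `--tokenizer_mode "mistral"` for correct encoding and decoding.
hf_tokenizer = get_tokenizer("mistralai/Mistral-7B-Instruct-v0.3")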