Default model load/config/tokenizer to mistral format if relevant files exist (#28659)

Signed-off-by: Julien Denize <julien.denize@mistral.ai>
Signed-off-by: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Julien Denize
2025-11-21 22:58:59 +01:00
committed by GitHub
parent c68c7b403d
commit 57430fc95c
15 changed files with 230 additions and 34 deletions

View File

@@ -81,7 +81,7 @@ TaskOption = Literal[
"transcription",
"draft",
]
TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
TokenizerMode = Literal["auto", "hf", "slow", "mistral", "custom"]
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
LogprobsMode = Literal[
"raw_logits", "raw_logprobs", "processed_logits", "processed_logprobs"
@@ -130,7 +130,8 @@ class ModelConfig:
name or path will be used."""
tokenizer_mode: TokenizerMode = "auto"
"""Tokenizer mode:\n
- "auto" will use the fast tokenizer if available.\n
- "auto" will use "hf" tokenizer if Mistral's tokenizer is not available.\n
- "hf" will use the fast tokenizer if available.\n
- "slow" will always use the slow tokenizer.\n
- "mistral" will always use the tokenizer from `mistral_common`.\n
- "custom" will use --tokenizer to select the preregistered tokenizer."""
@@ -241,8 +242,8 @@ class ModelConfig:
first one."""
config_format: str | ConfigFormat = "auto"
"""The format of the model config to load:\n
- "auto" will try to load the config in hf format if available else it
will try to load in mistral format.\n
- "auto" will try to load the config in hf format if available after trying
to load in mistral format.\n
- "hf" will load the config in hf format.\n
- "mistral" will load the config in mistral format."""
hf_token: bool | str | None = None

View File

@@ -30,6 +30,7 @@ logger = init_logger(__name__)
# if a new load format is added here
LoadFormats = Literal[
"auto",
"hf",
"bitsandbytes",
"dummy",
"fastsafetensors",
@@ -45,6 +46,7 @@ LoadFormats = Literal[
]
_LOAD_FORMAT_TO_MODEL_LOADER: dict[str, type[BaseModelLoader]] = {
"auto": DefaultModelLoader,
"hf": DefaultModelLoader,
"bitsandbytes": BitsAndBytesModelLoader,
"dummy": DummyModelLoader,
"fastsafetensors": DefaultModelLoader,

View File

@@ -31,6 +31,7 @@ from vllm.model_executor.model_loader.weight_utils import (
safetensors_weights_iterator,
)
from vllm.platforms import current_platform
from vllm.transformers_utils.config import list_filtered_repo_files
logger = init_logger(__name__)
@@ -96,8 +97,25 @@ class DefaultModelLoader(BaseModelLoader):
load_format = self.load_config.load_format
use_safetensors = False
index_file = SAFE_WEIGHTS_INDEX_NAME
# Some quantized models use .pt files for storing the weights.
# First check for 'auto' format that mistral files format are present.
# This is to load mistral models with official format by default.
if load_format == "auto":
load_format = (
"mistral"
if len(
list_filtered_repo_files(
model_name_or_path=model_name_or_path,
allow_patterns=["consolidated*.safetensors"],
revision=revision,
)
)
> 0
else "hf"
)
# Some quantized models use .pt files for storing the weights.
if load_format == "hf":
allow_patterns = ["*.safetensors", "*.bin"]
elif load_format == "safetensors" or load_format == "fastsafetensors":
use_safetensors = True

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import fnmatch
import json
import os
import time
@@ -355,6 +356,41 @@ def list_repo_files(
return with_retry(lookup_files, "Error retrieving file list")
def list_filtered_repo_files(
model_name_or_path: str,
allow_patterns: list[str],
revision: str | None = None,
repo_type: str | None = None,
token: str | bool | None = None,
) -> list[str]:
try:
all_files = list_repo_files(
repo_id=model_name_or_path,
revision=revision,
token=token,
repo_type=repo_type,
)
except Exception:
logger.error(
"Error retrieving file list. Please ensure your `model_name_or_path`"
"`repo_type`, `token` and `revision` arguments are correctly set. "
"Returning an empty list."
)
return []
file_list = []
# Filter patterns on filenames
for pattern in allow_patterns:
file_list.extend(
[
file
for file in all_files
if fnmatch.fnmatch(os.path.basename(file), pattern)
]
)
return file_list
def file_exists(
repo_id: str,
file_name: str,
@@ -619,10 +655,14 @@ def get_config(
if config_format == "auto":
try:
if is_gguf or file_or_path_exists(model, HF_CONFIG_NAME, revision=revision):
config_format = "hf"
elif file_or_path_exists(model, MISTRAL_CONFIG_NAME, revision=revision):
# First check for Mistral to avoid defaulting to
# Transformers implementation.
if file_or_path_exists(model, MISTRAL_CONFIG_NAME, revision=revision):
config_format = "mistral"
elif is_gguf or file_or_path_exists(
model, HF_CONFIG_NAME, revision=revision
):
config_format = "hf"
else:
raise ValueError(
"Could not detect config format for no config file found. "

View File

@@ -118,7 +118,7 @@ def _remap_general_mistral_args(config: dict) -> dict:
"model_type": ("model_type", "transformer"),
"hidden_act": ("activation", "silu"),
"tie_word_embeddings": ("tied_embeddings", False),
"max_seq_len": ("max_seq_len", 128_000),
"max_seq_len": ("max_seq_len", config.get("max_position_embeddings", 128_000)),
"max_position_embeddings": ("max_position_embeddings", 128_000),
}

View File

@@ -3,8 +3,8 @@
import contextlib
import copy
import importlib.util
import os
import warnings
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeAlias
@@ -15,7 +15,10 @@ from typing_extensions import assert_never
from vllm import envs
from vllm.logger import init_logger
from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config
from vllm.transformers_utils.config import (
get_sentence_transformer_tokenizer_config,
list_filtered_repo_files,
)
from vllm.transformers_utils.tokenizers import MistralTokenizer
from vllm.transformers_utils.utils import check_gguf_file
@@ -182,25 +185,29 @@ def get_tokenizer(
kwargs["gguf_file"] = Path(tokenizer_name).name
tokenizer_name = Path(tokenizer_name).parent
# if tokenizer is from official mistral org
is_from_mistral_org = str(tokenizer_name).split("/")[0] == "mistralai"
if is_from_mistral_org and tokenizer_mode != "mistral":
warnings.warn(
"It is strongly recommended to run mistral models with "
'`--tokenizer-mode "mistral"` to ensure correct '
"encoding and decoding.",
FutureWarning,
stacklevel=2,
# if `tokenizer_mode` == "auto", check if tokenizer can be loaded via Mistral format
# first to use official Mistral tokenizer if possible.
mistral_common_installed = importlib.util.find_spec("mistral_common") is not None
if tokenizer_mode == "auto" and mistral_common_installed:
allow_patterns = ["tekken.json", "tokenizer.model.v*"]
files_list = list_filtered_repo_files(
model_name_or_path=str(tokenizer_name),
allow_patterns=allow_patterns,
revision=revision,
)
if len(files_list) > 0:
tokenizer_mode = "mistral"
tokenizer: AnyTokenizer
if tokenizer_mode == "mistral":
logger.debug_once(f"Loading MistralTokenizer from {tokenizer_name}")
tokenizer = MistralTokenizer.from_pretrained(
str(tokenizer_name), revision=revision
)
elif tokenizer_mode == "custom":
from vllm.transformers_utils.tokenizer_base import TokenizerRegistry
logger.debug_once(f"Loading CustomTokenizer from {tokenizer_name}")
tokenizer = TokenizerRegistry.get_tokenizer(
str(tokenizer_name),
*args,
@@ -210,6 +217,7 @@ def get_tokenizer(
)
else:
try:
logger.debug_once(f"Loading AutoTokenizer from {tokenizer_name}")
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
*args,

View File

@@ -20,6 +20,7 @@ from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.metrics.stats import MultiModalCacheStats
@@ -300,12 +301,24 @@ class Processor:
# allows <|special_token|> and similar, see
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars.
if isinstance(self.tokenizer, MistralTokenizer):
raise ValueError(
"Mistral tokenizer is not supported for the 'guidance' "
"structured output backend. Please use ['xgrammar', 'outlines'] "
"backends or tokenizer_mode='hf' instead."
)
validate_guidance_grammar(params, tokenizer=None)
elif backend == "outlines":
# outlines backend
validate_structured_output_request_outlines(params)
elif backend == "lm-format-enforcer":
# lm format enforcer backend
if isinstance(self.tokenizer, MistralTokenizer):
raise ValueError(
"Mistral tokenizer is not supported for the 'lm-format-enforcer' "
"structured output backend. Please use ['xgrammar', 'outlines'] "
"backends or tokenizer_mode='hf' instead."
)
validate_structured_output_request_lm_format_enforcer(params)
else:
# NOTE: backend must be "auto" here, because we have
@@ -320,9 +333,15 @@ class Processor:
except ValueError:
# The request either failed validation
# or includes some jsonschema feature(s) that
# are not supported in xgrammar. Fall back to guidance.
validate_guidance_grammar(params, tokenizer=None)
params.structured_outputs._backend = "guidance"
# are not supported in xgrammar.
if isinstance(self.tokenizer, MistralTokenizer):
# Fall back to outlines if the tokenizer is Mistral
validate_structured_output_request_outlines(params)
params.structured_outputs._backend = "outlines"
else:
# Fall back to guidance by default.
validate_guidance_grammar(params, tokenizer=None)
params.structured_outputs._backend = "guidance"
# Remember that this backend was set automatically
params.structured_outputs._backend_was_auto = True