Patch Mistral config (#37104)

Signed-off-by: juliendenize <julien.denize@mistral.ai>
Julien Denize
2026-03-16 13:22:18 +01:00
committed by GitHub
parent f9e6db3034
commit ffbc2e5bdb
3 changed files with 49 additions and 30 deletions

View File

@@ -2,7 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from collections.abc import Callable
from collections.abc import Callable, Iterator
from contextlib import contextmanager
from dataclasses import asdict
from functools import cache, partial
from importlib.metadata import version
@@ -10,8 +11,10 @@ from pathlib import Path
from typing import Any, Literal, TypeAlias
import huggingface_hub
from huggingface_hub import get_safetensors_metadata
import torch
from huggingface_hub import constants, get_safetensors_metadata
from packaging.version import Version
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers import GenerationConfig, PretrainedConfig
from transformers.models.auto.image_processing_auto import get_image_processor_config
from transformers.models.auto.modeling_auto import (
@@ -28,6 +31,7 @@ from vllm.transformers_utils.utils import (
    parse_safetensors_file_metadata,
    without_trust_remote_code,
)
from vllm.utils.torch_utils import common_broadcastable_dtype
from .config_parser_base import ConfigParserBase
from .gguf_utils import (
@@ -135,6 +139,19 @@ def is_rope_parameters_nested(rope_parameters: dict[str, Any]) -> bool:
    return set(rope_parameters.keys()).issubset(ALLOWED_ATTENTION_LAYER_TYPES)


@contextmanager
def _mistral_patch_hf_hub_constants() -> Iterator[None]:
    hf_safetensors_single_file = constants.SAFETENSORS_SINGLE_FILE
    hf_safetensors_index_file = constants.SAFETENSORS_INDEX_FILE
    constants.SAFETENSORS_SINGLE_FILE = "consolidated.safetensors"
    constants.SAFETENSORS_INDEX_FILE = "consolidated.safetensors.index.json"
    try:
        yield
    finally:
        constants.SAFETENSORS_SINGLE_FILE = hf_safetensors_single_file
        constants.SAFETENSORS_INDEX_FILE = hf_safetensors_index_file


class HFConfigParser(ConfigParserBase):
    def parse(
        self,
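A minimal usage sketch of the new `_mistral_patch_hf_hub_constants` context manager (the assertions are illustrative, not part of the commit): while the block is active, huggingface_hub helpers resolve Mistral's consolidated weight filenames, and the `try`/`finally` guarantees the stock values come back on exit even if the body raises.

```python
from huggingface_hub import constants

# Normally "model.safetensors" / "model.safetensors.index.json".
default_single = constants.SAFETENSORS_SINGLE_FILE
default_index = constants.SAFETENSORS_INDEX_FILE

with _mistral_patch_hf_hub_constants():
    # Hub metadata helpers now look for Mistral's consolidated files.
    assert constants.SAFETENSORS_SINGLE_FILE == "consolidated.safetensors"
    assert constants.SAFETENSORS_INDEX_FILE == "consolidated.safetensors.index.json"

# Restored after the block, even on error.
assert constants.SAFETENSORS_SINGLE_FILE == default_single
assert constants.SAFETENSORS_INDEX_FILE == default_index
```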
@@ -245,6 +262,25 @@ class MistralConfigParser(ConfigParserBase):
        except OSError:  # Not found
            hf_config_dict = {}

        if config_dict.get("dtype") is None:
            with _mistral_patch_hf_hub_constants():
                model_str = model if isinstance(model, str) else model.as_posix()
                param_mt = get_safetensors_params_metadata(model_str, revision=revision)
            if param_mt:
                param_dtypes: set[torch.dtype] = {
                    _SAFETENSORS_TO_TORCH_DTYPE[dtype]
                    for info in param_mt.values()
                    if (dtype := info.get("dtype", None))
                    and dtype in _SAFETENSORS_TO_TORCH_DTYPE
                }
                if param_dtypes:
                    config_dict["dtype"] = common_broadcastable_dtype(param_dtypes)
                    logger.info_once(
                        f"Inferred dtype {config_dict['dtype']} from "
                        "consolidated*.safetensors files."
                    )

        config = adapt_config_dict(config_dict, defaults=hf_config_dict)
        return config_dict, config
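The dtype inference can be exercised in isolation. Below is a rough, self-contained sketch of the reduction step, with a made-up metadata dict and `torch.promote_types` standing in for vLLM's `common_broadcastable_dtype` (the real helper chooses the common dtype its own way):

```python
from functools import reduce

import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE

# Hypothetical per-parameter metadata in the shape the code above consumes:
# safetensors dtype strings keyed by tensor name.
param_mt = {
    "layers.0.attention.wq.weight": {"dtype": "BF16"},
    "layers.0.feed_forward.gate.weight": {"dtype": "F32"},
}

param_dtypes = {
    _SAFETENSORS_TO_TORCH_DTYPE[dtype]
    for info in param_mt.values()
    if (dtype := info.get("dtype", None)) and dtype in _SAFETENSORS_TO_TORCH_DTYPE
}
# {torch.bfloat16, torch.float32} -> torch.float32 under these illustrative values.
print(reduce(torch.promote_types, param_dtypes))
```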

View File

@@ -113,12 +113,13 @@ def _remap_mistral_vision_args(config: dict) -> dict:
def _remap_mistral_yarn_args(config: dict) -> dict:
    yarn_config_map = {
        "factor": "factor",
        "original_max_position_embeddings": "original_max_position_embeddings",
        "beta": "beta_fast",
        "alpha": "beta_slow",
        "apply_scale": "apply_yarn_scaling",
        "factor": ("factor", float),
        "original_max_position_embeddings": ("original_max_position_embeddings", int),
        "beta": ("beta_fast", float),
        "alpha": ("beta_slow", float),
        "apply_scale": ("apply_yarn_scaling", bool),
    }
    yarn_config = config.get("yarn") or {}
    config["rope_parameters"] = {
        "rope_type": "yarn",
@@ -128,9 +129,10 @@ def _remap_mistral_yarn_args(config: dict) -> dict:
    if rope_theta := config.pop("rope_theta", None):
        config["rope_parameters"]["rope_theta"] = rope_theta

    for old_name, new_name in yarn_config_map.items():
    for old_name, (new_name, cast) in yarn_config_map.items():
        if old_name in yarn_config:
            config["rope_parameters"][new_name] = yarn_config.pop(old_name)
            # Cast to remove Transformers > v5 type warnings
            config["rope_parameters"][new_name] = cast(yarn_config.pop(old_name))

    assert len(yarn_config) == 0, f"Unparsed yarn config: {yarn_config}"
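Applied to an illustrative `params.json` fragment (values made up for the example, and the surrounding `rope_parameters` handling simplified), the remapping in `_remap_mistral_yarn_args` both renames the YaRN keys and casts their values, so an integer `factor` from the checkpoint comes out as the float newer Transformers releases expect:

```python
# Made-up Mistral-style config fragment, not taken from a real checkpoint.
config = {
    "yarn": {"factor": 16, "beta": 32, "alpha": 1, "apply_scale": True},
    "rope_theta": 1_000_000,
}

yarn_config_map = {
    "factor": ("factor", float),
    "original_max_position_embeddings": ("original_max_position_embeddings", int),
    "beta": ("beta_fast", float),
    "alpha": ("beta_slow", float),
    "apply_scale": ("apply_yarn_scaling", bool),
}

yarn_config = config.get("yarn") or {}
config["rope_parameters"] = {"rope_type": "yarn"}
if rope_theta := config.pop("rope_theta", None):
    config["rope_parameters"]["rope_theta"] = rope_theta
for old_name, (new_name, cast) in yarn_config_map.items():
    if old_name in yarn_config:
        # Cast so e.g. an integer "factor" becomes 16.0 rather than 16.
        config["rope_parameters"][new_name] = cast(yarn_config.pop(old_name))
assert not yarn_config, f"Unparsed yarn config: {yarn_config}"

print(config["rope_parameters"])
# {'rope_type': 'yarn', 'rope_theta': 1000000, 'factor': 16.0,
#  'beta_fast': 32.0, 'beta_slow': 1.0, 'apply_yarn_scaling': True}
```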
@@ -154,6 +156,7 @@ def _remap_general_mistral_args(config: dict) -> dict:
"tie_word_embeddings": ("tied_embeddings", False),
"max_seq_len": ("max_seq_len", config.get("max_position_embeddings", 128_000)),
"max_position_embeddings": ("max_position_embeddings", 128_000),
"dtype": ("dtype", config.get("dtype")),
}
for key, new_key in config_mapping.items():

View File

@@ -1,12 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterator
from contextlib import contextmanager
from typing import final
import torch
from huggingface_hub import constants
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers import PretrainedConfig
@@ -25,22 +22,6 @@ from vllm.utils.torch_utils import common_broadcastable_dtype
logger = init_logger(__name__)
@contextmanager
def _maybe_patch_hf_hub_constants(config_format: ConfigFormat) -> Iterator[None]:
    if config_format == "mistral":
        hf_safetensors_single_file = constants.SAFETENSORS_SINGLE_FILE
        hf_safetensors_index_file = constants.SAFETENSORS_INDEX_FILE
        constants.SAFETENSORS_SINGLE_FILE = "consolidated.safetensors"
        constants.SAFETENSORS_INDEX_FILE = "consolidated.safetensors.index.json"
        try:
            yield
        finally:
            constants.SAFETENSORS_SINGLE_FILE = hf_safetensors_single_file
            constants.SAFETENSORS_INDEX_FILE = hf_safetensors_index_file
    else:
        yield


class ModelArchConfigConvertorBase:
    def __init__(self, hf_config: PretrainedConfig, hf_text_config: PretrainedConfig):
        self.hf_config = hf_config
@@ -164,8 +145,7 @@ class ModelArchConfigConvertorBase:
        # Try to read the dtype of the weights if they are in safetensors format
        if config_dtype is None:
            with _maybe_patch_hf_hub_constants(config_format):
                param_mt = get_safetensors_params_metadata(model_id, revision=revision)
            param_mt = get_safetensors_params_metadata(model_id, revision=revision)
            if param_mt:
                param_dtypes: set[torch.dtype] = {
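As a side note on the pattern itself: the hand-rolled save/patch/restore used in both files could equally be written with `unittest.mock.patch.object`, which saves and restores each attribute automatically. The sketch below is only an alternative formulation of the same idea (name and structure are hypothetical, not what the commit or vLLM ships); either way the key property is that the huggingface_hub constants are process-global, so they must be restored promptly to keep Mistral's consolidated filenames from leaking into unrelated loads.

```python
from collections.abc import Iterator
from contextlib import contextmanager
from unittest import mock

from huggingface_hub import constants


@contextmanager
def _patch_hf_hub_constants_alt() -> Iterator[None]:
    # Each patch.object swaps the constant on entry and restores it on exit,
    # mirroring the explicit try/finally used in the commit.
    with mock.patch.object(
        constants, "SAFETENSORS_SINGLE_FILE", "consolidated.safetensors"
    ), mock.patch.object(
        constants, "SAFETENSORS_INDEX_FILE", "consolidated.safetensors.index.json"
    ):
        yield
```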