Patch Mistral config (#37104)
Signed-off-by: juliendenize <julien.denize@mistral.ai>
This commit is contained in:
@@ -2,7 +2,8 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Iterator
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import asdict
|
||||
from functools import cache, partial
|
||||
from importlib.metadata import version
|
||||
@@ -10,8 +11,10 @@ from pathlib import Path
|
||||
from typing import Any, Literal, TypeAlias
|
||||
|
||||
import huggingface_hub
|
||||
from huggingface_hub import get_safetensors_metadata
|
||||
import torch
|
||||
from huggingface_hub import constants, get_safetensors_metadata
|
||||
from packaging.version import Version
|
||||
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||
from transformers import GenerationConfig, PretrainedConfig
|
||||
from transformers.models.auto.image_processing_auto import get_image_processor_config
|
||||
from transformers.models.auto.modeling_auto import (
|
||||
@@ -28,6 +31,7 @@ from vllm.transformers_utils.utils import (
|
||||
parse_safetensors_file_metadata,
|
||||
without_trust_remote_code,
|
||||
)
|
||||
from vllm.utils.torch_utils import common_broadcastable_dtype
|
||||
|
||||
from .config_parser_base import ConfigParserBase
|
||||
from .gguf_utils import (
|
||||
@@ -135,6 +139,19 @@ def is_rope_parameters_nested(rope_parameters: dict[str, Any]) -> bool:
|
||||
return set(rope_parameters.keys()).issubset(ALLOWED_ATTENTION_LAYER_TYPES)
|
||||
|
||||
|
||||
@contextmanager
def _mistral_patch_hf_hub_constants() -> Iterator[None]:
    """Temporarily swap the HF Hub safetensors filename constants.

    While the context is active, ``constants.SAFETENSORS_SINGLE_FILE`` is
    ``"consolidated.safetensors"`` and ``constants.SAFETENSORS_INDEX_FILE`` is
    ``"consolidated.safetensors.index.json"``. The values that were present on
    entry are put back when the context exits, including when the body raises.
    """
    # Capture the current values first so they can be reinstated afterwards.
    saved_single_file, saved_index_file = (
        constants.SAFETENSORS_SINGLE_FILE,
        constants.SAFETENSORS_INDEX_FILE,
    )

    constants.SAFETENSORS_SINGLE_FILE = "consolidated.safetensors"
    constants.SAFETENSORS_INDEX_FILE = "consolidated.safetensors.index.json"
    try:
        yield
    finally:
        # Always undo the override, even if the wrapped code failed.
        constants.SAFETENSORS_SINGLE_FILE = saved_single_file
        constants.SAFETENSORS_INDEX_FILE = saved_index_file
|
||||
|
||||
|
||||
class HFConfigParser(ConfigParserBase):
|
||||
def parse(
|
||||
self,
|
||||
@@ -245,6 +262,25 @@ class MistralConfigParser(ConfigParserBase):
|
||||
except OSError: # Not found
|
||||
hf_config_dict = {}
|
||||
|
||||
if config_dict.get("dtype") is None:
|
||||
with _mistral_patch_hf_hub_constants():
|
||||
model_str = model if isinstance(model, str) else model.as_posix()
|
||||
param_mt = get_safetensors_params_metadata(model_str, revision=revision)
|
||||
if param_mt:
|
||||
param_dtypes: set[torch.dtype] = {
|
||||
_SAFETENSORS_TO_TORCH_DTYPE[dtype]
|
||||
for info in param_mt.values()
|
||||
if (dtype := info.get("dtype", None))
|
||||
and dtype in _SAFETENSORS_TO_TORCH_DTYPE
|
||||
}
|
||||
|
||||
if param_dtypes:
|
||||
config_dict["dtype"] = common_broadcastable_dtype(param_dtypes)
|
||||
logger.info_once(
|
||||
"Inferred from consolidated*.safetensors files "
|
||||
f"{config_dict['dtype']} dtype."
|
||||
)
|
||||
|
||||
config = adapt_config_dict(config_dict, defaults=hf_config_dict)
|
||||
|
||||
return config_dict, config
|
||||
|
||||
@@ -113,12 +113,13 @@ def _remap_mistral_vision_args(config: dict) -> dict:
|
||||
|
||||
def _remap_mistral_yarn_args(config: dict) -> dict:
|
||||
yarn_config_map = {
|
||||
"factor": "factor",
|
||||
"original_max_position_embeddings": "original_max_position_embeddings",
|
||||
"beta": "beta_fast",
|
||||
"alpha": "beta_slow",
|
||||
"apply_scale": "apply_yarn_scaling",
|
||||
"factor": ("factor", float),
|
||||
"original_max_position_embeddings": ("original_max_position_embeddings", int),
|
||||
"beta": ("beta_fast", float),
|
||||
"alpha": ("beta_slow", float),
|
||||
"apply_scale": ("apply_yarn_scaling", bool),
|
||||
}
|
||||
|
||||
yarn_config = config.get("yarn") or {}
|
||||
config["rope_parameters"] = {
|
||||
"rope_type": "yarn",
|
||||
@@ -128,9 +129,10 @@ def _remap_mistral_yarn_args(config: dict) -> dict:
|
||||
if rope_theta := config.pop("rope_theta", None):
|
||||
config["rope_parameters"]["rope_theta"] = rope_theta
|
||||
|
||||
for old_name, new_name in yarn_config_map.items():
|
||||
for old_name, (new_name, cast) in yarn_config_map.items():
|
||||
if old_name in yarn_config:
|
||||
config["rope_parameters"][new_name] = yarn_config.pop(old_name)
|
||||
# Cast to remove Transformers > v5 type warnings
|
||||
config["rope_parameters"][new_name] = cast(yarn_config.pop(old_name))
|
||||
|
||||
assert len(yarn_config) == 0, f"Unparsed yarn config: {yarn_config}"
|
||||
|
||||
@@ -154,6 +156,7 @@ def _remap_general_mistral_args(config: dict) -> dict:
|
||||
"tie_word_embeddings": ("tied_embeddings", False),
|
||||
"max_seq_len": ("max_seq_len", config.get("max_position_embeddings", 128_000)),
|
||||
"max_position_embeddings": ("max_position_embeddings", 128_000),
|
||||
"dtype": ("dtype", config.get("dtype")),
|
||||
}
|
||||
|
||||
for key, new_key in config_mapping.items():
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import final
|
||||
|
||||
import torch
|
||||
from huggingface_hub import constants
|
||||
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
@@ -25,22 +22,6 @@ from vllm.utils.torch_utils import common_broadcastable_dtype
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@contextmanager
def _maybe_patch_hf_hub_constants(config_format: ConfigFormat) -> Iterator[None]:
    """Swap the HF Hub safetensors filename constants for Mistral configs.

    When ``config_format`` equals ``"mistral"``,
    ``constants.SAFETENSORS_SINGLE_FILE`` and
    ``constants.SAFETENSORS_INDEX_FILE`` are set to the ``consolidated``
    filenames for the duration of the context and restored on exit, including
    when the body raises. For any other format the context does nothing.
    """
    # Guard clause: nothing to patch for non-Mistral formats.
    if config_format != "mistral":
        yield
        return

    # Capture the current values first so they can be reinstated afterwards.
    saved_single_file, saved_index_file = (
        constants.SAFETENSORS_SINGLE_FILE,
        constants.SAFETENSORS_INDEX_FILE,
    )

    constants.SAFETENSORS_SINGLE_FILE = "consolidated.safetensors"
    constants.SAFETENSORS_INDEX_FILE = "consolidated.safetensors.index.json"
    try:
        yield
    finally:
        # Always undo the override, even if the wrapped code failed.
        constants.SAFETENSORS_SINGLE_FILE = saved_single_file
        constants.SAFETENSORS_INDEX_FILE = saved_index_file
|
||||
|
||||
|
||||
class ModelArchConfigConvertorBase:
|
||||
def __init__(self, hf_config: PretrainedConfig, hf_text_config: PretrainedConfig):
|
||||
self.hf_config = hf_config
|
||||
@@ -164,8 +145,7 @@ class ModelArchConfigConvertorBase:
|
||||
|
||||
# Try to read the dtype of the weights if they are in safetensors format
|
||||
if config_dtype is None:
|
||||
with _maybe_patch_hf_hub_constants(config_format):
|
||||
param_mt = get_safetensors_params_metadata(model_id, revision=revision)
|
||||
param_mt = get_safetensors_params_metadata(model_id, revision=revision)
|
||||
|
||||
if param_mt:
|
||||
param_dtypes: set[torch.dtype] = {
|
||||
|
||||
Reference in New Issue
Block a user