[Core] Support configuration parsing plugin (#24277)

Signed-off-by: Xingyu Liu <charlotteliu12x@gmail.com>
Signed-off-by: Xingyu Liu <38244988+charlotte12l@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Xingyu Liu
2025-09-10 11:32:43 -07:00
committed by GitHub
parent 4032949630
commit 9fb74c27a7
6 changed files with 237 additions and 107 deletions

View File

@@ -1,13 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import json
import os
import time
from functools import cache, partial
from pathlib import Path
from typing import Any, Callable, Optional, TypeVar, Union
from typing import Any, Callable, Literal, Optional, TypeVar, Union
import huggingface_hub
from huggingface_hub import get_safetensors_metadata, hf_hub_download
@@ -27,6 +26,7 @@ from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME
from vllm import envs
from vllm.logger import init_logger
from vllm.transformers_utils.config_parser_base import ConfigParserBase
from vllm.transformers_utils.utils import check_gguf_file
if envs.VLLM_USE_MODELSCOPE:
@@ -100,10 +100,163 @@ _AUTO_CONFIG_KWARGS_OVERRIDES: dict[str, dict[str, Any]] = {
}
class ConfigFormat(str, enum.Enum):
AUTO = "auto"
HF = "hf"
MISTRAL = "mistral"
class HFConfigParser(ConfigParserBase):
def parse(self,
model: Union[str, Path],
trust_remote_code: bool,
revision: Optional[str] = None,
code_revision: Optional[str] = None,
**kwargs) -> tuple[dict, PretrainedConfig]:
kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE
config_dict, _ = PretrainedConfig.get_config_dict(
model,
revision=revision,
code_revision=code_revision,
token=_get_hf_token(),
**kwargs,
)
# Use custom model class if it's in our registry
model_type = config_dict.get("model_type")
if model_type is None:
model_type = "speculators" if config_dict.get(
"speculators_config") is not None else model_type
if model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[model_type]
config = config_class.from_pretrained(
model,
revision=revision,
code_revision=code_revision,
token=_get_hf_token(),
**kwargs,
)
else:
try:
kwargs = _maybe_update_auto_config_kwargs(
kwargs, model_type=model_type)
config = AutoConfig.from_pretrained(
model,
trust_remote_code=trust_remote_code,
revision=revision,
code_revision=code_revision,
token=_get_hf_token(),
**kwargs,
)
except ValueError as e:
if (not trust_remote_code
and "requires you to execute the configuration file"
in str(e)):
err_msg = (
"Failed to load the model config. If the model "
"is a custom model not yet available in the "
"HuggingFace transformers library, consider setting "
"`trust_remote_code=True` in LLM or using the "
"`--trust-remote-code` flag in the CLI.")
raise RuntimeError(err_msg) from e
else:
raise e
config = _maybe_remap_hf_config_attrs(config)
return config_dict, config
class MistralConfigParser(ConfigParserBase):
def parse(self,
model: Union[str, Path],
trust_remote_code: bool,
revision: Optional[str] = None,
code_revision: Optional[str] = None,
**kwargs) -> tuple[dict, PretrainedConfig]:
# This function loads a params.json config which
# should be used when loading models in mistral format
config_dict = _download_mistral_config_file(model, revision)
if (max_position_embeddings :=
config_dict.get("max_position_embeddings")) is None:
max_position_embeddings = _maybe_retrieve_max_pos_from_hf(
model, revision, **kwargs)
config_dict["max_position_embeddings"] = max_position_embeddings
from vllm.transformers_utils.configs.mistral import adapt_config_dict
config = adapt_config_dict(config_dict)
# Mistral configs may define sliding_window as list[int]. Convert it
# to int and add the layer_types list[str] to make it HF compatible
if ((sliding_window := getattr(config, "sliding_window", None))
and isinstance(sliding_window, list)):
pattern_repeats = config.num_hidden_layers // len(sliding_window)
layer_types = sliding_window * pattern_repeats
config.layer_types = [
"full_attention" if layer_type is None else "sliding_attention"
for layer_type in layer_types
]
config.sliding_window = next(filter(None, sliding_window), None)
return config_dict, config
_CONFIG_FORMAT_TO_CONFIG_PARSER: dict[str, type[ConfigParserBase]] = {
"hf": HFConfigParser,
"mistral": MistralConfigParser,
}
ConfigFormat = Literal[
"auto",
"hf",
"mistral",
]
def get_config_parser(config_format: str) -> ConfigParserBase:
"""Get the config parser for a given config format."""
if config_format not in _CONFIG_FORMAT_TO_CONFIG_PARSER:
raise ValueError(f"Unknown config format `{config_format}`.")
return _CONFIG_FORMAT_TO_CONFIG_PARSER[config_format]()
def register_config_parser(config_format: str):
"""Register a customized vllm config parser.
When a config format is not supported by vllm, you can register a customized
config parser to support it.
Args:
config_format (str): The config parser format name.
Examples:
>>> from vllm.transformers_utils.config import (get_config_parser,
register_config_parser)
>>> from vllm.transformers_utils.config_parser_base import ConfigParserBase
>>>
>>> @register_config_parser("custom_config_parser")
... class CustomConfigParser(ConfigParserBase):
... def parse(self,
... model: Union[str, Path],
... trust_remote_code: bool,
... revision: Optional[str] = None,
... code_revision: Optional[str] = None,
... **kwargs) -> tuple[dict, PretrainedConfig]:
... raise NotImplementedError
>>>
>>> type(get_config_parser("custom_config_parser"))
<class 'CustomConfigParser'>
""" # noqa: E501
def _wrapper(config_parser_cls):
if config_format in _CONFIG_FORMAT_TO_CONFIG_PARSER:
logger.warning(
"Config format `%s` is already registered, and will be "
"overwritten by the new parser class `%s`.", config_format,
config_parser_cls)
if not issubclass(config_parser_cls, ConfigParserBase):
raise ValueError("The config parser must be a subclass of "
"`ConfigParserBase`.")
_CONFIG_FORMAT_TO_CONFIG_PARSER[config_format] = config_parser_cls
logger.info("Registered config parser `%s` with config format `%s`",
config_parser_cls, config_format)
return config_parser_cls
return _wrapper
_R = TypeVar("_R")
@@ -350,7 +503,7 @@ def get_config(
trust_remote_code: bool,
revision: Optional[str] = None,
code_revision: Optional[str] = None,
config_format: ConfigFormat = ConfigFormat.AUTO,
config_format: Union[str, ConfigFormat] = "auto",
hf_overrides_kw: Optional[dict[str, Any]] = None,
hf_overrides_fn: Optional[Callable[[PretrainedConfig],
PretrainedConfig]] = None,
@@ -363,20 +516,22 @@ def get_config(
kwargs["gguf_file"] = Path(model).name
model = Path(model).parent
if config_format == ConfigFormat.AUTO:
if config_format == "auto":
try:
if is_gguf or file_or_path_exists(
model, HF_CONFIG_NAME, revision=revision):
config_format = ConfigFormat.HF
config_format = "hf"
elif file_or_path_exists(model,
MISTRAL_CONFIG_NAME,
revision=revision):
config_format = ConfigFormat.MISTRAL
config_format = "mistral"
else:
raise ValueError(
"Could not detect config format for no config file found. "
"Ensure your model has either config.json (HF format) "
"or params.json (Mistral format).")
"With config_format 'auto', ensure your model has either"
"config.json (HF format) or params.json (Mistral format)."
"Otherwise please specify your_custom_config_format"
"in engine args for customized config parser")
except Exception as e:
error_message = (
@@ -395,92 +550,14 @@ def get_config(
raise ValueError(error_message) from e
if config_format == ConfigFormat.HF:
kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE
config_dict, _ = PretrainedConfig.get_config_dict(
model,
revision=revision,
code_revision=code_revision,
token=_get_hf_token(),
**kwargs,
)
# Use custom model class if it's in our registry
model_type = config_dict.get("model_type")
if model_type is None:
model_type = "speculators" if config_dict.get(
"speculators_config") is not None else model_type
if model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[model_type]
config = config_class.from_pretrained(
model,
revision=revision,
code_revision=code_revision,
token=_get_hf_token(),
**kwargs,
)
else:
try:
kwargs = _maybe_update_auto_config_kwargs(
kwargs, model_type=model_type)
config = AutoConfig.from_pretrained(
model,
trust_remote_code=trust_remote_code,
revision=revision,
code_revision=code_revision,
token=_get_hf_token(),
**kwargs,
)
except ValueError as e:
if (not trust_remote_code
and "requires you to execute the configuration file"
in str(e)):
err_msg = (
"Failed to load the model config. If the model "
"is a custom model not yet available in the "
"HuggingFace transformers library, consider setting "
"`trust_remote_code=True` in LLM or using the "
"`--trust-remote-code` flag in the CLI.")
raise RuntimeError(err_msg) from e
else:
raise e
config = _maybe_remap_hf_config_attrs(config)
elif config_format == ConfigFormat.MISTRAL:
# This function loads a params.json config which
# should be used when loading models in mistral format
config_dict = _download_mistral_config_file(model, revision)
if (max_position_embeddings :=
config_dict.get("max_position_embeddings")) is None:
max_position_embeddings = _maybe_retrieve_max_pos_from_hf(
model, revision, **kwargs)
config_dict["max_position_embeddings"] = max_position_embeddings
from vllm.transformers_utils.configs.mistral import adapt_config_dict
config = adapt_config_dict(config_dict)
# Mistral configs may define sliding_window as list[int]. Convert it
# to int and add the layer_types list[str] to make it HF compatible
if ((sliding_window := getattr(config, "sliding_window", None))
and isinstance(sliding_window, list)):
pattern_repeats = config.num_hidden_layers // len(sliding_window)
layer_types = sliding_window * pattern_repeats
config.layer_types = [
"full_attention" if layer_type is None else "sliding_attention"
for layer_type in layer_types
]
config.sliding_window = next(filter(None, sliding_window), None)
else:
supported_formats = [
fmt.value for fmt in ConfigFormat if fmt != ConfigFormat.AUTO
]
raise ValueError(
f"Unsupported config format: {config_format}. "
f"Supported formats are: {', '.join(supported_formats)}. "
f"Ensure your model uses one of these configuration formats "
f"or specify the correct format explicitly.")
config_parser = get_config_parser(config_format)
config_dict, config = config_parser.parse(
model,
trust_remote_code=trust_remote_code,
revision=revision,
code_revision=code_revision,
**kwargs,
)
# Special architecture mapping check for GGUF models
if is_gguf:
if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
@@ -914,7 +991,7 @@ def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int:
hf_config = get_config(model=model,
trust_remote_code=trust_remote_code_val,
revision=revision,
config_format=ConfigFormat.HF)
config_format="hf")
if hf_value := hf_config.get_text_config().max_position_embeddings:
max_position_embeddings = hf_value
except Exception as e: