[Core] Parse vLLM engine required fields from hf_config to model_arch_config (#28454)

Signed-off-by: Xingyu Liu <charlotteliu12x@gmail.com>
Signed-off-by: Xingyu Liu <38244988+charlotte12l@users.noreply.github.com>
This commit is contained in:
Xingyu Liu
2026-01-02 16:13:15 -07:00
committed by GitHub
parent a0e9ee83c7
commit 0eee877f67
11 changed files with 1121 additions and 287 deletions

View File

@@ -10,10 +10,12 @@ from typing import TYPE_CHECKING, Any, Literal, cast, get_args
import torch
from pydantic import ConfigDict, Field, field_validator, model_validator
from pydantic.dataclasses import dataclass
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
import vllm.envs as envs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.model_arch import (
ModelArchitectureConfig,
)
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig
from vllm.config.pooler import PoolerConfig
from vllm.config.scheduler import RunnerType
@@ -31,7 +33,6 @@ from vllm.transformers_utils.config import (
is_rope_parameters_nested,
try_get_dense_modules,
try_get_generation_config,
try_get_safetensors_metadata,
try_get_tokenizer_config,
uses_mrope,
uses_xdrope_dim,
@@ -42,10 +43,13 @@ from vllm.transformers_utils.gguf_utils import (
maybe_patch_hf_config_from_gguf,
split_remote_gguf,
)
from vllm.transformers_utils.model_arch_config_convertor import (
MODEL_ARCH_CONFIG_CONVERTORS,
ModelArchConfigConvertorBase,
)
from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri
from vllm.transformers_utils.utils import maybe_model_redirect
from vllm.utils.import_utils import LazyLoader
from vllm.utils.torch_utils import common_broadcastable_dtype
if TYPE_CHECKING:
from transformers import PretrainedConfig
@@ -483,6 +487,7 @@ class ModelConfig:
self.hf_image_processor_config = get_hf_image_processor_config(
self.model, hf_token=self.hf_token, revision=self.revision
)
self.model_arch_config = self.get_model_arch_config()
architectures = self.architectures
registry = self.registry
@@ -600,6 +605,15 @@ class ModelConfig:
self._verify_cuda_graph()
self._verify_bnb_config()
def get_model_arch_config(
self,
) -> ModelArchitectureConfig:
convertor_cls = MODEL_ARCH_CONFIG_CONVERTORS.get(
self.hf_config.model_type, ModelArchConfigConvertorBase
)
convertor = convertor_cls(self.hf_config, self.hf_text_config)
return convertor.convert()
@field_validator("tokenizer", "max_model_len", mode="wrap")
@classmethod
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
@@ -675,7 +689,7 @@ class ModelConfig:
@property
def architectures(self) -> list[str]:
return getattr(self.hf_config, "architectures", [])
return self.model_arch_config.architectures
@property
def architecture(self) -> str:
@@ -835,56 +849,16 @@ class ModelConfig:
return convert_type
def _parse_quant_hf_config(self, hf_config: PretrainedConfig):
quant_cfg = getattr(hf_config, "quantization_config", None)
if quant_cfg is None:
# compressed-tensors uses a "compression_config" key
quant_cfg = getattr(hf_config, "compression_config", None)
else:
# Set quant_method for ModelOpt models.
producer_name = quant_cfg.get("producer", {}).get("name")
if producer_name == "modelopt":
quant_algo = quant_cfg.get("quantization", {}).get("quant_algo")
if quant_algo is not None:
quant_algo_upper = str(quant_algo).upper()
if quant_algo_upper in {
"FP8",
"FP8_PER_CHANNEL_PER_TOKEN",
"FP8_PB_WO",
}:
quant_cfg["quant_method"] = "modelopt"
elif quant_algo_upper == "NVFP4":
quant_cfg["quant_method"] = "modelopt_fp4"
else:
raise ValueError(f"Unknown ModelOpt quant algo: {quant_algo}")
return quant_cfg
def _verify_quantization(self) -> None:
supported_quantization = me_quant.QUANTIZATION_METHODS
if self.quantization is not None:
self.quantization = cast(me_quant.QuantizationMethods, self.quantization)
# Parse quantization method from the HF model config, if available.
quant_cfg = self._parse_quant_hf_config(self.hf_config)
if quant_cfg is None and (
text_config := getattr(self.hf_config, "text_config", None)
):
# Check the text config as well for multi-modal models.
quant_cfg = self._parse_quant_hf_config(text_config)
quant_cfg = self.model_arch_config.quantization_config
if quant_cfg is not None:
# Use the community standard 'quant_method'
quant_method = quant_cfg.get("quant_method", "").lower()
# Normalize library names
quant_method = quant_method.replace(
"compressed_tensors", "compressed-tensors"
)
quant_cfg["quant_method"] = quant_method
quant_method = quant_cfg["quant_method"]
# Quantization methods which are overrides (i.e. they have a
# `override_quantization_method` method) must be checked in order
# of preference (this is particularly important for GPTQ).
@@ -966,7 +940,7 @@ class ModelConfig:
logger.warning(
"CUDA graph is not supported for %s on ROCm yet, fallback "
"to eager mode.",
self.hf_config.model_type,
self.model_arch_config.model_type,
)
self.enforce_eager = True
@@ -977,11 +951,9 @@ class ModelConfig:
# TODO Remove this when bitsandbytes supports.
"""
is_bitsandbytes = self.quantization == "bitsandbytes"
has_quantization_config = (
getattr(self.hf_config, "quantization_config", None) is not None
)
has_quantization_config = self.model_arch_config.quantization_config is not None
is_8bit = (
self.hf_config.quantization_config.get("load_in_8bit", False)
self.model_arch_config.quantization_config.get("load_in_8bit", False)
if has_quantization_config
else False
)
@@ -1051,9 +1023,7 @@ class ModelConfig:
self,
parallel_config: ParallelConfig,
) -> None:
total_num_attention_heads = getattr(
self.hf_text_config, "num_attention_heads", 0
)
total_num_attention_heads = self.model_arch_config.total_num_attention_heads
tensor_parallel_size = parallel_config.tensor_parallel_size
if total_num_attention_heads % tensor_parallel_size != 0:
raise ValueError(
@@ -1104,10 +1074,10 @@ class ModelConfig:
return getattr(self.hf_text_config, "sliding_window", None)
def get_vocab_size(self) -> int:
return getattr(self.hf_text_config, "vocab_size", 0)
return self.model_arch_config.vocab_size
def get_hidden_size(self) -> int:
return getattr(self.hf_text_config, "hidden_size", 0)
return self.model_arch_config.hidden_size
def get_inputs_embeds_size(self) -> int:
# The size of inputs_embeds is usually identical to the size
@@ -1120,29 +1090,7 @@ class ModelConfig:
@property
def is_deepseek_mla(self) -> bool:
if not hasattr(self.hf_text_config, "model_type"):
return False
elif self.hf_text_config.model_type in (
"deepseek_v2",
"deepseek_v3",
"deepseek_v32",
"deepseek_mtp",
"kimi_k2",
"kimi_linear",
"longcat_flash",
"pangu_ultra_moe",
"pangu_ultra_moe_mtp",
):
return self.hf_text_config.kv_lora_rank is not None
elif self.hf_text_config.model_type == "eagle":
# if the model is an EAGLE module, check for the
# underlying architecture
return (
self.hf_text_config.model.model_type
in ("deepseek_v2", "deepseek_v3", "deepseek_v32")
and self.hf_text_config.kv_lora_rank is not None
)
return False
return self.model_arch_config.is_deepseek_mla
@cached_property
def is_mm_prefix_lm(self) -> bool:
@@ -1158,103 +1106,11 @@ class ModelConfig:
return self.hf_config.model_type in MM_PREFIX_LM_MODELS
def get_head_size(self) -> int:
# TODO remove hard code
if self.is_deepseek_mla:
qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim", 0)
if self.use_mla:
return self.hf_text_config.kv_lora_rank + qk_rope_head_dim
else:
qk_nope_head_dim = getattr(self.hf_text_config, "qk_nope_head_dim", 0)
if qk_rope_head_dim and qk_nope_head_dim:
return qk_rope_head_dim + qk_nope_head_dim
if hasattr(self.hf_text_config, "model_type") and (
self.hf_text_config.model_type == "zamba2"
):
return self.hf_text_config.attention_head_dim
if self.is_attention_free:
return 0
# NOTE: Some configs may set head_dim=None in the config
if getattr(self.hf_text_config, "head_dim", None) is not None:
return self.hf_text_config.head_dim
# NOTE: Some models (such as PLaMo2.1) use `hidden_size_per_head`
if getattr(self.hf_text_config, "hidden_size_per_head", None) is not None:
return self.hf_text_config.hidden_size_per_head
# FIXME(woosuk): This may not be true for all models.
return (
self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads
)
return self.model_arch_config.head_size
def get_total_num_kv_heads(self) -> int:
"""Returns the total number of KV heads."""
# For GPTBigCode & Falcon:
# NOTE: for falcon, when new_decoder_architecture is True, the
# multi_query flag is ignored and we use n_head_kv for the number of
# KV heads.
falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
new_decoder_arch_falcon = (
self.hf_config.model_type in falcon_model_types
and getattr(self.hf_config, "new_decoder_architecture", False)
)
if not new_decoder_arch_falcon and getattr(
self.hf_text_config, "multi_query", False
):
# Multi-query attention, only one KV head.
# Currently, tensor parallelism is not supported in this case.
return 1
# For DBRX and MPT
if self.hf_config.model_type == "mpt":
if "kv_n_heads" in self.hf_config.attn_config:
return self.hf_config.attn_config["kv_n_heads"]
return self.hf_config.num_attention_heads
if self.hf_config.model_type == "dbrx":
return getattr(
self.hf_config.attn_config,
"kv_n_heads",
self.hf_config.num_attention_heads,
)
if self.hf_config.model_type == "nemotron-nas":
for block in self.hf_config.block_configs:
if not block.attention.no_op:
return (
self.hf_config.num_attention_heads
// block.attention.n_heads_in_group
)
raise RuntimeError(
"Could not determine the number of key-value attention heads "
"from model configuration. "
f"Model: {self.model}, Architecture: {self.architectures}. "
"This usually indicates an unsupported model architecture or "
"missing configuration. "
"Please check if your model is supported at: "
"https://docs.vllm.ai/en/latest/models/supported_models.html"
)
if self.is_attention_free:
return 0
attributes = [
# For Falcon:
"n_head_kv",
"num_kv_heads",
# For LLaMA-2:
"num_key_value_heads",
# For ChatGLM:
"multi_query_group_num",
]
# For non-grouped-query attention models, the number of KV heads is
# equal to the number of attention heads.
default_factory = lambda: self.hf_text_config.num_attention_heads
return getattr_iter(
self.hf_text_config, attributes, default_factory=default_factory
)
return self.model_arch_config.total_num_kv_heads
def get_num_kv_heads(self, parallel_config: ParallelConfig) -> int:
"""Returns the number of KV heads per GPU."""
@@ -1270,46 +1126,14 @@ class ModelConfig:
return max(1, total_num_kv_heads // parallel_config.tensor_parallel_size)
def get_num_attention_heads(self, parallel_config: ParallelConfig) -> int:
num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
num_heads = self.model_arch_config.total_num_attention_heads
return num_heads // parallel_config.tensor_parallel_size
def get_num_experts(self) -> int:
"""Returns the number of experts in the model."""
num_expert_names = [
"num_experts", # Jamba
"moe_num_experts", # Dbrx
"n_routed_experts", # DeepSeek
"num_local_experts", # Mixtral
]
num_experts = getattr_iter(self.hf_text_config, num_expert_names, 0)
if isinstance(num_experts, list):
# Ernie VL's remote code uses list[int]...
# The values are always the same so we just take the first one.
return num_experts[0]
# Coerce to 0 if explicitly set to None
return num_experts or 0
return self.model_arch_config.num_experts
def get_total_num_hidden_layers(self) -> int:
if (
self.hf_text_config.model_type == "deepseek_mtp"
or self.hf_config.model_type == "mimo_mtp"
or self.hf_config.model_type == "glm4_moe_mtp"
or self.hf_config.model_type == "ernie_mtp"
or self.hf_config.model_type == "qwen3_next_mtp"
or self.hf_config.model_type == "pangu_ultra_moe_mtp"
):
total_num_hidden_layers = getattr(
self.hf_text_config, "num_nextn_predict_layers", 0
)
elif self.hf_config.model_type == "longcat_flash_mtp":
total_num_hidden_layers = getattr(
self.hf_text_config, "num_nextn_predict_layers", 1
)
else:
total_num_hidden_layers = getattr(
self.hf_text_config, "num_hidden_layers", 0
)
return total_num_hidden_layers
return self.model_arch_config.total_num_hidden_layers
def get_layers_start_end_indices(
self, parallel_config: ParallelConfig
@@ -1360,9 +1184,7 @@ class ModelConfig:
self.hf_text_config, "layers_block_type", None
)
if layers_block_type_value is not None:
if hasattr(self.hf_text_config, "model_type") and (
self.hf_text_config.model_type == "zamba2"
):
if self.model_arch_config.text_model_type == "zamba2":
if attn_block_type:
return sum(
t == "hybrid" for t in layers_block_type_value[start:end]
@@ -1677,6 +1499,7 @@ class ModelConfig:
)
max_model_len = _get_and_verify_max_len(
hf_config=self.hf_text_config,
model_arch_config=self.model_arch_config,
tokenizer_config=tokenizer_config,
max_model_len=max_model_len,
disable_sliding_window=self.disable_sliding_window,
@@ -1907,46 +1730,6 @@ def _check_valid_dtype(model_type: str, dtype: torch.dtype):
return True
def _find_dtype(
model_id: str,
config: PretrainedConfig,
*,
revision: str | None,
):
# NOTE: getattr(config, "dtype", torch.float32) is not correct
# because config.dtype can be None.
config_dtype = getattr(config, "dtype", None)
# Fallbacks for multi-modal models if the root config
# does not define dtype
if config_dtype is None:
config_dtype = getattr(config.get_text_config(), "dtype", None)
if config_dtype is None and hasattr(config, "vision_config"):
config_dtype = getattr(config.vision_config, "dtype", None)
if config_dtype is None and hasattr(config, "encoder_config"):
config_dtype = getattr(config.encoder_config, "dtype", None)
# Try to read the dtype of the weights if they are in safetensors format
if config_dtype is None:
repo_mt = try_get_safetensors_metadata(model_id, revision=revision)
if repo_mt and (files_mt := repo_mt.files_metadata):
param_dtypes: set[torch.dtype] = {
_SAFETENSORS_TO_TORCH_DTYPE[dtype_str]
for file_mt in files_mt.values()
for dtype_str in file_mt.parameter_count
if dtype_str in _SAFETENSORS_TO_TORCH_DTYPE
}
if param_dtypes:
return common_broadcastable_dtype(param_dtypes)
if config_dtype is None:
config_dtype = torch.float32
return config_dtype
def _resolve_auto_dtype(
model_type: str,
config_dtype: torch.dtype,
@@ -2001,7 +1784,9 @@ def _get_and_verify_dtype(
is_pooling_model: bool,
revision: str | None = None,
) -> torch.dtype:
config_dtype = _find_dtype(model_id, config, revision=revision)
config_dtype = ModelArchConfigConvertorBase.get_torch_dtype(
config, model_id, revision=revision
)
model_type = config.model_type
if isinstance(dtype, str):
@@ -2064,6 +1849,7 @@ def _get_head_dtype(
def _get_and_verify_max_len(
hf_config: PretrainedConfig,
model_arch_config: ModelArchitectureConfig,
tokenizer_config: dict | None,
max_model_len: int | None,
disable_sliding_window: bool,
@@ -2072,36 +1858,9 @@ def _get_and_verify_max_len(
encoder_config: Any | None = None,
) -> int:
"""Get and verify the model's maximum length."""
derived_max_model_len = float("inf")
possible_keys = [
# OPT
"max_position_embeddings",
# GPT-2
"n_positions",
# MPT
"max_seq_len",
# ChatGLM2
"seq_length",
# Command-R
"model_max_length",
# Whisper
"max_target_positions",
# Others
"max_sequence_length",
"max_seq_length",
"seq_len",
]
# Choose the smallest "max_length" from the possible keys
max_len_key = None
for key in possible_keys:
max_len = getattr(hf_config, key, None)
if max_len is not None:
max_len_key = key if max_len < derived_max_model_len else max_len_key
derived_max_model_len = min(derived_max_model_len, max_len)
# For Command-R / Cohere, Cohere2 / Aya Vision models
if tmp_max_len := getattr(hf_config, "model_max_length", None):
max_len_key = "model_max_length"
derived_max_model_len = tmp_max_len
(derived_max_model_len, max_len_key) = (
model_arch_config.derived_max_model_len_and_key
)
# If sliding window is manually disabled, max_length should be less
# than the sliding window length in the model config.
@@ -2134,10 +1893,9 @@ def _get_and_verify_max_len(
default_max_len = 2048
logger.warning(
"The model's config.json does not contain any of the following "
"keys to determine the original maximum length of the model: "
"%s. Assuming the model's maximum length is %d.",
possible_keys,
"The model's config.json does not contain any of the keys "
"to determine the original maximum length of the model. "
"Assuming the model's maximum length is %d.",
default_max_len,
)
derived_max_model_len = default_max_len