[Bugfix] Fix unable to load some models (#10312)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2024-11-15 08:55:54 +08:00
committed by GitHub
parent 11cd1ae6ad
commit 972112d82f
13 changed files with 340 additions and 59 deletions

View File

@@ -3,8 +3,8 @@ import enum
import json
import warnings
from dataclasses import dataclass, field, replace
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal,
Mapping, Optional, Set, Tuple, Type, Union)
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Dict, Final, List,
Literal, Mapping, Optional, Set, Tuple, Type, Union)
import torch
from transformers import PretrainedConfig
@@ -20,7 +20,7 @@ from vllm.transformers_utils.config import (
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
print_warning_once)
identity, print_warning_once)
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
@@ -44,6 +44,9 @@ TaskOption = Literal["auto", "generate", "embedding"]
# "draft" is only used internally for speculative decoding
_Task = Literal["generate", "embedding", "draft"]
HfOverrides = Union[Dict[str, Any], Callable[[PretrainedConfig],
PretrainedConfig]]
class ModelConfig:
"""Configuration for the model.
@@ -115,7 +118,9 @@ class ModelConfig:
can not be gathered from the vllm arguments.
config_format: The config format which shall be loaded.
Defaults to 'auto' which defaults to 'hf'.
hf_overrides: Arguments to be forwarded to the HuggingFace config.
hf_overrides: If a dictionary, contains arguments to be forwarded to the
HuggingFace config. If a callable, it is called to update the
HuggingFace config.
mm_processor_kwargs: Arguments to be forwarded to the model's processor
for multi-modal data, e.g., image processor.
pooling_type: Used to configure the pooling method in the embedding
@@ -164,7 +169,7 @@ class ModelConfig:
override_neuron_config: Optional[Dict[str, Any]] = None,
config_format: ConfigFormat = ConfigFormat.AUTO,
chat_template_text_format: str = "string",
hf_overrides: Optional[Dict[str, Any]] = None,
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
pooling_type: Optional[str] = None,
pooling_norm: Optional[bool] = None,
@@ -182,15 +187,23 @@ class ModelConfig:
if hf_overrides is None:
hf_overrides = {}
if callable(hf_overrides):
hf_overrides_kw = {}
hf_overrides_fn = hf_overrides
else:
hf_overrides_kw = hf_overrides
hf_overrides_fn = identity
if rope_scaling is not None:
hf_override: Dict[str, Any] = {"rope_scaling": rope_scaling}
hf_overrides.update(hf_override)
hf_overrides_kw.update(hf_override)
msg = ("`--rope-scaling` will be removed in a future release. "
f"'Please instead use `--hf-overrides '{hf_override!r}'`")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
if rope_theta is not None:
hf_override = {"rope_theta": rope_theta}
hf_overrides.update(hf_override)
hf_overrides_kw.update(hf_override)
msg = ("`--rope-theta` will be removed in a future release. "
f"'Please instead use `--hf-overrides '{hf_override!r}'`")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
@@ -207,9 +220,12 @@ class ModelConfig:
self.max_logprobs = max_logprobs
self.disable_sliding_window = disable_sliding_window
self.skip_tokenizer_init = skip_tokenizer_init
self.hf_config = get_config(self.model, trust_remote_code, revision,
code_revision, config_format,
**hf_overrides)
hf_config = get_config(self.model, trust_remote_code, revision,
code_revision, config_format, **hf_overrides_kw)
hf_config = hf_overrides_fn(hf_config)
self.hf_config = hf_config
self.hf_text_config = get_hf_text_config(self.hf_config)
self.encoder_config = self._get_encoder_config()
self.hf_image_processor_config = get_hf_image_processor_config(