[Bugfix] Fix unable to load some models (#10312)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -3,8 +3,8 @@ import enum
|
||||
import json
|
||||
import warnings
|
||||
from dataclasses import dataclass, field, replace
|
||||
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal,
|
||||
Mapping, Optional, Set, Tuple, Type, Union)
|
||||
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Dict, Final, List,
|
||||
Literal, Mapping, Optional, Set, Tuple, Type, Union)
|
||||
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
@@ -20,7 +20,7 @@ from vllm.transformers_utils.config import (
|
||||
get_hf_text_config, get_pooling_config,
|
||||
get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
|
||||
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
|
||||
print_warning_once)
|
||||
identity, print_warning_once)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
@@ -44,6 +44,9 @@ TaskOption = Literal["auto", "generate", "embedding"]
|
||||
# "draft" is only used internally for speculative decoding
|
||||
_Task = Literal["generate", "embedding", "draft"]
|
||||
|
||||
HfOverrides = Union[Dict[str, Any], Callable[[PretrainedConfig],
|
||||
PretrainedConfig]]
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
"""Configuration for the model.
|
||||
@@ -115,7 +118,9 @@ class ModelConfig:
|
||||
can not be gathered from the vllm arguments.
|
||||
config_format: The config format which shall be loaded.
|
||||
Defaults to 'auto' which defaults to 'hf'.
|
||||
hf_overrides: Arguments to be forwarded to the HuggingFace config.
|
||||
hf_overrides: If a dictionary, contains arguments to be forwarded to the
|
||||
HuggingFace config. If a callable, it is called to update the
|
||||
HuggingFace config.
|
||||
mm_processor_kwargs: Arguments to be forwarded to the model's processor
|
||||
for multi-modal data, e.g., image processor.
|
||||
pooling_type: Used to configure the pooling method in the embedding
|
||||
@@ -164,7 +169,7 @@ class ModelConfig:
|
||||
override_neuron_config: Optional[Dict[str, Any]] = None,
|
||||
config_format: ConfigFormat = ConfigFormat.AUTO,
|
||||
chat_template_text_format: str = "string",
|
||||
hf_overrides: Optional[Dict[str, Any]] = None,
|
||||
hf_overrides: Optional[HfOverrides] = None,
|
||||
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
pooling_type: Optional[str] = None,
|
||||
pooling_norm: Optional[bool] = None,
|
||||
@@ -182,15 +187,23 @@ class ModelConfig:
|
||||
|
||||
if hf_overrides is None:
|
||||
hf_overrides = {}
|
||||
|
||||
if callable(hf_overrides):
|
||||
hf_overrides_kw = {}
|
||||
hf_overrides_fn = hf_overrides
|
||||
else:
|
||||
hf_overrides_kw = hf_overrides
|
||||
hf_overrides_fn = identity
|
||||
|
||||
if rope_scaling is not None:
|
||||
hf_override: Dict[str, Any] = {"rope_scaling": rope_scaling}
|
||||
hf_overrides.update(hf_override)
|
||||
hf_overrides_kw.update(hf_override)
|
||||
msg = ("`--rope-scaling` will be removed in a future release. "
|
||||
f"'Please instead use `--hf-overrides '{hf_override!r}'`")
|
||||
warnings.warn(DeprecationWarning(msg), stacklevel=2)
|
||||
if rope_theta is not None:
|
||||
hf_override = {"rope_theta": rope_theta}
|
||||
hf_overrides.update(hf_override)
|
||||
hf_overrides_kw.update(hf_override)
|
||||
msg = ("`--rope-theta` will be removed in a future release. "
|
||||
f"'Please instead use `--hf-overrides '{hf_override!r}'`")
|
||||
warnings.warn(DeprecationWarning(msg), stacklevel=2)
|
||||
@@ -207,9 +220,12 @@ class ModelConfig:
|
||||
self.max_logprobs = max_logprobs
|
||||
self.disable_sliding_window = disable_sliding_window
|
||||
self.skip_tokenizer_init = skip_tokenizer_init
|
||||
self.hf_config = get_config(self.model, trust_remote_code, revision,
|
||||
code_revision, config_format,
|
||||
**hf_overrides)
|
||||
|
||||
hf_config = get_config(self.model, trust_remote_code, revision,
|
||||
code_revision, config_format, **hf_overrides_kw)
|
||||
hf_config = hf_overrides_fn(hf_config)
|
||||
self.hf_config = hf_config
|
||||
|
||||
self.hf_text_config = get_hf_text_config(self.hf_config)
|
||||
self.encoder_config = self._get_encoder_config()
|
||||
self.hf_image_processor_config = get_hf_image_processor_config(
|
||||
|
||||
Reference in New Issue
Block a user