Improve configs - ModelConfig (#17130)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor
2025-04-30 03:38:22 +01:00
committed by GitHub
parent 2c4f59afc3
commit 13698db634
36 changed files with 490 additions and 648 deletions

@@ -16,9 +16,8 @@ from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
replace)
from importlib.util import find_spec
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
Optional, Protocol, TypeVar, Union, cast, get_args,
get_origin)
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
Protocol, TypeVar, Union, cast, get_args, get_origin)
import torch
from pydantic import BaseModel, Field, PrivateAttr
@@ -211,103 +210,190 @@ def get_field(cls: ConfigType, name: str) -> Field:
f"{cls.__name__}.{name} must have a default value or default factory.")
TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]

class ModelConfig:
"""Configuration for the model.

Args:
model: Name or path of the huggingface model to use.
It is also used as the content for `model_name` tag in metrics
output when `served_model_name` is not specified.
task: The task to use the model for. Each vLLM instance only supports
one task, even if the same model can be used for multiple tasks.
When the model only supports one task, "auto" can be used to select
it; otherwise, you must specify explicitly which task to use.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, "slow" will always use the slow tokenizer,
"mistral" will always use the tokenizer from `mistral_common`, and
"custom" will use --tokenizer to select the preregistered tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
allowed_local_media_path: Allow API requests to read local images or
videos from directories specified by the server file system.
This is a security risk. Should only be enabled in trusted
environments.
dtype: Data type for model weights and activations. The "auto" option
will use FP16 precision for FP32 and FP16 models, and BF16 precision
for BF16 models.
seed: Random seed for reproducibility.
revision: The specific model version to use. It can be a branch name,
a tag name, or a commit id. If unspecified, will use the default
version.
code_revision: The specific revision to use for the model code on
Hugging Face Hub. It can be a branch name, a tag name, or a
commit id. If unspecified, will use the default version.
tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id. If unspecified, will use
the default version.
max_model_len: Maximum length of a sequence (including prompt and
output). If None, will be derived from the model.
spec_target_max_model_len: Specify the maximum length for spec
decoding draft models.
quantization: Quantization method that was used to quantize the model
weights. If None, we assume the model weights are not quantized.
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
If None, the user did not specify, so default to False.
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode. Additionally for encoder-decoder models, if the
sequence length of the encoder input is larger than this, we fall
back to the eager mode.
max_logprobs: Maximum number of log probabilities. Defaults to 20.
disable_sliding_window: Whether to disable sliding window. If True,
we will disable the sliding window functionality of the model.
If the model does not support sliding window, this argument is
ignored.
skip_tokenizer_init: If true, skip initialization of tokenizer and
detokenizer.
served_model_name: The model name used in the metrics tag `model_name`;
it matches the model name exposed via the APIs. If multiple model
names are provided, the first name will be used. If not specified,
the model name will be the same as `model`.
limit_mm_per_prompt: Maximum number of data items per modality
per prompt. Only applicable for multimodal models.
mm_processor_kwargs: Overrides for the multi-modal processor obtained
from `AutoProcessor.from_pretrained`.
disable_mm_preprocessor_cache: If True, disable caching of the
processed multi-modal inputs.
use_async_output_proc: Whether to use async output processor.
Defaults to True.
config_format: The format of the model config to load.
Defaults to 'auto', which defaults to 'hf'.
hf_token: The token to use as HTTP bearer authorization for remote
files. If `True`, will use the token generated when running
`huggingface-cli login` (stored in `~/.huggingface`).
hf_overrides: If a dictionary, contains arguments to be forwarded to the
HuggingFace config. If a callable, it is called to update the
HuggingFace config.
override_neuron_config: Initialize non-default neuron config or
override the default neuron config that is specific to Neuron devices;
this argument will be used to configure the neuron config that
cannot be gathered from the vLLM arguments.
override_pooler_config: Initialize non-default pooling config or
override the default pooling config for the pooling model.
logits_processor_pattern: Optional regex pattern specifying valid
logits processor qualified names that can be passed with the
`logits_processors` extra completion argument. Defaults to None,
which allows no processors.
generation_config: Configuration parameter file for generation.
model_impl: Which implementation of the model to use:
"auto" will try to use the vLLM implementation if it exists and
fall back to the Transformers implementation if no vLLM
implementation is available.
"vllm" will use the vLLM model implementation.
"transformers" will use the Transformers model implementation.
override_generation_config: Override the generation config with the
given config.
"""
@config
@dataclass
class ModelConfig:
"""Configuration for the model."""
model: str = "facebook/opt-125m"
"""Name or path of the Hugging Face model to use. It is also used as the
content for `model_name` tag in metrics output when `served_model_name` is
not specified."""
task: Literal[TaskOption, Literal["draft"]] = "auto"
"""The task to use the model for. Each vLLM instance only supports one
task, even if the same model can be used for multiple tasks. When the model
only supports one task, "auto" can be used to select it; otherwise, you
must specify explicitly which task to use."""
tokenizer: str = None # type: ignore
"""Name or path of the Hugging Face tokenizer to use. If unspecified, model
name or path will be used."""
tokenizer_mode: TokenizerMode = "auto"
"""Tokenizer mode:\n
- "auto" will use the fast tokenizer if available.\n
- "slow" will always use the slow tokenizer.\n
- "mistral" will always use the tokenizer from `mistral_common`.\n
- "custom" will use --tokenizer to select the preregistered tokenizer."""
trust_remote_code: bool = False
"""Trust remote code (e.g., from HuggingFace) when downloading the model
and tokenizer."""
dtype: Union[ModelDType, torch.dtype] = "auto"
"""Data type for model weights and activations:\n
- "auto" will use FP16 precision for FP32 and FP16 models, and BF16
precision for BF16 models.\n
- "half" for FP16. Recommended for AWQ quantization.\n
- "float16" is the same as "half".\n
- "bfloat16" for a balance between precision and range.\n
- "float" is shorthand for FP32 precision.\n
- "float32" for FP32 precision."""
seed: Optional[int] = None
"""Random seed for reproducibility."""
hf_config_path: Optional[str] = None
"""Name or path of the Hugging Face config to use. If unspecified, model
name or path will be used."""
allowed_local_media_path: str = ""
"""Allowing API requests to read local images or videos from directories
specified by the server file system. This is a security risk. Should only
be enabled in trusted environments."""
revision: Optional[str] = None
"""The specific model version to use. It can be a branch name, a tag name,
or a commit id. If unspecified, will use the default version."""
code_revision: Optional[str] = None
"""The specific revision to use for the model code on the Hugging Face Hub.
It can be a branch name, a tag name, or a commit id. If unspecified, will
use the default version."""
rope_scaling: dict[str, Any] = field(default_factory=dict)
"""RoPE scaling configuration in JSON format. For example,
`{"rope_type":"dynamic","factor":2.0}`."""
rope_theta: Optional[float] = None
"""RoPE theta. Use with `rope_scaling`. In some cases, changing the RoPE
theta improves the performance of the scaled model."""
tokenizer_revision: Optional[str] = None
"""The specific revision to use for the tokenizer on the Hugging Face Hub.
It can be a branch name, a tag name, or a commit id. If unspecified, will
use the default version."""
max_model_len: int = None # type: ignore
"""Model context length (prompt and output). If unspecified, will be
automatically derived from the model config.
When passing via `--max-model-len`, supports k/m/g/K/M/G in human-readable
format. Examples:\n
- 1k -> 1000\n
- 1K -> 1024\n
- 25.6k -> 25,600"""
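For illustration, a minimal sketch of how such human-readable suffixes could be parsed (the helper name and the exact CLI plumbing are assumptions, not taken from this diff):

def parse_human_readable_int(value: str) -> int:
    # Lowercase suffixes are decimal (1k -> 1_000); uppercase are binary (1K -> 1_024).
    suffixes = {"k": 1_000, "m": 1_000**2, "g": 1_000**3,
                "K": 1_024, "M": 1_024**2, "G": 1_024**3}
    if value and value[-1] in suffixes:
        return int(float(value[:-1]) * suffixes[value[-1]])
    return int(value)

assert parse_human_readable_int("1k") == 1_000
assert parse_human_readable_int("1K") == 1_024
assert parse_human_readable_int("25.6k") == 25_600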
spec_target_max_model_len: Optional[int] = None
"""Specify the the maximum length for spec decoding draft models."""
quantization: Optional[QuantizationMethods] = None
"""Method used to quantize the weights. If `None`, we first check the
`quantization_config` attribute in the model config file. If that is
`None`, we assume the model weights are not quantized and use `dtype` to
determine the data type of the weights."""
enforce_eager: bool = False
"""Whether to always use eager-mode PyTorch. If True, we will disable CUDA
graph and always execute the model in eager mode. If False, we will use
CUDA graph and eager execution in hybrid for maximal performance and
flexibility."""
max_seq_len_to_capture: int = 8192
"""Maximum sequence len covered by CUDA graphs. When a sequence has context
length larger than this, we fall back to eager mode. Additionally for
encoder-decoder models, if the sequence length of the encoder input is
larger than this, we fall back to the eager mode."""
max_logprobs: int = 20
"""Maximum number of log probabilities to return when `logprobs` is
specified in `SamplingParams`. The default value comes the default for the
OpenAI Chat Completions API."""
disable_sliding_window: bool = False
"""Whether to disable sliding window. If True, we will disable the sliding
window functionality of the model, capping to sliding window size. If the
model does not support sliding window, this argument is ignored."""
disable_cascade_attn: bool = False
"""Disable cascade attention for V1. While cascade attention does not
change the mathematical correctness, disabling it could be useful for
preventing potential numerical issues. Note that even if this is set to
False, cascade attention will be only used when the heuristic tells that
it's beneficial."""
skip_tokenizer_init: bool = False
"""Skip initialization of tokenizer and detokenizer. Expects valid
`prompt_token_ids` and `None` for prompt from the input. The generated
output will contain token ids."""
served_model_name: Optional[Union[str, list[str]]] = None
"""The model name(s) used in the API. If multiple names are provided, the
server will respond to any of the provided names. The model name in the
model field of a response will be the first name in this list. If not
specified, the model name will be the same as the `--model` argument. Noted
that this name(s) will also be used in `model_name` tag content of
prometheus metrics, if multiple names provided, metrics tag will take the
first one."""
limit_mm_per_prompt: dict[str, int] = field(default_factory=dict)
"""Maximum number of data items per modality per prompt. Only applicable
for multimodal models."""
use_async_output_proc: bool = True
"""Whether to use async output processor."""
config_format: Union[str, ConfigFormat] = ConfigFormat.AUTO.value
"""The format of the model config to load:\n
- "auto" will try to load the config in hf format if available else it
will try to load in mistral format.\n
- "hf" will load the config in hf format.\n
- "mistral" will load the config in mistral format."""
hf_token: Optional[Union[bool, str]] = None
"""The token to use as HTTP bearer authorization for remote files . If
`True`, will use the token generated when running `huggingface-cli login`
(stored in `~/.huggingface`)."""
hf_overrides: HfOverrides = field(default_factory=dict)
"""If a dictionary, contains arguments to be forwarded to the Hugging Face
config. If a callable, it is called to update the HuggingFace config. When
specified via CLI, the argument must be a valid JSON string."""
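A brief sketch of the two accepted forms, assuming a standard `transformers.PretrainedConfig` object (the exact call site is outside this hunk):

from transformers import PretrainedConfig

# Dictionary form: keys are forwarded as attribute overrides to the HF config.
hf_overrides_dict = {"max_position_embeddings": 32768}

# Callable form: receives the loaded HF config and returns the updated config.
def hf_overrides_fn(hf_config: PretrainedConfig) -> PretrainedConfig:
    hf_config.update({"max_position_embeddings": 32768})
    return hf_config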
mm_processor_kwargs: Optional[dict[str, Any]] = None
"""Arguments to be forwarded to the model's processor for multi-modal data,
e.g., image processor. Overrides for the multi-modal processor obtained
from `AutoProcessor.from_pretrained`. The available overrides depend on the
model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`.
When specified via CLI, the argument must be a valid JSON string."""
disable_mm_preprocessor_cache: bool = False
"""If `True`, disable caching of the multi-modal preprocessor/mapper (not
recommended)."""
override_neuron_config: dict[str, Any] = field(default_factory=dict)
"""Initialize non-default neuron config or override default neuron config
that are specific to Neuron devices, this argument will be used to
configure the neuron config that can not be gathered from the vllm
arguments. e.g. `{"cast_logits_dtype": "bloat16"}`. When specified via CLI,
the argument must be a valid JSON string."""
pooler_config: Optional["PoolerConfig"] = field(init=False)
"""Pooler config which controls the behaviour of output pooling in pooling
models."""
override_pooler_config: Optional[Union[dict, "PoolerConfig"]] = None
"""Initialize non-default pooling config or override default pooling config
for the pooling model. e.g. `{"pooling_type": "mean", "normalize": false}`.
When specified via CLI, the argument must be a valid JSON string."""
logits_processor_pattern: Optional[str] = None
"""Optional regex pattern specifying valid logits processor qualified names
that can be passed with the `logits_processors` extra completion argument.
Defaults to `None`, which allows no processors."""
generation_config: str = "auto"
"""The folder path to the generation config. Defaults to `"auto"`, the
generation config will be loaded from model path. If set to `"vllm"`, no
generation config is loaded, vLLM defaults will be used. If set to a folder
path, the generation config will be loaded from the specified folder path.
If `max_new_tokens` is specified in generation config, then it sets a
server-wide limit on the number of output tokens for all requests."""
override_generation_config: dict[str, Any] = field(default_factory=dict)
"""Overrides or sets generation config. e.g. `{"temperature": 0.5}`. If
used with `--generation-config auto`, the override parameters will be
merged with the default config from the model. If used with
`--generation-config vllm`, only the override parameters are used.
When specified via CLI, the argument must be a valid JSON string."""
enable_sleep_mode: bool = False
"""Enable sleep mode for the engine (only cuda platform is supported)."""
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO.value
"""Which implementation of the model to use:\n
- "auto" will try to use the vLLM implementation, if it exists, and fall
back to the Transformers implementation if no vLLM implementation is
available.\n
- "vllm" will use the vLLM model implementation.\n
- "transformers" will use the Transformers model implementation."""
def compute_hash(self) -> str:
"""
@@ -342,92 +428,43 @@ class ModelConfig:
assert_hashable(str_factors)
return hashlib.sha256(str(factors).encode()).hexdigest()
def __init__(
self,
model: str,
task: Literal[TaskOption, Literal["draft"]],
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
dtype: Union[str, torch.dtype],
seed: int,
hf_config_path: Optional[str] = None,
allowed_local_media_path: str = "",
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[dict[str, Any]] = None,
rope_theta: Optional[float] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
spec_target_max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
enforce_eager: Optional[bool] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 20,
disable_sliding_window: bool = False,
disable_cascade_attn: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, list[str]]] = None,
limit_mm_per_prompt: Optional[dict[str, int]] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None,
disable_mm_preprocessor_cache: bool = False,
use_async_output_proc: bool = True,
config_format: ConfigFormat = ConfigFormat.AUTO,
hf_token: Optional[Union[bool, str]] = None,
hf_overrides: Optional[HfOverrides] = None,
override_neuron_config: Optional[dict[str, Any]] = None,
override_pooler_config: Optional["PoolerConfig"] = None,
logits_processor_pattern: Optional[str] = None,
generation_config: str = "auto",
enable_sleep_mode: bool = False,
override_generation_config: Optional[dict[str, Any]] = None,
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
) -> None:
self.model = maybe_model_redirect(model)
self.tokenizer = maybe_model_redirect(tokenizer)
def __post_init__(self) -> None:
self.model = maybe_model_redirect(self.model)
# The tokenizer is consistent with the model by default.
if self.tokenizer is None:
self.tokenizer = self.model
if self.tokenizer_revision is None:
self.tokenizer_revision = self.revision
self.tokenizer = maybe_model_redirect(self.tokenizer)
self.hf_config_path = hf_config_path
if isinstance(hf_config_path, str):
self.hf_config_path = maybe_model_redirect(hf_config_path)
if isinstance(self.hf_config_path, str):
self.hf_config_path = maybe_model_redirect(self.hf_config_path)
self.tokenizer_mode = tokenizer_mode
self.trust_remote_code = trust_remote_code
self.allowed_local_media_path = allowed_local_media_path
self.seed = seed
self.revision = revision
self.code_revision = code_revision
self.rope_scaling = rope_scaling
self.rope_theta = rope_theta
self.model_impl = model_impl
if hf_overrides is None:
hf_overrides = {}
if callable(hf_overrides):
if callable(self.hf_overrides):
hf_overrides_kw = {}
hf_overrides_fn = hf_overrides
hf_overrides_fn = self.hf_overrides
else:
hf_overrides_kw = hf_overrides
hf_overrides_kw = self.hf_overrides
hf_overrides_fn = None
if rope_scaling is not None:
hf_override: dict[str, Any] = {"rope_scaling": rope_scaling}
if self.rope_scaling:
hf_override: dict[str, Any] = {"rope_scaling": self.rope_scaling}
hf_overrides_kw.update(hf_override)
hf_overrides_str = json.dumps(hf_overrides)
hf_overrides_str = json.dumps(hf_overrides_kw)
msg = (
"`--rope-scaling` will be removed in a future release. "
f"'Please instead use `--hf-overrides '{hf_overrides_str}'`")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
if rope_theta is not None:
hf_override = {"rope_theta": rope_theta}
if self.rope_theta is not None:
hf_override = {"rope_theta": self.rope_theta}
hf_overrides_kw.update(hf_override)
hf_overrides_str = json.dumps(hf_overrides)
hf_overrides_str = json.dumps(hf_overrides_kw)
msg = (
"`--rope-theta` will be removed in a future release. "
f"'Please instead use `--hf-overrides '{hf_overrides_str}'`")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
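For reference, a sketch of the replacement the deprecation warnings point at: passing RoPE settings through `hf_overrides` instead of the dedicated `rope_scaling`/`rope_theta` arguments (model and values are illustrative, not taken from this diff):

from vllm.config import ModelConfig

# Deprecated style (still accepted, but warns):
#   rope_scaling={"rope_type": "dynamic", "factor": 2.0}, rope_theta=1000000.0
# Preferred style going forward:
config = ModelConfig(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    hf_overrides={"rope_scaling": {"rope_type": "dynamic", "factor": 2.0},
                  "rope_theta": 1000000.0},
)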
self.maybe_pull_model_tokenizer_for_s3(model, tokenizer)
self.maybe_pull_model_tokenizer_for_s3(self.model, self.tokenizer)
if (backend := envs.VLLM_ATTENTION_BACKEND
) and backend == "FLASHINFER" and find_spec("flashinfer") is None:
@@ -437,20 +474,6 @@ class ModelConfig:
"https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile " # noqa: E501
"for instructions on how to install it.")
# The tokenizer version is consistent with the model version by default.
if tokenizer_revision is None:
self.tokenizer_revision = revision
else:
self.tokenizer_revision = tokenizer_revision
self.quantization = quantization
self.enforce_eager = enforce_eager
self.max_seq_len_to_capture = max_seq_len_to_capture
self.max_logprobs = max_logprobs
self.disable_sliding_window = disable_sliding_window
self.disable_cascade_attn = disable_cascade_attn
self.skip_tokenizer_init = skip_tokenizer_init
self.enable_sleep_mode = enable_sleep_mode
from vllm.platforms import current_platform
if (self.enable_sleep_mode
@@ -458,9 +481,12 @@ class ModelConfig:
raise ValueError(
"Sleep mode is not supported on current platform.")
if isinstance(self.config_format, str):
self.config_format = ConfigFormat(self.config_format)
hf_config = get_config(self.hf_config_path or self.model,
trust_remote_code, revision, code_revision,
config_format)
self.trust_remote_code, self.revision,
self.code_revision, self.config_format)
if hf_overrides_kw:
logger.info("Overriding HF config with %s", hf_overrides_kw)
@@ -476,13 +502,8 @@ class ModelConfig:
"attention_chunk_size", None)
self.encoder_config = self._get_encoder_config()
self.hf_image_processor_config = get_hf_image_processor_config(
self.model, hf_token=hf_token, revision=revision)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self.use_async_output_proc = use_async_output_proc
# Set enforce_eager to False if the value is unset.
if self.enforce_eager is None:
self.enforce_eager = False
self.model, hf_token=self.hf_token, revision=self.revision)
self.dtype = _get_and_verify_dtype(self.hf_config, self.dtype)
interleaved_attn_models = ["gemma2", "gemma3_text", "cohere2"]
sliding_window = getattr(self.hf_text_config, "sliding_window", None)
@@ -515,18 +536,14 @@ class ModelConfig:
self.max_model_len = _get_and_verify_max_len(
hf_config=self.hf_text_config,
max_model_len=max_model_len,
max_model_len=self.max_model_len,
disable_sliding_window=self.disable_sliding_window,
sliding_window_len=self.get_hf_config_sliding_window(),
spec_target_max_model_len=spec_target_max_model_len,
spec_target_max_model_len=self.spec_target_max_model_len,
encoder_config=self.encoder_config)
self.served_model_name = get_served_model_name(model,
served_model_name)
self.multimodal_config = self._init_multimodal_config(
limit_mm_per_prompt=limit_mm_per_prompt,
mm_processor_kwargs=mm_processor_kwargs,
disable_mm_preprocessor_cache=disable_mm_preprocessor_cache,
)
self.served_model_name = get_served_model_name(self.model,
self.served_model_name)
self.multimodal_config = self._init_multimodal_config()
if not self.skip_tokenizer_init:
self._verify_tokenizer_mode()
@@ -535,24 +552,19 @@ class ModelConfig:
self.has_noops = self._init_has_noops()
self.has_inner_state = self._init_has_inner_state()
if current_platform.is_neuron():
self.override_neuron_config = override_neuron_config
else:
self.override_neuron_config = None
if (not current_platform.is_neuron() and self.override_neuron_config):
raise ValueError(
"`override_neuron_config` is only supported on Neuron.")
supported_tasks, task = self._resolve_task(task)
supported_tasks, task = self._resolve_task(self.task)
self.supported_tasks = supported_tasks
self.task: Final = task
self.task = task
if self.task in ("draft", "generate"):
self.truncation_side = "left"
else:
self.truncation_side = "right"
self.pooler_config = self._init_pooler_config(override_pooler_config)
self.logits_processor_pattern = logits_processor_pattern
self.generation_config = generation_config
self.override_generation_config = override_generation_config or {}
self.pooler_config = self._init_pooler_config()
self._verify_quantization()
self._verify_cuda_graph()
@@ -591,26 +603,21 @@ class ModelConfig:
model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
self.tokenizer = s3_tokenizer.dir
def _init_multimodal_config(
self,
limit_mm_per_prompt: Optional[dict[str, int]],
mm_processor_kwargs: Optional[dict[str, Any]],
disable_mm_preprocessor_cache: bool,
) -> Optional["MultiModalConfig"]:
def _init_multimodal_config(self) -> Optional["MultiModalConfig"]:
if self.registry.is_multimodal_model(self.architectures):
return MultiModalConfig(
limit_per_prompt=limit_mm_per_prompt or {},
mm_processor_kwargs=mm_processor_kwargs or {},
disable_mm_preprocessor_cache=disable_mm_preprocessor_cache,
)
limit_per_prompt=self.limit_mm_per_prompt,
mm_processor_kwargs=self.mm_processor_kwargs,
disable_mm_preprocessor_cache=self.
disable_mm_preprocessor_cache)
if limit_mm_per_prompt:
if self.limit_mm_per_prompt:
raise ValueError("`limit_mm_per_prompt` is only supported for "
"multimodal models.")
if mm_processor_kwargs:
if self.mm_processor_kwargs:
raise ValueError("`mm_processor_kwargs` is only supported for "
"multimodal models.")
if disable_mm_preprocessor_cache:
if self.disable_mm_preprocessor_cache:
raise ValueError("`disable_mm_preprocessor_cache` is only "
"supported for multimodal models.")
@@ -620,31 +627,32 @@ class ModelConfig:
return get_sentence_transformer_tokenizer_config(
self.model, self.revision)
def _init_pooler_config(
self,
override_pooler_config: Optional["PoolerConfig"],
) -> Optional["PoolerConfig"]:
def _init_pooler_config(self) -> Optional["PoolerConfig"]:
if self.runner_type == "pooling":
user_config = override_pooler_config or PoolerConfig()
if isinstance(self.override_pooler_config, dict):
self.override_pooler_config = PoolerConfig(
**self.override_pooler_config)
pooler_config = self.override_pooler_config or PoolerConfig()
base_config = get_pooling_config(self.model, self.revision)
if base_config is not None:
# Only set values that are not overridden by the user
for k, v in base_config.items():
if getattr(user_config, k) is None:
setattr(user_config, k, v)
if getattr(pooler_config, k) is None:
setattr(pooler_config, k, v)
if self.is_matryoshka:
if user_config.normalize is None:
user_config.normalize = True
elif not user_config.normalize:
if pooler_config.normalize is None:
pooler_config.normalize = True
elif not pooler_config.normalize:
raise ValueError(
"`normalize` must be enabled (set to True) "
"for models that are compatible with "
"Matryoshka Representation.")
return user_config
return pooler_config
return None
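In short, user-supplied values win and unset (`None`) attributes are filled from the repo's pooling config. A standalone sketch of that merge, assuming `PoolerConfig` attributes default to `None` (values are illustrative):

from vllm.config import PoolerConfig

pooler_config = PoolerConfig(pooling_type="MEAN")          # user override
base_config = {"pooling_type": "CLS", "normalize": True}   # e.g. from the model repo
for k, v in base_config.items():
    if getattr(pooler_config, k) is None:
        setattr(pooler_config, k, v)
# pooling_type stays "MEAN" (user-set); normalize is filled in as True.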
@@ -662,11 +670,11 @@ class ModelConfig:
return self.registry.model_has_inner_state(self.architectures)
def _verify_tokenizer_mode(self) -> None:
tokenizer_mode = self.tokenizer_mode.lower()
if tokenizer_mode not in ["auto", "slow", "mistral", "custom"]:
tokenizer_mode = cast(TokenizerMode, self.tokenizer_mode.lower())
if tokenizer_mode not in get_args(TokenizerMode):
raise ValueError(
f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
"either 'auto', 'slow', 'mistral' or 'custom'.")
f"one of {get_args(TokenizerMode)}.")
self.tokenizer_mode = tokenizer_mode
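The `Literal` + `get_args` pattern used here keeps the accepted values and the error message from drifting apart. A self-contained sketch of the same pattern:

from typing import Literal, get_args

TokenizerMode = Literal["auto", "slow", "mistral", "custom"]

def verify_tokenizer_mode(mode: str) -> TokenizerMode:
    mode = mode.lower()
    if mode not in get_args(TokenizerMode):
        raise ValueError(f"Unknown tokenizer mode: {mode}. "
                         f"Must be one of {get_args(TokenizerMode)}.")
    return mode  # type: ignore[return-value]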
def _get_preferred_task(
@@ -781,7 +789,8 @@ class ModelConfig:
"quark", "nvfp4", "bitblas", "gptq_bitblas"
]
if self.quantization is not None:
self.quantization = self.quantization.lower()
self.quantization = cast(QuantizationMethods,
self.quantization.lower())
# Parse quantization method from the HF model config, if available.
quant_cfg = self._parse_quant_hf_config()
@@ -857,8 +866,6 @@ class ModelConfig:
"non-quantized models.", self.quantization)
def _verify_cuda_graph(self) -> None:
if self.max_seq_len_to_capture is None:
self.max_seq_len_to_capture = self.max_model_len
self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
self.max_model_len)
ROCM_UNSUPPORTED_MODELS = ['mllama']
@@ -1294,7 +1301,7 @@ class ModelConfig:
@property
def runner_type(self) -> RunnerType:
return _TASK_RUNNER[self.task]
return _TASK_RUNNER[cast(_ResolvedTask, self.task)]
@property
def is_v1_compatible(self) -> bool:
@@ -2201,7 +2208,7 @@ class SpeculativeConfig:
according to the log probability settings in SamplingParams."""
# Draft model configuration
quantization: Optional[str] = None
quantization: Optional[QuantizationMethods] = None
"""Quantization method that was used to quantize the draft model weights.
If `None`, we assume the model weights are not quantized. Note that it only
takes effect when using the draft model-based speculative method."""
@@ -2386,7 +2393,6 @@ class SpeculativeConfig:
code_revision=self.code_revision,
tokenizer_revision=self.target_model_config.
tokenizer_revision,
max_model_len=None,
spec_target_max_model_len=self.target_model_config.
max_model_len,
quantization=self.quantization,
@@ -2793,30 +2799,31 @@ class PromptAdapterConfig:
class MultiModalConfig:
"""Controls the behavior of multimodal models."""
limit_per_prompt: dict[str, int] = field(default_factory=dict)
limit_per_prompt: dict[str, int] = get_field(ModelConfig,
"limit_mm_per_prompt")
"""
The maximum number of input items allowed per prompt for each modality.
This should be a JSON string that will be parsed into a dictionary.
Defaults to 1 (V0) or 999 (V1) for each modality.
For example, to allow up to 16 images and 2 videos per prompt:
:code:`{"images": 16, "videos": 2}`
`{"images": 16, "videos": 2}`
"""
mm_processor_kwargs: Optional[dict[str, object]] = None
"""
Overrides for the multi-modal processor obtained from
:meth:`transformers.AutoProcessor.from_pretrained`.
`transformers.AutoProcessor.from_pretrained`.
The available overrides depend on the model that is being run.
For example, for Phi-3-Vision:
:code:`{"num_crops": 4}`.
`{"num_crops": 4}`.
"""
disable_mm_preprocessor_cache: bool = False
"""
If :code:`True`, disable caching of the processed multi-modal inputs.
If `True`, disable caching of the processed multi-modal inputs.
"""
def compute_hash(self) -> str:
@@ -2907,10 +2914,6 @@ class PoolerConfig:
usedforsecurity=False).hexdigest()
return hash_str
@staticmethod
def from_json(json_str: str) -> "PoolerConfig":
return PoolerConfig(**json.loads(json_str))
_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.float16,