Move PoolerConfig from config/__init__.py to config/pooler.py (#25181)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -40,6 +40,7 @@ from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
|
||||
MultiModalConfig)
|
||||
from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig,
|
||||
ParallelConfig)
|
||||
from vllm.config.pooler import PoolerConfig
|
||||
from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
|
||||
from vllm.config.speculative import SpeculativeConfig
|
||||
from vllm.config.structured_outputs import StructuredOutputsConfig
|
||||
@@ -406,13 +407,6 @@ class ModelConfig:
|
||||
hf_overrides: HfOverrides = field(default_factory=dict)
|
||||
"""If a dictionary, contains arguments to be forwarded to the Hugging Face
|
||||
config. If a callable, it is called to update the HuggingFace config."""
|
||||
pooler_config: Optional["PoolerConfig"] = field(init=False)
|
||||
"""Pooler config which controls the behaviour of output pooling in pooling
|
||||
models."""
|
||||
override_pooler_config: Optional[Union[dict, "PoolerConfig"]] = None
|
||||
"""Initialize non-default pooling config or override default pooling config
|
||||
for the pooling model. e.g. `{"pooling_type": "mean", "normalize": false}`.
|
||||
"""
|
||||
logits_processor_pattern: Optional[str] = None
|
||||
"""Optional regex pattern specifying valid logits processor qualified names
|
||||
that can be passed with the `logits_processors` extra completion argument.
|
||||
@@ -448,6 +442,14 @@ class ModelConfig:
|
||||
io_processor_plugin: Optional[str] = None
|
||||
"""IOProcessor plugin name to load at model startup"""
|
||||
|
||||
# Pooler config
|
||||
pooler_config: Optional[PoolerConfig] = None
|
||||
"""Pooler config which controls the behaviour of output pooling in pooling
|
||||
models."""
|
||||
override_pooler_config: Optional[Union[dict, PoolerConfig]] = None
|
||||
"""[DEPRECATED] Use `pooler_config` instead. This field will be removed in
|
||||
v0.12.0 or v1.0.0, whichever is sooner."""
|
||||
|
||||
# Multimodal config and init vars
|
||||
multimodal_config: Optional[MultiModalConfig] = None
|
||||
"""Configuration for multimodal model. If `None`, this will be inferred
|
||||
@@ -709,7 +711,33 @@ class ModelConfig:
|
||||
self._architecture = arch
|
||||
logger.info("Resolved architecture: %s", arch)
|
||||
|
||||
self.pooler_config = self._init_pooler_config()
|
||||
# Init pooler config if needed
|
||||
if self.runner_type == "pooling":
|
||||
if self.override_pooler_config is not None:
|
||||
logger.warning_once(
|
||||
"`override_pooler_config` is deprecated and will be "
|
||||
"removed in v0.12.0 or v1.0.0, whichever is sooner. "
|
||||
"Please use `pooler_config` instead.")
|
||||
|
||||
if isinstance(self.override_pooler_config, dict):
|
||||
self.pooler_config = PoolerConfig(
|
||||
**self.override_pooler_config)
|
||||
else:
|
||||
self.pooler_config = self.override_pooler_config
|
||||
|
||||
if self.pooler_config is None:
|
||||
self.pooler_config = PoolerConfig()
|
||||
|
||||
base_config = get_pooling_config(self.model, self.revision)
|
||||
if base_config is not None:
|
||||
# Only set values that are not overridden by the user
|
||||
for k, v in base_config.items():
|
||||
if getattr(self.pooler_config, k) is None:
|
||||
setattr(self.pooler_config, k, v)
|
||||
|
||||
default_pooling_type = self._model_info.default_pooling_type
|
||||
if self.pooler_config.pooling_type is None:
|
||||
self.pooler_config.pooling_type = default_pooling_type
|
||||
|
||||
self.dtype: torch.dtype = _get_and_verify_dtype(
|
||||
self.model,
|
||||
@@ -869,29 +897,6 @@ class ModelConfig:
|
||||
return get_sentence_transformer_tokenizer_config(
|
||||
self.model, self.revision)
|
||||
|
||||
def _init_pooler_config(self) -> Optional["PoolerConfig"]:
    """Resolve the pooler config for this model.

    Returns:
        `None` when `self.runner_type` is not `"pooling"`. Otherwise a
        `PoolerConfig` built from, in order of precedence:

        1. `self.override_pooler_config` (a dict is first converted to a
           `PoolerConfig` and written back onto `self`),
        2. the mapping returned by
           `get_pooling_config(self.model, self.revision)`, used only for
           attributes that are still `None`,
        3. `self._model_info.default_pooling_type`, used only when
           `pooling_type` is still `None`.
    """
    # Non-pooling runners carry no pooler config at all.
    if self.runner_type != "pooling":
        return None

    # Normalise a dict override into a PoolerConfig and keep it on self.
    user_override = self.override_pooler_config
    if isinstance(user_override, dict):
        user_override = PoolerConfig(**user_override)
        self.override_pooler_config = user_override

    resolved = user_override if user_override else PoolerConfig()

    model_defaults = get_pooling_config(self.model, self.revision)
    if model_defaults is not None:
        for attr_name, attr_value in model_defaults.items():
            # Values already set by the user take precedence.
            if getattr(resolved, attr_name) is not None:
                continue
            setattr(resolved, attr_name, attr_value)

    fallback_pooling_type = self._model_info.default_pooling_type
    if resolved.pooling_type is None:
        resolved.pooling_type = fallback_pooling_type

    return resolved
|
||||
|
||||
def _verify_tokenizer_mode(self) -> None:
|
||||
tokenizer_mode = cast(TokenizerMode, self.tokenizer_mode.lower())
|
||||
if tokenizer_mode not in get_args(TokenizerMode):
|
||||
@@ -1833,94 +1838,6 @@ class DeviceConfig:
|
||||
self.device = torch.device(self.device_type)
|
||||
|
||||
|
||||
@config
@dataclass
class PoolerConfig:
    """Controls the behavior of output pooling in pooling models."""

    pooling_type: Optional[str] = None
    """
    The pooling method of the pooling model. This should be a key in
    [`vllm.model_executor.layers.pooler.PoolingType`][].
    """

    # --- Options for embedding models ---
    normalize: Optional[bool] = None
    """
    Whether to normalize the embeddings outputs. Defaults to True.
    """
    dimensions: Optional[int] = None
    """
    Reduce the dimensions of embeddings if model
    support matryoshka representation. Defaults to None.
    """
    enable_chunked_processing: Optional[bool] = None
    """
    Whether to enable chunked processing for long inputs that exceed the model's
    maximum position embeddings. When enabled, long inputs will be split into
    chunks, processed separately, and then aggregated using weighted averaging.
    This allows embedding models to handle arbitrarily long text without CUDA
    errors. Defaults to False.
    """
    max_embed_len: Optional[int] = None
    """
    Maximum input length allowed for embedding generation. When set, allows
    inputs longer than max_embed_len to be accepted for embedding models.
    When an input exceeds max_embed_len, it will be handled according to
    the original max_model_len validation logic.
    Defaults to None (i.e. set to max_model_len).
    """

    # --- Options for classification models ---
    activation: Optional[bool] = None
    """
    Whether to apply activation function to the classification outputs.
    Defaults to True.
    """
    logit_bias: Optional[float] = None
    """
    If provided, apply classification logit biases. Defaults to None.
    """

    # --- Options for reward models ---
    softmax: Optional[bool] = None
    """
    Whether to apply softmax to the reward outputs.
    Defaults to True.
    """
    step_tag_id: Optional[int] = None
    """
    If set, only the score corresponding to the ``step_tag_id`` in the
    generated sentence should be returned. Otherwise, the scores for all tokens
    are returned.
    """
    returned_token_ids: Optional[list[int]] = None
    """
    A list of indices for the vocabulary dimensions to be extracted,
    such as the token IDs of ``good_token`` and ``bad_token`` in the
    ``math-shepherd-mistral-7b-prm`` model.
    """

    def compute_hash(self) -> str:
        """
        Return a hash identifying the parts of this config that affect the
        structure of the computation graph (from input ids/embeddings up to
        the final hidden states; anything before or after is excluded).

        The factor list is empty here, so the result is the MD5 hex digest
        of the string form of an empty list.

        WARNING: Whenever a new field is added to this config, include it
        in the factors list if it affects the computation graph.
        """
        # Intentionally empty: no field of this config is treated as
        # affecting the computation graph.
        graph_factors: list[Any] = []
        serialized = str(graph_factors).encode()
        digest = hashlib.md5(serialized, usedforsecurity=False)
        return digest.hexdigest()
|
||||
|
||||
|
||||
_STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
"half": torch.float16,
|
||||
"float16": torch.float16,
|
||||
|
||||
Reference in New Issue
Block a user