[Misc] Consolidate pooler config overrides (#10351)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
112
vllm/config.py
112
vllm/config.py
@@ -112,10 +112,6 @@ class ModelConfig:
|
||||
the model name will be the same as `model`.
|
||||
limit_mm_per_prompt: Maximum number of data items per modality
|
||||
per prompt. Only applicable for multimodal models.
|
||||
override_neuron_config: Initialize non default neuron config or
|
||||
override default neuron config that are specific to Neuron devices,
|
||||
this argument will be used to configure the neuron config that
|
||||
can not be gathered from the vllm arguments.
|
||||
config_format: The config format which shall be loaded.
|
||||
Defaults to 'auto' which defaults to 'hf'.
|
||||
hf_overrides: If a dictionary, contains arguments to be forwarded to the
|
||||
@@ -123,20 +119,12 @@ class ModelConfig:
|
||||
HuggingFace config.
|
||||
mm_processor_kwargs: Arguments to be forwarded to the model's processor
|
||||
for multi-modal data, e.g., image processor.
|
||||
pooling_type: Used to configure the pooling method in the embedding
|
||||
model.
|
||||
pooling_norm: Used to determine whether to normalize the pooled
|
||||
data in the embedding model.
|
||||
pooling_softmax: Used to determine whether to softmax the pooled
|
||||
data in the embedding model.
|
||||
pooling_step_tag_id: When pooling_step_tag_id is not -1, it indicates
|
||||
that the score corresponding to the pooling_step_tag_id in the
|
||||
generated sentence should be returned. Otherwise, it returns
|
||||
the scores for all tokens.
|
||||
pooling_returned_token_ids: pooling_returned_token_ids represents a
|
||||
list of indices for the vocabulary dimensions to be extracted,
|
||||
such as the token IDs of good_token and bad_token in the
|
||||
math-shepherd-mistral-7b-prm model.
|
||||
override_neuron_config: Initialize non default neuron config or
|
||||
override default neuron config that are specific to Neuron devices,
|
||||
this argument will be used to configure the neuron config that
|
||||
can not be gathered from the vllm arguments.
|
||||
override_pooler_config: Initialize non default pooling config or
|
||||
override default pooling config for the embedding model.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -166,16 +154,12 @@ class ModelConfig:
|
||||
served_model_name: Optional[Union[str, List[str]]] = None,
|
||||
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
|
||||
use_async_output_proc: bool = True,
|
||||
override_neuron_config: Optional[Dict[str, Any]] = None,
|
||||
config_format: ConfigFormat = ConfigFormat.AUTO,
|
||||
chat_template_text_format: str = "string",
|
||||
hf_overrides: Optional[HfOverrides] = None,
|
||||
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
pooling_type: Optional[str] = None,
|
||||
pooling_norm: Optional[bool] = None,
|
||||
pooling_softmax: Optional[bool] = None,
|
||||
pooling_step_tag_id: Optional[int] = None,
|
||||
pooling_returned_token_ids: Optional[List[int]] = None) -> None:
|
||||
override_neuron_config: Optional[Dict[str, Any]] = None,
|
||||
override_pooler_config: Optional["PoolerConfig"] = None) -> None:
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.tokenizer_mode = tokenizer_mode
|
||||
@@ -280,13 +264,7 @@ class ModelConfig:
|
||||
supported_tasks, task = self._resolve_task(task, self.hf_config)
|
||||
self.supported_tasks = supported_tasks
|
||||
self.task: Final = task
|
||||
self.pooler_config = self._init_pooler_config(
|
||||
pooling_type,
|
||||
pooling_norm,
|
||||
pooling_softmax,
|
||||
pooling_step_tag_id,
|
||||
pooling_returned_token_ids,
|
||||
)
|
||||
self.pooler_config = self._init_pooler_config(override_pooler_config)
|
||||
|
||||
self._verify_quantization()
|
||||
self._verify_cuda_graph()
|
||||
@@ -311,27 +289,21 @@ class ModelConfig:
|
||||
|
||||
def _init_pooler_config(
|
||||
self,
|
||||
pooling_type: Optional[str] = None,
|
||||
pooling_norm: Optional[bool] = None,
|
||||
pooling_softmax: Optional[bool] = None,
|
||||
pooling_step_tag_id: Optional[int] = None,
|
||||
pooling_returned_token_ids: Optional[List[int]] = None
|
||||
override_pooler_config: Optional["PoolerConfig"],
|
||||
) -> Optional["PoolerConfig"]:
|
||||
|
||||
if self.task == "embedding":
|
||||
pooling_config = get_pooling_config(self.model, self.revision)
|
||||
if pooling_config is not None:
|
||||
# override if user does not
|
||||
# specify pooling_type and/or pooling_norm
|
||||
if pooling_type is None:
|
||||
pooling_type = pooling_config["pooling_type"]
|
||||
if pooling_norm is None:
|
||||
pooling_norm = pooling_config["normalize"]
|
||||
return PoolerConfig(
|
||||
pooling_type=pooling_type,
|
||||
pooling_norm=pooling_norm,
|
||||
pooling_softmax=pooling_softmax,
|
||||
pooling_step_tag_id=pooling_step_tag_id,
|
||||
pooling_returned_token_ids=pooling_returned_token_ids)
|
||||
user_config = override_pooler_config or PoolerConfig()
|
||||
|
||||
base_config = get_pooling_config(self.model, self.revision)
|
||||
if base_config is not None:
|
||||
# Only set values that are not overridden by the user
|
||||
for k, v in base_config.items():
|
||||
if getattr(user_config, k) is None:
|
||||
setattr(user_config, k, v)
|
||||
|
||||
return user_config
|
||||
|
||||
return None
|
||||
|
||||
def _init_attention_free(self) -> bool:
|
||||
@@ -1786,13 +1758,43 @@ class MultiModalConfig:
|
||||
|
||||
@dataclass
|
||||
class PoolerConfig:
|
||||
"""Controls the behavior of pooler in embedding model"""
|
||||
"""Controls the behavior of output pooling in embedding models."""
|
||||
|
||||
pooling_type: Optional[str] = None
|
||||
pooling_norm: Optional[bool] = None
|
||||
pooling_softmax: Optional[bool] = None
|
||||
pooling_step_tag_id: Optional[int] = None
|
||||
pooling_returned_token_ids: Optional[List[int]] = None
|
||||
"""
|
||||
The pooling method of the embedding model. This should be a key in
|
||||
:class:`vllm.model_executor.layers.pooler.PoolingType`.
|
||||
"""
|
||||
|
||||
normalize: Optional[bool] = None
|
||||
"""
|
||||
Whether to normalize the pooled outputs. Usually, this should be set to
|
||||
``True`` for embedding outputs.
|
||||
"""
|
||||
|
||||
softmax: Optional[bool] = None
|
||||
"""
|
||||
Whether to apply softmax to the pooled outputs. Usually, this should be set
|
||||
to ``True`` for classification outputs.
|
||||
"""
|
||||
|
||||
step_tag_id: Optional[int] = None
|
||||
"""
|
||||
If set, only the score corresponding to the ``step_tag_id`` in the
|
||||
generated sentence should be returned. Otherwise, the scores for all tokens
|
||||
are returned.
|
||||
"""
|
||||
|
||||
returned_token_ids: Optional[List[int]] = None
|
||||
"""
|
||||
A list of indices for the vocabulary dimensions to be extracted,
|
||||
such as the token IDs of ``good_token`` and ``bad_token`` in the
|
||||
``math-shepherd-mistral-7b-prm`` model.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def from_json(json_str: str) -> "PoolerConfig":
|
||||
return PoolerConfig(**json.loads(json_str))
|
||||
|
||||
|
||||
_STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
|
||||
Reference in New Issue
Block a user