Add explicit pooling classes for the Transformers backend (#25322)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Harry Mellor
2025-09-30 23:07:06 +01:00
committed by GitHub
parent 9a9f48dff7
commit a388252ac4
7 changed files with 295 additions and 137 deletions

View File

@@ -19,6 +19,7 @@ import vllm.envs as envs
from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
MultiModalConfig)
from vllm.config.pooler import PoolerConfig
from vllm.config.scheduler import RunnerType
from vllm.config.utils import assert_hashable, config
from vllm.logger import init_logger
from vllm.platforms import current_platform
@@ -40,7 +41,6 @@ if TYPE_CHECKING:
import vllm.model_executor.models as me_models
from vllm.config.load import LoadConfig
from vllm.config.parallel import ParallelConfig
from vllm.config.scheduler import RunnerType
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.v1.sample.logits_processor import LogitsProcessor
else:
@@ -52,13 +52,12 @@ else:
"vllm.model_executor.models")
LoadConfig = Any
ParallelConfig = Any
RunnerType = Any
QuantizationMethods = Any
LogitsProcessor = Any
logger = init_logger(__name__)
RunnerOption = Literal["auto", "generate", "pooling", "draft"]
RunnerOption = Literal["auto", RunnerType]
ConvertType = Literal["none", "embed", "classify", "reward"]
ConvertOption = Literal["auto", ConvertType]
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
@@ -668,8 +667,28 @@ class ModelConfig:
def _get_transformers_backend_cls(self) -> str:
"""Determine which Transformers backend class will be used if
`model_impl` is set to `transformers` or `auto`."""
if getattr(self, "runner_type", self.runner) == "pooling":
return "TransformersModel"
# Check if the architecture we're wrapping has defaults
runner = None
convert = None
if defaults := try_match_architecture_defaults(self.architectures[0]):
_, (runner, convert) = defaults
# Overwrite with user-specified values
if self.runner != "auto":
runner = self.runner
if self.convert not in {"auto", "none"}:
convert = self.convert
# Fall back to default values if still not set
if runner is None:
runner = "generate"
if convert in {None, "none"}:
convert = "embed"
# Resolve Transformers backend pooling classes
if runner == "pooling":
if convert == "embed":
return "TransformersEmbeddingModel"
if convert == "classify":
return "TransformersForSequenceClassification"
# Resolve Transformers backend generate classes
if self.hf_config != self.hf_text_config:
# If 'hf_text_config' differs from 'hf_config', it is probably a
# composite config, i.e. multimodal
@@ -678,7 +697,9 @@ class ModelConfig:
def using_transformers_backend(self) -> bool:
"""Check if the model is using the Transformers backend class."""
return self.architecture == self._get_transformers_backend_cls()
used_cls = self._model_info.architecture
transformers_backend_cls = self._get_transformers_backend_cls()
return used_cls == transformers_backend_cls
@property
def registry(self):