Add explicit pooling classes for the Transformers backend (#25322)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
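In effect, a pooling model served through the Transformers backend now resolves to a dedicated class (TransformersEmbeddingModel or TransformersForSequenceClassification) instead of the old catch-all TransformersModel. A minimal usage sketch, assuming the offline LLM entrypoint forwards model_impl and runner through to ModelConfig (the model name is only an example, not part of the commit):

    from vllm import LLM

    llm = LLM(
        model="BAAI/bge-base-en-v1.5",  # example embedding model
        runner="pooling",               # pooling runner; convert falls back to "embed"
        model_impl="transformers",      # force the Transformers backend
    )
    outputs = llm.embed(["explicit pooling classes for the Transformers backend"])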
@@ -19,6 +19,7 @@ import vllm.envs as envs
 from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
                                     MultiModalConfig)
 from vllm.config.pooler import PoolerConfig
+from vllm.config.scheduler import RunnerType
 from vllm.config.utils import assert_hashable, config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -40,7 +41,6 @@ if TYPE_CHECKING:
     import vllm.model_executor.models as me_models
     from vllm.config.load import LoadConfig
     from vllm.config.parallel import ParallelConfig
-    from vllm.config.scheduler import RunnerType
     from vllm.model_executor.layers.quantization import QuantizationMethods
     from vllm.v1.sample.logits_processor import LogitsProcessor
 else:
@@ -52,13 +52,12 @@ else:
         "vllm.model_executor.models")
     LoadConfig = Any
     ParallelConfig = Any
-    RunnerType = Any
     QuantizationMethods = Any
     LogitsProcessor = Any
 
 logger = init_logger(__name__)
 
-RunnerOption = Literal["auto", "generate", "pooling", "draft"]
+RunnerOption = Literal["auto", RunnerType]
 ConvertType = Literal["none", "embed", "classify", "reward"]
 ConvertOption = Literal["auto", ConvertType]
 TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
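The RunnerOption rewrite leans on typing.Literal flattening nested Literal aliases, so RunnerOption can no longer drift out of sync with RunnerType. A small sketch of that behavior (the members of RunnerType are inferred from the removed line, for illustration only):

    from typing import Literal, get_args

    # Assumed to mirror vllm.config.scheduler.RunnerType for this sketch.
    RunnerType = Literal["generate", "pooling", "draft"]

    # Nested Literal aliases are flattened, so this is equivalent to
    # Literal["auto", "generate", "pooling", "draft"].
    RunnerOption = Literal["auto", RunnerType]

    print(get_args(RunnerOption))  # ('auto', 'generate', 'pooling', 'draft')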
@@ -668,8 +667,28 @@ class ModelConfig:
     def _get_transformers_backend_cls(self) -> str:
         """Determine which Transformers backend class will be used if
         `model_impl` is set to `transformers` or `auto`."""
-        if getattr(self, "runner_type", self.runner) == "pooling":
-            return "TransformersModel"
+        # Check if the architecture we're wrapping has defaults
+        runner = None
+        convert = None
+        if defaults := try_match_architecture_defaults(self.architectures[0]):
+            _, (runner, convert) = defaults
+        # Overwrite with user-specified values
+        if self.runner != "auto":
+            runner = self.runner
+        if self.convert not in {"auto", "none"}:
+            convert = self.convert
+        # Fall back to default values if still not set
+        if runner is None:
+            runner = "generate"
+        if convert in {None, "none"}:
+            convert = "embed"
+        # Resolve Transformers backend pooling classes
+        if runner == "pooling":
+            if convert == "embed":
+                return "TransformersEmbeddingModel"
+            if convert == "classify":
+                return "TransformersForSequenceClassification"
+        # Resolve Transformers backend generate classes
         if self.hf_config != self.hf_text_config:
             # If 'hf_text_config' is the same as 'hf_config'. If not, it is
             # probably a composite config, i.e. multimodal
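Read as a decision table: architecture defaults seed runner/convert, explicit user settings override them, remaining gaps fall back to "generate"/"embed", and only then is a pooling class chosen. A standalone sketch of just the pooling branch shown above (the helper is a simplified stand-in, not the real method; the generate-branch class names are intentionally omitted):

    from typing import Optional, Tuple

    def resolve_pooling_cls(
        runner: str,
        convert: str,
        defaults: Optional[Tuple[Optional[str], Optional[str]]] = None,
    ) -> Optional[str]:
        # Architecture defaults seed the values...
        d_runner, d_convert = defaults if defaults else (None, None)
        # ...user-specified values override them...
        runner = runner if runner != "auto" else d_runner
        convert = convert if convert not in {"auto", "none"} else d_convert
        # ...and anything still unset falls back to "generate"/"embed".
        runner = runner or "generate"
        convert = convert if convert not in {None, "none"} else "embed"
        if runner == "pooling":
            if convert == "embed":
                return "TransformersEmbeddingModel"
            if convert == "classify":
                return "TransformersForSequenceClassification"
        return None  # generate classes are resolved later in the real method

    assert resolve_pooling_cls("pooling", "auto") == "TransformersEmbeddingModel"
    assert resolve_pooling_cls(
        "auto", "auto", defaults=("pooling", "classify")
    ) == "TransformersForSequenceClassification"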
@@ -678,7 +697,9 @@ class ModelConfig:
 
     def using_transformers_backend(self) -> bool:
         """Check if the model is using the Transformers backend class."""
-        return self.architecture == self._get_transformers_backend_cls()
+        used_cls = self._model_info.architecture
+        transformers_backend_cls = self._get_transformers_backend_cls()
+        return used_cls == transformers_backend_cls
 
     @property
     def registry(self):
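The rewritten using_transformers_backend compares the architecture class that was actually resolved for the model against the expected Transformers backend class name. A self-contained sketch of that comparison (the classes and attribute values below are stand-ins for illustration, not vLLM internals):

    class FakeModelInfo:
        # The class the registry actually resolved for this model.
        architecture = "TransformersEmbeddingModel"

    class FakeModelConfig:
        _model_info = FakeModelInfo()

        def _get_transformers_backend_cls(self) -> str:
            return "TransformersEmbeddingModel"

        def using_transformers_backend(self) -> bool:
            # Mirrors the new check: compare the resolved class with the
            # expected Transformers backend class.
            used_cls = self._model_info.architecture
            return used_cls == self._get_transformers_backend_cls()

    assert FakeModelConfig().using_transformers_backend()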