[Model] Avoid hardcoding pooling type (#32119)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -23,7 +23,6 @@ from transformers import AutoModelForSequenceClassification
|
||||
|
||||
from vllm.config.utils import getattr_iter
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler
|
||||
- from vllm.model_executor.layers.pooler.seqwise import CLSPool
|
||||
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
|
||||
from vllm.model_executor.models.interfaces_base import VllmModelForPooling
|
||||
|
||||
@@ -32,7 +31,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class EmbeddingMixin(VllmModelForPooling):
|
||||
- default_pooling_type = "CLS"
|
||||
+ default_seq_pooling_type = "CLS"
|
||||
|
||||
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
|
||||
# Skip VllmModelForPooling.__init__ and call the next class in MRO
|
||||
@@ -47,7 +46,7 @@ class EmbeddingMixin(VllmModelForPooling):
|
||||
|
||||
|
||||
class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
|
||||
- default_pooling_type = "CLS"
|
||||
+ default_seq_pooling_type = "CLS"
|
||||
|
||||
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
|
||||
# Skip VllmModelForPooling.__init__ and call the next class in MRO
|
||||
@@ -85,8 +84,10 @@ class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
|
||||
self.init_parameters(self.classifier, dtype=self.model_config.head_dtype)
|
||||
|
||||
class ClassifierWithReshape(self.classifier.__class__):
|
||||
- """CLSPool has already been applied in `pooling`.
|
||||
- Add dim to match expected input shape of `classifier.forward`."""
|
||||
+ """
|
||||
+ Token extraction has already been applied in `pooler.pooling`.
|
||||
+ Add dim to match expected input shape of `classifier.forward`.
|
||||
+ """
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if len(args) > 0:
|
||||
@@ -97,6 +98,5 @@ class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
|
||||
|
||||
self.pooler = DispatchPooler.for_seq_cls(
|
||||
pooler_config,
|
||||
- pooling=CLSPool(),
|
||||
classifier=self.classifier,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user