diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py
index b09e76015..59e768853 100644
--- a/vllm/model_executor/models/bert.py
+++ b/vllm/model_executor/models/bert.py
@@ -25,11 +25,11 @@ from vllm.model_executor.layers.pooler import (
     PoolingParamsUpdate,
 )
 from vllm.model_executor.layers.pooler.seqwise import (
-    CLSPool,
     SequencePooler,
     SequencePoolerHeadOutput,
     SequencePoolerOutput,
     SequencePoolingMethodOutput,
+    get_seq_pooling_method,
 )
 from vllm.model_executor.layers.pooler.tokwise import (
     pooler_for_token_classify,
@@ -94,9 +94,9 @@ class BertEmbedding(nn.Module):
 
 
 class BertPooler(SequencePooler):
-    def __init__(self, config: BertConfig):
+    def __init__(self, config: BertConfig, pooler_config: PoolerConfig):
         super().__init__(
-            pooling=CLSPool(),
+            pooling=get_seq_pooling_method(pooler_config.seq_pooling_type),
             head=self.head,
         )
@@ -450,7 +450,11 @@ class BertPoolingModel(BertModel):
         )
 
         config = vllm_config.model_config.hf_config
-        self.pooler = BertPooler(config)
+
+        pooler_config = vllm_config.model_config.pooler_config
+        assert pooler_config is not None
+
+        self.pooler = BertPooler(config, pooler_config)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         other_weights, loaded_stacked_params = self._load_weights(weights)
@@ -711,6 +715,8 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
             layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
         )
 
+        # None of vLLM's built-in sequence pooling types are
+        # applicable, so it is overridden by SPLADESparsePooler.
         pooling_mode = getattr(self, "_splade_pooling", "max")
         cls_id = getattr(cfg, "cls_token_id", None)
diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py
index cfe350db1..8f9617062 100644
--- a/vllm/model_executor/models/bert_with_rope.py
+++ b/vllm/model_executor/models/bert_with_rope.py
@@ -453,6 +453,7 @@ class BertWithRope(nn.Module, SupportsQuant):
         add_pooling_layer: bool = False,
     ):
         super().__init__()
+        self.vllm_config = vllm_config
         self.add_pooling_layer = add_pooling_layer
         self.config = vllm_config.model_config.hf_config
@@ -463,7 +464,14 @@ class BertWithRope(nn.Module, SupportsQuant):
             rotary_kwargs=self.config.rotary_kwargs,
             prefix=f"{prefix}.encoder",
         )
-        self.pooler = BertPooler(self.config) if add_pooling_layer else None
+
+        if add_pooling_layer:
+            pooler_config = vllm_config.model_config.pooler_config
+            assert pooler_config is not None
+
+            self.pooler = BertPooler(self.config, pooler_config)
+        else:
+            self.pooler = None
 
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embeddings(input_ids)
diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py
index 34d7e5c92..08ace0c8e 100644
--- a/vllm/model_executor/models/gritlm.py
+++ b/vllm/model_executor/models/gritlm.py
@@ -5,7 +5,7 @@ from collections.abc import Set
 
 import numpy as np
 import torch
-from vllm.config import ModelConfig, VllmConfig
+from vllm.config import ModelConfig, PoolerConfig, VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.pooler import (
     DispatchPooler,
@@ -17,6 +17,7 @@ from vllm.model_executor.layers.pooler.seqwise import (
     SequencePoolerHeadOutput,
     SequencePoolingMethod,
     SequencePoolingMethodOutput,
+    get_seq_pooling_method,
 )
 from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed
 from vllm.model_executor.models.llama import LlamaForCausalLM
@@ -177,9 +178,13 @@ class GritLMMeanPool(SequencePoolingMethod):
 
 
 class GritLMPooler(SequencePooler):
-    def __init__(self, model_config: ModelConfig):
+    def __init__(self, model_config: ModelConfig, pooler_config: PoolerConfig):
         super().__init__(
-            pooling=GritLMMeanPool(model_config),
+            pooling=(
+                GritLMMeanPool(model_config)
+                if pooler_config.seq_pooling_type == "MEAN"
+                else get_seq_pooling_method(pooler_config.seq_pooling_type)
+            ),
             head=self.head,
         )
@@ -235,6 +240,6 @@ class GritLM(LlamaForCausalLM):
         self.pooler = DispatchPooler(
             {
                 "token_embed": pooler_for_token_embed(pooler_config),
-                "embed": GritLMPooler(vllm_config.model_config),
+                "embed": GritLMPooler(vllm_config.model_config, pooler_config),
             }
         )
diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py
index b80258daf..2b56540e6 100644
--- a/vllm/model_executor/models/modernbert.py
+++ b/vllm/model_executor/models/modernbert.py
@@ -8,7 +8,7 @@ from transformers import ModernBertConfig
 from transformers.activations import ACT2FN
 
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import VllmConfig
+from vllm.config import PoolerConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.attention.encoder_only_attention import (
     EncoderOnlyAttention,
@@ -282,9 +282,14 @@ class ModernBertModel(nn.Module):
 
 
 class ModernBertPooler(SequencePooler):
-    def __init__(self, config: ModernBertConfig):
+    def __init__(self, config: ModernBertConfig, pooler_config: PoolerConfig):
+        hf_pooling_type = config.classifier_pooling.upper()
+        # vllm_pooling_type = pooler_config.seq_pooling_type
+        # Currently we don't have a way to see if the user set the pooling type
+        # explicitly or not, so we always use the HF pooling type for now.
+
         super().__init__(
-            pooling=get_seq_pooling_method(config.classifier_pooling.upper()),
+            pooling=get_seq_pooling_method(hf_pooling_type),
             head=self.head,
         )
@@ -314,7 +319,9 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
+
         config = vllm_config.model_config.hf_config
+        self.config = config
         self.model = ModernBertModel(
             vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert")
@@ -324,11 +331,12 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
             config.num_labels,
             dtype=vllm_config.model_config.head_dtype,
         )
-        self.pooling = ModernBertPooler(config)
 
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None
 
+        self.pooling = ModernBertPooler(config, pooler_config)
+
         self.pooler = DispatchPooler.for_seq_cls(
             pooler_config,
             pooling=self.pooling,
diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py
index f52123901..7bf9a6882 100644
--- a/vllm/model_executor/models/roberta.py
+++ b/vllm/model_executor/models/roberta.py
@@ -9,7 +9,6 @@ from transformers import RobertaConfig
 
 from vllm.config import ModelConfig, VllmConfig
 from vllm.model_executor.layers.pooler import DispatchPooler
-from vllm.model_executor.layers.pooler.seqwise import CLSPool
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.models.bert import (
     TOKEN_TYPE_SHIFT,
@@ -86,7 +85,7 @@ class RobertaClassificationHead(nn.Module):
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        # CLSPool has already been applied in `pooling`
+        # Token extraction has already been applied in `pooler.pooling`
         x = self.dense(x)
         x = torch.tanh(x)
         x = self.out_proj(x)
@@ -194,7 +193,6 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
 
         self.pooler = DispatchPooler.for_seq_cls(
             pooler_config,
-            pooling=CLSPool(),
             classifier=self.classifier,
         )
diff --git a/vllm/model_executor/models/transformers/pooling.py b/vllm/model_executor/models/transformers/pooling.py
index 470ca48ee..8f3173c33 100644
--- a/vllm/model_executor/models/transformers/pooling.py
+++ b/vllm/model_executor/models/transformers/pooling.py
@@ -23,7 +23,6 @@ from transformers import AutoModelForSequenceClassification
 
 from vllm.config.utils import getattr_iter
 from vllm.model_executor.layers.pooler import DispatchPooler
-from vllm.model_executor.layers.pooler.seqwise import CLSPool
 from vllm.model_executor.models.interfaces import SupportsCrossEncoding
 from vllm.model_executor.models.interfaces_base import VllmModelForPooling
@@ -32,7 +31,7 @@ if TYPE_CHECKING:
 
 
 class EmbeddingMixin(VllmModelForPooling):
-    default_pooling_type = "CLS"
+    default_seq_pooling_type = "CLS"
 
     def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
         # Skip VllmModelForPooling.__init__ and call the next class in MRO
@@ -47,7 +46,7 @@ class EmbeddingMixin(VllmModelForPooling):
 
 
 class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
-    default_pooling_type = "CLS"
+    default_seq_pooling_type = "CLS"
 
     def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
         # Skip VllmModelForPooling.__init__ and call the next class in MRO
@@ -85,8 +84,10 @@ class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
         self.init_parameters(self.classifier, dtype=self.model_config.head_dtype)
 
         class ClassifierWithReshape(self.classifier.__class__):
-            """CLSPool has already been applied in `pooling`.
-            Add dim to match expected input shape of `classifier.forward`."""
+            """
+            Token extraction has already been applied in `pooler.pooling`.
+            Add dim to match expected input shape of `classifier.forward`.
+            """
 
             def forward(self, *args, **kwargs):
                 if len(args) > 0:
@@ -97,6 +98,5 @@ class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
 
         self.pooler = DispatchPooler.for_seq_cls(
             pooler_config,
-            pooling=CLSPool(),
             classifier=self.classifier,
         )
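
For context on the pattern this patch introduces: instead of hard-coding CLSPool(), poolers now look up a pooling method from the user-facing pooler_config.seq_pooling_type string via get_seq_pooling_method. Below is a minimal sketch of such a string-keyed lookup, assuming unbatched (seq_len, hidden) tensors; the helper names, the registry dict, and the function bodies here are illustrative assumptions, not vLLM's actual implementation (the real get_seq_pooling_method returns SequencePoolingMethod instances, as the GritLMMeanPool usage above shows).

    # Illustrative sketch only -- not vLLM's actual get_seq_pooling_method.
    from collections.abc import Callable

    import torch

    def cls_pool(hidden_states: torch.Tensor) -> torch.Tensor:
        # "CLS": take the first token's hidden state.
        return hidden_states[0]

    def last_pool(hidden_states: torch.Tensor) -> torch.Tensor:
        # "LAST": take the final token's hidden state.
        return hidden_states[-1]

    def mean_pool(hidden_states: torch.Tensor) -> torch.Tensor:
        # "MEAN": average over the sequence dimension.
        return hidden_states.mean(dim=0)

    # Hypothetical registry keyed by the uppercase pooling-type string,
    # mirroring values like "CLS" and "MEAN" seen in the diff.
    _SEQ_POOLING: dict[str, Callable[[torch.Tensor], torch.Tensor]] = {
        "CLS": cls_pool,
        "LAST": last_pool,
        "MEAN": mean_pool,
    }

    def get_seq_pooling_method_sketch(
        seq_pooling_type: str,
    ) -> Callable[[torch.Tensor], torch.Tensor]:
        try:
            return _SEQ_POOLING[seq_pooling_type]
        except KeyError:
            raise ValueError(f"Unsupported seq_pooling_type: {seq_pooling_type!r}")

This also explains the GritLMPooler change above: GritLM keeps its custom mean pooling (which must respect instruction tokens) when the configured type is "MEAN", and only defers to the generic registry for other types.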