diff --git a/vllm/model_executor/layers/pooler/common.py b/vllm/model_executor/layers/pooler/common.py
index 7dc77cf79..d8aa78e70 100644
--- a/vllm/model_executor/layers/pooler/common.py
+++ b/vllm/model_executor/layers/pooler/common.py
@@ -2,12 +2,17 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Callable
 from dataclasses import dataclass
+from typing import TypeVar

 import torch

 from vllm.pooling_params import PoolingParams

+_T = TypeVar("_T", bound=torch.Tensor | list[torch.Tensor])
+
+ProjectorFn = Callable[[torch.Tensor], torch.Tensor]
 ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
+ActivationFn = Callable[[_T], _T]


 @dataclass(frozen=True)
@@ -24,4 +29,4 @@ class PoolingParamsUpdate:
         params.requires_token_ids = self.requires_token_ids


-__all__ = ["ClassifierFn", "PoolingParamsUpdate"]
+__all__ = ["ActivationFn", "ClassifierFn", "ProjectorFn", "PoolingParamsUpdate"]
diff --git a/vllm/model_executor/layers/pooler/seqwise/heads.py b/vllm/model_executor/layers/pooler/seqwise/heads.py
index 24aed94fd..21a94a89e 100644
--- a/vllm/model_executor/layers/pooler/seqwise/heads.py
+++ b/vllm/model_executor/layers/pooler/seqwise/heads.py
@@ -7,14 +7,7 @@ from typing import TypeAlias
 import torch
 import torch.nn as nn

-from vllm.config import get_current_vllm_config
-from vllm.model_executor.layers.pooler import ClassifierFn
-from vllm.model_executor.layers.pooler.activations import (
-    PoolerActivation,
-    PoolerNormalize,
-    resolve_classifier_act_fn,
-)
-from vllm.model_executor.models.adapters import _load_st_projector
+from vllm.model_executor.layers.pooler import ActivationFn, ClassifierFn, ProjectorFn
 from vllm.tasks import PoolingTask
 from vllm.v1.pool.metadata import PoolingMetadata

@@ -38,17 +31,17 @@ class SequencePoolerHead(nn.Module, ABC):


 class EmbeddingPoolerHead(SequencePoolerHead):
-    def __init__(self) -> None:
+    def __init__(
+        self,
+        projector: ProjectorFn | None = None,
+        head_dtype: torch.dtype | str | None = None,
+        activation: ActivationFn | None = None,
+    ) -> None:
         super().__init__()

-        # Load ST projector if available
-        vllm_config = get_current_vllm_config()
-        model_config = vllm_config.model_config
-
-        self.projector = _load_st_projector(model_config)
-        self.head_dtype = model_config.head_dtype
-
-        self.activation = PoolerNormalize()
+        self.projector = projector
+        self.head_dtype = head_dtype
+        self.activation = activation

     def get_supported_tasks(self) -> Set[PoolingTask]:
         return {"embed"}
@@ -65,7 +58,8 @@ class EmbeddingPoolerHead(SequencePoolerHead):
             pooled_data = torch.stack(pooled_data)

         # pooled_data shape: [batchsize, hidden_dimension]
-        pooled_data = pooled_data.to(self.head_dtype)
+        if self.head_dtype is not None:
+            pooled_data = pooled_data.to(self.head_dtype)

         # Apply ST projector
         if self.projector is not None:
@@ -88,15 +82,16 @@
             ]

         # for normalize
-        flags = [p.normalize for p in pooling_params]
-        if len(set(flags)) == 1:
-            if flags[0]:
-                pooled_data = self.activation(pooled_data)
-        else:
-            pooled_data = [
-                self.activation(vecs) if f else vecs
-                for vecs, f in zip(pooled_data, flags)
-            ]
+        if self.activation is not None:
+            flags = [p.normalize for p in pooling_params]
+            if len(set(flags)) == 1:
+                if flags[0]:
+                    pooled_data = self.activation(pooled_data)
+            else:
+                pooled_data = [
+                    self.activation(vecs) if f else vecs
+                    for vecs, f in zip(pooled_data, flags)
+                ]

         # pooled_data shape: [batchsize, embedding_dimension]
         return pooled_data
@@ -106,20 +101,16 @@ class ClassifierPoolerHead(SequencePoolerHead):
     def __init__(
         self,
         classifier: ClassifierFn | None = None,
-        act_fn: PoolerActivation | str | None = None,
+        logit_bias: float | None = None,
+        head_dtype: torch.dtype | str | None = None,
+        activation: ActivationFn | None = None,
     ) -> None:
         super().__init__()

-        vllm_config = get_current_vllm_config()
-        model_config = vllm_config.model_config
-
         self.classifier = classifier
-        self.logit_bias: float | None = model_config.pooler_config.logit_bias
-        self.head_dtype = model_config.head_dtype
-
-        self.act_fn = resolve_classifier_act_fn(
-            model_config, static_num_labels=True, act_fn=act_fn
-        )
+        self.logit_bias = logit_bias
+        self.head_dtype = head_dtype
+        self.activation = activation

     def get_supported_tasks(self) -> Set[PoolingTask]:
         return {"classify", "score"}
@@ -136,7 +127,8 @@ class ClassifierPoolerHead(SequencePoolerHead):
             pooled_data = torch.stack(pooled_data)

         # pooled_data shape: [batchsize, hidden_size]
-        pooled_data = pooled_data.to(self.head_dtype)
+        if self.head_dtype is not None:
+            pooled_data = pooled_data.to(self.head_dtype)

         if self.classifier is not None:
             pooled_data = self.classifier(pooled_data)
@@ -145,13 +137,15 @@ class ClassifierPoolerHead(SequencePoolerHead):
         if self.logit_bias is not None:
             pooled_data -= self.logit_bias

-        flags = [p.use_activation for p in pooling_params]
-        if len(set(flags)) == 1:
-            scores = self.act_fn(pooled_data) if flags[0] else pooled_data
-        else:
-            scores = [
-                self.act_fn(vecs) if f else vecs for vecs, f in zip(pooled_data, flags)
-            ]
+        if self.activation is not None:
+            flags = [p.use_activation for p in pooling_params]
+            if len(set(flags)) == 1:
+                pooled_data = self.activation(pooled_data) if flags[0] else pooled_data
+            else:
+                pooled_data = [
+                    self.activation(vecs) if f else vecs
+                    for vecs, f in zip(pooled_data, flags)
+                ]

-        # scores shape: [batchsize, num_labels]
-        return scores
+        # pooled_data shape: [batchsize, num_labels]
+        return pooled_data
diff --git a/vllm/model_executor/layers/pooler/seqwise/poolers.py b/vllm/model_executor/layers/pooler/seqwise/poolers.py
index db867fb60..eb058849e 100644
--- a/vllm/model_executor/layers/pooler/seqwise/poolers.py
+++ b/vllm/model_executor/layers/pooler/seqwise/poolers.py
@@ -5,10 +5,15 @@ from typing import TypeAlias

 import torch

-from vllm.config import PoolerConfig
+from vllm.config import PoolerConfig, get_current_vllm_config
 from vllm.model_executor.layers.pooler import ClassifierFn, PoolingParamsUpdate
 from vllm.model_executor.layers.pooler.abstract import Pooler
-from vllm.model_executor.layers.pooler.activations import PoolerActivation
+from vllm.model_executor.layers.pooler.activations import (
+    PoolerActivation,
+    PoolerNormalize,
+    resolve_classifier_act_fn,
+)
+from vllm.model_executor.models.adapters import _load_st_projector
 from vllm.tasks import POOLING_TASKS, PoolingTask
 from vllm.v1.pool.metadata import PoolingMetadata

@@ -86,7 +91,14 @@ class SequencePooler(Pooler):


 def pooler_for_embed(pooler_config: PoolerConfig):
     pooling = get_seq_pooling_method(pooler_config.get_seq_pooling_type())
-    head = EmbeddingPoolerHead()
+
+    vllm_config = get_current_vllm_config()
+    model_config = vllm_config.model_config
+    head = EmbeddingPoolerHead(
+        projector=_load_st_projector(model_config),
+        head_dtype=model_config.head_dtype,
+        activation=PoolerNormalize(),
+    )

     return SequencePooler(pooling=pooling, head=head)

@@ -101,6 +113,15 @@ def pooler_for_classify(
     if pooling is None:
         pooling = get_seq_pooling_method(pooler_config.get_seq_pooling_type())

-    head = ClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
+    vllm_config = get_current_vllm_config()
+    model_config = vllm_config.model_config
+    head = ClassifierPoolerHead(
+        classifier=classifier,
+        logit_bias=model_config.pooler_config.logit_bias,
+        head_dtype=model_config.head_dtype,
+        activation=resolve_classifier_act_fn(
+            model_config, static_num_labels=True, act_fn=act_fn
+        ),
+    )

     return SequencePooler(pooling=pooling, head=head)
diff --git a/vllm/model_executor/layers/pooler/tokwise/heads.py b/vllm/model_executor/layers/pooler/tokwise/heads.py
index 7421ff5c2..923b9b977 100644
--- a/vllm/model_executor/layers/pooler/tokwise/heads.py
+++ b/vllm/model_executor/layers/pooler/tokwise/heads.py
@@ -7,14 +7,7 @@ from typing import TypeAlias
 import torch
 import torch.nn as nn

-from vllm.config import get_current_vllm_config
-from vllm.model_executor.layers.pooler import ClassifierFn
-from vllm.model_executor.layers.pooler.activations import (
-    PoolerActivation,
-    PoolerNormalize,
-    resolve_classifier_act_fn,
-)
-from vllm.model_executor.models.adapters import _load_st_projector
+from vllm.model_executor.layers.pooler import ActivationFn, ClassifierFn, ProjectorFn
 from vllm.pooling_params import PoolingParams
 from vllm.tasks import PoolingTask
 from vllm.v1.pool.metadata import PoolingMetadata
@@ -49,17 +42,17 @@ class TokenPoolerHead(nn.Module, ABC):


 class TokenEmbeddingPoolerHead(TokenPoolerHead):
-    def __init__(self) -> None:
+    def __init__(
+        self,
+        projector: ProjectorFn | None = None,
+        head_dtype: torch.dtype | str | None = None,
+        activation: ActivationFn | None = None,
+    ) -> None:
         super().__init__()

-        # Load ST projector if available
-        vllm_config = get_current_vllm_config()
-        model_config = vllm_config.model_config
-
-        self.projector = _load_st_projector(model_config)
-        self.head_dtype = model_config.head_dtype
-
-        self.activation = PoolerNormalize()
+        self.projector = projector
+        self.head_dtype = head_dtype
+        self.activation = activation

     def get_supported_tasks(self) -> Set[PoolingTask]:
         return {"token_embed"}
@@ -73,7 +66,8 @@ class TokenEmbeddingPoolerHead(TokenPoolerHead):
         if pooled_data is None:
             return None

-        pooled_data = pooled_data.to(self.head_dtype)
+        if self.head_dtype is not None:
+            pooled_data = pooled_data.to(self.head_dtype)

         # pooled_data shape: [n_tokens, hidden_dimension]
         # Apply ST projector
@@ -85,7 +79,7 @@ class TokenEmbeddingPoolerHead(TokenPoolerHead):
             pooled_data = pooled_data[..., : pooling_param.dimensions]

         # for normalize
-        if pooling_param.normalize:
+        if self.activation is not None and pooling_param.normalize:
             pooled_data = self.activation(pooled_data)

         # pooled_data shape: [n_tokens, embedding_dimension]
@@ -96,20 +90,16 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
     def __init__(
         self,
         classifier: ClassifierFn | None = None,
-        act_fn: PoolerActivation | str | None = None,
+        logit_bias: float | None = None,
+        head_dtype: torch.dtype | str | None = None,
+        activation: ActivationFn | None = None,
     ) -> None:
         super().__init__()

-        vllm_config = get_current_vllm_config()
-        model_config = vllm_config.model_config
-
         self.classifier = classifier
-        self.logit_bias: float | None = model_config.pooler_config.logit_bias
-        self.head_dtype = model_config.head_dtype
-
-        self.act_fn = resolve_classifier_act_fn(
-            model_config, static_num_labels=False, act_fn=act_fn
-        )
+        self.logit_bias = logit_bias
+        self.head_dtype = head_dtype
+        self.activation = activation

     def get_supported_tasks(self) -> Set[PoolingTask]:
         return {"token_classify"}
@@ -123,7 +113,8 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
         if pooled_data is None:
             return None

-        pooled_data = pooled_data.to(self.head_dtype)
+        if self.head_dtype is not None:
+            pooled_data = pooled_data.to(self.head_dtype)

         # hidden_states shape: [n_token, hidden_size]
         if self.classifier is not None:
@@ -135,8 +126,8 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
         if self.logit_bias is not None:
             scores -= self.logit_bias

-        if pooling_param.use_activation:
-            scores = self.act_fn(scores)
+        if self.activation is not None and pooling_param.use_activation:
+            scores = self.activation(scores)

         # scores shape: [n_token, num_labels]
         return scores
diff --git a/vllm/model_executor/layers/pooler/tokwise/poolers.py b/vllm/model_executor/layers/pooler/tokwise/poolers.py
index 991daaeba..8b4fe5568 100644
--- a/vllm/model_executor/layers/pooler/tokwise/poolers.py
+++ b/vllm/model_executor/layers/pooler/tokwise/poolers.py
@@ -5,10 +5,15 @@ from typing import TypeAlias

 import torch

-from vllm.config import PoolerConfig
+from vllm.config import PoolerConfig, get_current_vllm_config
 from vllm.model_executor.layers.pooler import ClassifierFn, PoolingParamsUpdate
 from vllm.model_executor.layers.pooler.abstract import Pooler
-from vllm.model_executor.layers.pooler.activations import PoolerActivation
+from vllm.model_executor.layers.pooler.activations import (
+    PoolerActivation,
+    PoolerNormalize,
+    resolve_classifier_act_fn,
+)
+from vllm.model_executor.models.adapters import _load_st_projector
 from vllm.tasks import POOLING_TASKS, PoolingTask
 from vllm.v1.pool.metadata import PoolingMetadata

@@ -86,7 +91,14 @@ class TokenPooler(Pooler):


 def pooler_for_token_embed(pooler_config: PoolerConfig):
     pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())
-    head = TokenEmbeddingPoolerHead()
+
+    vllm_config = get_current_vllm_config()
+    model_config = vllm_config.model_config
+    head = TokenEmbeddingPoolerHead(
+        projector=_load_st_projector(model_config),
+        head_dtype=model_config.head_dtype,
+        activation=PoolerNormalize(),
+    )

     return TokenPooler(pooling=pooling, head=head)

@@ -101,6 +113,15 @@ def pooler_for_token_classify(
     if pooling is None:
         pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())

-    head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
+    vllm_config = get_current_vllm_config()
+    model_config = vllm_config.model_config
+    head = TokenClassifierPoolerHead(
+        classifier=classifier,
+        logit_bias=model_config.pooler_config.logit_bias,
+        head_dtype=model_config.head_dtype,
+        activation=resolve_classifier_act_fn(
+            model_config, static_num_labels=False, act_fn=act_fn
+        ),
+    )

     return TokenPooler(pooling=pooling, head=head)
diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py
index 59e768853..10952bcd9 100644
--- a/vllm/model_executor/models/bert.py
+++ b/vllm/model_executor/models/bert.py
@@ -8,7 +8,7 @@ from torch import nn
 from transformers import BertConfig

 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, PoolerConfig, VllmConfig
+from vllm.config import CacheConfig, ModelConfig, PoolerConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention.encoder_only_attention import (
@@ -24,11 +24,11 @@ from vllm.model_executor.layers.pooler import (
     Pooler,
     PoolingParamsUpdate,
 )
+from vllm.model_executor.layers.pooler.activations import LambdaPoolerActivation
 from vllm.model_executor.layers.pooler.seqwise import (
+    EmbeddingPoolerHead,
     SequencePooler,
-    SequencePoolerHeadOutput,
     SequencePoolerOutput,
-    SequencePoolingMethodOutput,
     get_seq_pooling_method,
 )
 from vllm.model_executor.layers.pooler.tokwise import (
@@ -94,26 +94,32 @@ class BertEmbedding(nn.Module):


 class BertPooler(SequencePooler):
-    def __init__(self, config: BertConfig, pooler_config: PoolerConfig):
+    def __init__(self, model_config: ModelConfig):
+        pooler_config = model_config.pooler_config
+        assert pooler_config is not None
+
+        config: BertConfig = model_config.hf_config
+
         super().__init__(
             pooling=get_seq_pooling_method(pooler_config.seq_pooling_type),
-            head=self.head,
+            # We set this dummy to avoid adding parameters to nn.Module too early
+            head=nn.Identity(),
         )

-        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
-        self.activation = nn.Tanh()
+        head_dtype = model_config.head_dtype
+        self.dense = nn.Linear(
+            config.hidden_size,
+            config.hidden_size,
+            dtype=head_dtype,
+        )
+        self.act_fn = nn.Tanh()

-    def head(
-        self,
-        pooled_data: SequencePoolingMethodOutput,
-        pooling_metadata: PoolingMetadata,
-    ) -> SequencePoolerHeadOutput:
-        if isinstance(pooled_data, list):
-            pooled_data = torch.stack(pooled_data)
-
-        pooled_data = self.dense(pooled_data)
-        pooled_data = self.activation(pooled_data)
-        return pooled_data
+        # Use lambdas so that weights are not registered under `self.head`
+        self.head = EmbeddingPoolerHead(
+            projector=lambda x: self.dense(x),
+            head_dtype=head_dtype,
+            activation=LambdaPoolerActivation(self.act_fn),
+        )


 class BertEncoder(nn.Module):
@@ -449,12 +455,7 @@ class BertPoolingModel(BertModel):
             embedding_class=embedding_class,
         )

-        config = vllm_config.model_config.hf_config
-
-        pooler_config = vllm_config.model_config.pooler_config
-        assert pooler_config is not None
-
-        self.pooler = BertPooler(config, pooler_config)
+        self.pooler = BertPooler(vllm_config.model_config)

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         other_weights, loaded_stacked_params = self._load_weights(weights)
diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py
index 8f9617062..02950dc9e 100644
--- a/vllm/model_executor/models/bert_with_rope.py
+++ b/vllm/model_executor/models/bert_with_rope.py
@@ -466,10 +466,7 @@ class BertWithRope(nn.Module, SupportsQuant):
         )

         if add_pooling_layer:
-            pooler_config = vllm_config.model_config.pooler_config
-            assert pooler_config is not None
-
-            self.pooler = BertPooler(self.config, pooler_config)
+            self.pooler = BertPooler(vllm_config.model_config)
         else:
             self.pooler = None

diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py
index 08ace0c8e..b5c6946b6 100644
--- a/vllm/model_executor/models/gritlm.py
+++ b/vllm/model_executor/models/gritlm.py
@@ -5,7 +5,7 @@ from collections.abc import Set
 import numpy as np
 import torch

-from vllm.config import ModelConfig, PoolerConfig, VllmConfig
+from vllm.config import ModelConfig, VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.pooler import (
     DispatchPooler,
@@ -13,8 +13,8 @@ from vllm.model_executor.layers.pooler import (
 )
 from vllm.model_executor.layers.pooler.activations import PoolerNormalize
 from vllm.model_executor.layers.pooler.seqwise import (
+    EmbeddingPoolerHead,
     SequencePooler,
-    SequencePoolerHeadOutput,
     SequencePoolingMethod,
     SequencePoolingMethodOutput,
     get_seq_pooling_method,
@@ -178,25 +178,22 @@ class GritLMMeanPool(SequencePoolingMethod):


 class GritLMPooler(SequencePooler):
-    def __init__(self, model_config: ModelConfig, pooler_config: PoolerConfig):
+    def __init__(self, model_config: ModelConfig):
+        pooler_config = model_config.pooler_config
+        assert pooler_config is not None
+
         super().__init__(
             pooling=(
                 GritLMMeanPool(model_config)
                 if pooler_config.seq_pooling_type == "MEAN"
                 else get_seq_pooling_method(pooler_config.seq_pooling_type)
             ),
-            head=self.head,
+            head=EmbeddingPoolerHead(
+                head_dtype=model_config.head_dtype,
+                activation=PoolerNormalize(),
+            ),
         )

-        self.activation = PoolerNormalize()
-
-    def head(
-        self,
-        pooled_data: SequencePoolingMethodOutput,
-        pooling_metadata: PoolingMetadata,
-    ) -> SequencePoolerHeadOutput:
-        return self.activation(pooled_data)
-

 @default_pooling_type(seq_pooling_type="MEAN")
 class GritLM(LlamaForCausalLM):
@@ -240,6 +237,6 @@ class GritLM(LlamaForCausalLM):
         self.pooler = DispatchPooler(
             {
                 "token_embed": pooler_for_token_embed(pooler_config),
-                "embed": GritLMPooler(vllm_config.model_config, pooler_config),
+                "embed": GritLMPooler(vllm_config.model_config),
             }
         )
diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py
index 2b56540e6..f0d9ecfa9 100644
--- a/vllm/model_executor/models/modernbert.py
+++ b/vllm/model_executor/models/modernbert.py
@@ -8,17 +8,17 @@ from transformers import ModernBertConfig
 from transformers.activations import ACT2FN

 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import PoolerConfig, VllmConfig
+from vllm.config import ModelConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.attention.encoder_only_attention import (
     EncoderOnlyAttention,
 )
 from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
 from vllm.model_executor.layers.pooler import DispatchPooler
+from vllm.model_executor.layers.pooler.activations import LambdaPoolerActivation
 from vllm.model_executor.layers.pooler.seqwise import (
+    EmbeddingPoolerHead,
     SequencePooler,
-    SequencePoolerHeadOutput,
-    SequencePoolingMethodOutput,
     get_seq_pooling_method,
 )
 from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_classify
@@ -26,7 +26,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import IntermediateTensors
-from vllm.v1.pool.metadata import PoolingMetadata

 from .interfaces import SupportsCrossEncoding
 from .interfaces_base import attn_type, default_pooling_type
@@ -282,7 +281,11 @@ class ModernBertModel(nn.Module):


 class ModernBertPooler(SequencePooler):
-    def __init__(self, config: ModernBertConfig, pooler_config: PoolerConfig):
+    def __init__(self, model_config: ModelConfig):
+        pooler_config = model_config.pooler_config
+        assert pooler_config is not None
+
+        config: ModernBertConfig = model_config.hf_config
         hf_pooling_type = config.classifier_pooling.upper()
         # vllm_pooling_type = pooler_config.seq_pooling_type
         # Currently we don't have a way to see if the user set the pooling type
@@ -290,27 +293,30 @@

         super().__init__(
             pooling=get_seq_pooling_method(hf_pooling_type),
-            head=self.head,
+            # We set this dummy to avoid adding parameters to nn.Module too early
+            head=nn.Identity(),
         )

+        head_dtype = model_config.head_dtype
         self.dense = nn.Linear(
-            config.hidden_size, config.hidden_size, config.classifier_bias
+            config.hidden_size,
+            config.hidden_size,
+            config.classifier_bias,
+            dtype=head_dtype,
         )
         self.act = nn.GELU()
         self.norm = nn.LayerNorm(
-            config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
+            config.hidden_size,
+            eps=config.norm_eps,
+            bias=config.norm_bias,
         )

-    def head(
-        self,
-        pooled_data: SequencePoolingMethodOutput,
-        pooling_metadata: PoolingMetadata,
-    ) -> SequencePoolerHeadOutput:
-        if isinstance(pooled_data, list):
-            pooled_data = torch.stack(pooled_data)
-
-        pooled_data = pooled_data.to(self.dense.weight.dtype)
-        return self.norm(self.act(self.dense(pooled_data)))
+        # Use lambdas so that weights are not registered under `self.head`
+        self.head = EmbeddingPoolerHead(
+            projector=lambda x: self.dense(x),
+            head_dtype=head_dtype,
+            activation=LambdaPoolerActivation(lambda x: self.norm(self.act(x))),
+        )


 @default_pooling_type(seq_pooling_type="CLS")
@@ -335,7 +341,7 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None

-        self.pooling = ModernBertPooler(config, pooler_config)
+        self.pooling = ModernBertPooler(vllm_config.model_config)

         self.pooler = DispatchPooler.for_seq_cls(
             pooler_config,
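
For reference, a minimal sketch (not part of the patch) of the construction pattern this diff introduces: heads now receive their dependencies (projector, head_dtype, activation) as constructor arguments instead of calling get_current_vllm_config() inside __init__, so they can be built anywhere the caller already holds the resolved values. The import paths match the ones used in the diff above; the 768/512 dimensions and float32 dtype are hypothetical placeholders, and the sketch assumes these modules are importable outside an engine context.

    # Sketch only: the injected-dependency constructor added by this diff.
    # Dimensions and dtype are hypothetical.
    import torch
    import torch.nn as nn

    from vllm.model_executor.layers.pooler.activations import PoolerNormalize
    from vllm.model_executor.layers.pooler.seqwise import EmbeddingPoolerHead

    # Stand-in for an ST projector; any Callable[[Tensor], Tensor] works.
    projector = nn.Linear(768, 512, dtype=torch.float32)

    head = EmbeddingPoolerHead(
        projector=projector,           # ProjectorFn
        head_dtype=torch.float32,      # pooled outputs are cast to this dtype first
        activation=PoolerNormalize(),  # ActivationFn, applied per params.normalize
    )

The factory functions in the diff (pooler_for_embed, pooler_for_classify, and their token-wise counterparts) follow this pattern, resolving _load_st_projector(model_config), model_config.head_dtype, and the activation from the current vLLM config, while model-specific poolers such as BertPooler and ModernBertPooler inject lambdas over their own submodules so the dense/norm weights are not registered a second time under self.head.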