[Model] Standardize pooling heads (#32148)

commit 8863c2b25c (parent 3f72639d36)
Author: Cyrus Leung
Date: 2026-01-13 01:01:49 +08:00
Committed by: GitHub
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

9 changed files with 182 additions and 149 deletions

View File

@@ -2,12 +2,17 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Callable
 from dataclasses import dataclass
+from typing import TypeVar
 import torch
 from vllm.pooling_params import PoolingParams
+_T = TypeVar("_T", bound=torch.Tensor | list[torch.Tensor])
+ProjectorFn = Callable[[torch.Tensor], torch.Tensor]
 ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
+ActivationFn = Callable[[_T], _T]
 @dataclass(frozen=True)
@@ -24,4 +29,4 @@ class PoolingParamsUpdate:
         params.requires_token_ids = self.requires_token_ids
-__all__ = ["ClassifierFn", "PoolingParamsUpdate"]
+__all__ = ["ActivationFn", "ClassifierFn", "ProjectorFn", "PoolingParamsUpdate"]

View File

@@ -7,14 +7,7 @@ from typing import TypeAlias
 import torch
 import torch.nn as nn
-from vllm.config import get_current_vllm_config
-from vllm.model_executor.layers.pooler import ClassifierFn
-from vllm.model_executor.layers.pooler.activations import (
-    PoolerActivation,
-    PoolerNormalize,
-    resolve_classifier_act_fn,
-)
-from vllm.model_executor.models.adapters import _load_st_projector
+from vllm.model_executor.layers.pooler import ActivationFn, ClassifierFn, ProjectorFn
 from vllm.tasks import PoolingTask
 from vllm.v1.pool.metadata import PoolingMetadata
@@ -38,17 +31,17 @@ class SequencePoolerHead(nn.Module, ABC):
 class EmbeddingPoolerHead(SequencePoolerHead):
-    def __init__(self) -> None:
+    def __init__(
+        self,
+        projector: ProjectorFn | None = None,
+        head_dtype: torch.dtype | str | None = None,
+        activation: ActivationFn | None = None,
+    ) -> None:
         super().__init__()
-        # Load ST projector if available
-        vllm_config = get_current_vllm_config()
-        model_config = vllm_config.model_config
-        self.projector = _load_st_projector(model_config)
-        self.head_dtype = model_config.head_dtype
-        self.activation = PoolerNormalize()
+        self.projector = projector
+        self.head_dtype = head_dtype
+        self.activation = activation
     def get_supported_tasks(self) -> Set[PoolingTask]:
         return {"embed"}
@@ -65,7 +58,8 @@ class EmbeddingPoolerHead(SequencePoolerHead):
             pooled_data = torch.stack(pooled_data)
         # pooled_data shape: [batchsize, hidden_dimension]
-        pooled_data = pooled_data.to(self.head_dtype)
+        if self.head_dtype is not None:
+            pooled_data = pooled_data.to(self.head_dtype)
         # Apply ST projector
         if self.projector is not None:
@@ -88,15 +82,16 @@ class EmbeddingPoolerHead(SequencePoolerHead):
         ]
         # for normalize
-        flags = [p.normalize for p in pooling_params]
-        if len(set(flags)) == 1:
-            if flags[0]:
-                pooled_data = self.activation(pooled_data)
-        else:
-            pooled_data = [
-                self.activation(vecs) if f else vecs
-                for vecs, f in zip(pooled_data, flags)
-            ]
+        if self.activation is not None:
+            flags = [p.normalize for p in pooling_params]
+            if len(set(flags)) == 1:
+                if flags[0]:
+                    pooled_data = self.activation(pooled_data)
+            else:
+                pooled_data = [
+                    self.activation(vecs) if f else vecs
+                    for vecs, f in zip(pooled_data, flags)
+                ]
         # pooled_data shape: [batchsize, embedding_dimension]
         return pooled_data
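
Note: the control flow above can be hard to read in diff form. A rough, runnable sketch of the new behavior (a plain function standing in for EmbeddingPoolerHead.forward, with illustrative names): every stage injected as None is skipped, and mixed per-request normalize flags fall back to element-wise application:

import torch
import torch.nn.functional as F

def embed_head_sketch(pooled, projector=None, head_dtype=None,
                      activation=None, flags=(True,)):
    # flags: one `normalize` boolean per request in the batch
    if head_dtype is not None:
        pooled = pooled.to(head_dtype)
    if projector is not None:
        pooled = projector(pooled)
    if activation is not None:
        if len(set(flags)) == 1:  # homogeneous batch: one fused call
            if flags[0]:
                pooled = activation(pooled)
        else:  # mixed batch: apply per request, yielding a list
            pooled = [activation(v) if f else v
                      for v, f in zip(pooled, flags)]
    return pooled

out = embed_head_sketch(torch.randn(2, 8),
                        activation=lambda x: F.normalize(x, dim=-1),
                        flags=[True, True])
print(out.norm(dim=-1))  # tensor([1., 1.]) up to rounding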
@@ -106,20 +101,16 @@ class ClassifierPoolerHead(SequencePoolerHead):
     def __init__(
         self,
         classifier: ClassifierFn | None = None,
-        act_fn: PoolerActivation | str | None = None,
+        logit_bias: float | None = None,
+        head_dtype: torch.dtype | str | None = None,
+        activation: ActivationFn | None = None,
     ) -> None:
         super().__init__()
-        vllm_config = get_current_vllm_config()
-        model_config = vllm_config.model_config
         self.classifier = classifier
-        self.logit_bias: float | None = model_config.pooler_config.logit_bias
-        self.head_dtype = model_config.head_dtype
-        self.act_fn = resolve_classifier_act_fn(
-            model_config, static_num_labels=True, act_fn=act_fn
-        )
+        self.logit_bias = logit_bias
+        self.head_dtype = head_dtype
+        self.activation = activation
     def get_supported_tasks(self) -> Set[PoolingTask]:
         return {"classify", "score"}
@@ -136,7 +127,8 @@ class ClassifierPoolerHead(SequencePoolerHead):
             pooled_data = torch.stack(pooled_data)
         # pooled_data shape: [batchsize, hidden_size]
-        pooled_data = pooled_data.to(self.head_dtype)
+        if self.head_dtype is not None:
+            pooled_data = pooled_data.to(self.head_dtype)
         if self.classifier is not None:
             pooled_data = self.classifier(pooled_data)
@@ -145,13 +137,15 @@ class ClassifierPoolerHead(SequencePoolerHead):
         if self.logit_bias is not None:
             pooled_data -= self.logit_bias
-        flags = [p.use_activation for p in pooling_params]
-        if len(set(flags)) == 1:
-            scores = self.act_fn(pooled_data) if flags[0] else pooled_data
-        else:
-            scores = [
-                self.act_fn(vecs) if f else vecs for vecs, f in zip(pooled_data, flags)
-            ]
+        if self.activation is not None:
+            flags = [p.use_activation for p in pooling_params]
+            if len(set(flags)) == 1:
+                pooled_data = self.activation(pooled_data) if flags[0] else pooled_data
+            else:
+                pooled_data = [
+                    self.activation(vecs) if f else vecs
+                    for vecs, f in zip(pooled_data, flags)
+                ]
-        # scores shape: [batchsize, num_labels]
-        return scores
+        # pooled_data shape: [batchsize, num_labels]
+        return pooled_data
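
Note: the classifier head follows the same dependency-injected shape. A condensed sketch (plain function, illustrative names; the per-request flag handling shown above is collapsed here to a single use_activation boolean for brevity):

import torch

def classify_head_sketch(pooled, classifier=None, logit_bias=None,
                         head_dtype=None, activation=None,
                         use_activation=True):
    if head_dtype is not None:
        pooled = pooled.to(head_dtype)
    if classifier is not None:
        pooled = classifier(pooled)   # -> [batch, num_labels]
    if logit_bias is not None:
        pooled = pooled - logit_bias
    if activation is not None and use_activation:
        pooled = activation(pooled)   # e.g. sigmoid for scoring
    return pooled

scores = classify_head_sketch(torch.randn(4, 16),
                              classifier=torch.nn.Linear(16, 2),
                              activation=torch.sigmoid)
print(scores.shape)  # torch.Size([4, 2])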

View File

@@ -5,10 +5,15 @@ from typing import TypeAlias
 import torch
-from vllm.config import PoolerConfig
+from vllm.config import PoolerConfig, get_current_vllm_config
 from vllm.model_executor.layers.pooler import ClassifierFn, PoolingParamsUpdate
 from vllm.model_executor.layers.pooler.abstract import Pooler
-from vllm.model_executor.layers.pooler.activations import PoolerActivation
+from vllm.model_executor.layers.pooler.activations import (
+    PoolerActivation,
+    PoolerNormalize,
+    resolve_classifier_act_fn,
+)
+from vllm.model_executor.models.adapters import _load_st_projector
 from vllm.tasks import POOLING_TASKS, PoolingTask
 from vllm.v1.pool.metadata import PoolingMetadata
@@ -86,7 +91,14 @@ class SequencePooler(Pooler):
 def pooler_for_embed(pooler_config: PoolerConfig):
     pooling = get_seq_pooling_method(pooler_config.get_seq_pooling_type())
-    head = EmbeddingPoolerHead()
+    vllm_config = get_current_vllm_config()
+    model_config = vllm_config.model_config
+    head = EmbeddingPoolerHead(
+        projector=_load_st_projector(model_config),
+        head_dtype=model_config.head_dtype,
+        activation=PoolerNormalize(),
+    )
     return SequencePooler(pooling=pooling, head=head)
@@ -101,6 +113,15 @@ def pooler_for_classify(
     if pooling is None:
         pooling = get_seq_pooling_method(pooler_config.get_seq_pooling_type())
-    head = ClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
+    vllm_config = get_current_vllm_config()
+    model_config = vllm_config.model_config
+    head = ClassifierPoolerHead(
+        classifier=classifier,
+        logit_bias=model_config.pooler_config.logit_bias,
+        head_dtype=model_config.head_dtype,
+        activation=resolve_classifier_act_fn(
+            model_config, static_num_labels=True, act_fn=act_fn
+        ),
+    )
     return SequencePooler(pooling=pooling, head=head)
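
Note: the net effect of this file is an inversion of responsibility: the factories are now the only place that reads global config, while the heads receive plain values and callables. A hedged sketch of the pattern (FakeModelConfig is a stand-in, not vLLM's real config object):

from dataclasses import dataclass

import torch
import torch.nn.functional as F

@dataclass
class FakeModelConfig:  # stand-in for vLLM's ModelConfig
    head_dtype: torch.dtype | None = torch.float32

class Head:  # config-free: only resolved values are injected
    def __init__(self, head_dtype=None, activation=None):
        self.head_dtype = head_dtype
        self.activation = activation

def head_from_config(model_config: FakeModelConfig) -> Head:
    # All config resolution happens here, mirroring pooler_for_embed()
    return Head(head_dtype=model_config.head_dtype,
                activation=lambda x: F.normalize(x, dim=-1))

head = head_from_config(FakeModelConfig())

This keeps the heads trivially unit-testable and lets the models further down the diff (BERT, ModernBERT) build them with custom projectors.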

View File

@@ -7,14 +7,7 @@ from typing import TypeAlias
 import torch
 import torch.nn as nn
-from vllm.config import get_current_vllm_config
-from vllm.model_executor.layers.pooler import ClassifierFn
-from vllm.model_executor.layers.pooler.activations import (
-    PoolerActivation,
-    PoolerNormalize,
-    resolve_classifier_act_fn,
-)
-from vllm.model_executor.models.adapters import _load_st_projector
+from vllm.model_executor.layers.pooler import ActivationFn, ClassifierFn, ProjectorFn
 from vllm.pooling_params import PoolingParams
 from vllm.tasks import PoolingTask
 from vllm.v1.pool.metadata import PoolingMetadata
@@ -49,17 +42,17 @@ class TokenPoolerHead(nn.Module, ABC):
 class TokenEmbeddingPoolerHead(TokenPoolerHead):
-    def __init__(self) -> None:
+    def __init__(
+        self,
+        projector: ProjectorFn | None = None,
+        head_dtype: torch.dtype | str | None = None,
+        activation: ActivationFn | None = None,
+    ) -> None:
         super().__init__()
-        # Load ST projector if available
-        vllm_config = get_current_vllm_config()
-        model_config = vllm_config.model_config
-        self.projector = _load_st_projector(model_config)
-        self.head_dtype = model_config.head_dtype
-        self.activation = PoolerNormalize()
+        self.projector = projector
+        self.head_dtype = head_dtype
+        self.activation = activation
     def get_supported_tasks(self) -> Set[PoolingTask]:
         return {"token_embed"}
@@ -73,7 +66,8 @@ class TokenEmbeddingPoolerHead(TokenPoolerHead):
         if pooled_data is None:
             return None
-        pooled_data = pooled_data.to(self.head_dtype)
+        if self.head_dtype is not None:
+            pooled_data = pooled_data.to(self.head_dtype)
         # pooled_data shape: [n_tokens, hidden_dimension]
         # Apply ST projector
@@ -85,7 +79,7 @@ class TokenEmbeddingPoolerHead(TokenPoolerHead):
             pooled_data = pooled_data[..., : pooling_param.dimensions]
         # for normalize
-        if pooling_param.normalize:
+        if self.activation is not None and pooling_param.normalize:
             pooled_data = self.activation(pooled_data)
         # pooled_data shape: [n_tokens, embedding_dimension]
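
Note: the ordering in this hunk matters for Matryoshka-style truncation: the embedding is sliced to the requested dimensions first and only re-normalized afterwards, because a prefix of a unit vector is generally no longer unit-length. A quick demonstration:

import torch
import torch.nn.functional as F

full = F.normalize(torch.randn(256), dim=-1)  # unit-length embedding
prefix = full[:64]                            # truncate first...
print(prefix.norm())                          # < 1.0: no longer unit
print(F.normalize(prefix, dim=-1).norm())     # ...then normalize: 1.0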
@@ -96,20 +90,16 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
     def __init__(
         self,
         classifier: ClassifierFn | None = None,
-        act_fn: PoolerActivation | str | None = None,
+        logit_bias: float | None = None,
+        head_dtype: torch.dtype | str | None = None,
+        activation: ActivationFn | None = None,
     ) -> None:
         super().__init__()
-        vllm_config = get_current_vllm_config()
-        model_config = vllm_config.model_config
         self.classifier = classifier
-        self.logit_bias: float | None = model_config.pooler_config.logit_bias
-        self.head_dtype = model_config.head_dtype
-        self.act_fn = resolve_classifier_act_fn(
-            model_config, static_num_labels=False, act_fn=act_fn
-        )
+        self.logit_bias = logit_bias
+        self.head_dtype = head_dtype
+        self.activation = activation
     def get_supported_tasks(self) -> Set[PoolingTask]:
         return {"token_classify"}
@@ -123,7 +113,8 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
         if pooled_data is None:
             return None
-        pooled_data = pooled_data.to(self.head_dtype)
+        if self.head_dtype is not None:
+            pooled_data = pooled_data.to(self.head_dtype)
         # hidden_states shape: [n_token, hidden_size]
         if self.classifier is not None:
@@ -135,8 +126,8 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
         if self.logit_bias is not None:
             scores -= self.logit_bias
-        if pooling_param.use_activation:
-            scores = self.act_fn(scores)
+        if self.activation is not None and pooling_param.use_activation:
+            scores = self.activation(scores)
         # scores shape: [n_token, num_labels]
         return scores

View File

@@ -5,10 +5,15 @@ from typing import TypeAlias
 import torch
-from vllm.config import PoolerConfig
+from vllm.config import PoolerConfig, get_current_vllm_config
 from vllm.model_executor.layers.pooler import ClassifierFn, PoolingParamsUpdate
 from vllm.model_executor.layers.pooler.abstract import Pooler
-from vllm.model_executor.layers.pooler.activations import PoolerActivation
+from vllm.model_executor.layers.pooler.activations import (
+    PoolerActivation,
+    PoolerNormalize,
+    resolve_classifier_act_fn,
+)
+from vllm.model_executor.models.adapters import _load_st_projector
 from vllm.tasks import POOLING_TASKS, PoolingTask
 from vllm.v1.pool.metadata import PoolingMetadata
@@ -86,7 +91,14 @@ class TokenPooler(Pooler):
 def pooler_for_token_embed(pooler_config: PoolerConfig):
     pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())
-    head = TokenEmbeddingPoolerHead()
+    vllm_config = get_current_vllm_config()
+    model_config = vllm_config.model_config
+    head = TokenEmbeddingPoolerHead(
+        projector=_load_st_projector(model_config),
+        head_dtype=model_config.head_dtype,
+        activation=PoolerNormalize(),
+    )
     return TokenPooler(pooling=pooling, head=head)
@@ -101,6 +113,15 @@ def pooler_for_token_classify(
     if pooling is None:
         pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())
-    head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
+    vllm_config = get_current_vllm_config()
+    model_config = vllm_config.model_config
+    head = TokenClassifierPoolerHead(
+        classifier=classifier,
+        logit_bias=model_config.pooler_config.logit_bias,
+        head_dtype=model_config.head_dtype,
+        activation=resolve_classifier_act_fn(
+            model_config, static_num_labels=False, act_fn=act_fn
+        ),
+    )
     return TokenPooler(pooling=pooling, head=head)

View File

@@ -8,7 +8,7 @@ from torch import nn
 from transformers import BertConfig
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, PoolerConfig, VllmConfig
+from vllm.config import CacheConfig, ModelConfig, PoolerConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention.encoder_only_attention import (
@@ -24,11 +24,11 @@ from vllm.model_executor.layers.pooler import (
     Pooler,
     PoolingParamsUpdate,
 )
+from vllm.model_executor.layers.pooler.activations import LambdaPoolerActivation
 from vllm.model_executor.layers.pooler.seqwise import (
+    EmbeddingPoolerHead,
     SequencePooler,
-    SequencePoolerHeadOutput,
     SequencePoolerOutput,
-    SequencePoolingMethodOutput,
     get_seq_pooling_method,
 )
 from vllm.model_executor.layers.pooler.tokwise import (
@@ -94,26 +94,32 @@ class BertEmbedding(nn.Module):
 class BertPooler(SequencePooler):
-    def __init__(self, config: BertConfig, pooler_config: PoolerConfig):
+    def __init__(self, model_config: ModelConfig):
+        pooler_config = model_config.pooler_config
+        assert pooler_config is not None
+        config: BertConfig = model_config.hf_config
         super().__init__(
             pooling=get_seq_pooling_method(pooler_config.seq_pooling_type),
-            head=self.head,
+            # We set this dummy to avoid adding parameters to nn.Module too early
+            head=nn.Identity(),
         )
-        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
-        self.activation = nn.Tanh()
+        head_dtype = model_config.head_dtype
+        self.dense = nn.Linear(
+            config.hidden_size,
+            config.hidden_size,
+            dtype=head_dtype,
+        )
+        self.act_fn = nn.Tanh()
-    def head(
-        self,
-        pooled_data: SequencePoolingMethodOutput,
-        pooling_metadata: PoolingMetadata,
-    ) -> SequencePoolerHeadOutput:
-        if isinstance(pooled_data, list):
-            pooled_data = torch.stack(pooled_data)
-        pooled_data = self.dense(pooled_data)
-        pooled_data = self.activation(pooled_data)
-        return pooled_data
+        # Use lambdas so that weights are not registered under `self.head`
+        self.head = EmbeddingPoolerHead(
+            projector=lambda x: self.dense(x),
+            head_dtype=head_dtype,
+            activation=LambdaPoolerActivation(self.act_fn),
+        )
 class BertEncoder(nn.Module):
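
Note: the nn.Identity() placeholder and the lambda-wrapped projector both exist so that self.dense is registered exactly once, under the pooler itself, rather than a second time under self.head. A self-contained demonstration of the pitfall being avoided (class names here are illustrative):

import torch.nn as nn

class Head(nn.Module):
    def __init__(self, projector=None):
        super().__init__()
        # An nn.Module assigned here becomes a registered submodule;
        # a plain callable (e.g. a lambda) does not.
        self.projector = projector

class NaivePooler(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(4, 4)
        self.head = Head(projector=self.dense)  # registered twice

class LambdaPooler(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(4, 4)
        self.head = Head(projector=lambda x: self.dense(x))  # once

print(list(NaivePooler().state_dict()))
# ['dense.weight', 'dense.bias', 'head.projector.weight', 'head.projector.bias']
print(list(LambdaPooler().state_dict()))
# ['dense.weight', 'dense.bias']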
@@ -449,12 +455,7 @@ class BertPoolingModel(BertModel):
             embedding_class=embedding_class,
         )
-        config = vllm_config.model_config.hf_config
-        pooler_config = vllm_config.model_config.pooler_config
-        assert pooler_config is not None
-        self.pooler = BertPooler(config, pooler_config)
+        self.pooler = BertPooler(vllm_config.model_config)
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         other_weights, loaded_stacked_params = self._load_weights(weights)

View File

@@ -466,10 +466,7 @@ class BertWithRope(nn.Module, SupportsQuant):
         )
         if add_pooling_layer:
-            pooler_config = vllm_config.model_config.pooler_config
-            assert pooler_config is not None
-            self.pooler = BertPooler(self.config, pooler_config)
+            self.pooler = BertPooler(vllm_config.model_config)
         else:
             self.pooler = None

View File

@@ -5,7 +5,7 @@ from collections.abc import Set
 import numpy as np
 import torch
-from vllm.config import ModelConfig, PoolerConfig, VllmConfig
+from vllm.config import ModelConfig, VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.pooler import (
     DispatchPooler,
@@ -13,8 +13,8 @@ from vllm.model_executor.layers.pooler import (
 )
 from vllm.model_executor.layers.pooler.activations import PoolerNormalize
 from vllm.model_executor.layers.pooler.seqwise import (
+    EmbeddingPoolerHead,
     SequencePooler,
-    SequencePoolerHeadOutput,
     SequencePoolingMethod,
-    SequencePoolingMethodOutput,
     get_seq_pooling_method,
@@ -178,25 +178,22 @@ class GritLMMeanPool(SequencePoolingMethod):
 class GritLMPooler(SequencePooler):
-    def __init__(self, model_config: ModelConfig, pooler_config: PoolerConfig):
+    def __init__(self, model_config: ModelConfig):
+        pooler_config = model_config.pooler_config
+        assert pooler_config is not None
         super().__init__(
             pooling=(
                 GritLMMeanPool(model_config)
                 if pooler_config.seq_pooling_type == "MEAN"
                 else get_seq_pooling_method(pooler_config.seq_pooling_type)
             ),
-            head=self.head,
+            head=EmbeddingPoolerHead(
+                head_dtype=model_config.head_dtype,
+                activation=PoolerNormalize(),
+            ),
         )
-        self.activation = PoolerNormalize()
-    def head(
-        self,
-        pooled_data: SequencePoolingMethodOutput,
-        pooling_metadata: PoolingMetadata,
-    ) -> SequencePoolerHeadOutput:
-        return self.activation(pooled_data)
 @default_pooling_type(seq_pooling_type="MEAN")
class GritLM(LlamaForCausalLM):
@@ -240,6 +237,6 @@ class GritLM(LlamaForCausalLM):
         self.pooler = DispatchPooler(
             {
                 "token_embed": pooler_for_token_embed(pooler_config),
-                "embed": GritLMPooler(vllm_config.model_config, pooler_config),
+                "embed": GritLMPooler(vllm_config.model_config),
             }
         )

View File

@@ -8,17 +8,17 @@ from transformers import ModernBertConfig
 from transformers.activations import ACT2FN
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import PoolerConfig, VllmConfig
+from vllm.config import ModelConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.attention.encoder_only_attention import (
     EncoderOnlyAttention,
 )
 from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
 from vllm.model_executor.layers.pooler import DispatchPooler
+from vllm.model_executor.layers.pooler.activations import LambdaPoolerActivation
 from vllm.model_executor.layers.pooler.seqwise import (
+    EmbeddingPoolerHead,
     SequencePooler,
-    SequencePoolerHeadOutput,
-    SequencePoolingMethodOutput,
     get_seq_pooling_method,
 )
 from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_classify
@@ -26,7 +26,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import IntermediateTensors
-from vllm.v1.pool.metadata import PoolingMetadata
 from .interfaces import SupportsCrossEncoding
 from .interfaces_base import attn_type, default_pooling_type
@@ -282,7 +281,11 @@ class ModernBertModel(nn.Module):
 class ModernBertPooler(SequencePooler):
-    def __init__(self, config: ModernBertConfig, pooler_config: PoolerConfig):
+    def __init__(self, model_config: ModelConfig):
+        pooler_config = model_config.pooler_config
+        assert pooler_config is not None
+        config: ModernBertConfig = model_config.hf_config
         hf_pooling_type = config.classifier_pooling.upper()
         # vllm_pooling_type = pooler_config.seq_pooling_type
         # Currently we don't have a way to see if the user set the pooling type
@@ -290,27 +293,30 @@ class ModernBertPooler(SequencePooler):
         super().__init__(
             pooling=get_seq_pooling_method(hf_pooling_type),
-            head=self.head,
+            # We set this dummy to avoid adding parameters to nn.Module too early
+            head=nn.Identity(),
         )
+        head_dtype = model_config.head_dtype
         self.dense = nn.Linear(
-            config.hidden_size, config.hidden_size, config.classifier_bias
+            config.hidden_size,
+            config.hidden_size,
+            config.classifier_bias,
+            dtype=head_dtype,
         )
         self.act = nn.GELU()
         self.norm = nn.LayerNorm(
-            config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
+            config.hidden_size,
+            eps=config.norm_eps,
+            bias=config.norm_bias,
         )
-    def head(
-        self,
-        pooled_data: SequencePoolingMethodOutput,
-        pooling_metadata: PoolingMetadata,
-    ) -> SequencePoolerHeadOutput:
-        if isinstance(pooled_data, list):
-            pooled_data = torch.stack(pooled_data)
-        pooled_data = pooled_data.to(self.dense.weight.dtype)
-        return self.norm(self.act(self.dense(pooled_data)))
+        # Use lambdas so that weights are not registered under `self.head`
+        self.head = EmbeddingPoolerHead(
+            projector=lambda x: self.dense(x),
+            head_dtype=head_dtype,
+            activation=LambdaPoolerActivation(lambda x: self.norm(self.act(x))),
+        )
 @default_pooling_type(seq_pooling_type="CLS")
@@ -335,7 +341,7 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None
-        self.pooling = ModernBertPooler(config, pooler_config)
+        self.pooling = ModernBertPooler(vllm_config.model_config)
         self.pooler = DispatchPooler.for_seq_cls(
             pooler_config,