[Model] Standardize pooling heads (#32148)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-13 01:01:49 +08:00
committed by GitHub
parent 3f72639d36
commit 8863c2b25c
9 changed files with 182 additions and 149 deletions

View File

@@ -2,12 +2,17 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from dataclasses import dataclass
from typing import TypeVar
import torch
from vllm.pooling_params import PoolingParams
_T = TypeVar("_T", bound=torch.Tensor | list[torch.Tensor])
ProjectorFn = Callable[[torch.Tensor], torch.Tensor]
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
ActivationFn = Callable[[_T], _T]
@dataclass(frozen=True)
@@ -24,4 +29,4 @@ class PoolingParamsUpdate:
params.requires_token_ids = self.requires_token_ids
# Public names exported by this module.  A single assignment is kept: an
# earlier, shorter list placed directly before it was overwritten at once and
# left out ``ActivationFn`` and ``ProjectorFn``.
__all__ = ["ActivationFn", "ClassifierFn", "ProjectorFn", "PoolingParamsUpdate"]

View File

@@ -7,14 +7,7 @@ from typing import TypeAlias
import torch
import torch.nn as nn
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.pooler import ClassifierFn
from vllm.model_executor.layers.pooler.activations import (
PoolerActivation,
PoolerNormalize,
resolve_classifier_act_fn,
)
from vllm.model_executor.models.adapters import _load_st_projector
from vllm.model_executor.layers.pooler import ActivationFn, ClassifierFn, ProjectorFn
from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata
@@ -38,17 +31,17 @@ class SequencePoolerHead(nn.Module, ABC):
class EmbeddingPoolerHead(SequencePoolerHead):
def __init__(self) -> None:
def __init__(
self,
projector: ProjectorFn | None = None,
head_dtype: torch.dtype | str | None = None,
activation: ActivationFn | None = None,
) -> None:
super().__init__()
# Load ST projector if available
vllm_config = get_current_vllm_config()
model_config = vllm_config.model_config
self.projector = _load_st_projector(model_config)
self.head_dtype = model_config.head_dtype
self.activation = PoolerNormalize()
self.projector = projector
self.head_dtype = head_dtype
self.activation = activation
def get_supported_tasks(self) -> Set[PoolingTask]:
    """Report the pooling tasks this head handles: only ``"embed"``."""
    supported: Set[PoolingTask] = {"embed"}
    return supported
@@ -65,7 +58,8 @@ class EmbeddingPoolerHead(SequencePoolerHead):
pooled_data = torch.stack(pooled_data)
# pooled_data shape: [batchsize, hidden_dimension]
pooled_data = pooled_data.to(self.head_dtype)
if self.head_dtype is not None:
pooled_data = pooled_data.to(self.head_dtype)
# Apply ST projector
if self.projector is not None:
@@ -88,15 +82,16 @@ class EmbeddingPoolerHead(SequencePoolerHead):
]
# for normalize
flags = [p.normalize for p in pooling_params]
if len(set(flags)) == 1:
if flags[0]:
pooled_data = self.activation(pooled_data)
else:
pooled_data = [
self.activation(vecs) if f else vecs
for vecs, f in zip(pooled_data, flags)
]
if self.activation is not None:
flags = [p.normalize for p in pooling_params]
if len(set(flags)) == 1:
if flags[0]:
pooled_data = self.activation(pooled_data)
else:
pooled_data = [
self.activation(vecs) if f else vecs
for vecs, f in zip(pooled_data, flags)
]
# pooled_data shape: [batchsize, embedding_dimension]
return pooled_data
@@ -106,20 +101,16 @@ class ClassifierPoolerHead(SequencePoolerHead):
def __init__(
self,
classifier: ClassifierFn | None = None,
act_fn: PoolerActivation | str | None = None,
logit_bias: float | None = None,
head_dtype: torch.dtype | str | None = None,
activation: ActivationFn | None = None,
) -> None:
super().__init__()
vllm_config = get_current_vllm_config()
model_config = vllm_config.model_config
self.classifier = classifier
self.logit_bias: float | None = model_config.pooler_config.logit_bias
self.head_dtype = model_config.head_dtype
self.act_fn = resolve_classifier_act_fn(
model_config, static_num_labels=True, act_fn=act_fn
)
self.logit_bias = logit_bias
self.head_dtype = head_dtype
self.activation = activation
def get_supported_tasks(self) -> Set[PoolingTask]:
    """Report the pooling tasks this head handles: ``"classify"`` and ``"score"``."""
    supported: Set[PoolingTask] = {"classify", "score"}
    return supported
@@ -136,7 +127,8 @@ class ClassifierPoolerHead(SequencePoolerHead):
pooled_data = torch.stack(pooled_data)
# pooled_data shape: [batchsize, hidden_size]
pooled_data = pooled_data.to(self.head_dtype)
if self.head_dtype is not None:
pooled_data = pooled_data.to(self.head_dtype)
if self.classifier is not None:
pooled_data = self.classifier(pooled_data)
@@ -145,13 +137,15 @@ class ClassifierPoolerHead(SequencePoolerHead):
if self.logit_bias is not None:
pooled_data -= self.logit_bias
flags = [p.use_activation for p in pooling_params]
if len(set(flags)) == 1:
scores = self.act_fn(pooled_data) if flags[0] else pooled_data
else:
scores = [
self.act_fn(vecs) if f else vecs for vecs, f in zip(pooled_data, flags)
]
if self.activation is not None:
flags = [p.use_activation for p in pooling_params]
if len(set(flags)) == 1:
pooled_data = self.activation(pooled_data) if flags[0] else pooled_data
else:
pooled_data = [
self.activation(vecs) if f else vecs
for vecs, f in zip(pooled_data, flags)
]
# scores shape: [batchsize, num_labels]
return scores
# pooled_data shape: [batchsize, num_labels]
return pooled_data

View File

@@ -5,10 +5,15 @@ from typing import TypeAlias
import torch
from vllm.config import PoolerConfig
from vllm.config import PoolerConfig, get_current_vllm_config
from vllm.model_executor.layers.pooler import ClassifierFn, PoolingParamsUpdate
from vllm.model_executor.layers.pooler.abstract import Pooler
from vllm.model_executor.layers.pooler.activations import PoolerActivation
from vllm.model_executor.layers.pooler.activations import (
PoolerActivation,
PoolerNormalize,
resolve_classifier_act_fn,
)
from vllm.model_executor.models.adapters import _load_st_projector
from vllm.tasks import POOLING_TASKS, PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata
@@ -86,7 +91,14 @@ class SequencePooler(Pooler):
def pooler_for_embed(pooler_config: PoolerConfig):
    """Build the sequence pooler used for embedding requests.

    The pooling method is selected from the sequence pooling type reported by
    ``pooler_config``.  The head is configured from the current vLLM model
    config: the projector returned by ``_load_st_projector``, the model's
    ``head_dtype`` and a ``PoolerNormalize`` activation.

    Args:
        pooler_config: Pooler settings; only its sequence pooling type is
            read here.

    Returns:
        A ``SequencePooler`` combining the selected pooling method with the
        configured ``EmbeddingPoolerHead``.
    """
    pooling = get_seq_pooling_method(pooler_config.get_seq_pooling_type())

    # Resolve everything the head needs from the current config and inject it
    # through the constructor.  The head is built exactly once; a bare
    # ``EmbeddingPoolerHead()`` created first would only be thrown away.
    model_config = get_current_vllm_config().model_config
    head = EmbeddingPoolerHead(
        projector=_load_st_projector(model_config),
        head_dtype=model_config.head_dtype,
        activation=PoolerNormalize(),
    )

    return SequencePooler(pooling=pooling, head=head)
@@ -101,6 +113,15 @@ def pooler_for_classify(
if pooling is None:
pooling = get_seq_pooling_method(pooler_config.get_seq_pooling_type())
head = ClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
vllm_config = get_current_vllm_config()
model_config = vllm_config.model_config
head = ClassifierPoolerHead(
classifier=classifier,
logit_bias=model_config.pooler_config.logit_bias,
head_dtype=model_config.head_dtype,
activation=resolve_classifier_act_fn(
model_config, static_num_labels=True, act_fn=act_fn
),
)
return SequencePooler(pooling=pooling, head=head)

View File

@@ -7,14 +7,7 @@ from typing import TypeAlias
import torch
import torch.nn as nn
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.pooler import ClassifierFn
from vllm.model_executor.layers.pooler.activations import (
PoolerActivation,
PoolerNormalize,
resolve_classifier_act_fn,
)
from vllm.model_executor.models.adapters import _load_st_projector
from vllm.model_executor.layers.pooler import ActivationFn, ClassifierFn, ProjectorFn
from vllm.pooling_params import PoolingParams
from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata
@@ -49,17 +42,17 @@ class TokenPoolerHead(nn.Module, ABC):
class TokenEmbeddingPoolerHead(TokenPoolerHead):
def __init__(self) -> None:
def __init__(
self,
projector: ProjectorFn | None = None,
head_dtype: torch.dtype | str | None = None,
activation: ActivationFn | None = None,
) -> None:
super().__init__()
# Load ST projector if available
vllm_config = get_current_vllm_config()
model_config = vllm_config.model_config
self.projector = _load_st_projector(model_config)
self.head_dtype = model_config.head_dtype
self.activation = PoolerNormalize()
self.projector = projector
self.head_dtype = head_dtype
self.activation = activation
def get_supported_tasks(self) -> Set[PoolingTask]:
    """Report the pooling tasks this head handles: only ``"token_embed"``."""
    supported: Set[PoolingTask] = {"token_embed"}
    return supported
@@ -73,7 +66,8 @@ class TokenEmbeddingPoolerHead(TokenPoolerHead):
if pooled_data is None:
return None
pooled_data = pooled_data.to(self.head_dtype)
if self.head_dtype is not None:
pooled_data = pooled_data.to(self.head_dtype)
# pooled_data shape: [n_tokens, hidden_dimension]
# Apply ST projector
@@ -85,7 +79,7 @@ class TokenEmbeddingPoolerHead(TokenPoolerHead):
pooled_data = pooled_data[..., : pooling_param.dimensions]
# for normalize
if pooling_param.normalize:
if self.activation is not None and pooling_param.normalize:
pooled_data = self.activation(pooled_data)
# pooled_data shape: [n_tokens, embedding_dimension]
@@ -96,20 +90,16 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
def __init__(
self,
classifier: ClassifierFn | None = None,
act_fn: PoolerActivation | str | None = None,
logit_bias: float | None = None,
head_dtype: torch.dtype | str | None = None,
activation: ActivationFn | None = None,
) -> None:
super().__init__()
vllm_config = get_current_vllm_config()
model_config = vllm_config.model_config
self.classifier = classifier
self.logit_bias: float | None = model_config.pooler_config.logit_bias
self.head_dtype = model_config.head_dtype
self.act_fn = resolve_classifier_act_fn(
model_config, static_num_labels=False, act_fn=act_fn
)
self.logit_bias = logit_bias
self.head_dtype = head_dtype
self.activation = activation
def get_supported_tasks(self) -> Set[PoolingTask]:
    """Report the pooling tasks this head handles: only ``"token_classify"``."""
    supported: Set[PoolingTask] = {"token_classify"}
    return supported
@@ -123,7 +113,8 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
if pooled_data is None:
return None
pooled_data = pooled_data.to(self.head_dtype)
if self.head_dtype is not None:
pooled_data = pooled_data.to(self.head_dtype)
# hidden_states shape: [n_token, hidden_size]
if self.classifier is not None:
@@ -135,8 +126,8 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
if self.logit_bias is not None:
scores -= self.logit_bias
if pooling_param.use_activation:
scores = self.act_fn(scores)
if self.activation is not None and pooling_param.use_activation:
scores = self.activation(scores)
# scores shape: [n_token, num_labels]
return scores

View File

@@ -5,10 +5,15 @@ from typing import TypeAlias
import torch
from vllm.config import PoolerConfig
from vllm.config import PoolerConfig, get_current_vllm_config
from vllm.model_executor.layers.pooler import ClassifierFn, PoolingParamsUpdate
from vllm.model_executor.layers.pooler.abstract import Pooler
from vllm.model_executor.layers.pooler.activations import PoolerActivation
from vllm.model_executor.layers.pooler.activations import (
PoolerActivation,
PoolerNormalize,
resolve_classifier_act_fn,
)
from vllm.model_executor.models.adapters import _load_st_projector
from vllm.tasks import POOLING_TASKS, PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata
@@ -86,7 +91,14 @@ class TokenPooler(Pooler):
def pooler_for_token_embed(pooler_config: PoolerConfig):
    """Build the token pooler used for per-token embedding requests.

    The pooling method is selected from the token pooling type reported by
    ``pooler_config``.  The head is configured from the current vLLM model
    config: the projector returned by ``_load_st_projector``, the model's
    ``head_dtype`` and a ``PoolerNormalize`` activation.

    Args:
        pooler_config: Pooler settings; only its token pooling type is read
            here.

    Returns:
        A ``TokenPooler`` combining the selected pooling method with the
        configured ``TokenEmbeddingPoolerHead``.
    """
    pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())

    # Resolve everything the head needs from the current config and inject it
    # through the constructor.  The head is built exactly once; a bare
    # ``TokenEmbeddingPoolerHead()`` created first would only be thrown away.
    model_config = get_current_vllm_config().model_config
    head = TokenEmbeddingPoolerHead(
        projector=_load_st_projector(model_config),
        head_dtype=model_config.head_dtype,
        activation=PoolerNormalize(),
    )

    return TokenPooler(pooling=pooling, head=head)
@@ -101,6 +113,15 @@ def pooler_for_token_classify(
if pooling is None:
pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())
head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
vllm_config = get_current_vllm_config()
model_config = vllm_config.model_config
head = TokenClassifierPoolerHead(
classifier=classifier,
logit_bias=model_config.pooler_config.logit_bias,
head_dtype=model_config.head_dtype,
activation=resolve_classifier_act_fn(
model_config, static_num_labels=False, act_fn=act_fn
),
)
return TokenPooler(pooling=pooling, head=head)