[Model] Standardize pooling heads (#32148)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2026-01-13 01:01:49 +08:00
Committed by: GitHub
Parent: 3f72639d36
Commit: 8863c2b25c

9 changed files with 182 additions and 149 deletions

View File

@@ -2,12 +2,17 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Callable
 from dataclasses import dataclass
+from typing import TypeVar
 
 import torch
 
 from vllm.pooling_params import PoolingParams
 
+_T = TypeVar("_T", bound=torch.Tensor | list[torch.Tensor])
+
+ProjectorFn = Callable[[torch.Tensor], torch.Tensor]
 ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
+ActivationFn = Callable[[_T], _T]
 
 
 @dataclass(frozen=True)
@@ -24,4 +29,4 @@ class PoolingParamsUpdate:
         params.requires_token_ids = self.requires_token_ids
 
 
-__all__ = ["ClassifierFn", "PoolingParamsUpdate"]
+__all__ = ["ActivationFn", "ClassifierFn", "ProjectorFn", "PoolingParamsUpdate"]

View File

@@ -7,14 +7,7 @@ from typing import TypeAlias
 import torch
 import torch.nn as nn
 
-from vllm.config import get_current_vllm_config
-from vllm.model_executor.layers.pooler import ClassifierFn
-from vllm.model_executor.layers.pooler.activations import (
-    PoolerActivation,
-    PoolerNormalize,
-    resolve_classifier_act_fn,
-)
-from vllm.model_executor.models.adapters import _load_st_projector
+from vllm.model_executor.layers.pooler import ActivationFn, ClassifierFn, ProjectorFn
 from vllm.tasks import PoolingTask
 from vllm.v1.pool.metadata import PoolingMetadata
@@ -38,17 +31,17 @@ class SequencePoolerHead(nn.Module, ABC):
 class EmbeddingPoolerHead(SequencePoolerHead):
-    def __init__(self) -> None:
+    def __init__(
+        self,
+        projector: ProjectorFn | None = None,
+        head_dtype: torch.dtype | str | None = None,
+        activation: ActivationFn | None = None,
+    ) -> None:
         super().__init__()
 
-        # Load ST projector if available
-        vllm_config = get_current_vllm_config()
-        model_config = vllm_config.model_config
-
-        self.projector = _load_st_projector(model_config)
-        self.head_dtype = model_config.head_dtype
-        self.activation = PoolerNormalize()
+        self.projector = projector
+        self.head_dtype = head_dtype
+        self.activation = activation
 
     def get_supported_tasks(self) -> Set[PoolingTask]:
         return {"embed"}
@@ -65,7 +58,8 @@ class EmbeddingPoolerHead(SequencePoolerHead):
             pooled_data = torch.stack(pooled_data)
         # pooled_data shape: [batchsize, hidden_dimension]
 
-        pooled_data = pooled_data.to(self.head_dtype)
+        if self.head_dtype is not None:
+            pooled_data = pooled_data.to(self.head_dtype)
 
         # Apply ST projector
         if self.projector is not None:
@@ -88,15 +82,16 @@ class EmbeddingPoolerHead(SequencePoolerHead):
             ]
 
         # for normalize
-        flags = [p.normalize for p in pooling_params]
-        if len(set(flags)) == 1:
-            if flags[0]:
-                pooled_data = self.activation(pooled_data)
-        else:
-            pooled_data = [
-                self.activation(vecs) if f else vecs
-                for vecs, f in zip(pooled_data, flags)
-            ]
+        if self.activation is not None:
+            flags = [p.normalize for p in pooling_params]
+            if len(set(flags)) == 1:
+                if flags[0]:
+                    pooled_data = self.activation(pooled_data)
+            else:
+                pooled_data = [
+                    self.activation(vecs) if f else vecs
+                    for vecs, f in zip(pooled_data, flags)
+                ]
 
         # pooled_data shape: [batchsize, embedding_dimension]
         return pooled_data
@@ -106,20 +101,16 @@ class ClassifierPoolerHead(SequencePoolerHead):
     def __init__(
         self,
         classifier: ClassifierFn | None = None,
-        act_fn: PoolerActivation | str | None = None,
+        logit_bias: float | None = None,
+        head_dtype: torch.dtype | str | None = None,
+        activation: ActivationFn | None = None,
     ) -> None:
         super().__init__()
 
-        vllm_config = get_current_vllm_config()
-        model_config = vllm_config.model_config
-
         self.classifier = classifier
-        self.logit_bias: float | None = model_config.pooler_config.logit_bias
-        self.head_dtype = model_config.head_dtype
-
-        self.act_fn = resolve_classifier_act_fn(
-            model_config, static_num_labels=True, act_fn=act_fn
-        )
+        self.logit_bias = logit_bias
+        self.head_dtype = head_dtype
+        self.activation = activation
 
     def get_supported_tasks(self) -> Set[PoolingTask]:
         return {"classify", "score"}
@@ -136,7 +127,8 @@ class ClassifierPoolerHead(SequencePoolerHead):
             pooled_data = torch.stack(pooled_data)
         # pooled_data shape: [batchsize, hidden_size]
 
-        pooled_data = pooled_data.to(self.head_dtype)
+        if self.head_dtype is not None:
+            pooled_data = pooled_data.to(self.head_dtype)
 
         if self.classifier is not None:
             pooled_data = self.classifier(pooled_data)
@@ -145,13 +137,15 @@ class ClassifierPoolerHead(SequencePoolerHead):
         if self.logit_bias is not None:
            pooled_data -= self.logit_bias
 
-        flags = [p.use_activation for p in pooling_params]
-        if len(set(flags)) == 1:
-            scores = self.act_fn(pooled_data) if flags[0] else pooled_data
-        else:
-            scores = [
-                self.act_fn(vecs) if f else vecs for vecs, f in zip(pooled_data, flags)
-            ]
+        if self.activation is not None:
+            flags = [p.use_activation for p in pooling_params]
+            if len(set(flags)) == 1:
+                pooled_data = self.activation(pooled_data) if flags[0] else pooled_data
+            else:
+                pooled_data = [
+                    self.activation(vecs) if f else vecs
+                    for vecs, f in zip(pooled_data, flags)
+                ]
 
-        # scores shape: [batchsize, num_labels]
-        return scores
+        # pooled_data shape: [batchsize, num_labels]
+        return pooled_data
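After this change the sequence-level heads are plain containers for injected dependencies: every field (projector, head_dtype, activation) is optional, and the corresponding stage is skipped when it is None. A hedged sketch of constructing one directly, outside the factories (the import path for EmbeddingPoolerHead is taken from the model files later in this commit; the values are illustrative):

    import torch

    from vllm.model_executor.layers.pooler.activations import PoolerNormalize
    from vllm.model_executor.layers.pooler.seqwise import EmbeddingPoolerHead

    head = EmbeddingPoolerHead(
        projector=None,                # no Sentence-Transformers projector
        head_dtype=torch.float32,      # cast before the head; None skips the cast
        activation=PoolerNormalize(),  # applied when a request sets params.normalize
    )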

View File

@@ -5,10 +5,15 @@ from typing import TypeAlias
 import torch
 
-from vllm.config import PoolerConfig
+from vllm.config import PoolerConfig, get_current_vllm_config
 from vllm.model_executor.layers.pooler import ClassifierFn, PoolingParamsUpdate
 from vllm.model_executor.layers.pooler.abstract import Pooler
-from vllm.model_executor.layers.pooler.activations import PoolerActivation
+from vllm.model_executor.layers.pooler.activations import (
+    PoolerActivation,
+    PoolerNormalize,
+    resolve_classifier_act_fn,
+)
+from vllm.model_executor.models.adapters import _load_st_projector
 from vllm.tasks import POOLING_TASKS, PoolingTask
 from vllm.v1.pool.metadata import PoolingMetadata
@@ -86,7 +91,14 @@ class SequencePooler(Pooler):
 def pooler_for_embed(pooler_config: PoolerConfig):
     pooling = get_seq_pooling_method(pooler_config.get_seq_pooling_type())
-    head = EmbeddingPoolerHead()
+    vllm_config = get_current_vllm_config()
+    model_config = vllm_config.model_config
+
+    head = EmbeddingPoolerHead(
+        projector=_load_st_projector(model_config),
+        head_dtype=model_config.head_dtype,
+        activation=PoolerNormalize(),
+    )
 
     return SequencePooler(pooling=pooling, head=head)
@@ -101,6 +113,15 @@ def pooler_for_classify(
     if pooling is None:
         pooling = get_seq_pooling_method(pooler_config.get_seq_pooling_type())
-    head = ClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
+
+    vllm_config = get_current_vllm_config()
+    model_config = vllm_config.model_config
+
+    head = ClassifierPoolerHead(
+        classifier=classifier,
+        logit_bias=model_config.pooler_config.logit_bias,
+        head_dtype=model_config.head_dtype,
+        activation=resolve_classifier_act_fn(
+            model_config, static_num_labels=True, act_fn=act_fn
+        ),
+    )
 
     return SequencePooler(pooling=pooling, head=head)
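The config lookup now lives in the factories rather than in the heads, so a factory must be called while a vLLM config is current. A hedged usage sketch, assuming set_current_vllm_config is the context-manager counterpart of get_current_vllm_config and that vllm_config is an already-built VllmConfig:

    from vllm.config import set_current_vllm_config  # assumed counterpart API

    with set_current_vllm_config(vllm_config):
        pooler = pooler_for_embed(vllm_config.model_config.pooler_config)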

View File

@@ -7,14 +7,7 @@ from typing import TypeAlias
 import torch
 import torch.nn as nn
 
-from vllm.config import get_current_vllm_config
-from vllm.model_executor.layers.pooler import ClassifierFn
-from vllm.model_executor.layers.pooler.activations import (
-    PoolerActivation,
-    PoolerNormalize,
-    resolve_classifier_act_fn,
-)
-from vllm.model_executor.models.adapters import _load_st_projector
+from vllm.model_executor.layers.pooler import ActivationFn, ClassifierFn, ProjectorFn
 from vllm.pooling_params import PoolingParams
 from vllm.tasks import PoolingTask
 from vllm.v1.pool.metadata import PoolingMetadata
@@ -49,17 +42,17 @@ class TokenPoolerHead(nn.Module, ABC):
 class TokenEmbeddingPoolerHead(TokenPoolerHead):
-    def __init__(self) -> None:
+    def __init__(
+        self,
+        projector: ProjectorFn | None = None,
+        head_dtype: torch.dtype | str | None = None,
+        activation: ActivationFn | None = None,
+    ) -> None:
         super().__init__()
 
-        # Load ST projector if available
-        vllm_config = get_current_vllm_config()
-        model_config = vllm_config.model_config
-
-        self.projector = _load_st_projector(model_config)
-        self.head_dtype = model_config.head_dtype
-        self.activation = PoolerNormalize()
+        self.projector = projector
+        self.head_dtype = head_dtype
+        self.activation = activation
 
     def get_supported_tasks(self) -> Set[PoolingTask]:
         return {"token_embed"}
@@ -73,7 +66,8 @@ class TokenEmbeddingPoolerHead(TokenPoolerHead):
         if pooled_data is None:
             return None
 
-        pooled_data = pooled_data.to(self.head_dtype)
+        if self.head_dtype is not None:
+            pooled_data = pooled_data.to(self.head_dtype)
         # pooled_data shape: [n_tokens, hidden_dimension]
 
         # Apply ST projector
@@ -85,7 +79,7 @@ class TokenEmbeddingPoolerHead(TokenPoolerHead):
             pooled_data = pooled_data[..., : pooling_param.dimensions]
 
         # for normalize
-        if pooling_param.normalize:
+        if self.activation is not None and pooling_param.normalize:
             pooled_data = self.activation(pooled_data)
 
         # pooled_data shape: [n_tokens, embedding_dimension]
@@ -96,20 +90,16 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
     def __init__(
         self,
         classifier: ClassifierFn | None = None,
-        act_fn: PoolerActivation | str | None = None,
+        logit_bias: float | None = None,
+        head_dtype: torch.dtype | str | None = None,
+        activation: ActivationFn | None = None,
    ) -> None:
         super().__init__()
 
-        vllm_config = get_current_vllm_config()
-        model_config = vllm_config.model_config
-
         self.classifier = classifier
-        self.logit_bias: float | None = model_config.pooler_config.logit_bias
-        self.head_dtype = model_config.head_dtype
-
-        self.act_fn = resolve_classifier_act_fn(
-            model_config, static_num_labels=False, act_fn=act_fn
-        )
+        self.logit_bias = logit_bias
+        self.head_dtype = head_dtype
+        self.activation = activation
 
     def get_supported_tasks(self) -> Set[PoolingTask]:
         return {"token_classify"}
@@ -123,7 +113,8 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
         if pooled_data is None:
             return None
 
-        pooled_data = pooled_data.to(self.head_dtype)
+        if self.head_dtype is not None:
+            pooled_data = pooled_data.to(self.head_dtype)
         # hidden_states shape: [n_token, hidden_size]
 
         if self.classifier is not None:
@@ -135,8 +126,8 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
         if self.logit_bias is not None:
             scores -= self.logit_bias
 
-        if pooling_param.use_activation:
-            scores = self.act_fn(scores)
+        if self.activation is not None and pooling_param.use_activation:
+            scores = self.activation(scores)
 
         # scores shape: [n_token, num_labels]
         return scores
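The token-level heads mirror the sequence-level ones but gate on a single request's pooling_param instead of batched flags, again skipping any stage whose dependency is None. A hedged sketch of the all-defaults case (the import path is assumed from context):

    from vllm.model_executor.layers.pooler.tokwise import TokenClassifierPoolerHead  # path assumed

    head = TokenClassifierPoolerHead(
        classifier=None,   # hidden states pass through unchanged
        logit_bias=None,   # nothing is subtracted from the scores
        head_dtype=None,   # the dtype cast is skipped entirely
        activation=None,   # per-request use_activation then has no effect
    )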

View File

@@ -5,10 +5,15 @@ from typing import TypeAlias
 import torch
 
-from vllm.config import PoolerConfig
+from vllm.config import PoolerConfig, get_current_vllm_config
 from vllm.model_executor.layers.pooler import ClassifierFn, PoolingParamsUpdate
 from vllm.model_executor.layers.pooler.abstract import Pooler
-from vllm.model_executor.layers.pooler.activations import PoolerActivation
+from vllm.model_executor.layers.pooler.activations import (
+    PoolerActivation,
+    PoolerNormalize,
+    resolve_classifier_act_fn,
+)
+from vllm.model_executor.models.adapters import _load_st_projector
 from vllm.tasks import POOLING_TASKS, PoolingTask
 from vllm.v1.pool.metadata import PoolingMetadata
@@ -86,7 +91,14 @@ class TokenPooler(Pooler):
 def pooler_for_token_embed(pooler_config: PoolerConfig):
     pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())
-    head = TokenEmbeddingPoolerHead()
+    vllm_config = get_current_vllm_config()
+    model_config = vllm_config.model_config
+
+    head = TokenEmbeddingPoolerHead(
+        projector=_load_st_projector(model_config),
+        head_dtype=model_config.head_dtype,
+        activation=PoolerNormalize(),
+    )
 
     return TokenPooler(pooling=pooling, head=head)
@@ -101,6 +113,15 @@ def pooler_for_token_classify(
     if pooling is None:
         pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())
-    head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
+
+    vllm_config = get_current_vllm_config()
+    model_config = vllm_config.model_config
+
+    head = TokenClassifierPoolerHead(
+        classifier=classifier,
+        logit_bias=model_config.pooler_config.logit_bias,
+        head_dtype=model_config.head_dtype,
+        activation=resolve_classifier_act_fn(
+            model_config, static_num_labels=False, act_fn=act_fn
+        ),
+    )
 
     return TokenPooler(pooling=pooling, head=head)
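The only wiring difference from the sequence-level classify factory is static_num_labels=False. Both resolver calls appear verbatim in this commit; side by side (model_config and act_fn are assumed in scope, as in the factories above):

    # Sequence-level: the label count is static for the whole model.
    seq_act = resolve_classifier_act_fn(model_config, static_num_labels=True, act_fn=act_fn)
    # Token-level: the label count is not static.
    tok_act = resolve_classifier_act_fn(model_config, static_num_labels=False, act_fn=act_fn)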

View File

@@ -8,7 +8,7 @@ from torch import nn
 from transformers import BertConfig
 
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, PoolerConfig, VllmConfig
+from vllm.config import CacheConfig, ModelConfig, PoolerConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention.encoder_only_attention import (
@@ -24,11 +24,11 @@ from vllm.model_executor.layers.pooler import (
     Pooler,
     PoolingParamsUpdate,
 )
+from vllm.model_executor.layers.pooler.activations import LambdaPoolerActivation
 from vllm.model_executor.layers.pooler.seqwise import (
+    EmbeddingPoolerHead,
     SequencePooler,
-    SequencePoolerHeadOutput,
     SequencePoolerOutput,
-    SequencePoolingMethodOutput,
     get_seq_pooling_method,
 )
 from vllm.model_executor.layers.pooler.tokwise import (
@@ -94,26 +94,32 @@ class BertEmbedding(nn.Module):
 class BertPooler(SequencePooler):
-    def __init__(self, config: BertConfig, pooler_config: PoolerConfig):
+    def __init__(self, model_config: ModelConfig):
+        pooler_config = model_config.pooler_config
+        assert pooler_config is not None
+
+        config: BertConfig = model_config.hf_config
+
         super().__init__(
             pooling=get_seq_pooling_method(pooler_config.seq_pooling_type),
-            head=self.head,
+            # We set this dummy to avoid adding parameters to nn.Module too early
+            head=nn.Identity(),
         )
 
-        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
-        self.activation = nn.Tanh()
+        head_dtype = model_config.head_dtype
 
-    def head(
-        self,
-        pooled_data: SequencePoolingMethodOutput,
-        pooling_metadata: PoolingMetadata,
-    ) -> SequencePoolerHeadOutput:
-        if isinstance(pooled_data, list):
-            pooled_data = torch.stack(pooled_data)
-
-        pooled_data = self.dense(pooled_data)
-        pooled_data = self.activation(pooled_data)
-        return pooled_data
+        self.dense = nn.Linear(
+            config.hidden_size,
+            config.hidden_size,
+            dtype=head_dtype,
+        )
+        self.act_fn = nn.Tanh()
+
+        # Use lambdas so that weights are not registered under `self.head`
+        self.head = EmbeddingPoolerHead(
+            projector=lambda x: self.dense(x),
+            head_dtype=head_dtype,
+            activation=LambdaPoolerActivation(self.act_fn),
+        )
 
 
 class BertEncoder(nn.Module):
@@ -449,12 +455,7 @@ class BertPoolingModel(BertModel):
             embedding_class=embedding_class,
         )
 
-        config = vllm_config.model_config.hf_config
-
-        pooler_config = vllm_config.model_config.pooler_config
-        assert pooler_config is not None
-
-        self.pooler = BertPooler(config, pooler_config)
+        self.pooler = BertPooler(vllm_config.model_config)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         other_weights, loaded_stacked_params = self._load_weights(weights)
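The two comments above ("dummy to avoid adding parameters too early", "use lambdas so that weights are not registered under self.head") refer to standard nn.Module behavior: assigning a module to an attribute registers its parameters under that attribute's name, which would change the checkpoint key names the model expects. A self-contained sketch of the mechanism (plain PyTorch, not vLLM code):

    import torch.nn as nn

    class Direct(nn.Module):
        def __init__(self):
            super().__init__()
            self.head = nn.Module()
            self.head.proj = nn.Linear(4, 4)  # registered: parameter names become "head.proj.*"

    class ViaLambda(nn.Module):
        def __init__(self):
            super().__init__()
            self.dense = nn.Linear(4, 4)              # registered once, as "dense.*"
            self.head = nn.Module()
            self.head.proj = lambda x: self.dense(x)  # plain attribute; nothing registered

    print(sorted(n for n, _ in Direct().named_parameters()))
    # ['head.proj.bias', 'head.proj.weight']
    print(sorted(n for n, _ in ViaLambda().named_parameters()))
    # ['dense.bias', 'dense.weight']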

View File

@@ -466,10 +466,7 @@ class BertWithRope(nn.Module, SupportsQuant):
         )
 
         if add_pooling_layer:
-            pooler_config = vllm_config.model_config.pooler_config
-            assert pooler_config is not None
-
-            self.pooler = BertPooler(self.config, pooler_config)
+            self.pooler = BertPooler(vllm_config.model_config)
         else:
             self.pooler = None

View File

@@ -5,7 +5,7 @@ from collections.abc import Set
 import numpy as np
 import torch
 
-from vllm.config import ModelConfig, PoolerConfig, VllmConfig
+from vllm.config import ModelConfig, VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.pooler import (
     DispatchPooler,
@@ -13,8 +13,8 @@ from vllm.model_executor.layers.pooler import (
 )
 from vllm.model_executor.layers.pooler.activations import PoolerNormalize
 from vllm.model_executor.layers.pooler.seqwise import (
+    EmbeddingPoolerHead,
     SequencePooler,
-    SequencePoolerHeadOutput,
     SequencePoolingMethod,
     SequencePoolingMethodOutput,
     get_seq_pooling_method,
@@ -178,25 +178,22 @@ class GritLMMeanPool(SequencePoolingMethod):
 class GritLMPooler(SequencePooler):
-    def __init__(self, model_config: ModelConfig, pooler_config: PoolerConfig):
+    def __init__(self, model_config: ModelConfig):
+        pooler_config = model_config.pooler_config
+        assert pooler_config is not None
+
         super().__init__(
             pooling=(
                 GritLMMeanPool(model_config)
                 if pooler_config.seq_pooling_type == "MEAN"
                 else get_seq_pooling_method(pooler_config.seq_pooling_type)
             ),
-            head=self.head,
+            head=EmbeddingPoolerHead(
+                head_dtype=model_config.head_dtype,
+                activation=PoolerNormalize(),
+            ),
         )
 
-        self.activation = PoolerNormalize()
-
-    def head(
-        self,
-        pooled_data: SequencePoolingMethodOutput,
-        pooling_metadata: PoolingMetadata,
-    ) -> SequencePoolerHeadOutput:
-        return self.activation(pooled_data)
-
 
 @default_pooling_type(seq_pooling_type="MEAN")
 class GritLM(LlamaForCausalLM):
@@ -240,6 +237,6 @@ class GritLM(LlamaForCausalLM):
         self.pooler = DispatchPooler(
             {
                 "token_embed": pooler_for_token_embed(pooler_config),
-                "embed": GritLMPooler(vllm_config.model_config, pooler_config),
+                "embed": GritLMPooler(vllm_config.model_config),
             }
         )

View File

@@ -8,17 +8,17 @@ from transformers import ModernBertConfig
 from transformers.activations import ACT2FN
 
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import PoolerConfig, VllmConfig
+from vllm.config import ModelConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.attention.encoder_only_attention import (
     EncoderOnlyAttention,
 )
 from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
 from vllm.model_executor.layers.pooler import DispatchPooler
+from vllm.model_executor.layers.pooler.activations import LambdaPoolerActivation
 from vllm.model_executor.layers.pooler.seqwise import (
+    EmbeddingPoolerHead,
     SequencePooler,
-    SequencePoolerHeadOutput,
-    SequencePoolingMethodOutput,
     get_seq_pooling_method,
 )
 from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_classify
@@ -26,7 +26,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import IntermediateTensors
-from vllm.v1.pool.metadata import PoolingMetadata
 
 from .interfaces import SupportsCrossEncoding
 from .interfaces_base import attn_type, default_pooling_type
@@ -282,7 +281,11 @@ class ModernBertModel(nn.Module):
 class ModernBertPooler(SequencePooler):
-    def __init__(self, config: ModernBertConfig, pooler_config: PoolerConfig):
+    def __init__(self, model_config: ModelConfig):
+        pooler_config = model_config.pooler_config
+        assert pooler_config is not None
+
+        config: ModernBertConfig = model_config.hf_config
+
         hf_pooling_type = config.classifier_pooling.upper()
         # vllm_pooling_type = pooler_config.seq_pooling_type
         # Currently we don't have a way to see if the user set the pooling type
@@ -290,27 +293,30 @@ class ModernBertPooler(SequencePooler):
         super().__init__(
             pooling=get_seq_pooling_method(hf_pooling_type),
-            head=self.head,
+            # We set this dummy to avoid adding parameters to nn.Module too early
+            head=nn.Identity(),
         )
 
+        head_dtype = model_config.head_dtype
+
         self.dense = nn.Linear(
-            config.hidden_size, config.hidden_size, config.classifier_bias
+            config.hidden_size,
+            config.hidden_size,
+            config.classifier_bias,
+            dtype=head_dtype,
         )
         self.act = nn.GELU()
         self.norm = nn.LayerNorm(
-            config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
+            config.hidden_size,
+            eps=config.norm_eps,
+            bias=config.norm_bias,
         )
 
-    def head(
-        self,
-        pooled_data: SequencePoolingMethodOutput,
-        pooling_metadata: PoolingMetadata,
-    ) -> SequencePoolerHeadOutput:
-        if isinstance(pooled_data, list):
-            pooled_data = torch.stack(pooled_data)
-
-        pooled_data = pooled_data.to(self.dense.weight.dtype)
-        return self.norm(self.act(self.dense(pooled_data)))
+        # Use lambdas so that weights are not registered under `self.head`
+        self.head = EmbeddingPoolerHead(
+            projector=lambda x: self.dense(x),
+            head_dtype=head_dtype,
+            activation=LambdaPoolerActivation(lambda x: self.norm(self.act(x))),
+        )
 
 
 @default_pooling_type(seq_pooling_type="CLS")
@@ -335,7 +341,7 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None
 
-        self.pooling = ModernBertPooler(config, pooler_config)
+        self.pooling = ModernBertPooler(vllm_config.model_config)
         self.pooler = DispatchPooler.for_seq_cls(
             pooler_config,
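For reference, the data flow of the new ModernBertPooler head matches the deleted head method: project with dense, then apply GELU followed by LayerNorm. A minimal standalone sketch of that pipeline (the hidden size is illustrative; the real value comes from ModernBertConfig):

    import torch
    import torch.nn as nn

    hidden_size = 8
    dense = nn.Linear(hidden_size, hidden_size)
    act = nn.GELU()
    norm = nn.LayerNorm(hidden_size)

    pooled = torch.randn(2, hidden_size)  # [batchsize, hidden_size]
    out = norm(act(dense(pooled)))        # projector, then the composed activation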