[Model] Standardize pooling heads (#32148)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -2,12 +2,17 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.pooling_params import PoolingParams
|
||||
|
||||
_T = TypeVar("_T", bound=torch.Tensor | list[torch.Tensor])
|
||||
|
||||
ProjectorFn = Callable[[torch.Tensor], torch.Tensor]
|
||||
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
|
||||
ActivationFn = Callable[[_T], _T]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -24,4 +29,4 @@ class PoolingParamsUpdate:
|
||||
params.requires_token_ids = self.requires_token_ids
|
||||
|
||||
|
||||
__all__ = ["ClassifierFn", "PoolingParamsUpdate"]
|
||||
__all__ = ["ActivationFn", "ClassifierFn", "ProjectorFn", "PoolingParamsUpdate"]
|
||||
|
||||
@@ -7,14 +7,7 @@ from typing import TypeAlias
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.model_executor.layers.pooler import ClassifierFn
|
||||
from vllm.model_executor.layers.pooler.activations import (
|
||||
PoolerActivation,
|
||||
PoolerNormalize,
|
||||
resolve_classifier_act_fn,
|
||||
)
|
||||
from vllm.model_executor.models.adapters import _load_st_projector
|
||||
from vllm.model_executor.layers.pooler import ActivationFn, ClassifierFn, ProjectorFn
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
|
||||
@@ -38,17 +31,17 @@ class SequencePoolerHead(nn.Module, ABC):
|
||||
|
||||
|
||||
class EmbeddingPoolerHead(SequencePoolerHead):
|
||||
def __init__(self) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
projector: ProjectorFn | None = None,
|
||||
head_dtype: torch.dtype | str | None = None,
|
||||
activation: ActivationFn | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# Load ST projector if available
|
||||
vllm_config = get_current_vllm_config()
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
self.projector = _load_st_projector(model_config)
|
||||
self.head_dtype = model_config.head_dtype
|
||||
|
||||
self.activation = PoolerNormalize()
|
||||
self.projector = projector
|
||||
self.head_dtype = head_dtype
|
||||
self.activation = activation
|
||||
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
return {"embed"}
|
||||
@@ -65,7 +58,8 @@ class EmbeddingPoolerHead(SequencePoolerHead):
|
||||
pooled_data = torch.stack(pooled_data)
|
||||
# pooled_data shape: [batchsize, hidden_dimension]
|
||||
|
||||
pooled_data = pooled_data.to(self.head_dtype)
|
||||
if self.head_dtype is not None:
|
||||
pooled_data = pooled_data.to(self.head_dtype)
|
||||
|
||||
# Apply ST projector
|
||||
if self.projector is not None:
|
||||
@@ -88,15 +82,16 @@ class EmbeddingPoolerHead(SequencePoolerHead):
|
||||
]
|
||||
|
||||
# for normalize
|
||||
flags = [p.normalize for p in pooling_params]
|
||||
if len(set(flags)) == 1:
|
||||
if flags[0]:
|
||||
pooled_data = self.activation(pooled_data)
|
||||
else:
|
||||
pooled_data = [
|
||||
self.activation(vecs) if f else vecs
|
||||
for vecs, f in zip(pooled_data, flags)
|
||||
]
|
||||
if self.activation is not None:
|
||||
flags = [p.normalize for p in pooling_params]
|
||||
if len(set(flags)) == 1:
|
||||
if flags[0]:
|
||||
pooled_data = self.activation(pooled_data)
|
||||
else:
|
||||
pooled_data = [
|
||||
self.activation(vecs) if f else vecs
|
||||
for vecs, f in zip(pooled_data, flags)
|
||||
]
|
||||
|
||||
# pooled_data shape: [batchsize, embedding_dimension]
|
||||
return pooled_data
|
||||
@@ -106,20 +101,16 @@ class ClassifierPoolerHead(SequencePoolerHead):
|
||||
def __init__(
|
||||
self,
|
||||
classifier: ClassifierFn | None = None,
|
||||
act_fn: PoolerActivation | str | None = None,
|
||||
logit_bias: float | None = None,
|
||||
head_dtype: torch.dtype | str | None = None,
|
||||
activation: ActivationFn | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
self.classifier = classifier
|
||||
self.logit_bias: float | None = model_config.pooler_config.logit_bias
|
||||
self.head_dtype = model_config.head_dtype
|
||||
|
||||
self.act_fn = resolve_classifier_act_fn(
|
||||
model_config, static_num_labels=True, act_fn=act_fn
|
||||
)
|
||||
self.logit_bias = logit_bias
|
||||
self.head_dtype = head_dtype
|
||||
self.activation = activation
|
||||
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
return {"classify", "score"}
|
||||
@@ -136,7 +127,8 @@ class ClassifierPoolerHead(SequencePoolerHead):
|
||||
pooled_data = torch.stack(pooled_data)
|
||||
# pooled_data shape: [batchsize, hidden_size]
|
||||
|
||||
pooled_data = pooled_data.to(self.head_dtype)
|
||||
if self.head_dtype is not None:
|
||||
pooled_data = pooled_data.to(self.head_dtype)
|
||||
|
||||
if self.classifier is not None:
|
||||
pooled_data = self.classifier(pooled_data)
|
||||
@@ -145,13 +137,15 @@ class ClassifierPoolerHead(SequencePoolerHead):
|
||||
if self.logit_bias is not None:
|
||||
pooled_data -= self.logit_bias
|
||||
|
||||
flags = [p.use_activation for p in pooling_params]
|
||||
if len(set(flags)) == 1:
|
||||
scores = self.act_fn(pooled_data) if flags[0] else pooled_data
|
||||
else:
|
||||
scores = [
|
||||
self.act_fn(vecs) if f else vecs for vecs, f in zip(pooled_data, flags)
|
||||
]
|
||||
if self.activation is not None:
|
||||
flags = [p.use_activation for p in pooling_params]
|
||||
if len(set(flags)) == 1:
|
||||
pooled_data = self.activation(pooled_data) if flags[0] else pooled_data
|
||||
else:
|
||||
pooled_data = [
|
||||
self.activation(vecs) if f else vecs
|
||||
for vecs, f in zip(pooled_data, flags)
|
||||
]
|
||||
|
||||
# scores shape: [batchsize, num_labels]
|
||||
return scores
|
||||
# pooled_data shape: [batchsize, num_labels]
|
||||
return pooled_data
|
||||
|
||||
@@ -5,10 +5,15 @@ from typing import TypeAlias
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import PoolerConfig
|
||||
from vllm.config import PoolerConfig, get_current_vllm_config
|
||||
from vllm.model_executor.layers.pooler import ClassifierFn, PoolingParamsUpdate
|
||||
from vllm.model_executor.layers.pooler.abstract import Pooler
|
||||
from vllm.model_executor.layers.pooler.activations import PoolerActivation
|
||||
from vllm.model_executor.layers.pooler.activations import (
|
||||
PoolerActivation,
|
||||
PoolerNormalize,
|
||||
resolve_classifier_act_fn,
|
||||
)
|
||||
from vllm.model_executor.models.adapters import _load_st_projector
|
||||
from vllm.tasks import POOLING_TASKS, PoolingTask
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
|
||||
@@ -86,7 +91,14 @@ class SequencePooler(Pooler):
|
||||
|
||||
def pooler_for_embed(pooler_config: PoolerConfig):
|
||||
pooling = get_seq_pooling_method(pooler_config.get_seq_pooling_type())
|
||||
head = EmbeddingPoolerHead()
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
model_config = vllm_config.model_config
|
||||
head = EmbeddingPoolerHead(
|
||||
projector=_load_st_projector(model_config),
|
||||
head_dtype=model_config.head_dtype,
|
||||
activation=PoolerNormalize(),
|
||||
)
|
||||
|
||||
return SequencePooler(pooling=pooling, head=head)
|
||||
|
||||
@@ -101,6 +113,15 @@ def pooler_for_classify(
|
||||
if pooling is None:
|
||||
pooling = get_seq_pooling_method(pooler_config.get_seq_pooling_type())
|
||||
|
||||
head = ClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
|
||||
vllm_config = get_current_vllm_config()
|
||||
model_config = vllm_config.model_config
|
||||
head = ClassifierPoolerHead(
|
||||
classifier=classifier,
|
||||
logit_bias=model_config.pooler_config.logit_bias,
|
||||
head_dtype=model_config.head_dtype,
|
||||
activation=resolve_classifier_act_fn(
|
||||
model_config, static_num_labels=True, act_fn=act_fn
|
||||
),
|
||||
)
|
||||
|
||||
return SequencePooler(pooling=pooling, head=head)
|
||||
|
||||
@@ -7,14 +7,7 @@ from typing import TypeAlias
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.model_executor.layers.pooler import ClassifierFn
|
||||
from vllm.model_executor.layers.pooler.activations import (
|
||||
PoolerActivation,
|
||||
PoolerNormalize,
|
||||
resolve_classifier_act_fn,
|
||||
)
|
||||
from vllm.model_executor.models.adapters import _load_st_projector
|
||||
from vllm.model_executor.layers.pooler import ActivationFn, ClassifierFn, ProjectorFn
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
@@ -49,17 +42,17 @@ class TokenPoolerHead(nn.Module, ABC):
|
||||
|
||||
|
||||
class TokenEmbeddingPoolerHead(TokenPoolerHead):
|
||||
def __init__(self) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
projector: ProjectorFn | None = None,
|
||||
head_dtype: torch.dtype | str | None = None,
|
||||
activation: ActivationFn | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# Load ST projector if available
|
||||
vllm_config = get_current_vllm_config()
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
self.projector = _load_st_projector(model_config)
|
||||
self.head_dtype = model_config.head_dtype
|
||||
|
||||
self.activation = PoolerNormalize()
|
||||
self.projector = projector
|
||||
self.head_dtype = head_dtype
|
||||
self.activation = activation
|
||||
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
return {"token_embed"}
|
||||
@@ -73,7 +66,8 @@ class TokenEmbeddingPoolerHead(TokenPoolerHead):
|
||||
if pooled_data is None:
|
||||
return None
|
||||
|
||||
pooled_data = pooled_data.to(self.head_dtype)
|
||||
if self.head_dtype is not None:
|
||||
pooled_data = pooled_data.to(self.head_dtype)
|
||||
# pooled_data shape: [n_tokens, hidden_dimension]
|
||||
|
||||
# Apply ST projector
|
||||
@@ -85,7 +79,7 @@ class TokenEmbeddingPoolerHead(TokenPoolerHead):
|
||||
pooled_data = pooled_data[..., : pooling_param.dimensions]
|
||||
|
||||
# for normalize
|
||||
if pooling_param.normalize:
|
||||
if self.activation is not None and pooling_param.normalize:
|
||||
pooled_data = self.activation(pooled_data)
|
||||
|
||||
# pooled_data shape: [n_tokens, embedding_dimension]
|
||||
@@ -96,20 +90,16 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
|
||||
def __init__(
|
||||
self,
|
||||
classifier: ClassifierFn | None = None,
|
||||
act_fn: PoolerActivation | str | None = None,
|
||||
logit_bias: float | None = None,
|
||||
head_dtype: torch.dtype | str | None = None,
|
||||
activation: ActivationFn | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
self.classifier = classifier
|
||||
self.logit_bias: float | None = model_config.pooler_config.logit_bias
|
||||
self.head_dtype = model_config.head_dtype
|
||||
|
||||
self.act_fn = resolve_classifier_act_fn(
|
||||
model_config, static_num_labels=False, act_fn=act_fn
|
||||
)
|
||||
self.logit_bias = logit_bias
|
||||
self.head_dtype = head_dtype
|
||||
self.activation = activation
|
||||
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
return {"token_classify"}
|
||||
@@ -123,7 +113,8 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
|
||||
if pooled_data is None:
|
||||
return None
|
||||
|
||||
pooled_data = pooled_data.to(self.head_dtype)
|
||||
if self.head_dtype is not None:
|
||||
pooled_data = pooled_data.to(self.head_dtype)
|
||||
# hidden_states shape: [n_token, hidden_size]
|
||||
|
||||
if self.classifier is not None:
|
||||
@@ -135,8 +126,8 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
|
||||
if self.logit_bias is not None:
|
||||
scores -= self.logit_bias
|
||||
|
||||
if pooling_param.use_activation:
|
||||
scores = self.act_fn(scores)
|
||||
if self.activation is not None and pooling_param.use_activation:
|
||||
scores = self.activation(scores)
|
||||
|
||||
# scores shape: [n_token, num_labels]
|
||||
return scores
|
||||
|
||||
@@ -5,10 +5,15 @@ from typing import TypeAlias
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import PoolerConfig
|
||||
from vllm.config import PoolerConfig, get_current_vllm_config
|
||||
from vllm.model_executor.layers.pooler import ClassifierFn, PoolingParamsUpdate
|
||||
from vllm.model_executor.layers.pooler.abstract import Pooler
|
||||
from vllm.model_executor.layers.pooler.activations import PoolerActivation
|
||||
from vllm.model_executor.layers.pooler.activations import (
|
||||
PoolerActivation,
|
||||
PoolerNormalize,
|
||||
resolve_classifier_act_fn,
|
||||
)
|
||||
from vllm.model_executor.models.adapters import _load_st_projector
|
||||
from vllm.tasks import POOLING_TASKS, PoolingTask
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
|
||||
@@ -86,7 +91,14 @@ class TokenPooler(Pooler):
|
||||
|
||||
def pooler_for_token_embed(pooler_config: PoolerConfig):
|
||||
pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())
|
||||
head = TokenEmbeddingPoolerHead()
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
model_config = vllm_config.model_config
|
||||
head = TokenEmbeddingPoolerHead(
|
||||
projector=_load_st_projector(model_config),
|
||||
head_dtype=model_config.head_dtype,
|
||||
activation=PoolerNormalize(),
|
||||
)
|
||||
|
||||
return TokenPooler(pooling=pooling, head=head)
|
||||
|
||||
@@ -101,6 +113,15 @@ def pooler_for_token_classify(
|
||||
if pooling is None:
|
||||
pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())
|
||||
|
||||
head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
|
||||
vllm_config = get_current_vllm_config()
|
||||
model_config = vllm_config.model_config
|
||||
head = TokenClassifierPoolerHead(
|
||||
classifier=classifier,
|
||||
logit_bias=model_config.pooler_config.logit_bias,
|
||||
head_dtype=model_config.head_dtype,
|
||||
activation=resolve_classifier_act_fn(
|
||||
model_config, static_num_labels=False, act_fn=act_fn
|
||||
),
|
||||
)
|
||||
|
||||
return TokenPooler(pooling=pooling, head=head)
|
||||
|
||||
@@ -8,7 +8,7 @@ from torch import nn
|
||||
from transformers import BertConfig
|
||||
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
|
||||
from vllm.config import CacheConfig, ModelConfig, PoolerConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.attention.encoder_only_attention import (
|
||||
@@ -24,11 +24,11 @@ from vllm.model_executor.layers.pooler import (
|
||||
Pooler,
|
||||
PoolingParamsUpdate,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler.activations import LambdaPoolerActivation
|
||||
from vllm.model_executor.layers.pooler.seqwise import (
|
||||
EmbeddingPoolerHead,
|
||||
SequencePooler,
|
||||
SequencePoolerHeadOutput,
|
||||
SequencePoolerOutput,
|
||||
SequencePoolingMethodOutput,
|
||||
get_seq_pooling_method,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler.tokwise import (
|
||||
@@ -94,26 +94,32 @@ class BertEmbedding(nn.Module):
|
||||
|
||||
|
||||
class BertPooler(SequencePooler):
|
||||
def __init__(self, config: BertConfig, pooler_config: PoolerConfig):
|
||||
def __init__(self, model_config: ModelConfig):
|
||||
pooler_config = model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
config: BertConfig = model_config.hf_config
|
||||
|
||||
super().__init__(
|
||||
pooling=get_seq_pooling_method(pooler_config.seq_pooling_type),
|
||||
head=self.head,
|
||||
# We set this dummy to avoid adding parameters to nn.Module too early
|
||||
head=nn.Identity(),
|
||||
)
|
||||
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.activation = nn.Tanh()
|
||||
head_dtype = model_config.head_dtype
|
||||
self.dense = nn.Linear(
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
dtype=head_dtype,
|
||||
)
|
||||
self.act_fn = nn.Tanh()
|
||||
|
||||
def head(
|
||||
self,
|
||||
pooled_data: SequencePoolingMethodOutput,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> SequencePoolerHeadOutput:
|
||||
if isinstance(pooled_data, list):
|
||||
pooled_data = torch.stack(pooled_data)
|
||||
|
||||
pooled_data = self.dense(pooled_data)
|
||||
pooled_data = self.activation(pooled_data)
|
||||
return pooled_data
|
||||
# Use lambdas so that weights are not registered under `self.head`
|
||||
self.head = EmbeddingPoolerHead(
|
||||
projector=lambda x: self.dense(x),
|
||||
head_dtype=head_dtype,
|
||||
activation=LambdaPoolerActivation(self.act_fn),
|
||||
)
|
||||
|
||||
|
||||
class BertEncoder(nn.Module):
|
||||
@@ -449,12 +455,7 @@ class BertPoolingModel(BertModel):
|
||||
embedding_class=embedding_class,
|
||||
)
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = BertPooler(config, pooler_config)
|
||||
self.pooler = BertPooler(vllm_config.model_config)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
other_weights, loaded_stacked_params = self._load_weights(weights)
|
||||
|
||||
@@ -466,10 +466,7 @@ class BertWithRope(nn.Module, SupportsQuant):
|
||||
)
|
||||
|
||||
if add_pooling_layer:
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = BertPooler(self.config, pooler_config)
|
||||
self.pooler = BertPooler(vllm_config.model_config)
|
||||
else:
|
||||
self.pooler = None
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from collections.abc import Set
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.config import ModelConfig, PoolerConfig, VllmConfig
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.pooler import (
|
||||
DispatchPooler,
|
||||
@@ -13,8 +13,8 @@ from vllm.model_executor.layers.pooler import (
|
||||
)
|
||||
from vllm.model_executor.layers.pooler.activations import PoolerNormalize
|
||||
from vllm.model_executor.layers.pooler.seqwise import (
|
||||
EmbeddingPoolerHead,
|
||||
SequencePooler,
|
||||
SequencePoolerHeadOutput,
|
||||
SequencePoolingMethod,
|
||||
SequencePoolingMethodOutput,
|
||||
get_seq_pooling_method,
|
||||
@@ -178,25 +178,22 @@ class GritLMMeanPool(SequencePoolingMethod):
|
||||
|
||||
|
||||
class GritLMPooler(SequencePooler):
|
||||
def __init__(self, model_config: ModelConfig, pooler_config: PoolerConfig):
|
||||
def __init__(self, model_config: ModelConfig):
|
||||
pooler_config = model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
super().__init__(
|
||||
pooling=(
|
||||
GritLMMeanPool(model_config)
|
||||
if pooler_config.seq_pooling_type == "MEAN"
|
||||
else get_seq_pooling_method(pooler_config.seq_pooling_type)
|
||||
),
|
||||
head=self.head,
|
||||
head=EmbeddingPoolerHead(
|
||||
head_dtype=model_config.head_dtype,
|
||||
activation=PoolerNormalize(),
|
||||
),
|
||||
)
|
||||
|
||||
self.activation = PoolerNormalize()
|
||||
|
||||
def head(
|
||||
self,
|
||||
pooled_data: SequencePoolingMethodOutput,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> SequencePoolerHeadOutput:
|
||||
return self.activation(pooled_data)
|
||||
|
||||
|
||||
@default_pooling_type(seq_pooling_type="MEAN")
|
||||
class GritLM(LlamaForCausalLM):
|
||||
@@ -240,6 +237,6 @@ class GritLM(LlamaForCausalLM):
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"token_embed": pooler_for_token_embed(pooler_config),
|
||||
"embed": GritLMPooler(vllm_config.model_config, pooler_config),
|
||||
"embed": GritLMPooler(vllm_config.model_config),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -8,17 +8,17 @@ from transformers import ModernBertConfig
|
||||
from transformers.activations import ACT2FN
|
||||
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import PoolerConfig, VllmConfig
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.attention.encoder_only_attention import (
|
||||
EncoderOnlyAttention,
|
||||
)
|
||||
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler
|
||||
from vllm.model_executor.layers.pooler.activations import LambdaPoolerActivation
|
||||
from vllm.model_executor.layers.pooler.seqwise import (
|
||||
EmbeddingPoolerHead,
|
||||
SequencePooler,
|
||||
SequencePoolerHeadOutput,
|
||||
SequencePoolingMethodOutput,
|
||||
get_seq_pooling_method,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_classify
|
||||
@@ -26,7 +26,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
|
||||
from .interfaces import SupportsCrossEncoding
|
||||
from .interfaces_base import attn_type, default_pooling_type
|
||||
@@ -282,7 +281,11 @@ class ModernBertModel(nn.Module):
|
||||
|
||||
|
||||
class ModernBertPooler(SequencePooler):
|
||||
def __init__(self, config: ModernBertConfig, pooler_config: PoolerConfig):
|
||||
def __init__(self, model_config: ModelConfig):
|
||||
pooler_config = model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
config: ModernBertConfig = model_config.hf_config
|
||||
hf_pooling_type = config.classifier_pooling.upper()
|
||||
# vllm_pooling_type = pooler_config.seq_pooling_type
|
||||
# Currently we don't have a way to see if the user set the pooling type
|
||||
@@ -290,27 +293,30 @@ class ModernBertPooler(SequencePooler):
|
||||
|
||||
super().__init__(
|
||||
pooling=get_seq_pooling_method(hf_pooling_type),
|
||||
head=self.head,
|
||||
# We set this dummy to avoid adding parameters to nn.Module too early
|
||||
head=nn.Identity(),
|
||||
)
|
||||
|
||||
head_dtype = model_config.head_dtype
|
||||
self.dense = nn.Linear(
|
||||
config.hidden_size, config.hidden_size, config.classifier_bias
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
config.classifier_bias,
|
||||
dtype=head_dtype,
|
||||
)
|
||||
self.act = nn.GELU()
|
||||
self.norm = nn.LayerNorm(
|
||||
config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
|
||||
config.hidden_size,
|
||||
eps=config.norm_eps,
|
||||
bias=config.norm_bias,
|
||||
)
|
||||
|
||||
def head(
|
||||
self,
|
||||
pooled_data: SequencePoolingMethodOutput,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> SequencePoolerHeadOutput:
|
||||
if isinstance(pooled_data, list):
|
||||
pooled_data = torch.stack(pooled_data)
|
||||
|
||||
pooled_data = pooled_data.to(self.dense.weight.dtype)
|
||||
return self.norm(self.act(self.dense(pooled_data)))
|
||||
# Use lambdas so that weights are not registered under `self.head`
|
||||
self.head = EmbeddingPoolerHead(
|
||||
projector=lambda x: self.dense(x),
|
||||
head_dtype=head_dtype,
|
||||
activation=LambdaPoolerActivation(lambda x: self.norm(self.act(x))),
|
||||
)
|
||||
|
||||
|
||||
@default_pooling_type(seq_pooling_type="CLS")
|
||||
@@ -335,7 +341,7 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooling = ModernBertPooler(config, pooler_config)
|
||||
self.pooling = ModernBertPooler(vllm_config.model_config)
|
||||
|
||||
self.pooler = DispatchPooler.for_seq_cls(
|
||||
pooler_config,
|
||||
|
||||
Reference in New Issue
Block a user