[Model] Avoid hardcoding pooling type (#32119)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -25,11 +25,11 @@ from vllm.model_executor.layers.pooler import (
|
||||
PoolingParamsUpdate,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler.seqwise import (
|
||||
CLSPool,
|
||||
SequencePooler,
|
||||
SequencePoolerHeadOutput,
|
||||
SequencePoolerOutput,
|
||||
SequencePoolingMethodOutput,
|
||||
get_seq_pooling_method,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler.tokwise import (
|
||||
pooler_for_token_classify,
|
||||
@@ -94,9 +94,9 @@ class BertEmbedding(nn.Module):
|
||||
|
||||
|
||||
class BertPooler(SequencePooler):
|
||||
def __init__(self, config: BertConfig):
|
||||
def __init__(self, config: BertConfig, pooler_config: PoolerConfig):
|
||||
super().__init__(
|
||||
pooling=CLSPool(),
|
||||
pooling=get_seq_pooling_method(pooler_config.seq_pooling_type),
|
||||
head=self.head,
|
||||
)
|
||||
|
||||
@@ -450,7 +450,11 @@ class BertPoolingModel(BertModel):
|
||||
)
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.pooler = BertPooler(config)
|
||||
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = BertPooler(config, pooler_config)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
other_weights, loaded_stacked_params = self._load_weights(weights)
|
||||
@@ -711,6 +715,8 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
|
||||
layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
|
||||
)
|
||||
|
||||
# None of vLLM's built-in sequence pooling types are
|
||||
# applicable so it is overwritten by SPLADESparsePooler
|
||||
pooling_mode = getattr(self, "_splade_pooling", "max")
|
||||
|
||||
cls_id = getattr(cfg, "cls_token_id", None)
|
||||
|
||||
@@ -453,6 +453,7 @@ class BertWithRope(nn.Module, SupportsQuant):
|
||||
add_pooling_layer: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.vllm_config = vllm_config
|
||||
self.add_pooling_layer = add_pooling_layer
|
||||
self.config = vllm_config.model_config.hf_config
|
||||
@@ -463,7 +464,14 @@ class BertWithRope(nn.Module, SupportsQuant):
|
||||
rotary_kwargs=self.config.rotary_kwargs,
|
||||
prefix=f"{prefix}.encoder",
|
||||
)
|
||||
self.pooler = BertPooler(self.config) if add_pooling_layer else None
|
||||
|
||||
if add_pooling_layer:
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = BertPooler(self.config, pooler_config)
|
||||
else:
|
||||
self.pooler = None
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embeddings(input_ids)
|
||||
|
||||
@@ -5,7 +5,7 @@ from collections.abc import Set
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.config import ModelConfig, PoolerConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.pooler import (
|
||||
DispatchPooler,
|
||||
@@ -17,6 +17,7 @@ from vllm.model_executor.layers.pooler.seqwise import (
|
||||
SequencePoolerHeadOutput,
|
||||
SequencePoolingMethod,
|
||||
SequencePoolingMethodOutput,
|
||||
get_seq_pooling_method,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
@@ -177,9 +178,13 @@ class GritLMMeanPool(SequencePoolingMethod):
|
||||
|
||||
|
||||
class GritLMPooler(SequencePooler):
|
||||
def __init__(self, model_config: ModelConfig):
|
||||
def __init__(self, model_config: ModelConfig, pooler_config: PoolerConfig):
|
||||
super().__init__(
|
||||
pooling=GritLMMeanPool(model_config),
|
||||
pooling=(
|
||||
GritLMMeanPool(model_config)
|
||||
if pooler_config.seq_pooling_type == "MEAN"
|
||||
else get_seq_pooling_method(pooler_config.seq_pooling_type)
|
||||
),
|
||||
head=self.head,
|
||||
)
|
||||
|
||||
@@ -235,6 +240,6 @@ class GritLM(LlamaForCausalLM):
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"token_embed": pooler_for_token_embed(pooler_config),
|
||||
"embed": GritLMPooler(vllm_config.model_config),
|
||||
"embed": GritLMPooler(vllm_config.model_config, pooler_config),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -8,7 +8,7 @@ from transformers import ModernBertConfig
|
||||
from transformers.activations import ACT2FN
|
||||
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import PoolerConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.attention.encoder_only_attention import (
|
||||
EncoderOnlyAttention,
|
||||
@@ -282,9 +282,14 @@ class ModernBertModel(nn.Module):
|
||||
|
||||
|
||||
class ModernBertPooler(SequencePooler):
|
||||
def __init__(self, config: ModernBertConfig):
|
||||
def __init__(self, config: ModernBertConfig, pooler_config: PoolerConfig):
|
||||
hf_pooling_type = config.classifier_pooling.upper()
|
||||
# vllm_pooling_type = pooler_config.seq_pooling_type
|
||||
# Currently we don't have a way to see if the user set the pooling type
|
||||
# explicitly or not, so we always use the HF pooling type for now.
|
||||
|
||||
super().__init__(
|
||||
pooling=get_seq_pooling_method(config.classifier_pooling.upper()),
|
||||
pooling=get_seq_pooling_method(hf_pooling_type),
|
||||
head=self.head,
|
||||
)
|
||||
|
||||
@@ -314,7 +319,9 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
self.config = config
|
||||
self.model = ModernBertModel(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert")
|
||||
@@ -324,11 +331,12 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
config.num_labels,
|
||||
dtype=vllm_config.model_config.head_dtype,
|
||||
)
|
||||
self.pooling = ModernBertPooler(config)
|
||||
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooling = ModernBertPooler(config, pooler_config)
|
||||
|
||||
self.pooler = DispatchPooler.for_seq_cls(
|
||||
pooler_config,
|
||||
pooling=self.pooling,
|
||||
|
||||
@@ -9,7 +9,6 @@ from transformers import RobertaConfig
|
||||
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler
|
||||
from vllm.model_executor.layers.pooler.seqwise import CLSPool
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.model_executor.models.bert import (
|
||||
TOKEN_TYPE_SHIFT,
|
||||
@@ -86,7 +85,7 @@ class RobertaClassificationHead(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# CLSPool has already been applied in `pooling`
|
||||
# Token extraction has already been applied in `pooler.pooling`
|
||||
x = self.dense(x)
|
||||
x = torch.tanh(x)
|
||||
x = self.out_proj(x)
|
||||
@@ -194,7 +193,6 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
|
||||
self.pooler = DispatchPooler.for_seq_cls(
|
||||
pooler_config,
|
||||
pooling=CLSPool(),
|
||||
classifier=self.classifier,
|
||||
)
|
||||
|
||||
|
||||
@@ -23,7 +23,6 @@ from transformers import AutoModelForSequenceClassification
|
||||
|
||||
from vllm.config.utils import getattr_iter
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler
|
||||
from vllm.model_executor.layers.pooler.seqwise import CLSPool
|
||||
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
|
||||
from vllm.model_executor.models.interfaces_base import VllmModelForPooling
|
||||
|
||||
@@ -32,7 +31,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class EmbeddingMixin(VllmModelForPooling):
|
||||
default_pooling_type = "CLS"
|
||||
default_seq_pooling_type = "CLS"
|
||||
|
||||
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
|
||||
# Skip VllmModelForPooling.__init__ and call the next class in MRO
|
||||
@@ -47,7 +46,7 @@ class EmbeddingMixin(VllmModelForPooling):
|
||||
|
||||
|
||||
class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
|
||||
default_pooling_type = "CLS"
|
||||
default_seq_pooling_type = "CLS"
|
||||
|
||||
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
|
||||
# Skip VllmModelForPooling.__init__ and call the next class in MRO
|
||||
@@ -85,8 +84,10 @@ class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
|
||||
self.init_parameters(self.classifier, dtype=self.model_config.head_dtype)
|
||||
|
||||
class ClassifierWithReshape(self.classifier.__class__):
|
||||
"""CLSPool has already been applied in `pooling`.
|
||||
Add dim to match expected input shape of `classifier.forward`."""
|
||||
"""
|
||||
Token extraction has already been applied in `pooler.pooling`.
|
||||
Add dim to match expected input shape of `classifier.forward`.
|
||||
"""
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if len(args) > 0:
|
||||
@@ -97,6 +98,5 @@ class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
|
||||
|
||||
self.pooler = DispatchPooler.for_seq_cls(
|
||||
pooler_config,
|
||||
pooling=CLSPool(),
|
||||
classifier=self.classifier,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user