[Refactor] Separate sequence and token pooling types (#32026)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-10 12:53:24 +08:00
committed by GitHub
parent 52d428295d
commit 583a90e005
42 changed files with 324 additions and 204 deletions

View File

@@ -357,7 +357,7 @@ class BertOutput(nn.Module):
@support_torch_compile
@default_pooling_type("CLS")
@default_pooling_type(seq_pooling_type="CLS")
class BertModel(nn.Module, SupportsQuant):
is_pooling_model = True
@@ -461,7 +461,7 @@ class BertPoolingModel(BertModel):
return loaded_params
@default_pooling_type("CLS")
@default_pooling_type(seq_pooling_type="CLS")
class BertEmbeddingModel(nn.Module, SupportsQuant):
"""A model that uses Bert to provide embedding functionalities.
@@ -675,7 +675,7 @@ class SPLADESparsePooler(Pooler):
return torch.stack(pooled_list, dim=0).contiguous()
@default_pooling_type("CLS")
@default_pooling_type(seq_pooling_type="CLS")
class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
"""
BertEmbeddingModel + SPLADE sparse embedding.
@@ -780,7 +780,7 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
return loaded
@default_pooling_type("CLS")
@default_pooling_type(seq_pooling_type="CLS")
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQuant):
"""A model that uses Bert to provide embedding functionalities.
@@ -849,7 +849,7 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu
@attn_type("encoder_only")
@default_pooling_type("ALL")
@default_pooling_type(tok_pooling_type="ALL")
class BertForTokenClassification(nn.Module):
is_pooling_model = True

View File

@@ -441,7 +441,7 @@ class BertWithRopeEncoder(nn.Module):
@support_torch_compile
@default_pooling_type("CLS")
@default_pooling_type(seq_pooling_type="CLS")
class BertWithRope(nn.Module, SupportsQuant):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
@@ -670,7 +670,7 @@ class JinaRobertaModel(BertWithRope):
return super().load_weights(weights)
@default_pooling_type("CLS")
@default_pooling_type(seq_pooling_type="CLS")
class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
is_pooling_model = True

View File

@@ -145,7 +145,7 @@ class CLIPProcessingInfo(BaseProcessingInfo):
image_width=image_width,
image_height=image_height,
),
_get_vision_feature_select_strategy(pooler_config.pooling_type),
_get_vision_feature_select_strategy(pooler_config.seq_pooling_type),
)
def get_image_size_with_most_features(self) -> ImageSize:
@@ -819,7 +819,7 @@ class CLIPVisionModel(nn.Module):
# Assume EOS token corresponds to LAST token in text model
@default_pooling_type("LAST")
@default_pooling_type(seq_pooling_type="LAST")
@MULTIMODAL_REGISTRY.register_processor(
CLIPMultiModalProcessor,
info=CLIPProcessingInfo,
@@ -908,7 +908,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
) -> torch.Tensor:
if feature_select_strategy is None:
feature_select_strategy = _get_vision_feature_select_strategy(
self.pooler_config.pooling_type
self.pooler_config.seq_pooling_type
)
pooled_output = self.vision_model(

View File

@@ -94,12 +94,12 @@ class JinaRobertaModelConfig(VerifyAndUpdateConfig):
class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
from vllm.config.pooler import PoolingTypeStr
from vllm.config.pooler import SequencePoolingType
hf_config = model_config.hf_config
hf_config.is_causal = False
pooling_type_map: dict[str, PoolingTypeStr] = {
pooling_type_map: dict[str, SequencePoolingType] = {
"avg": "MEAN",
"cls": "CLS",
"last": "LAST",
@@ -107,8 +107,9 @@ class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
pooling_type = pooling_type_map.get(hf_config.pooling, None)
if pooling_type is None:
raise ValueError(f"pool_type {hf_config.pooling} not supported")
model_config.pooler_config.pooling_type = pooling_type
raise ValueError(f"pool_type {hf_config.pooling!r} not supported")
model_config.pooler_config.seq_pooling_type = pooling_type
class NomicBertModelConfig(VerifyAndUpdateConfig):

View File

@@ -193,7 +193,7 @@ class GritLMPooler(SequencePooler):
return self.activation(pooled_data)
@default_pooling_type("MEAN")
@default_pooling_type(seq_pooling_type="MEAN")
class GritLM(LlamaForCausalLM):
"""This class implements the embedding model for parasail-ai/GritLM-7B-vllm.

View File

@@ -20,12 +20,13 @@ from vllm.utils.func_utils import supports_kw
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config.model import AttnTypeStr
from vllm.config.pooler import PoolingTypeStr
from vllm.config.pooler import SequencePoolingType, TokenPoolingType
from vllm.model_executor.layers.pooler import Pooler
else:
VllmConfig = Any
Pooler = Any
PoolingTypeStr = Any
SequencePoolingType = Any
TokenPoolingType = Any
AttnTypeStr = Any
logger = init_logger(__name__)
@@ -155,9 +156,19 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
MRO of your model class.
"""
default_pooling_type: ClassVar[PoolingTypeStr] = "LAST"
default_seq_pooling_type: ClassVar[SequencePoolingType] = "LAST"
"""
Indicates the [vllm.config.pooler.PoolerConfig.pooling_type][]
Indicates the [vllm.config.pooler.PoolerConfig.seq_pooling_type][]
to use by default.
You can use the
[vllm.model_executor.models.interfaces_base.default_pooling_type][]
decorator to conveniently set this field.
"""
default_tok_pooling_type: ClassVar[TokenPoolingType] = "ALL"
"""
Indicates the [vllm.config.pooler.PoolerConfig.tok_pooling_type][]
to use by default.
You can use the
@@ -200,18 +211,31 @@ def is_pooling_model(
_T = TypeVar("_T", bound=type[nn.Module])
def default_pooling_type(pooling_type: PoolingTypeStr):
"""Decorator to set `VllmModelForPooling.default_pooling_type`."""
def default_pooling_type(
*,
seq_pooling_type: SequencePoolingType = "LAST",
tok_pooling_type: TokenPoolingType = "ALL",
):
"""Decorator to set `VllmModelForPooling.default_*_pooling_type`."""
def func(model: _T) -> _T:
model.default_pooling_type = pooling_type # type: ignore
model.default_seq_pooling_type = seq_pooling_type # type: ignore
model.default_tok_pooling_type = tok_pooling_type # type: ignore
return model
return func
def get_default_pooling_type(model: type[object] | object) -> PoolingTypeStr:
return getattr(model, "default_pooling_type", "LAST")
def get_default_seq_pooling_type(
model: type[object] | object,
) -> SequencePoolingType:
return getattr(model, "default_seq_pooling_type", "LAST")
def get_default_tok_pooling_type(
model: type[object] | object,
) -> TokenPoolingType:
return getattr(model, "default_tok_pooling_type", "ALL")
def attn_type(attn_type: AttnTypeStr):

View File

@@ -402,7 +402,7 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
return loaded_params
@default_pooling_type("ALL")
@default_pooling_type(tok_pooling_type="ALL")
class InternLM2ForRewardModel(InternLM2ForCausalLM):
is_pooling_model = True

View File

@@ -221,7 +221,7 @@ class ModernBertEncoderLayer(nn.Module):
@support_torch_compile
@default_pooling_type("CLS")
@default_pooling_type(seq_pooling_type="CLS")
class ModernBertModel(nn.Module):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={"layers.": "encoder_layer.layers."}
@@ -308,7 +308,7 @@ class ModernBertPooler(SequencePooler):
return self.norm(self.act(self.dense(pooled_data)))
@default_pooling_type("CLS")
@default_pooling_type(seq_pooling_type="CLS")
class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
is_pooling_model = True
@@ -395,7 +395,7 @@ class ModernBertPredictionHead(nn.Module):
@attn_type("encoder_only")
@default_pooling_type("ALL")
@default_pooling_type(tok_pooling_type="ALL")
class ModernBertForTokenClassification(nn.Module):
is_pooling_model = True

View File

@@ -96,7 +96,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
return loader.load_weights(weights)
@default_pooling_type("ALL")
@default_pooling_type(tok_pooling_type="ALL")
class Qwen2ForRewardModel(Qwen2RewardBaseModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
vllm_config.model_config.hf_config.num_labels = 1
@@ -108,7 +108,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
self.pooler = pooler_for_token_classify(pooler_config)
@default_pooling_type("STEP")
@default_pooling_type(tok_pooling_type="STEP")
class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
vllm_config.model_config.hf_config.num_labels = 2

View File

@@ -35,10 +35,11 @@ from vllm.utils.hashing import safe_hash
if TYPE_CHECKING:
from vllm.config.model import AttnTypeStr
from vllm.config.pooler import PoolingTypeStr
from vllm.config.pooler import SequencePoolingType, TokenPoolingType
else:
AttnTypeStr = Any
PoolingTypeStr = Any
SequencePoolingType = Any
TokenPoolingType = Any
from .interfaces import (
@@ -57,7 +58,8 @@ from .interfaces import (
)
from .interfaces_base import (
get_attn_type,
get_default_pooling_type,
get_default_seq_pooling_type,
get_default_tok_pooling_type,
is_pooling_model,
is_text_generation_model,
)
@@ -548,7 +550,8 @@ class _ModelInfo:
is_text_generation_model: bool
is_pooling_model: bool
attn_type: AttnTypeStr
default_pooling_type: PoolingTypeStr
default_seq_pooling_type: SequencePoolingType
default_tok_pooling_type: TokenPoolingType
supports_cross_encoding: bool
supports_multimodal: bool
supports_multimodal_raw_input_only: bool
@@ -569,7 +572,8 @@ class _ModelInfo:
architecture=model.__name__,
is_text_generation_model=is_text_generation_model(model),
is_pooling_model=is_pooling_model(model),
default_pooling_type=get_default_pooling_type(model),
default_seq_pooling_type=get_default_seq_pooling_type(model),
default_tok_pooling_type=get_default_tok_pooling_type(model),
attn_type=get_attn_type(model),
supports_cross_encoding=supports_cross_encoding(model),
supports_multimodal=supports_multimodal(model),

View File

@@ -93,7 +93,7 @@ class RobertaClassificationHead(nn.Module):
return x
@default_pooling_type("CLS")
@default_pooling_type(seq_pooling_type="CLS")
class RobertaEmbeddingModel(BertEmbeddingModel):
"""A model that uses Roberta to provide embedding functionalities."""
@@ -150,7 +150,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
return loader.load_weights(weights_list, mapper=mapper)
@default_pooling_type("CLS")
@default_pooling_type(seq_pooling_type="CLS")
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
"""A model that uses Roberta to provide embedding functionalities.

View File

@@ -129,7 +129,7 @@ class SiglipProcessingInfo(BaseProcessingInfo):
image_width=image_width,
image_height=image_height,
),
_get_vision_feature_select_strategy(pooler_config.pooling_type),
_get_vision_feature_select_strategy(pooler_config.seq_pooling_type),
)
def get_image_size_with_most_features(self) -> ImageSize:
@@ -998,7 +998,7 @@ class SiglipTextEmbeddings(nn.Module):
# Assume EOS token corresponds to CLS token in text model
@default_pooling_type("CLS")
@default_pooling_type(seq_pooling_type="CLS")
@MULTIMODAL_REGISTRY.register_processor(
SiglipMultiModalProcessor,
info=SiglipProcessingInfo,
@@ -1125,7 +1125,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
) -> torch.Tensor:
if feature_select_strategy is None:
feature_select_strategy = _get_vision_feature_select_strategy(
self.pooler_config.pooling_type
self.pooler_config.seq_pooling_type
)
pooled_output = self.vision_model(