[Refactor] Separate sequence and token pooling types (#32026)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -357,7 +357,7 @@ class BertOutput(nn.Module):


 @support_torch_compile
-@default_pooling_type("CLS")
+@default_pooling_type(seq_pooling_type="CLS")
 class BertModel(nn.Module, SupportsQuant):
     is_pooling_model = True
@@ -461,7 +461,7 @@ class BertPoolingModel(BertModel):
         return loaded_params


-@default_pooling_type("CLS")
+@default_pooling_type(seq_pooling_type="CLS")
 class BertEmbeddingModel(nn.Module, SupportsQuant):
     """A model that uses Bert to provide embedding functionalities.
@@ -675,7 +675,7 @@ class SPLADESparsePooler(Pooler):
         return torch.stack(pooled_list, dim=0).contiguous()


-@default_pooling_type("CLS")
+@default_pooling_type(seq_pooling_type="CLS")
 class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
     """
     BertEmbeddingModel + SPLADE sparse embedding.
@@ -780,7 +780,7 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
         return loaded


-@default_pooling_type("CLS")
+@default_pooling_type(seq_pooling_type="CLS")
 class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQuant):
     """A model that uses Bert to provide embedding functionalities.
@@ -849,7 +849,7 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu


 @attn_type("encoder_only")
-@default_pooling_type("ALL")
+@default_pooling_type(tok_pooling_type="ALL")
 class BertForTokenClassification(nn.Module):
     is_pooling_model = True
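The pattern in these hunks is consistent: embedding and sequence-classification models keep a whole-sequence default via seq_pooling_type ("CLS", "LAST", "MEAN"), while token-classification and reward models declare a per-token default via tok_pooling_type ("ALL", "STEP"). A minimal sketch of the shape difference between the two families (plain PyTorch for illustration, not vLLM's actual pooler code):

import torch

hidden_states = torch.randn(7, 16)  # one sequence: 7 tokens, hidden size 16

# Sequence pooling reduces the whole sequence to a single vector.
cls_vec = hidden_states[0]            # "CLS": first token, shape (16,)
last_vec = hidden_states[-1]          # "LAST": final token, shape (16,)
mean_vec = hidden_states.mean(dim=0)  # "MEAN": token average, shape (16,)

# Token pooling keeps one vector per token.
all_vecs = hidden_states              # "ALL": shape (7, 16)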
@@ -441,7 +441,7 @@ class BertWithRopeEncoder(nn.Module):


 @support_torch_compile
-@default_pooling_type("CLS")
+@default_pooling_type(seq_pooling_type="CLS")
 class BertWithRope(nn.Module, SupportsQuant):
     hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
@@ -670,7 +670,7 @@ class JinaRobertaModel(BertWithRope):
         return super().load_weights(weights)


-@default_pooling_type("CLS")
+@default_pooling_type(seq_pooling_type="CLS")
 class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
     is_pooling_model = True
@@ -145,7 +145,7 @@ class CLIPProcessingInfo(BaseProcessingInfo):
                 image_width=image_width,
                 image_height=image_height,
             ),
-            _get_vision_feature_select_strategy(pooler_config.pooling_type),
+            _get_vision_feature_select_strategy(pooler_config.seq_pooling_type),
         )

     def get_image_size_with_most_features(self) -> ImageSize:
@@ -819,7 +819,7 @@ class CLIPVisionModel(nn.Module):


 # Assume EOS token corresponds to LAST token in text model
-@default_pooling_type("LAST")
+@default_pooling_type(seq_pooling_type="LAST")
 @MULTIMODAL_REGISTRY.register_processor(
     CLIPMultiModalProcessor,
     info=CLIPProcessingInfo,
@@ -908,7 +908,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
     ) -> torch.Tensor:
         if feature_select_strategy is None:
             feature_select_strategy = _get_vision_feature_select_strategy(
-                self.pooler_config.pooling_type
+                self.pooler_config.seq_pooling_type
             )

         pooled_output = self.vision_model(
@@ -94,12 +94,12 @@ class JinaRobertaModelConfig(VerifyAndUpdateConfig):
 class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
     @staticmethod
     def verify_and_update_model_config(model_config: "ModelConfig") -> None:
-        from vllm.config.pooler import PoolingTypeStr
+        from vllm.config.pooler import SequencePoolingType

         hf_config = model_config.hf_config
         hf_config.is_causal = False

-        pooling_type_map: dict[str, PoolingTypeStr] = {
+        pooling_type_map: dict[str, SequencePoolingType] = {
             "avg": "MEAN",
             "cls": "CLS",
             "last": "LAST",
@@ -107,8 +107,9 @@ class LlamaBidirectionalConfig(VerifyAndUpdateConfig):

         pooling_type = pooling_type_map.get(hf_config.pooling, None)
         if pooling_type is None:
-            raise ValueError(f"pool_type {hf_config.pooling} not supported")
-        model_config.pooler_config.pooling_type = pooling_type
+            raise ValueError(f"pool_type {hf_config.pooling!r} not supported")
+
+        model_config.pooler_config.seq_pooling_type = pooling_type


 class NomicBertModelConfig(VerifyAndUpdateConfig):
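These two hunks change where LlamaBidirectionalConfig writes the mapped value (seq_pooling_type instead of pooling_type) and repr-quote the offending value in the error message. A standalone sketch of that lookup, with a made-up pooling value standing in for hf_config.pooling:

pooling_type_map = {"avg": "MEAN", "cls": "CLS", "last": "LAST"}

pooling = "max"  # hypothetical unsupported HF pooling mode
pooling_type = pooling_type_map.get(pooling, None)
if pooling_type is None:
    # !r quotes the value: ValueError: pool_type 'max' not supported
    raise ValueError(f"pool_type {pooling!r} not supported")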
@@ -193,7 +193,7 @@ class GritLMPooler(SequencePooler):
         return self.activation(pooled_data)


-@default_pooling_type("MEAN")
+@default_pooling_type(seq_pooling_type="MEAN")
 class GritLM(LlamaForCausalLM):
     """This class implements the embedding model for parasail-ai/GritLM-7B-vllm.
@@ -20,12 +20,13 @@ from vllm.utils.func_utils import supports_kw
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
     from vllm.config.model import AttnTypeStr
-    from vllm.config.pooler import PoolingTypeStr
+    from vllm.config.pooler import SequencePoolingType, TokenPoolingType
     from vllm.model_executor.layers.pooler import Pooler
 else:
     VllmConfig = Any
     Pooler = Any
-    PoolingTypeStr = Any
+    SequencePoolingType = Any
+    TokenPoolingType = Any
     AttnTypeStr = Any

 logger = init_logger(__name__)
@@ -155,9 +156,19 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
     MRO of your model class.
     """

-    default_pooling_type: ClassVar[PoolingTypeStr] = "LAST"
+    default_seq_pooling_type: ClassVar[SequencePoolingType] = "LAST"
     """
-    Indicates the [vllm.config.pooler.PoolerConfig.pooling_type][]
+    Indicates the [vllm.config.pooler.PoolerConfig.seq_pooling_type][]
     to use by default.

     You can use the
+    [vllm.model_executor.models.interfaces_base.default_pooling_type][]
+    decorator to conveniently set this field.
+    """
+
+    default_tok_pooling_type: ClassVar[TokenPoolingType] = "ALL"
+    """
+    Indicates the [vllm.config.pooler.PoolerConfig.tok_pooling_type][]
+    to use by default.
+
+    You can use the
@@ -200,18 +211,31 @@ def is_pooling_model(
 _T = TypeVar("_T", bound=type[nn.Module])


-def default_pooling_type(pooling_type: PoolingTypeStr):
-    """Decorator to set `VllmModelForPooling.default_pooling_type`."""
+def default_pooling_type(
+    *,
+    seq_pooling_type: SequencePoolingType = "LAST",
+    tok_pooling_type: TokenPoolingType = "ALL",
+):
+    """Decorator to set `VllmModelForPooling.default_*_pooling_type`."""

     def func(model: _T) -> _T:
-        model.default_pooling_type = pooling_type  # type: ignore
+        model.default_seq_pooling_type = seq_pooling_type  # type: ignore
+        model.default_tok_pooling_type = tok_pooling_type  # type: ignore
         return model

     return func


-def get_default_pooling_type(model: type[object] | object) -> PoolingTypeStr:
-    return getattr(model, "default_pooling_type", "LAST")
+def get_default_seq_pooling_type(
+    model: type[object] | object,
+) -> SequencePoolingType:
+    return getattr(model, "default_seq_pooling_type", "LAST")
+
+
+def get_default_tok_pooling_type(
+    model: type[object] | object,
+) -> TokenPoolingType:
+    return getattr(model, "default_tok_pooling_type", "ALL")


 def attn_type(attn_type: AttnTypeStr):
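The keyword-only * in the new signature means the old positional form @default_pooling_type("CLS") is now a TypeError, which is why every call site in this diff moves to an explicit keyword; omitted keywords fall back to "LAST" and "ALL". A small usage sketch (the decorated class here is hypothetical; the imports match the module shown in this hunk):

import torch.nn as nn

from vllm.model_executor.models.interfaces_base import (
    default_pooling_type,
    get_default_seq_pooling_type,
    get_default_tok_pooling_type,
)


@default_pooling_type(seq_pooling_type="CLS")  # tok_pooling_type keeps "ALL"
class MyEncoderModel(nn.Module):  # stand-in for a real vLLM model class
    pass


assert get_default_seq_pooling_type(MyEncoderModel) == "CLS"
assert get_default_tok_pooling_type(MyEncoderModel) == "ALL"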
@@ -402,7 +402,7 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
         return loaded_params


-@default_pooling_type("ALL")
+@default_pooling_type(tok_pooling_type="ALL")
 class InternLM2ForRewardModel(InternLM2ForCausalLM):
     is_pooling_model = True
@@ -221,7 +221,7 @@ class ModernBertEncoderLayer(nn.Module):


 @support_torch_compile
-@default_pooling_type("CLS")
+@default_pooling_type(seq_pooling_type="CLS")
 class ModernBertModel(nn.Module):
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={"layers.": "encoder_layer.layers."}
@@ -308,7 +308,7 @@ class ModernBertPooler(SequencePooler):
         return self.norm(self.act(self.dense(pooled_data)))


-@default_pooling_type("CLS")
+@default_pooling_type(seq_pooling_type="CLS")
 class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
     is_pooling_model = True
@@ -395,7 +395,7 @@ class ModernBertPredictionHead(nn.Module):


 @attn_type("encoder_only")
-@default_pooling_type("ALL")
+@default_pooling_type(tok_pooling_type="ALL")
 class ModernBertForTokenClassification(nn.Module):
     is_pooling_model = True
@@ -96,7 +96,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
         return loader.load_weights(weights)


-@default_pooling_type("ALL")
+@default_pooling_type(tok_pooling_type="ALL")
 class Qwen2ForRewardModel(Qwen2RewardBaseModel):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         vllm_config.model_config.hf_config.num_labels = 1
@@ -108,7 +108,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
         self.pooler = pooler_for_token_classify(pooler_config)


-@default_pooling_type("STEP")
+@default_pooling_type(tok_pooling_type="STEP")
 class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         vllm_config.model_config.hf_config.num_labels = 2
@@ -35,10 +35,11 @@ from vllm.utils.hashing import safe_hash

 if TYPE_CHECKING:
     from vllm.config.model import AttnTypeStr
-    from vllm.config.pooler import PoolingTypeStr
+    from vllm.config.pooler import SequencePoolingType, TokenPoolingType
 else:
     AttnTypeStr = Any
-    PoolingTypeStr = Any
+    SequencePoolingType = Any
+    TokenPoolingType = Any


 from .interfaces import (
@@ -57,7 +58,8 @@ from .interfaces import (
 )
 from .interfaces_base import (
     get_attn_type,
-    get_default_pooling_type,
+    get_default_seq_pooling_type,
+    get_default_tok_pooling_type,
     is_pooling_model,
     is_text_generation_model,
 )
@@ -548,7 +550,8 @@ class _ModelInfo:
     is_text_generation_model: bool
     is_pooling_model: bool
     attn_type: AttnTypeStr
-    default_pooling_type: PoolingTypeStr
+    default_seq_pooling_type: SequencePoolingType
+    default_tok_pooling_type: TokenPoolingType
     supports_cross_encoding: bool
     supports_multimodal: bool
     supports_multimodal_raw_input_only: bool
@@ -569,7 +572,8 @@ class _ModelInfo:
             architecture=model.__name__,
             is_text_generation_model=is_text_generation_model(model),
             is_pooling_model=is_pooling_model(model),
-            default_pooling_type=get_default_pooling_type(model),
+            default_seq_pooling_type=get_default_seq_pooling_type(model),
+            default_tok_pooling_type=get_default_tok_pooling_type(model),
             attn_type=get_attn_type(model),
             supports_cross_encoding=supports_cross_encoding(model),
             supports_multimodal=supports_multimodal(model),
@@ -93,7 +93,7 @@ class RobertaClassificationHead(nn.Module):
         return x


-@default_pooling_type("CLS")
+@default_pooling_type(seq_pooling_type="CLS")
 class RobertaEmbeddingModel(BertEmbeddingModel):
     """A model that uses Roberta to provide embedding functionalities."""
@@ -150,7 +150,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
         return loader.load_weights(weights_list, mapper=mapper)


-@default_pooling_type("CLS")
+@default_pooling_type(seq_pooling_type="CLS")
 class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
     """A model that uses Roberta to provide embedding functionalities.
@@ -129,7 +129,7 @@ class SiglipProcessingInfo(BaseProcessingInfo):
                 image_width=image_width,
                 image_height=image_height,
             ),
-            _get_vision_feature_select_strategy(pooler_config.pooling_type),
+            _get_vision_feature_select_strategy(pooler_config.seq_pooling_type),
         )

     def get_image_size_with_most_features(self) -> ImageSize:
@@ -998,7 +998,7 @@ class SiglipTextEmbeddings(nn.Module):


 # Assume EOS token corresponds to CLS token in text model
-@default_pooling_type("CLS")
+@default_pooling_type(seq_pooling_type="CLS")
 @MULTIMODAL_REGISTRY.register_processor(
     SiglipMultiModalProcessor,
     info=SiglipProcessingInfo,
@@ -1125,7 +1125,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
     ) -> torch.Tensor:
         if feature_select_strategy is None:
             feature_select_strategy = _get_vision_feature_select_strategy(
-                self.pooler_config.pooling_type
+                self.pooler_config.seq_pooling_type
             )

         pooled_output = self.vision_model(