[Refactor] Separate sequence and token pooling types (#32026)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -46,7 +46,8 @@ def test_model_loading_with_params(vllm_runner, monkeypatch):
|
||||
assert model_config.encoder_config["do_lower_case"]
|
||||
|
||||
# asserts on the pooling config files
|
||||
assert model_config.pooler_config.pooling_type == "CLS"
|
||||
assert model_config.pooler_config.seq_pooling_type == "CLS"
|
||||
assert model_config.pooler_config.tok_pooling_type == "ALL"
|
||||
assert model_config.pooler_config.normalize
|
||||
|
||||
# asserts on the tokenizer loaded
|
||||
@@ -90,7 +91,8 @@ def test_roberta_model_loading_with_params(vllm_runner, monkeypatch):
|
||||
assert not model_config.encoder_config["do_lower_case"]
|
||||
|
||||
# asserts on the pooling config files
|
||||
assert model_config.pooler_config.pooling_type == "MEAN"
|
||||
assert model_config.pooler_config.seq_pooling_type == "MEAN"
|
||||
assert model_config.pooler_config.tok_pooling_type == "ALL"
|
||||
assert model_config.pooler_config.normalize
|
||||
|
||||
# asserts on the tokenizer loaded
|
||||
|
||||
@@ -54,7 +54,7 @@ def test_models(
|
||||
vllm_extra_kwargs = {}
|
||||
if model == "ssmits/Qwen2-7B-Instruct-embed-base":
|
||||
vllm_extra_kwargs["pooler_config"] = PoolerConfig(
|
||||
pooling_type="MEAN", normalize=False
|
||||
seq_pooling_type="MEAN", normalize=False
|
||||
)
|
||||
|
||||
max_model_len: int | None = 512
|
||||
|
||||
@@ -88,7 +88,7 @@ def test_gemma_multimodal(
|
||||
convert="classify",
|
||||
load_format="auto",
|
||||
hf_overrides=update_config,
|
||||
pooler_config=PoolerConfig(pooling_type="LAST"),
|
||||
pooler_config=PoolerConfig(seq_pooling_type="LAST"),
|
||||
max_model_len=512,
|
||||
enforce_eager=True,
|
||||
tensor_parallel_size=1,
|
||||
|
||||
@@ -162,8 +162,11 @@ def mteb_test_embed_models(
|
||||
assert model_info.architecture in model_config.architectures
|
||||
|
||||
# Confirm whether the important configs in model_config are correct.
|
||||
if model_info.pooling_type is not None:
|
||||
assert model_config.pooler_config.pooling_type == model_info.pooling_type
|
||||
pooler_config = model_config.pooler_config
|
||||
if model_info.seq_pooling_type is not None:
|
||||
assert pooler_config.seq_pooling_type == model_info.seq_pooling_type
|
||||
if model_info.tok_pooling_type is not None:
|
||||
assert pooler_config.tok_pooling_type == model_info.tok_pooling_type
|
||||
if model_info.attn_type is not None:
|
||||
assert model_config.attn_type == model_info.attn_type
|
||||
if model_info.is_prefix_caching_supported is not None:
|
||||
|
||||
@@ -254,8 +254,11 @@ def mteb_test_rerank_models(
|
||||
assert model_config.hf_config.num_labels == 1
|
||||
|
||||
# Confirm whether the important configs in model_config are correct.
|
||||
if model_info.pooling_type is not None:
|
||||
assert model_config.pooler_config.pooling_type == model_info.pooling_type
|
||||
pooler_config = model_config.pooler_config
|
||||
if model_info.seq_pooling_type is not None:
|
||||
assert pooler_config.seq_pooling_type == model_info.seq_pooling_type
|
||||
if model_info.tok_pooling_type is not None:
|
||||
assert pooler_config.tok_pooling_type == model_info.tok_pooling_type
|
||||
if model_info.attn_type is not None:
|
||||
assert model_config.attn_type == model_info.attn_type
|
||||
if model_info.is_prefix_caching_supported is not None:
|
||||
|
||||
@@ -17,7 +17,7 @@ MODELS = [
|
||||
"BAAI/bge-base-en",
|
||||
architecture="BertModel",
|
||||
mteb_score=0.779336792,
|
||||
pooling_type="CLS",
|
||||
seq_pooling_type="CLS",
|
||||
attn_type="encoder_only",
|
||||
is_prefix_caching_supported=False,
|
||||
is_chunked_prefill_supported=False,
|
||||
@@ -54,7 +54,7 @@ MODELS = [
|
||||
"BAAI/bge-m3",
|
||||
architecture="XLMRobertaModel",
|
||||
mteb_score=0.787343078,
|
||||
pooling_type="CLS",
|
||||
seq_pooling_type="CLS",
|
||||
attn_type="encoder_only",
|
||||
is_prefix_caching_supported=False,
|
||||
is_chunked_prefill_supported=False,
|
||||
@@ -65,7 +65,7 @@ MODELS = [
|
||||
"BAAI/bge-code-v1",
|
||||
architecture="Qwen2Model",
|
||||
mteb_score=0.75724465,
|
||||
pooling_type="LAST",
|
||||
seq_pooling_type="LAST",
|
||||
attn_type="decoder",
|
||||
is_prefix_caching_supported=True,
|
||||
is_chunked_prefill_supported=True,
|
||||
@@ -79,7 +79,7 @@ RERANK_MODELS = [
|
||||
"BAAI/bge-reranker-base",
|
||||
architecture="XLMRobertaForSequenceClassification",
|
||||
mteb_score=0.32398,
|
||||
pooling_type="CLS",
|
||||
seq_pooling_type="CLS",
|
||||
attn_type="encoder_only",
|
||||
is_prefix_caching_supported=False,
|
||||
is_chunked_prefill_supported=False,
|
||||
|
||||
@@ -26,7 +26,7 @@ RERANK_MODELS = [
|
||||
"method": "no_post_processing",
|
||||
},
|
||||
mteb_score=0.33757,
|
||||
pooling_type="LAST",
|
||||
seq_pooling_type="LAST",
|
||||
attn_type="decoder",
|
||||
is_prefix_caching_supported=True,
|
||||
is_chunked_prefill_supported=True,
|
||||
|
||||
@@ -12,7 +12,7 @@ RERANK_MODELS = [
|
||||
RerankModelInfo(
|
||||
"cross-encoder/ms-marco-TinyBERT-L-2-v2",
|
||||
architecture="BertForSequenceClassification",
|
||||
pooling_type="CLS",
|
||||
seq_pooling_type="CLS",
|
||||
attn_type="encoder_only",
|
||||
is_prefix_caching_supported=False,
|
||||
is_chunked_prefill_supported=False,
|
||||
@@ -21,7 +21,7 @@ RERANK_MODELS = [
|
||||
RerankModelInfo(
|
||||
"tomaarsen/Qwen3-Reranker-0.6B-seq-cls",
|
||||
architecture="Qwen3ForSequenceClassification",
|
||||
pooling_type="LAST",
|
||||
seq_pooling_type="LAST",
|
||||
attn_type="decoder",
|
||||
is_prefix_caching_supported=True,
|
||||
is_chunked_prefill_supported=True,
|
||||
|
||||
@@ -18,7 +18,7 @@ MODELS = [
|
||||
"thenlper/gte-large",
|
||||
mteb_score=0.76807651,
|
||||
architecture="BertModel",
|
||||
pooling_type="MEAN",
|
||||
seq_pooling_type="MEAN",
|
||||
attn_type="encoder_only",
|
||||
is_prefix_caching_supported=False,
|
||||
is_chunked_prefill_supported=False,
|
||||
@@ -44,7 +44,7 @@ MODELS = [
|
||||
architecture="GteNewModel",
|
||||
mteb_score=0.775074696,
|
||||
hf_overrides={"architectures": ["GteNewModel"]},
|
||||
pooling_type="CLS",
|
||||
seq_pooling_type="CLS",
|
||||
attn_type="encoder_only",
|
||||
is_prefix_caching_supported=False,
|
||||
is_chunked_prefill_supported=False,
|
||||
@@ -67,7 +67,7 @@ MODELS = [
|
||||
"Alibaba-NLP/gte-Qwen2-1.5B-instruct",
|
||||
mteb_score=0.758473459018872,
|
||||
architecture="Qwen2ForCausalLM",
|
||||
pooling_type="LAST",
|
||||
seq_pooling_type="LAST",
|
||||
attn_type="encoder_only",
|
||||
is_prefix_caching_supported=False,
|
||||
is_chunked_prefill_supported=False,
|
||||
@@ -78,7 +78,7 @@ MODELS = [
|
||||
"Alibaba-NLP/gte-modernbert-base",
|
||||
mteb_score=0.748193353,
|
||||
architecture="ModernBertModel",
|
||||
pooling_type="CLS",
|
||||
seq_pooling_type="CLS",
|
||||
attn_type="encoder_only",
|
||||
is_prefix_caching_supported=False,
|
||||
is_chunked_prefill_supported=False,
|
||||
@@ -89,7 +89,7 @@ MODELS = [
|
||||
"Qwen/Qwen3-Embedding-0.6B",
|
||||
mteb_score=0.771163695,
|
||||
architecture="Qwen3ForCausalLM",
|
||||
pooling_type="LAST",
|
||||
seq_pooling_type="LAST",
|
||||
attn_type="decoder",
|
||||
is_prefix_caching_supported=True,
|
||||
is_chunked_prefill_supported=True,
|
||||
@@ -108,7 +108,7 @@ RERANK_MODELS = [
|
||||
"Alibaba-NLP/gte-reranker-modernbert-base",
|
||||
mteb_score=0.33386,
|
||||
architecture="ModernBertForSequenceClassification",
|
||||
pooling_type="CLS",
|
||||
seq_pooling_type="CLS",
|
||||
attn_type="encoder_only",
|
||||
is_prefix_caching_supported=False,
|
||||
is_chunked_prefill_supported=False,
|
||||
@@ -119,7 +119,7 @@ RERANK_MODELS = [
|
||||
mteb_score=0.33062,
|
||||
architecture="GteNewForSequenceClassification",
|
||||
hf_overrides={"architectures": ["GteNewForSequenceClassification"]},
|
||||
pooling_type="CLS",
|
||||
seq_pooling_type="CLS",
|
||||
attn_type="encoder_only",
|
||||
is_prefix_caching_supported=False,
|
||||
is_chunked_prefill_supported=False,
|
||||
|
||||
@@ -13,7 +13,7 @@ MODELS = [
|
||||
"intfloat/e5-small",
|
||||
architecture="BertModel",
|
||||
mteb_score=0.742285423,
|
||||
pooling_type="MEAN",
|
||||
seq_pooling_type="MEAN",
|
||||
attn_type="encoder_only",
|
||||
is_prefix_caching_supported=False,
|
||||
is_chunked_prefill_supported=False,
|
||||
@@ -29,7 +29,7 @@ MODELS = [
|
||||
"intfloat/multilingual-e5-base",
|
||||
architecture="XLMRobertaModel",
|
||||
mteb_score=0.779325955,
|
||||
pooling_type="MEAN",
|
||||
seq_pooling_type="MEAN",
|
||||
attn_type="encoder_only",
|
||||
is_prefix_caching_supported=False,
|
||||
is_chunked_prefill_supported=False,
|
||||
|
||||
@@ -24,7 +24,7 @@ EMBEDDING_MODELS = [
|
||||
mteb_score=0.824413164,
|
||||
architecture="XLMRobertaModel",
|
||||
is_matryoshka=True,
|
||||
pooling_type="MEAN",
|
||||
seq_pooling_type="MEAN",
|
||||
attn_type="encoder_only",
|
||||
is_prefix_caching_supported=False,
|
||||
is_chunked_prefill_supported=False,
|
||||
@@ -36,7 +36,7 @@ RERANK_MODELS = [
|
||||
"jinaai/jina-reranker-v2-base-multilingual",
|
||||
mteb_score=0.33643,
|
||||
architecture="XLMRobertaForSequenceClassification",
|
||||
pooling_type="CLS",
|
||||
seq_pooling_type="CLS",
|
||||
attn_type="encoder_only",
|
||||
is_prefix_caching_supported=False,
|
||||
is_chunked_prefill_supported=False,
|
||||
|
||||
@@ -24,7 +24,7 @@ RERANK_MODELS = [
|
||||
"mixedbread-ai/mxbai-rerank-base-v2",
|
||||
architecture="Qwen2ForSequenceClassification",
|
||||
hf_overrides=mxbai_rerank_hf_overrides,
|
||||
pooling_type="LAST",
|
||||
seq_pooling_type="LAST",
|
||||
attn_type="decoder",
|
||||
is_prefix_caching_supported=True,
|
||||
is_chunked_prefill_supported=True,
|
||||
|
||||
@@ -19,7 +19,7 @@ EMBEDDING_MODELS = [
|
||||
"nvidia/llama-nemotron-embed-1b-v2",
|
||||
architecture="LlamaBidirectionalModel",
|
||||
mteb_score=0.689164662128673,
|
||||
pooling_type="MEAN",
|
||||
seq_pooling_type="MEAN",
|
||||
attn_type="encoder_only",
|
||||
is_prefix_caching_supported=False,
|
||||
is_chunked_prefill_supported=False,
|
||||
@@ -32,7 +32,7 @@ RERANK_MODELS = [
|
||||
architecture="LlamaBidirectionalForSequenceClassification",
|
||||
chat_template_name="nemotron-rerank.jinja",
|
||||
mteb_score=0.33994,
|
||||
pooling_type="MEAN",
|
||||
seq_pooling_type="MEAN",
|
||||
attn_type="encoder_only",
|
||||
is_prefix_caching_supported=False,
|
||||
is_chunked_prefill_supported=False,
|
||||
|
||||
@@ -14,7 +14,7 @@ MODELS = [
|
||||
architecture="NomicBertModel",
|
||||
mteb_score=0.737568559,
|
||||
enable_test=True,
|
||||
pooling_type="MEAN",
|
||||
seq_pooling_type="MEAN",
|
||||
attn_type="encoder_only",
|
||||
is_prefix_caching_supported=False,
|
||||
is_chunked_prefill_supported=False,
|
||||
@@ -32,7 +32,7 @@ MODELS = [
|
||||
architecture="NomicBertModel",
|
||||
mteb_score=0.715488912,
|
||||
enable_test=True,
|
||||
pooling_type="MEAN",
|
||||
seq_pooling_type="MEAN",
|
||||
attn_type="encoder_only",
|
||||
is_prefix_caching_supported=False,
|
||||
is_chunked_prefill_supported=False,
|
||||
|
||||
@@ -27,7 +27,7 @@ RERANK_MODELS = [
|
||||
architecture="Qwen3ForSequenceClassification",
|
||||
hf_overrides=qwen3_reranker_hf_overrides,
|
||||
chat_template_name="qwen3_reranker.jinja",
|
||||
pooling_type="LAST",
|
||||
seq_pooling_type="LAST",
|
||||
attn_type="decoder",
|
||||
is_prefix_caching_supported=True,
|
||||
is_chunked_prefill_supported=True,
|
||||
|
||||
@@ -14,7 +14,7 @@ MODELS = [
|
||||
is_matryoshka=False,
|
||||
architecture="BertModel",
|
||||
mteb_score=0.714927797,
|
||||
pooling_type="CLS",
|
||||
seq_pooling_type="CLS",
|
||||
attn_type="encoder_only",
|
||||
is_prefix_caching_supported=False,
|
||||
is_chunked_prefill_supported=False,
|
||||
@@ -37,7 +37,7 @@ MODELS = [
|
||||
is_matryoshka=False,
|
||||
architecture="NomicBertModel",
|
||||
mteb_score=0.681146831,
|
||||
pooling_type="CLS",
|
||||
seq_pooling_type="CLS",
|
||||
attn_type="encoder_only",
|
||||
is_prefix_caching_supported=False,
|
||||
is_chunked_prefill_supported=False,
|
||||
@@ -54,7 +54,7 @@ MODELS = [
|
||||
is_matryoshka=True,
|
||||
architecture="BertModel",
|
||||
mteb_score=0.649088363,
|
||||
pooling_type="CLS",
|
||||
seq_pooling_type="CLS",
|
||||
attn_type="encoder_only",
|
||||
is_prefix_caching_supported=False,
|
||||
is_chunked_prefill_supported=False,
|
||||
@@ -65,7 +65,7 @@ MODELS = [
|
||||
is_matryoshka=True,
|
||||
architecture="XLMRobertaModel",
|
||||
mteb_score=0.712258299,
|
||||
pooling_type="CLS",
|
||||
seq_pooling_type="CLS",
|
||||
attn_type="encoder_only",
|
||||
is_prefix_caching_supported=False,
|
||||
is_chunked_prefill_supported=False,
|
||||
@@ -76,7 +76,7 @@ MODELS = [
|
||||
is_matryoshka=True,
|
||||
architecture="GteModel",
|
||||
mteb_score=0.706622444,
|
||||
pooling_type="CLS",
|
||||
seq_pooling_type="CLS",
|
||||
attn_type="encoder_only",
|
||||
is_prefix_caching_supported=False,
|
||||
is_chunked_prefill_supported=False,
|
||||
|
||||
@@ -14,7 +14,7 @@ ST_PROJECTOR_MODELS = [
|
||||
"TencentBAC/Conan-embedding-v1",
|
||||
architecture="BertModel",
|
||||
mteb_score=0.688611955,
|
||||
pooling_type="MEAN",
|
||||
seq_pooling_type="MEAN",
|
||||
attn_type="encoder_only",
|
||||
is_prefix_caching_supported=False,
|
||||
is_chunked_prefill_supported=False,
|
||||
@@ -24,7 +24,7 @@ ST_PROJECTOR_MODELS = [
|
||||
"google/embeddinggemma-300m",
|
||||
architecture="Gemma3TextModel",
|
||||
mteb_score=0.7473819294684156,
|
||||
pooling_type="MEAN",
|
||||
seq_pooling_type="MEAN",
|
||||
attn_type="encoder_only",
|
||||
is_prefix_caching_supported=False,
|
||||
is_chunked_prefill_supported=False,
|
||||
|
||||
@@ -11,6 +11,7 @@ import torch.nn.functional as F
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config.model import AttnTypeStr, ModelConfig, ModelDType, RunnerOption
|
||||
from vllm.config.pooler import SequencePoolingType, TokenPoolingType
|
||||
from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs
|
||||
from vllm.multimodal.processing import InputProcessingContext
|
||||
from vllm.tokenizers import cached_tokenizer_from_config
|
||||
@@ -379,7 +380,8 @@ class ModelInfo:
|
||||
max_model_len: int | None = None
|
||||
hf_dtype: str = "float32"
|
||||
hf_overrides: dict[str, Any] | None = None
|
||||
pooling_type: str | None = None
|
||||
seq_pooling_type: SequencePoolingType | None = None
|
||||
tok_pooling_type: TokenPoolingType | None = None
|
||||
attn_type: AttnTypeStr | None = None
|
||||
is_prefix_caching_supported: bool | None = None
|
||||
is_chunked_prefill_supported: bool | None = None
|
||||
|
||||
@@ -161,7 +161,8 @@ def test_get_pooling_config():
|
||||
|
||||
assert model_config.pooler_config is not None
|
||||
assert model_config.pooler_config.normalize
|
||||
assert model_config.pooler_config.pooling_type == "MEAN"
|
||||
assert model_config.pooler_config.seq_pooling_type == "MEAN"
|
||||
assert model_config.pooler_config.tok_pooling_type == "ALL"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
@@ -169,7 +170,7 @@ def test_get_pooling_config():
|
||||
)
|
||||
def test_get_pooling_config_from_args():
|
||||
model_id = "sentence-transformers/all-MiniLM-L12-v2"
|
||||
pooler_config = PoolerConfig(pooling_type="CLS", normalize=True)
|
||||
pooler_config = PoolerConfig(seq_pooling_type="CLS", normalize=True)
|
||||
model_config = ModelConfig(model_id, pooler_config=pooler_config)
|
||||
|
||||
assert asdict(model_config.pooler_config) == asdict(pooler_config)
|
||||
@@ -180,14 +181,25 @@ def test_get_pooling_config_from_args():
|
||||
[
|
||||
("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", "LAST", "LAST"), # LLM
|
||||
("intfloat/e5-small", "CLS", "MEAN"), # BertModel
|
||||
],
|
||||
)
|
||||
def test_default_seq_pooling_type(model_id, default_pooling_type, pooling_type):
|
||||
model_config = ModelConfig(model_id)
|
||||
assert model_config._model_info.default_seq_pooling_type == default_pooling_type
|
||||
assert model_config.pooler_config.seq_pooling_type == pooling_type
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_id", "default_pooling_type", "pooling_type"),
|
||||
[
|
||||
("Qwen/Qwen2.5-Math-RM-72B", "ALL", "ALL"), # reward
|
||||
("Qwen/Qwen2.5-Math-PRM-7B", "STEP", "STEP"), # step reward
|
||||
],
|
||||
)
|
||||
def test_default_pooling_type(model_id, default_pooling_type, pooling_type):
|
||||
def test_default_tok_pooling_type(model_id, default_pooling_type, pooling_type):
|
||||
model_config = ModelConfig(model_id)
|
||||
assert model_config._model_info.default_pooling_type == default_pooling_type
|
||||
assert model_config.pooler_config.pooling_type == pooling_type
|
||||
assert model_config._model_info.default_tok_pooling_type == default_pooling_type
|
||||
assert model_config.pooler_config.tok_pooling_type == pooling_type
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -554,100 +566,100 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files):
|
||||
"jason9693/Qwen2.5-1.5B-apeach",
|
||||
"decoder",
|
||||
True,
|
||||
"Pooling models with causal attn and last pooling support chunked prefill.",
|
||||
"Pooling models with causal attn and LAST/ALL pooling support chunked prefill.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"Qwen/Qwen3-Embedding-0.6B",
|
||||
"decoder",
|
||||
True,
|
||||
"Pooling models with causal attn and last pooling support chunked prefill.",
|
||||
"Pooling models with causal attn and LAST/ALL pooling support chunked prefill.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"Qwen/Qwen2.5-Math-PRM-7B",
|
||||
"decoder",
|
||||
False,
|
||||
"Pooling models with step pooling does not support chunked prefill.",
|
||||
"Pooling models with causal attn and LAST/STEP pooling do not support chunked prefill.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"internlm/internlm2-1_8b-reward",
|
||||
"decoder",
|
||||
True,
|
||||
"Pooling models with causal attn and all pooling support chunked prefill.",
|
||||
"Pooling models with causal attn and LAST/ALL pooling support chunked prefill.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"BAAI/bge-base-en",
|
||||
"encoder_only",
|
||||
False,
|
||||
"Pooling models with bidirectional attn does not support chunked prefill.",
|
||||
"Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"boltuix/NeuroBERT-NER",
|
||||
"encoder_only",
|
||||
False,
|
||||
"Pooling models with bidirectional attn does not support chunked prefill.",
|
||||
"Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"papluca/xlm-roberta-base-language-detection",
|
||||
"encoder_only",
|
||||
False,
|
||||
"Pooling models with bidirectional attn does not support chunked prefill.",
|
||||
"Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"Alibaba-NLP/gte-Qwen2-1.5B-instruct",
|
||||
"encoder_only",
|
||||
False,
|
||||
"Pooling models with bidirectional attn does not support chunked prefill.",
|
||||
"Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"intfloat/e5-small",
|
||||
"encoder_only",
|
||||
False,
|
||||
"Pooling models with bidirectional attn does not support chunked prefill.",
|
||||
"Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
|
||||
),
|
||||
# multimodal models
|
||||
(
|
||||
"openai/clip-vit-base-patch32",
|
||||
"decoder",
|
||||
True,
|
||||
"Pooling models with causal attn and last pooling support chunked prefill.",
|
||||
"Pooling models with causal attn and LAST/ALL pooling support chunked prefill.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"google/siglip-base-patch16-224",
|
||||
"encoder_only",
|
||||
False,
|
||||
"Pooling models with bidirectional attn does not support chunked prefill.",
|
||||
"Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
|
||||
),
|
||||
# generate models
|
||||
(
|
||||
"Qwen/Qwen3-0.6B",
|
||||
"decoder",
|
||||
True,
|
||||
"Generative models support chunked prefill.",
|
||||
"Generative models support chunked prefill.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"Qwen/Qwen3-Next-80B-A3B-Instruct",
|
||||
"hybrid",
|
||||
True,
|
||||
"Generative models support chunked prefill.",
|
||||
"Generative models support chunked prefill.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"ibm-granite/granite-4.0-h-small",
|
||||
"hybrid",
|
||||
True,
|
||||
"Generative models support chunked prefill.",
|
||||
"Generative models support chunked prefill.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"state-spaces/mamba-130m-hf",
|
||||
"attention_free",
|
||||
True,
|
||||
"Generative models support chunked prefill.",
|
||||
"Generative models support chunked prefill.", # noqa: E501
|
||||
),
|
||||
# encoder_decoder models
|
||||
(
|
||||
"openai/whisper-small",
|
||||
"encoder_decoder",
|
||||
False,
|
||||
"Encoder decoder models does not support chunked prefill.",
|
||||
"Encoder decoder models do not support chunked prefill.", # noqa: E501
|
||||
),
|
||||
],
|
||||
)
|
||||
@@ -673,100 +685,100 @@ def test_is_chunked_prefill_supported(
|
||||
"jason9693/Qwen2.5-1.5B-apeach",
|
||||
"decoder",
|
||||
True,
|
||||
"Pooling models with causal attn and last pooling support prefix caching.",
|
||||
"Pooling models with causal attn and LAST/ALL pooling support prefix caching.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"Qwen/Qwen3-Embedding-0.6B",
|
||||
"decoder",
|
||||
True,
|
||||
"Pooling models with causal attn and last pooling support prefix caching.",
|
||||
"Pooling models with causal attn and LAST/ALL pooling support prefix caching.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"Qwen/Qwen2.5-Math-PRM-7B",
|
||||
"decoder",
|
||||
False,
|
||||
"Pooling models with step pooling does not support prefix caching.",
|
||||
"Pooling models with causal attn and LAST/STEP pooling do not support prefix caching.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"internlm/internlm2-1_8b-reward",
|
||||
"decoder",
|
||||
True,
|
||||
"Pooling models with causal attn and all pooling support prefix caching.",
|
||||
"Pooling models with causal attn and LAST/ALL pooling support prefix caching.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"BAAI/bge-base-en",
|
||||
"encoder_only",
|
||||
False,
|
||||
"Pooling models with bidirectional attn does not support prefix caching.",
|
||||
"Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"boltuix/NeuroBERT-NER",
|
||||
"encoder_only",
|
||||
False,
|
||||
"Pooling models with bidirectional attn does not support prefix caching.",
|
||||
"Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"papluca/xlm-roberta-base-language-detection",
|
||||
"encoder_only",
|
||||
False,
|
||||
"Pooling models with bidirectional attn does not support prefix caching.",
|
||||
"Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"Alibaba-NLP/gte-Qwen2-1.5B-instruct",
|
||||
"encoder_only",
|
||||
False,
|
||||
"Pooling models with bidirectional attn does not support prefix caching.",
|
||||
"Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"intfloat/e5-small",
|
||||
"encoder_only",
|
||||
False,
|
||||
"Pooling models with bidirectional attn does not support prefix caching.",
|
||||
"Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
|
||||
),
|
||||
# multimodal models
|
||||
(
|
||||
"openai/clip-vit-base-patch32",
|
||||
"decoder",
|
||||
True,
|
||||
"Pooling models with causal attn and last pooling support prefix caching.",
|
||||
"Pooling models with causal attn and LAST/ALL pooling support prefix caching.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"google/siglip-base-patch16-224",
|
||||
"encoder_only",
|
||||
False,
|
||||
"Pooling models with bidirectional attn does not support prefix caching.",
|
||||
"Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
|
||||
),
|
||||
# generate models
|
||||
(
|
||||
"Qwen/Qwen3-0.6B",
|
||||
"decoder",
|
||||
True,
|
||||
"Generative models support prefix caching.",
|
||||
"Generative models support prefix caching.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"Qwen/Qwen3-Next-80B-A3B-Instruct",
|
||||
"hybrid",
|
||||
False,
|
||||
"Hybrid models does not support prefix caching since the feature is still experimental.", # noqa: E501
|
||||
"Hybrid models do not support prefix caching since the feature is still experimental.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"ibm-granite/granite-4.0-h-small",
|
||||
"hybrid",
|
||||
False,
|
||||
"Hybrid models does not support prefix caching since the feature is still experimental.", # noqa: E501
|
||||
"Hybrid models do not support prefix caching since the feature is still experimental.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"state-spaces/mamba-130m-hf",
|
||||
"attention_free",
|
||||
False,
|
||||
"Attention free models does not support prefix caching since the feature is still experimental.", # noqa: E501
|
||||
"Attention free models do not support prefix caching since the feature is still experimental.", # noqa: E501
|
||||
),
|
||||
# encoder_decoder models
|
||||
(
|
||||
"openai/whisper-small",
|
||||
"encoder_decoder",
|
||||
False,
|
||||
"Encoder decoder models does not support prefix caching.",
|
||||
"Encoder decoder models do not support prefix caching.", # noqa: E501
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -40,7 +40,7 @@ def test_task():
|
||||
|
||||
def test_embed():
|
||||
task = "embed"
|
||||
model_config = MockModelConfig(pooler_config=PoolerConfig(pooling_type="CLS"))
|
||||
model_config = MockModelConfig(pooler_config=PoolerConfig(seq_pooling_type="CLS"))
|
||||
|
||||
pooling_params = PoolingParams(normalize=None)
|
||||
pooling_params.verify(task=task, model_config=model_config)
|
||||
@@ -86,7 +86,7 @@ def test_embed_dimensions(model_info: EmbedModelInfo):
|
||||
|
||||
@pytest.mark.parametrize("task", ["score", "classify"])
|
||||
def test_classify(task):
|
||||
model_config = MockModelConfig(pooler_config=PoolerConfig(pooling_type="CLS"))
|
||||
model_config = MockModelConfig(pooler_config=PoolerConfig(seq_pooling_type="CLS"))
|
||||
|
||||
pooling_params = PoolingParams(use_activation=None)
|
||||
pooling_params.verify(task=task, model_config=model_config)
|
||||
@@ -108,7 +108,7 @@ def test_classify(task):
|
||||
def test_token_embed(pooling_type: str):
|
||||
task = "token_embed"
|
||||
model_config = MockModelConfig(
|
||||
pooler_config=PoolerConfig(pooling_type=pooling_type)
|
||||
pooler_config=PoolerConfig(tok_pooling_type=pooling_type)
|
||||
)
|
||||
|
||||
pooling_params = PoolingParams(normalize=None)
|
||||
@@ -134,7 +134,7 @@ def test_token_embed(pooling_type: str):
|
||||
def test_token_classify(pooling_type: str):
|
||||
task = "token_classify"
|
||||
model_config = MockModelConfig(
|
||||
pooler_config=PoolerConfig(pooling_type=pooling_type)
|
||||
pooler_config=PoolerConfig(tok_pooling_type=pooling_type)
|
||||
)
|
||||
|
||||
pooling_params = PoolingParams(use_activation=None)
|
||||
|
||||
@@ -539,9 +539,12 @@ class ModelConfig:
|
||||
if getattr(self.pooler_config, k) is None:
|
||||
setattr(self.pooler_config, k, v)
|
||||
|
||||
default_pooling_type = self._model_info.default_pooling_type
|
||||
if self.pooler_config.pooling_type is None:
|
||||
self.pooler_config.pooling_type = default_pooling_type
|
||||
default_seq_pooling_type = self._model_info.default_seq_pooling_type
|
||||
if self.pooler_config.seq_pooling_type is None:
|
||||
self.pooler_config.seq_pooling_type = default_seq_pooling_type
|
||||
default_tok_pooling_type = self._model_info.default_tok_pooling_type
|
||||
if self.pooler_config.tok_pooling_type is None:
|
||||
self.pooler_config.tok_pooling_type = default_tok_pooling_type
|
||||
|
||||
self.dtype: torch.dtype = _get_and_verify_dtype(
|
||||
self.model,
|
||||
@@ -1543,8 +1546,8 @@ class ModelConfig:
|
||||
@property
|
||||
def attn_type(self) -> AttnTypeStr:
|
||||
if self.pooler_config is not None:
|
||||
pooling_type = self._model_info.default_pooling_type.lower()
|
||||
if pooling_type == "cls":
|
||||
seq_pooling_type = self._model_info.default_seq_pooling_type
|
||||
if seq_pooling_type == "CLS":
|
||||
return "encoder_only"
|
||||
else:
|
||||
is_causal = getattr(self.hf_config, "is_causal", True)
|
||||
@@ -1561,89 +1564,102 @@ class ModelConfig:
|
||||
@property
|
||||
def is_chunked_prefill_supported(self) -> bool:
|
||||
attn_type = self.attn_type
|
||||
if self.pooler_config is not None:
|
||||
|
||||
if pooler_config := self.pooler_config:
|
||||
# for pooling models
|
||||
if attn_type == "encoder_only":
|
||||
logger.debug(
|
||||
"Pooling models with bidirectional attn does not support "
|
||||
"chunked prefill."
|
||||
"Pooling models with bidirectional attn "
|
||||
"do not support chunked prefill."
|
||||
)
|
||||
return False
|
||||
elif attn_type == "decoder":
|
||||
pooling_type = self.pooler_config.pooling_type.lower()
|
||||
if pooling_type in ["mean", "step", "cls"]:
|
||||
|
||||
if attn_type == "decoder":
|
||||
if (
|
||||
pooler_config.seq_pooling_type in ("MEAN", "CLS")
|
||||
or pooler_config.tok_pooling_type == "STEP"
|
||||
):
|
||||
logger.debug(
|
||||
"Pooling models with %s pooling does not "
|
||||
"support chunked prefill.",
|
||||
pooling_type,
|
||||
"Pooling models with causal attn and %s/%s pooling "
|
||||
"do not support chunked prefill.",
|
||||
pooler_config.seq_pooling_type,
|
||||
pooler_config.tok_pooling_type,
|
||||
)
|
||||
return False
|
||||
elif pooling_type in ["all", "last"]:
|
||||
else:
|
||||
logger.debug(
|
||||
"Pooling models with causal attn and %s pooling support "
|
||||
"chunked prefill.",
|
||||
pooling_type,
|
||||
"Pooling models with causal attn and %s/%s pooling "
|
||||
"support chunked prefill.",
|
||||
pooler_config.seq_pooling_type,
|
||||
pooler_config.tok_pooling_type,
|
||||
)
|
||||
return True
|
||||
else:
|
||||
raise ValueError(f"{pooling_type=} not supported.")
|
||||
|
||||
# vLLM currently does not have pooling models using hybrid,
|
||||
# attention_free or encoder_decoder attn types.
|
||||
return attn_type != "encoder_decoder"
|
||||
else:
|
||||
# for generative models
|
||||
if attn_type == "encoder_decoder":
|
||||
logger.debug("Encoder decoder models does not support chunked prefill.")
|
||||
logger.debug("Encoder decoder models do not support chunked prefill.")
|
||||
return False
|
||||
|
||||
logger.debug("Generative models support chunked prefill.")
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_prefix_caching_supported(self) -> bool:
|
||||
attn_type = self.attn_type
|
||||
if self.pooler_config is not None:
|
||||
|
||||
if pooler_config := self.pooler_config:
|
||||
# for pooling models
|
||||
if attn_type == "encoder_only":
|
||||
logger.debug(
|
||||
"Pooling models with bidirectional attn does not "
|
||||
"support prefix caching."
|
||||
"Pooling models with bidirectional attn "
|
||||
"do not support prefix caching."
|
||||
)
|
||||
return False
|
||||
elif attn_type == "decoder":
|
||||
pooling_type = self.pooler_config.pooling_type.lower()
|
||||
if pooling_type in ["mean", "step", "cls"]:
|
||||
|
||||
if attn_type == "decoder":
|
||||
if (
|
||||
pooler_config.seq_pooling_type in ("MEAN", "CLS")
|
||||
or pooler_config.tok_pooling_type == "STEP"
|
||||
):
|
||||
logger.debug(
|
||||
"Pooling models with %s pooling does not "
|
||||
"support prefix caching.",
|
||||
pooling_type,
|
||||
"Pooling models with causal attn and %s/%s pooling "
|
||||
"do not support prefix caching.",
|
||||
pooler_config.seq_pooling_type,
|
||||
pooler_config.tok_pooling_type,
|
||||
)
|
||||
return False
|
||||
elif pooling_type in ["all", "last"]:
|
||||
else:
|
||||
logger.debug(
|
||||
"Pooling models with causal attn and %s pooling support "
|
||||
"prefix caching.",
|
||||
pooling_type,
|
||||
"Pooling models with causal attn and %s/%s pooling "
|
||||
"support prefix caching.",
|
||||
pooler_config.seq_pooling_type,
|
||||
pooler_config.tok_pooling_type,
|
||||
)
|
||||
return True
|
||||
else:
|
||||
raise ValueError(f"{pooling_type=} not supported.")
|
||||
|
||||
# vllm currently does not have pooling models using hybrid,
|
||||
# attention_free or encoder_decoder attn types.
|
||||
return False
|
||||
else:
|
||||
# for generative models
|
||||
if attn_type == "hybrid":
|
||||
logger.debug(
|
||||
"Hybrid models does not support prefix caching since the feature "
|
||||
"Hybrid models do not support prefix caching since the feature "
|
||||
"is still experimental."
|
||||
)
|
||||
return False
|
||||
elif attn_type == "attention_free":
|
||||
logger.debug(
|
||||
"Attention free models does not support prefix caching since the "
|
||||
"Attention free models do not support prefix caching since the "
|
||||
"feature is still experimental."
|
||||
)
|
||||
return False
|
||||
elif attn_type == "encoder_decoder":
|
||||
logger.debug("Encoder decoder models does not support prefix caching.")
|
||||
logger.debug("Encoder decoder models do not support prefix caching.")
|
||||
return False
|
||||
else: # attn_type == "decoder"
|
||||
logger.debug("Generative models support prefix caching.")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, get_args
|
||||
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
@@ -11,7 +11,11 @@ from vllm.utils.hashing import safe_hash
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
PoolingTypeStr = Literal["LAST", "ALL", "CLS", "STEP", "MEAN"]
|
||||
SequencePoolingType = Literal["CLS", "LAST", "MEAN"]
|
||||
SEQ_POOLING_TYPES: tuple[SequencePoolingType, ...] = get_args(SequencePoolingType)
|
||||
|
||||
TokenPoolingType = Literal["ALL", "STEP"]
|
||||
TOK_POOLING_TYPES: tuple[TokenPoolingType, ...] = get_args(TokenPoolingType)
|
||||
|
||||
|
||||
@config
|
||||
@@ -19,9 +23,26 @@ PoolingTypeStr = Literal["LAST", "ALL", "CLS", "STEP", "MEAN"]
|
||||
class PoolerConfig:
|
||||
"""Controls the behavior of output pooling in pooling models."""
|
||||
|
||||
pooling_type: PoolingTypeStr | None = None
|
||||
pooling_type: SequencePoolingType | TokenPoolingType | None = None
|
||||
"""
|
||||
The pooling method of the pooling model.
|
||||
The pooling method used for pooling.
|
||||
|
||||
If set, `seq_pooling_type` or `tok_pooling_type` are automatically populated
|
||||
with this field. Alternatively, users can set `seq_pooling_type` and
|
||||
`tok_pooling_type` explicitly.
|
||||
|
||||
This field is mainly for user convenience. Internal code should always use
|
||||
`seq_pooling_type` or `tok_pooling_type` instead of `pooling_type`.
|
||||
"""
|
||||
|
||||
seq_pooling_type: SequencePoolingType | None = None
|
||||
"""
|
||||
The pooling method used for sequence pooling.
|
||||
"""
|
||||
|
||||
tok_pooling_type: TokenPoolingType | None = None
|
||||
"""
|
||||
The pooling method used for tokenwise pooling.
|
||||
"""
|
||||
|
||||
## for embeddings models
|
||||
@@ -88,9 +109,40 @@ class PoolerConfig:
|
||||
# raise deprecated warning for softmax and activation
|
||||
self.use_activation = get_use_activation(self)
|
||||
|
||||
def get_pooling_type(self) -> PoolingTypeStr:
|
||||
assert self.pooling_type is not None, "Should be resolved by ModelConfig"
|
||||
return self.pooling_type
|
||||
if pooling_type := self.pooling_type:
|
||||
if self.seq_pooling_type is not None:
|
||||
raise ValueError(
|
||||
"Cannot set both `pooling_type` and `seq_pooling_type`"
|
||||
)
|
||||
if self.tok_pooling_type is not None:
|
||||
raise ValueError(
|
||||
"Cannot set both `pooling_type` and `tok_pooling_type`"
|
||||
)
|
||||
|
||||
if pooling_type in SEQ_POOLING_TYPES:
|
||||
logger.debug(
|
||||
"Resolved `pooling_type=%r` to `seq_pooling_type=%r`.",
|
||||
pooling_type,
|
||||
pooling_type,
|
||||
)
|
||||
self.seq_pooling_type = pooling_type
|
||||
elif pooling_type in TOK_POOLING_TYPES:
|
||||
logger.debug(
|
||||
"Resolved `pooling_type=%r` to `tok_pooling_type=%r`.",
|
||||
pooling_type,
|
||||
pooling_type,
|
||||
)
|
||||
self.tok_pooling_type = pooling_type
|
||||
else:
|
||||
raise NotImplementedError(pooling_type)
|
||||
|
||||
def get_seq_pooling_type(self) -> SequencePoolingType:
|
||||
assert self.seq_pooling_type is not None, "Should be resolved by ModelConfig"
|
||||
return self.seq_pooling_type
|
||||
|
||||
def get_tok_pooling_type(self) -> TokenPoolingType:
|
||||
assert self.tok_pooling_type is not None, "Should be resolved by ModelConfig"
|
||||
return self.tok_pooling_type
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
|
||||
@@ -172,7 +172,7 @@ class LLM:
|
||||
The available overrides depend on the model that is being run.
|
||||
For example, for Phi-3-Vision: `{"num_crops": 4}`.
|
||||
pooler_config: Initialize non-default pooling config for the pooling
|
||||
model. e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
|
||||
model. e.g. `PoolerConfig(seq_pooling_type="MEAN", normalize=False)`.
|
||||
compilation_config: Either an integer or a dictionary. If it is an
|
||||
integer, it is used as the mode of compilation optimization. If it
|
||||
is a dictionary, it can specify the full compilation configuration.
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import TypeAlias
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config.pooler import PoolingTypeStr
|
||||
from vllm.config.pooler import SequencePoolingType
|
||||
from vllm.model_executor.layers.pooler import PoolingParamsUpdate
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
@@ -82,11 +82,11 @@ class MeanPool(SequencePoolingMethod):
|
||||
) / prompt_lens.unsqueeze(1)
|
||||
|
||||
|
||||
def get_seq_pooling_method(pooling_type: PoolingTypeStr | str):
|
||||
if pooling_type == "LAST":
|
||||
return LastPool()
|
||||
def get_seq_pooling_method(pooling_type: SequencePoolingType | str):
|
||||
if pooling_type == "CLS":
|
||||
return CLSPool()
|
||||
if pooling_type == "LAST":
|
||||
return LastPool()
|
||||
if pooling_type == "MEAN":
|
||||
return MeanPool()
|
||||
|
||||
|
||||
@@ -85,7 +85,7 @@ class SequencePooler(Pooler):
|
||||
|
||||
|
||||
def pooler_for_embed(pooler_config: PoolerConfig):
|
||||
pooling = get_seq_pooling_method(pooler_config.get_pooling_type())
|
||||
pooling = get_seq_pooling_method(pooler_config.get_seq_pooling_type())
|
||||
head = EmbeddingPoolerHead()
|
||||
|
||||
return SequencePooler(pooling=pooling, head=head)
|
||||
@@ -99,7 +99,7 @@ def pooler_for_classify(
|
||||
act_fn: PoolerActivation | str | None = None,
|
||||
):
|
||||
if pooling is None:
|
||||
pooling = get_seq_pooling_method(pooler_config.get_pooling_type())
|
||||
pooling = get_seq_pooling_method(pooler_config.get_seq_pooling_type())
|
||||
|
||||
head = ClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.config.pooler import PoolingTypeStr
|
||||
from vllm.config.pooler import TokenPoolingType
|
||||
from vllm.model_executor.layers.pooler import PoolingParamsUpdate
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
@@ -113,12 +113,10 @@ class StepPool(AllPool):
|
||||
return pooled_data
|
||||
|
||||
|
||||
def get_tok_pooling_method(pooling_type: PoolingTypeStr | str):
|
||||
def get_tok_pooling_method(pooling_type: TokenPoolingType | str):
|
||||
if pooling_type == "ALL":
|
||||
return AllPool()
|
||||
if pooling_type == "STEP":
|
||||
return StepPool()
|
||||
|
||||
# TODO: Separate seq and tok pooling types so we don't need this fallback
|
||||
return AllPool()
|
||||
raise NotImplementedError(f"Unknown tokenwise pooling type: {pooling_type!r}")
|
||||
|
||||
@@ -85,7 +85,7 @@ class TokenPooler(Pooler):
|
||||
|
||||
|
||||
def pooler_for_token_embed(pooler_config: PoolerConfig):
|
||||
pooling = get_tok_pooling_method(pooler_config.get_pooling_type())
|
||||
pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())
|
||||
head = TokenEmbeddingPoolerHead()
|
||||
|
||||
return TokenPooler(pooling=pooling, head=head)
|
||||
@@ -99,7 +99,7 @@ def pooler_for_token_classify(
|
||||
act_fn: PoolerActivation | str | None = None,
|
||||
):
|
||||
if pooling is None:
|
||||
pooling = get_tok_pooling_method(pooler_config.get_pooling_type())
|
||||
pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())
|
||||
|
||||
head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
|
||||
|
||||
|
||||
@@ -357,7 +357,7 @@ class BertOutput(nn.Module):
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
@default_pooling_type("CLS")
|
||||
@default_pooling_type(seq_pooling_type="CLS")
|
||||
class BertModel(nn.Module, SupportsQuant):
|
||||
is_pooling_model = True
|
||||
|
||||
@@ -461,7 +461,7 @@ class BertPoolingModel(BertModel):
|
||||
return loaded_params
|
||||
|
||||
|
||||
@default_pooling_type("CLS")
|
||||
@default_pooling_type(seq_pooling_type="CLS")
|
||||
class BertEmbeddingModel(nn.Module, SupportsQuant):
|
||||
"""A model that uses Bert to provide embedding functionalities.
|
||||
|
||||
@@ -675,7 +675,7 @@ class SPLADESparsePooler(Pooler):
|
||||
return torch.stack(pooled_list, dim=0).contiguous()
|
||||
|
||||
|
||||
@default_pooling_type("CLS")
|
||||
@default_pooling_type(seq_pooling_type="CLS")
|
||||
class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
|
||||
"""
|
||||
BertEmbeddingModel + SPLADE sparse embedding.
|
||||
@@ -780,7 +780,7 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
|
||||
return loaded
|
||||
|
||||
|
||||
@default_pooling_type("CLS")
|
||||
@default_pooling_type(seq_pooling_type="CLS")
|
||||
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQuant):
|
||||
"""A model that uses Bert to provide embedding functionalities.
|
||||
|
||||
@@ -849,7 +849,7 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu
|
||||
|
||||
|
||||
@attn_type("encoder_only")
|
||||
@default_pooling_type("ALL")
|
||||
@default_pooling_type(tok_pooling_type="ALL")
|
||||
class BertForTokenClassification(nn.Module):
|
||||
is_pooling_model = True
|
||||
|
||||
|
||||
@@ -441,7 +441,7 @@ class BertWithRopeEncoder(nn.Module):
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
@default_pooling_type("CLS")
|
||||
@default_pooling_type(seq_pooling_type="CLS")
|
||||
class BertWithRope(nn.Module, SupportsQuant):
|
||||
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
|
||||
|
||||
@@ -670,7 +670,7 @@ class JinaRobertaModel(BertWithRope):
|
||||
return super().load_weights(weights)
|
||||
|
||||
|
||||
@default_pooling_type("CLS")
|
||||
@default_pooling_type(seq_pooling_type="CLS")
|
||||
class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
is_pooling_model = True
|
||||
|
||||
|
||||
@@ -145,7 +145,7 @@ class CLIPProcessingInfo(BaseProcessingInfo):
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
),
|
||||
_get_vision_feature_select_strategy(pooler_config.pooling_type),
|
||||
_get_vision_feature_select_strategy(pooler_config.seq_pooling_type),
|
||||
)
|
||||
|
||||
def get_image_size_with_most_features(self) -> ImageSize:
|
||||
@@ -819,7 +819,7 @@ class CLIPVisionModel(nn.Module):
|
||||
|
||||
|
||||
# Assume EOS token corresponds to LAST token in text model
|
||||
@default_pooling_type("LAST")
|
||||
@default_pooling_type(seq_pooling_type="LAST")
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
CLIPMultiModalProcessor,
|
||||
info=CLIPProcessingInfo,
|
||||
@@ -908,7 +908,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
|
||||
) -> torch.Tensor:
|
||||
if feature_select_strategy is None:
|
||||
feature_select_strategy = _get_vision_feature_select_strategy(
|
||||
self.pooler_config.pooling_type
|
||||
self.pooler_config.seq_pooling_type
|
||||
)
|
||||
|
||||
pooled_output = self.vision_model(
|
||||
|
||||
@@ -94,12 +94,12 @@ class JinaRobertaModelConfig(VerifyAndUpdateConfig):
|
||||
class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||
from vllm.config.pooler import PoolingTypeStr
|
||||
from vllm.config.pooler import SequencePoolingType
|
||||
|
||||
hf_config = model_config.hf_config
|
||||
hf_config.is_causal = False
|
||||
|
||||
pooling_type_map: dict[str, PoolingTypeStr] = {
|
||||
pooling_type_map: dict[str, SequencePoolingType] = {
|
||||
"avg": "MEAN",
|
||||
"cls": "CLS",
|
||||
"last": "LAST",
|
||||
@@ -107,8 +107,9 @@ class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
|
||||
|
||||
pooling_type = pooling_type_map.get(hf_config.pooling, None)
|
||||
if pooling_type is None:
|
||||
raise ValueError(f"pool_type {hf_config.pooling} not supported")
|
||||
model_config.pooler_config.pooling_type = pooling_type
|
||||
raise ValueError(f"pool_type {hf_config.pooling!r} not supported")
|
||||
|
||||
model_config.pooler_config.seq_pooling_type = pooling_type
|
||||
|
||||
|
||||
class NomicBertModelConfig(VerifyAndUpdateConfig):
|
||||
|
||||
@@ -193,7 +193,7 @@ class GritLMPooler(SequencePooler):
|
||||
return self.activation(pooled_data)
|
||||
|
||||
|
||||
@default_pooling_type("MEAN")
|
||||
@default_pooling_type(seq_pooling_type="MEAN")
|
||||
class GritLM(LlamaForCausalLM):
|
||||
"""This class implements the embedding model for parasail-ai/GritLM-7B-vllm.
|
||||
|
||||
|
||||
@@ -20,12 +20,13 @@ from vllm.utils.func_utils import supports_kw
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.model import AttnTypeStr
|
||||
from vllm.config.pooler import PoolingTypeStr
|
||||
from vllm.config.pooler import SequencePoolingType, TokenPoolingType
|
||||
from vllm.model_executor.layers.pooler import Pooler
|
||||
else:
|
||||
VllmConfig = Any
|
||||
Pooler = Any
|
||||
PoolingTypeStr = Any
|
||||
SequencePoolingType = Any
|
||||
TokenPoolingType = Any
|
||||
AttnTypeStr = Any
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -155,9 +156,19 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
|
||||
MRO of your model class.
|
||||
"""
|
||||
|
||||
default_pooling_type: ClassVar[PoolingTypeStr] = "LAST"
|
||||
default_seq_pooling_type: ClassVar[SequencePoolingType] = "LAST"
|
||||
"""
|
||||
Indicates the [vllm.config.pooler.PoolerConfig.pooling_type][]
|
||||
Indicates the [vllm.config.pooler.PoolerConfig.seq_pooling_type][]
|
||||
to use by default.
|
||||
|
||||
You can use the
|
||||
[vllm.model_executor.models.interfaces_base.default_pooling_type][]
|
||||
decorator to conveniently set this field.
|
||||
"""
|
||||
|
||||
default_tok_pooling_type: ClassVar[TokenPoolingType] = "ALL"
|
||||
"""
|
||||
Indicates the [vllm.config.pooler.PoolerConfig.tok_pooling_type][]
|
||||
to use by default.
|
||||
|
||||
You can use the
|
||||
@@ -200,18 +211,31 @@ def is_pooling_model(
|
||||
_T = TypeVar("_T", bound=type[nn.Module])
|
||||
|
||||
|
||||
def default_pooling_type(pooling_type: PoolingTypeStr):
|
||||
"""Decorator to set `VllmModelForPooling.default_pooling_type`."""
|
||||
def default_pooling_type(
|
||||
*,
|
||||
seq_pooling_type: SequencePoolingType = "LAST",
|
||||
tok_pooling_type: TokenPoolingType = "ALL",
|
||||
):
|
||||
"""Decorator to set `VllmModelForPooling.default_*_pooling_type`."""
|
||||
|
||||
def func(model: _T) -> _T:
|
||||
model.default_pooling_type = pooling_type # type: ignore
|
||||
model.default_seq_pooling_type = seq_pooling_type # type: ignore
|
||||
model.default_tok_pooling_type = tok_pooling_type # type: ignore
|
||||
return model
|
||||
|
||||
return func
|
||||
|
||||
|
||||
def get_default_pooling_type(model: type[object] | object) -> PoolingTypeStr:
|
||||
return getattr(model, "default_pooling_type", "LAST")
|
||||
def get_default_seq_pooling_type(
|
||||
model: type[object] | object,
|
||||
) -> SequencePoolingType:
|
||||
return getattr(model, "default_seq_pooling_type", "LAST")
|
||||
|
||||
|
||||
def get_default_tok_pooling_type(
|
||||
model: type[object] | object,
|
||||
) -> TokenPoolingType:
|
||||
return getattr(model, "default_tok_pooling_type", "ALL")
|
||||
|
||||
|
||||
def attn_type(attn_type: AttnTypeStr):
|
||||
|
||||
@@ -402,7 +402,7 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
|
||||
return loaded_params
|
||||
|
||||
|
||||
@default_pooling_type("ALL")
|
||||
@default_pooling_type(tok_pooling_type="ALL")
|
||||
class InternLM2ForRewardModel(InternLM2ForCausalLM):
|
||||
is_pooling_model = True
|
||||
|
||||
|
||||
@@ -221,7 +221,7 @@ class ModernBertEncoderLayer(nn.Module):
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
@default_pooling_type("CLS")
|
||||
@default_pooling_type(seq_pooling_type="CLS")
|
||||
class ModernBertModel(nn.Module):
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={"layers.": "encoder_layer.layers."}
|
||||
@@ -308,7 +308,7 @@ class ModernBertPooler(SequencePooler):
|
||||
return self.norm(self.act(self.dense(pooled_data)))
|
||||
|
||||
|
||||
@default_pooling_type("CLS")
|
||||
@default_pooling_type(seq_pooling_type="CLS")
|
||||
class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
is_pooling_model = True
|
||||
|
||||
@@ -395,7 +395,7 @@ class ModernBertPredictionHead(nn.Module):
|
||||
|
||||
|
||||
@attn_type("encoder_only")
|
||||
@default_pooling_type("ALL")
|
||||
@default_pooling_type(tok_pooling_type="ALL")
|
||||
class ModernBertForTokenClassification(nn.Module):
|
||||
is_pooling_model = True
|
||||
|
||||
|
||||
@@ -96,7 +96,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||
return loader.load_weights(weights)
|
||||
|
||||
|
||||
@default_pooling_type("ALL")
|
||||
@default_pooling_type(tok_pooling_type="ALL")
|
||||
class Qwen2ForRewardModel(Qwen2RewardBaseModel):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
vllm_config.model_config.hf_config.num_labels = 1
|
||||
@@ -108,7 +108,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
|
||||
self.pooler = pooler_for_token_classify(pooler_config)
|
||||
|
||||
|
||||
@default_pooling_type("STEP")
|
||||
@default_pooling_type(tok_pooling_type="STEP")
|
||||
class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
vllm_config.model_config.hf_config.num_labels = 2
|
||||
|
||||
@@ -35,10 +35,11 @@ from vllm.utils.hashing import safe_hash
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config.model import AttnTypeStr
|
||||
from vllm.config.pooler import PoolingTypeStr
|
||||
from vllm.config.pooler import SequencePoolingType, TokenPoolingType
|
||||
else:
|
||||
AttnTypeStr = Any
|
||||
PoolingTypeStr = Any
|
||||
SequencePoolingType = Any
|
||||
TokenPoolingType = Any
|
||||
|
||||
|
||||
from .interfaces import (
|
||||
@@ -57,7 +58,8 @@ from .interfaces import (
|
||||
)
|
||||
from .interfaces_base import (
|
||||
get_attn_type,
|
||||
get_default_pooling_type,
|
||||
get_default_seq_pooling_type,
|
||||
get_default_tok_pooling_type,
|
||||
is_pooling_model,
|
||||
is_text_generation_model,
|
||||
)
|
||||
@@ -548,7 +550,8 @@ class _ModelInfo:
|
||||
is_text_generation_model: bool
|
||||
is_pooling_model: bool
|
||||
attn_type: AttnTypeStr
|
||||
default_pooling_type: PoolingTypeStr
|
||||
default_seq_pooling_type: SequencePoolingType
|
||||
default_tok_pooling_type: TokenPoolingType
|
||||
supports_cross_encoding: bool
|
||||
supports_multimodal: bool
|
||||
supports_multimodal_raw_input_only: bool
|
||||
@@ -569,7 +572,8 @@ class _ModelInfo:
|
||||
architecture=model.__name__,
|
||||
is_text_generation_model=is_text_generation_model(model),
|
||||
is_pooling_model=is_pooling_model(model),
|
||||
default_pooling_type=get_default_pooling_type(model),
|
||||
default_seq_pooling_type=get_default_seq_pooling_type(model),
|
||||
default_tok_pooling_type=get_default_tok_pooling_type(model),
|
||||
attn_type=get_attn_type(model),
|
||||
supports_cross_encoding=supports_cross_encoding(model),
|
||||
supports_multimodal=supports_multimodal(model),
|
||||
|
||||
@@ -93,7 +93,7 @@ class RobertaClassificationHead(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
@default_pooling_type("CLS")
|
||||
@default_pooling_type(seq_pooling_type="CLS")
|
||||
class RobertaEmbeddingModel(BertEmbeddingModel):
|
||||
"""A model that uses Roberta to provide embedding functionalities."""
|
||||
|
||||
@@ -150,7 +150,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
|
||||
return loader.load_weights(weights_list, mapper=mapper)
|
||||
|
||||
|
||||
@default_pooling_type("CLS")
|
||||
@default_pooling_type(seq_pooling_type="CLS")
|
||||
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
"""A model that uses Roberta to provide embedding functionalities.
|
||||
|
||||
|
||||
@@ -129,7 +129,7 @@ class SiglipProcessingInfo(BaseProcessingInfo):
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
),
|
||||
_get_vision_feature_select_strategy(pooler_config.pooling_type),
|
||||
_get_vision_feature_select_strategy(pooler_config.seq_pooling_type),
|
||||
)
|
||||
|
||||
def get_image_size_with_most_features(self) -> ImageSize:
|
||||
@@ -998,7 +998,7 @@ class SiglipTextEmbeddings(nn.Module):
|
||||
|
||||
|
||||
# Assume EOS token corresponds to CLS token in text model
|
||||
@default_pooling_type("CLS")
|
||||
@default_pooling_type(seq_pooling_type="CLS")
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
SiglipMultiModalProcessor,
|
||||
info=SiglipProcessingInfo,
|
||||
@@ -1125,7 +1125,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
|
||||
) -> torch.Tensor:
|
||||
if feature_select_strategy is None:
|
||||
feature_select_strategy = _get_vision_feature_select_strategy(
|
||||
self.pooler_config.pooling_type
|
||||
self.pooler_config.seq_pooling_type
|
||||
)
|
||||
|
||||
pooled_output = self.vision_model(
|
||||
|
||||
@@ -140,7 +140,7 @@ class PoolingParams(
|
||||
self, pooler_config: "PoolerConfig", valid_parameters: list[str]
|
||||
):
|
||||
step_pooling_parameters = ["step_tag_id", "returned_token_ids"]
|
||||
if pooler_config.pooling_type != "STEP":
|
||||
if pooler_config.tok_pooling_type != "STEP":
|
||||
invalid_parameters = []
|
||||
for k in step_pooling_parameters:
|
||||
if getattr(self, k, None) is not None:
|
||||
|
||||
@@ -3,11 +3,11 @@
|
||||
from typing import Literal, get_args
|
||||
|
||||
GenerationTask = Literal["generate", "transcription"]
|
||||
GENERATION_TASKS = get_args(GenerationTask)
|
||||
GENERATION_TASKS: tuple[GenerationTask, ...] = get_args(GenerationTask)
|
||||
|
||||
PoolingTask = Literal[
|
||||
"embed", "classify", "score", "token_embed", "token_classify", "plugin"
|
||||
]
|
||||
POOLING_TASKS = get_args(PoolingTask)
|
||||
POOLING_TASKS: tuple[PoolingTask, ...] = get_args(PoolingTask)
|
||||
|
||||
SupportedTask = Literal[GenerationTask, PoolingTask]
|
||||
|
||||
@@ -10,9 +10,7 @@ from pathlib import Path
|
||||
from typing import Any, Literal, TypeAlias
|
||||
|
||||
import huggingface_hub
|
||||
from huggingface_hub import (
|
||||
get_safetensors_metadata,
|
||||
)
|
||||
from huggingface_hub import get_safetensors_metadata
|
||||
from packaging.version import Version
|
||||
from transformers import GenerationConfig, PretrainedConfig
|
||||
from transformers.models.auto.image_processing_auto import get_image_processor_config
|
||||
@@ -742,7 +740,10 @@ def get_config(
|
||||
|
||||
|
||||
@cache
|
||||
def get_pooling_config(model: str, revision: str | None = "main") -> dict | None:
|
||||
def get_pooling_config(
|
||||
model: str,
|
||||
revision: str | None = "main",
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
This function gets the pooling and normalize
|
||||
config from the model - only applies to
|
||||
@@ -793,38 +794,40 @@ def get_pooling_config(model: str, revision: str | None = "main") -> dict | None
|
||||
)
|
||||
|
||||
if pooling:
|
||||
pooling_file_name = "{}/config.json".format(pooling["path"])
|
||||
pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision)
|
||||
pooling_type_name = next(
|
||||
(item for item, val in pooling_dict.items() if val is True), None
|
||||
)
|
||||
from vllm.config.pooler import SEQ_POOLING_TYPES, TOK_POOLING_TYPES
|
||||
|
||||
if pooling_type_name is not None:
|
||||
pooling_type_name = get_pooling_config_name(pooling_type_name)
|
||||
pooling_file_name = "{}/config.json".format(pooling["path"])
|
||||
pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision) or {}
|
||||
|
||||
logger.info("Found pooling configuration.")
|
||||
return {"pooling_type": pooling_type_name, "normalize": normalize}
|
||||
|
||||
config: dict[str, Any] = {"normalize": normalize}
|
||||
for key, val in pooling_dict.items():
|
||||
if val is True:
|
||||
pooling_type = parse_pooling_type(key)
|
||||
if pooling_type in SEQ_POOLING_TYPES:
|
||||
config["seq_pooling_type"] = pooling_type
|
||||
elif pooling_type in TOK_POOLING_TYPES:
|
||||
config["tok_pooling_type"] = pooling_type
|
||||
else:
|
||||
logger.debug("Skipping unrelated field: %r=%r", key, val)
|
||||
|
||||
return config
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_pooling_config_name(pooling_name: str) -> str | None:
|
||||
def parse_pooling_type(pooling_name: str):
|
||||
if "pooling_mode_" in pooling_name:
|
||||
pooling_name = pooling_name.replace("pooling_mode_", "")
|
||||
|
||||
if "_" in pooling_name:
|
||||
pooling_name = pooling_name.split("_")[0]
|
||||
pooling_name = pooling_name.split("_", 1)[0]
|
||||
|
||||
if "lasttoken" in pooling_name:
|
||||
pooling_name = "last"
|
||||
|
||||
supported_pooling_types = ["LAST", "ALL", "CLS", "STEP", "MEAN"]
|
||||
pooling_type_name = pooling_name.upper()
|
||||
|
||||
if pooling_type_name in supported_pooling_types:
|
||||
return pooling_type_name
|
||||
|
||||
raise NotImplementedError(f"Pooling type {pooling_type_name} not supported")
|
||||
return pooling_name.upper()
|
||||
|
||||
|
||||
@cache
|
||||
|
||||
Reference in New Issue
Block a user