diff --git a/tests/model_executor/test_model_load_with_params.py b/tests/model_executor/test_model_load_with_params.py
index 4aeae8e36..0a923cc22 100644
--- a/tests/model_executor/test_model_load_with_params.py
+++ b/tests/model_executor/test_model_load_with_params.py
@@ -46,7 +46,8 @@ def test_model_loading_with_params(vllm_runner, monkeypatch):
         assert model_config.encoder_config["do_lower_case"]

         # asserts on the pooling config files
-        assert model_config.pooler_config.pooling_type == "CLS"
+        assert model_config.pooler_config.seq_pooling_type == "CLS"
+        assert model_config.pooler_config.tok_pooling_type == "ALL"
         assert model_config.pooler_config.normalize

         # asserts on the tokenizer loaded
@@ -90,7 +91,8 @@ def test_roberta_model_loading_with_params(vllm_runner, monkeypatch):
         assert not model_config.encoder_config["do_lower_case"]

         # asserts on the pooling config files
-        assert model_config.pooler_config.pooling_type == "MEAN"
+        assert model_config.pooler_config.seq_pooling_type == "MEAN"
+        assert model_config.pooler_config.tok_pooling_type == "ALL"
         assert model_config.pooler_config.normalize

         # asserts on the tokenizer loaded
diff --git a/tests/models/language/pooling/test_embedding.py b/tests/models/language/pooling/test_embedding.py
index 93b984331..982dc73f6 100644
--- a/tests/models/language/pooling/test_embedding.py
+++ b/tests/models/language/pooling/test_embedding.py
@@ -54,7 +54,7 @@ def test_models(
     vllm_extra_kwargs = {}
     if model == "ssmits/Qwen2-7B-Instruct-embed-base":
         vllm_extra_kwargs["pooler_config"] = PoolerConfig(
-            pooling_type="MEAN", normalize=False
+            seq_pooling_type="MEAN", normalize=False
         )

     max_model_len: int | None = 512
diff --git a/tests/models/language/pooling/test_mm_classifier_conversion.py b/tests/models/language/pooling/test_mm_classifier_conversion.py
index d50ee85b9..631fd394f 100644
--- a/tests/models/language/pooling/test_mm_classifier_conversion.py
+++ b/tests/models/language/pooling/test_mm_classifier_conversion.py
@@ -88,7 +88,7 @@ def test_gemma_multimodal(
         convert="classify",
         load_format="auto",
         hf_overrides=update_config,
-        pooler_config=PoolerConfig(pooling_type="LAST"),
+        pooler_config=PoolerConfig(seq_pooling_type="LAST"),
         max_model_len=512,
         enforce_eager=True,
         tensor_parallel_size=1,
diff --git a/tests/models/language/pooling_mteb_test/mteb_embed_utils.py b/tests/models/language/pooling_mteb_test/mteb_embed_utils.py
index e048318e9..a736b991d 100644
--- a/tests/models/language/pooling_mteb_test/mteb_embed_utils.py
+++ b/tests/models/language/pooling_mteb_test/mteb_embed_utils.py
@@ -162,8 +162,11 @@ def mteb_test_embed_models(
         assert model_info.architecture in model_config.architectures

         # Confirm whether the important configs in model_config are correct.
-        if model_info.pooling_type is not None:
-            assert model_config.pooler_config.pooling_type == model_info.pooling_type
+        pooler_config = model_config.pooler_config
+        if model_info.seq_pooling_type is not None:
+            assert pooler_config.seq_pooling_type == model_info.seq_pooling_type
+        if model_info.tok_pooling_type is not None:
+            assert pooler_config.tok_pooling_type == model_info.tok_pooling_type
         if model_info.attn_type is not None:
             assert model_config.attn_type == model_info.attn_type
         if model_info.is_prefix_caching_supported is not None:
diff --git a/tests/models/language/pooling_mteb_test/mteb_score_utils.py b/tests/models/language/pooling_mteb_test/mteb_score_utils.py
index c5c23b153..adc2cf3e4 100644
--- a/tests/models/language/pooling_mteb_test/mteb_score_utils.py
+++ b/tests/models/language/pooling_mteb_test/mteb_score_utils.py
@@ -254,8 +254,11 @@ def mteb_test_rerank_models(
         assert model_config.hf_config.num_labels == 1

         # Confirm whether the important configs in model_config are correct.
-        if model_info.pooling_type is not None:
-            assert model_config.pooler_config.pooling_type == model_info.pooling_type
+        pooler_config = model_config.pooler_config
+        if model_info.seq_pooling_type is not None:
+            assert pooler_config.seq_pooling_type == model_info.seq_pooling_type
+        if model_info.tok_pooling_type is not None:
+            assert pooler_config.tok_pooling_type == model_info.tok_pooling_type
         if model_info.attn_type is not None:
             assert model_config.attn_type == model_info.attn_type
         if model_info.is_prefix_caching_supported is not None:
diff --git a/tests/models/language/pooling_mteb_test/test_baai.py b/tests/models/language/pooling_mteb_test/test_baai.py
index 2a639f550..1199393d4 100644
--- a/tests/models/language/pooling_mteb_test/test_baai.py
+++ b/tests/models/language/pooling_mteb_test/test_baai.py
@@ -17,7 +17,7 @@ MODELS = [
         "BAAI/bge-base-en",
         architecture="BertModel",
         mteb_score=0.779336792,
-        pooling_type="CLS",
+        seq_pooling_type="CLS",
         attn_type="encoder_only",
         is_prefix_caching_supported=False,
         is_chunked_prefill_supported=False,
@@ -54,7 +54,7 @@ MODELS = [
         "BAAI/bge-m3",
         architecture="XLMRobertaModel",
         mteb_score=0.787343078,
-        pooling_type="CLS",
+        seq_pooling_type="CLS",
         attn_type="encoder_only",
         is_prefix_caching_supported=False,
         is_chunked_prefill_supported=False,
@@ -65,7 +65,7 @@ MODELS = [
         "BAAI/bge-code-v1",
         architecture="Qwen2Model",
         mteb_score=0.75724465,
-        pooling_type="LAST",
+        seq_pooling_type="LAST",
         attn_type="decoder",
         is_prefix_caching_supported=True,
         is_chunked_prefill_supported=True,
@@ -79,7 +79,7 @@ RERANK_MODELS = [
         "BAAI/bge-reranker-base",
         architecture="XLMRobertaForSequenceClassification",
         mteb_score=0.32398,
-        pooling_type="CLS",
+        seq_pooling_type="CLS",
         attn_type="encoder_only",
         is_prefix_caching_supported=False,
         is_chunked_prefill_supported=False,
diff --git a/tests/models/language/pooling_mteb_test/test_bge_reranker_v2_gemma.py b/tests/models/language/pooling_mteb_test/test_bge_reranker_v2_gemma.py
index 3e58d5999..23bc95548 100644
--- a/tests/models/language/pooling_mteb_test/test_bge_reranker_v2_gemma.py
+++ b/tests/models/language/pooling_mteb_test/test_bge_reranker_v2_gemma.py
@@ -26,7 +26,7 @@ RERANK_MODELS = [
             "method": "no_post_processing",
         },
         mteb_score=0.33757,
-        pooling_type="LAST",
+        seq_pooling_type="LAST",
         attn_type="decoder",
         is_prefix_caching_supported=True,
         is_chunked_prefill_supported=True,
diff --git a/tests/models/language/pooling_mteb_test/test_cross_encoder.py b/tests/models/language/pooling_mteb_test/test_cross_encoder.py
index fb7b0fff3..0d1067d5e 100644
--- a/tests/models/language/pooling_mteb_test/test_cross_encoder.py
+++ b/tests/models/language/pooling_mteb_test/test_cross_encoder.py
@@ -12,7 +12,7 @@ RERANK_MODELS = [
     RerankModelInfo(
         "cross-encoder/ms-marco-TinyBERT-L-2-v2",
         architecture="BertForSequenceClassification",
-        pooling_type="CLS",
+        seq_pooling_type="CLS",
         attn_type="encoder_only",
         is_prefix_caching_supported=False,
         is_chunked_prefill_supported=False,
@@ -21,7 +21,7 @@ RERANK_MODELS = [
     RerankModelInfo(
         "tomaarsen/Qwen3-Reranker-0.6B-seq-cls",
         architecture="Qwen3ForSequenceClassification",
-        pooling_type="LAST",
+        seq_pooling_type="LAST",
         attn_type="decoder",
         is_prefix_caching_supported=True,
         is_chunked_prefill_supported=True,
diff --git a/tests/models/language/pooling_mteb_test/test_gte.py b/tests/models/language/pooling_mteb_test/test_gte.py
index 2a5b2090b..f87fd832a 100644
--- a/tests/models/language/pooling_mteb_test/test_gte.py
+++ b/tests/models/language/pooling_mteb_test/test_gte.py
@@ -18,7 +18,7 @@ MODELS = [
         "thenlper/gte-large",
         mteb_score=0.76807651,
         architecture="BertModel",
-        pooling_type="MEAN",
+        seq_pooling_type="MEAN",
         attn_type="encoder_only",
         is_prefix_caching_supported=False,
         is_chunked_prefill_supported=False,
@@ -44,7 +44,7 @@ MODELS = [
         architecture="GteNewModel",
         mteb_score=0.775074696,
         hf_overrides={"architectures": ["GteNewModel"]},
-        pooling_type="CLS",
+        seq_pooling_type="CLS",
         attn_type="encoder_only",
         is_prefix_caching_supported=False,
         is_chunked_prefill_supported=False,
@@ -67,7 +67,7 @@ MODELS = [
         "Alibaba-NLP/gte-Qwen2-1.5B-instruct",
         mteb_score=0.758473459018872,
         architecture="Qwen2ForCausalLM",
-        pooling_type="LAST",
+        seq_pooling_type="LAST",
         attn_type="encoder_only",
         is_prefix_caching_supported=False,
         is_chunked_prefill_supported=False,
@@ -78,7 +78,7 @@ MODELS = [
         "Alibaba-NLP/gte-modernbert-base",
         mteb_score=0.748193353,
         architecture="ModernBertModel",
-        pooling_type="CLS",
+        seq_pooling_type="CLS",
         attn_type="encoder_only",
         is_prefix_caching_supported=False,
         is_chunked_prefill_supported=False,
@@ -89,7 +89,7 @@ MODELS = [
         "Qwen/Qwen3-Embedding-0.6B",
         mteb_score=0.771163695,
         architecture="Qwen3ForCausalLM",
-        pooling_type="LAST",
+        seq_pooling_type="LAST",
         attn_type="decoder",
         is_prefix_caching_supported=True,
         is_chunked_prefill_supported=True,
@@ -108,7 +108,7 @@ RERANK_MODELS = [
         "Alibaba-NLP/gte-reranker-modernbert-base",
         mteb_score=0.33386,
         architecture="ModernBertForSequenceClassification",
-        pooling_type="CLS",
+        seq_pooling_type="CLS",
         attn_type="encoder_only",
         is_prefix_caching_supported=False,
         is_chunked_prefill_supported=False,
@@ -119,7 +119,7 @@ RERANK_MODELS = [
         mteb_score=0.33062,
         architecture="GteNewForSequenceClassification",
         hf_overrides={"architectures": ["GteNewForSequenceClassification"]},
-        pooling_type="CLS",
+        seq_pooling_type="CLS",
         attn_type="encoder_only",
         is_prefix_caching_supported=False,
         is_chunked_prefill_supported=False,
diff --git a/tests/models/language/pooling_mteb_test/test_intfloat.py b/tests/models/language/pooling_mteb_test/test_intfloat.py
index 377ab600a..adadb60ee 100644
--- a/tests/models/language/pooling_mteb_test/test_intfloat.py
+++ b/tests/models/language/pooling_mteb_test/test_intfloat.py
@@ -13,7 +13,7 @@ MODELS = [
         "intfloat/e5-small",
         architecture="BertModel",
         mteb_score=0.742285423,
-        pooling_type="MEAN",
+        seq_pooling_type="MEAN",
         attn_type="encoder_only",
         is_prefix_caching_supported=False,
         is_chunked_prefill_supported=False,
@@ -29,7 +29,7 @@ MODELS = [
         "intfloat/multilingual-e5-base",
         architecture="XLMRobertaModel",
         mteb_score=0.779325955,
-        pooling_type="MEAN",
+        seq_pooling_type="MEAN",
         attn_type="encoder_only",
         is_prefix_caching_supported=False,
         is_chunked_prefill_supported=False,
diff --git a/tests/models/language/pooling_mteb_test/test_jina.py b/tests/models/language/pooling_mteb_test/test_jina.py
index cf6ba1851..627cc0431 100644
--- a/tests/models/language/pooling_mteb_test/test_jina.py
+++ b/tests/models/language/pooling_mteb_test/test_jina.py
@@ -24,7 +24,7 @@ EMBEDDING_MODELS = [
         mteb_score=0.824413164,
         architecture="XLMRobertaModel",
         is_matryoshka=True,
-        pooling_type="MEAN",
+        seq_pooling_type="MEAN",
         attn_type="encoder_only",
         is_prefix_caching_supported=False,
         is_chunked_prefill_supported=False,
@@ -36,7 +36,7 @@ RERANK_MODELS = [
         "jinaai/jina-reranker-v2-base-multilingual",
         mteb_score=0.33643,
         architecture="XLMRobertaForSequenceClassification",
-        pooling_type="CLS",
+        seq_pooling_type="CLS",
         attn_type="encoder_only",
         is_prefix_caching_supported=False,
         is_chunked_prefill_supported=False,
diff --git a/tests/models/language/pooling_mteb_test/test_mxbai_rerank.py b/tests/models/language/pooling_mteb_test/test_mxbai_rerank.py
index b03f59962..74fe760e7 100644
--- a/tests/models/language/pooling_mteb_test/test_mxbai_rerank.py
+++ b/tests/models/language/pooling_mteb_test/test_mxbai_rerank.py
@@ -24,7 +24,7 @@ RERANK_MODELS = [
         "mixedbread-ai/mxbai-rerank-base-v2",
         architecture="Qwen2ForSequenceClassification",
         hf_overrides=mxbai_rerank_hf_overrides,
-        pooling_type="LAST",
+        seq_pooling_type="LAST",
         attn_type="decoder",
         is_prefix_caching_supported=True,
         is_chunked_prefill_supported=True,
diff --git a/tests/models/language/pooling_mteb_test/test_nemotron.py b/tests/models/language/pooling_mteb_test/test_nemotron.py
index 4e8304dde..79fae2833 100644
--- a/tests/models/language/pooling_mteb_test/test_nemotron.py
+++ b/tests/models/language/pooling_mteb_test/test_nemotron.py
@@ -19,7 +19,7 @@ EMBEDDING_MODELS = [
         "nvidia/llama-nemotron-embed-1b-v2",
         architecture="LlamaBidirectionalModel",
         mteb_score=0.689164662128673,
-        pooling_type="MEAN",
+        seq_pooling_type="MEAN",
         attn_type="encoder_only",
         is_prefix_caching_supported=False,
         is_chunked_prefill_supported=False,
@@ -32,7 +32,7 @@ RERANK_MODELS = [
         architecture="LlamaBidirectionalForSequenceClassification",
         chat_template_name="nemotron-rerank.jinja",
         mteb_score=0.33994,
-        pooling_type="MEAN",
+        seq_pooling_type="MEAN",
         attn_type="encoder_only",
         is_prefix_caching_supported=False,
         is_chunked_prefill_supported=False,
diff --git a/tests/models/language/pooling_mteb_test/test_nomic.py b/tests/models/language/pooling_mteb_test/test_nomic.py
index 06c568026..fa987fab7 100644
--- a/tests/models/language/pooling_mteb_test/test_nomic.py
+++ b/tests/models/language/pooling_mteb_test/test_nomic.py
@@ -14,7 +14,7 @@ MODELS = [
         architecture="NomicBertModel",
         mteb_score=0.737568559,
         enable_test=True,
-        pooling_type="MEAN",
+        seq_pooling_type="MEAN",
         attn_type="encoder_only",
         is_prefix_caching_supported=False,
         is_chunked_prefill_supported=False,
@@ -32,7 +32,7 @@ MODELS = [
         architecture="NomicBertModel",
         mteb_score=0.715488912,
         enable_test=True,
-        pooling_type="MEAN",
+        seq_pooling_type="MEAN",
         attn_type="encoder_only",
         is_prefix_caching_supported=False,
         is_chunked_prefill_supported=False,
diff --git a/tests/models/language/pooling_mteb_test/test_qwen3_reranker.py b/tests/models/language/pooling_mteb_test/test_qwen3_reranker.py
index 228ae457b..3c182cb04 100644
--- a/tests/models/language/pooling_mteb_test/test_qwen3_reranker.py
+++ b/tests/models/language/pooling_mteb_test/test_qwen3_reranker.py
@@ -27,7 +27,7 @@ RERANK_MODELS = [
         architecture="Qwen3ForSequenceClassification",
         hf_overrides=qwen3_reranker_hf_overrides,
         chat_template_name="qwen3_reranker.jinja",
-        pooling_type="LAST",
+        seq_pooling_type="LAST",
         attn_type="decoder",
         is_prefix_caching_supported=True,
         is_chunked_prefill_supported=True,
diff --git a/tests/models/language/pooling_mteb_test/test_snowflake_arctic_embed.py b/tests/models/language/pooling_mteb_test/test_snowflake_arctic_embed.py
index 37597a7e9..f3afbe84f 100644
--- a/tests/models/language/pooling_mteb_test/test_snowflake_arctic_embed.py
+++ b/tests/models/language/pooling_mteb_test/test_snowflake_arctic_embed.py
@@ -14,7 +14,7 @@ MODELS = [
         is_matryoshka=False,
         architecture="BertModel",
         mteb_score=0.714927797,
-        pooling_type="CLS",
+        seq_pooling_type="CLS",
         attn_type="encoder_only",
         is_prefix_caching_supported=False,
         is_chunked_prefill_supported=False,
@@ -37,7 +37,7 @@ MODELS = [
         is_matryoshka=False,
         architecture="NomicBertModel",
         mteb_score=0.681146831,
-        pooling_type="CLS",
+        seq_pooling_type="CLS",
         attn_type="encoder_only",
         is_prefix_caching_supported=False,
         is_chunked_prefill_supported=False,
@@ -54,7 +54,7 @@ MODELS = [
         is_matryoshka=True,
         architecture="BertModel",
         mteb_score=0.649088363,
-        pooling_type="CLS",
+        seq_pooling_type="CLS",
         attn_type="encoder_only",
         is_prefix_caching_supported=False,
         is_chunked_prefill_supported=False,
@@ -65,7 +65,7 @@ MODELS = [
         is_matryoshka=True,
         architecture="XLMRobertaModel",
         mteb_score=0.712258299,
-        pooling_type="CLS",
+        seq_pooling_type="CLS",
         attn_type="encoder_only",
         is_prefix_caching_supported=False,
         is_chunked_prefill_supported=False,
@@ -76,7 +76,7 @@ MODELS = [
         is_matryoshka=True,
         architecture="GteModel",
         mteb_score=0.706622444,
-        pooling_type="CLS",
+        seq_pooling_type="CLS",
         attn_type="encoder_only",
         is_prefix_caching_supported=False,
         is_chunked_prefill_supported=False,
diff --git a/tests/models/language/pooling_mteb_test/test_st_projector.py b/tests/models/language/pooling_mteb_test/test_st_projector.py
index 4ce7a4aed..395846347 100644
--- a/tests/models/language/pooling_mteb_test/test_st_projector.py
+++ b/tests/models/language/pooling_mteb_test/test_st_projector.py
@@ -14,7 +14,7 @@ ST_PROJECTOR_MODELS = [
         "TencentBAC/Conan-embedding-v1",
         architecture="BertModel",
         mteb_score=0.688611955,
-        pooling_type="MEAN",
+        seq_pooling_type="MEAN",
         attn_type="encoder_only",
         is_prefix_caching_supported=False,
         is_chunked_prefill_supported=False,
@@ -24,7 +24,7 @@ ST_PROJECTOR_MODELS = [
         "google/embeddinggemma-300m",
         architecture="Gemma3TextModel",
         mteb_score=0.7473819294684156,
-        pooling_type="MEAN",
+        seq_pooling_type="MEAN",
         attn_type="encoder_only",
         is_prefix_caching_supported=False,
         is_chunked_prefill_supported=False,
diff --git a/tests/models/utils.py b/tests/models/utils.py
index bd9fcf31d..1b820d284 100644
--- a/tests/models/utils.py
+++ b/tests/models/utils.py
@@ -11,6 +11,7 @@ import torch.nn.functional as F
 from transformers import PretrainedConfig

 from vllm.config.model import AttnTypeStr, ModelConfig, ModelDType, RunnerOption
+from vllm.config.pooler import SequencePoolingType, TokenPoolingType
 from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs
 from vllm.multimodal.processing import InputProcessingContext
 from vllm.tokenizers import cached_tokenizer_from_config
@@ -379,7 +380,8 @@ class ModelInfo:
     max_model_len: int | None = None
     hf_dtype: str = "float32"
     hf_overrides: dict[str, Any] | None = None
-    pooling_type: str | None = None
+    seq_pooling_type: SequencePoolingType | None = None
+    tok_pooling_type: TokenPoolingType | None = None
     attn_type: AttnTypeStr | None = None
     is_prefix_caching_supported: bool | None = None
     is_chunked_prefill_supported: bool | None = None
diff --git a/tests/test_config.py b/tests/test_config.py
index da5080fad..db8f3066c 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -161,7 +161,8 @@ def test_get_pooling_config():

     assert model_config.pooler_config is not None
     assert model_config.pooler_config.normalize
-    assert model_config.pooler_config.pooling_type == "MEAN"
+    assert model_config.pooler_config.seq_pooling_type == "MEAN"
+    assert model_config.pooler_config.tok_pooling_type == "ALL"


 @pytest.mark.skipif(
@@ -169,7 +170,7 @@ def test_get_pooling_config():
 )
 def test_get_pooling_config_from_args():
     model_id = "sentence-transformers/all-MiniLM-L12-v2"
-    pooler_config = PoolerConfig(pooling_type="CLS", normalize=True)
+    pooler_config = PoolerConfig(seq_pooling_type="CLS", normalize=True)
     model_config = ModelConfig(model_id, pooler_config=pooler_config)

     assert asdict(model_config.pooler_config) == asdict(pooler_config)
@@ -180,14 +181,25 @@ def test_get_pooling_config_from_args():
     [
         ("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", "LAST", "LAST"),  # LLM
         ("intfloat/e5-small", "CLS", "MEAN"),  # BertModel
+    ],
+)
+def test_default_seq_pooling_type(model_id, default_pooling_type, pooling_type):
+    model_config = ModelConfig(model_id)
+    assert model_config._model_info.default_seq_pooling_type == default_pooling_type
+    assert model_config.pooler_config.seq_pooling_type == pooling_type
+
+
+@pytest.mark.parametrize(
+    ("model_id", "default_pooling_type", "pooling_type"),
+    [
         ("Qwen/Qwen2.5-Math-RM-72B", "ALL", "ALL"),  # reward
         ("Qwen/Qwen2.5-Math-PRM-7B", "STEP", "STEP"),  # step reward
     ],
 )
-def test_default_pooling_type(model_id, default_pooling_type, pooling_type):
+def test_default_tok_pooling_type(model_id, default_pooling_type, pooling_type):
     model_config = ModelConfig(model_id)
-    assert model_config._model_info.default_pooling_type == default_pooling_type
-    assert model_config.pooler_config.pooling_type == pooling_type
+    assert model_config._model_info.default_tok_pooling_type == default_pooling_type
+    assert model_config.pooler_config.tok_pooling_type == pooling_type


 @pytest.mark.parametrize(
@@ -554,100 +566,100 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files):
         (
             "jason9693/Qwen2.5-1.5B-apeach",
             "decoder",
             True,
-            "Pooling models with causal attn and last pooling support chunked prefill.",
+            "Pooling models with causal attn and LAST/ALL pooling support chunked prefill.",  # noqa: E501
         ),
         (
             "Qwen/Qwen3-Embedding-0.6B",
             "decoder",
             True,
-            "Pooling models with causal attn and last pooling support chunked prefill.",
+            "Pooling models with causal attn and LAST/ALL pooling support chunked prefill.",  # noqa: E501
         ),
         (
             "Qwen/Qwen2.5-Math-PRM-7B",
             "decoder",
             False,
-            "Pooling models with step pooling does not support chunked prefill.",
+            "Pooling models with causal attn and LAST/STEP pooling do not support chunked prefill.",  # noqa: E501
         ),
         (
             "internlm/internlm2-1_8b-reward",
             "decoder",
             True,
-            "Pooling models with causal attn and all pooling support chunked prefill.",
+            "Pooling models with causal attn and LAST/ALL pooling support chunked prefill.",  # noqa: E501
         ),
         (
             "BAAI/bge-base-en",
             "encoder_only",
             False,
-            "Pooling models with bidirectional attn does not support chunked prefill.",
+            "Pooling models with bidirectional attn do not support chunked prefill.",  # noqa: E501
         ),
         (
             "boltuix/NeuroBERT-NER",
             "encoder_only",
             False,
-            "Pooling models with bidirectional attn does not support chunked prefill.",
+            "Pooling models with bidirectional attn do not support chunked prefill.",  # noqa: E501
         ),
         (
             "papluca/xlm-roberta-base-language-detection",
             "encoder_only",
             False,
-            "Pooling models with bidirectional attn does not support chunked prefill.",
+            "Pooling models with bidirectional attn do not support chunked prefill.",  # noqa: E501
         ),
         (
             "Alibaba-NLP/gte-Qwen2-1.5B-instruct",
             "encoder_only",
             False,
-            "Pooling models with bidirectional attn does not support chunked prefill.",
+            "Pooling models with bidirectional attn do not support chunked prefill.",  # noqa: E501
         ),
         (
             "intfloat/e5-small",
             "encoder_only",
             False,
-            "Pooling models with bidirectional attn does not support chunked prefill.",
+            "Pooling models with bidirectional attn do not support chunked prefill.",  # noqa: E501
         ),
         # multimodal models
         (
             "openai/clip-vit-base-patch32",
             "decoder",
             True,
-            "Pooling models with causal attn and last pooling support chunked prefill.",
+            "Pooling models with causal attn and LAST/ALL pooling support chunked prefill.",  # noqa: E501
         ),
         (
             "google/siglip-base-patch16-224",
             "encoder_only",
             False,
-            "Pooling models with bidirectional attn does not support chunked prefill.",
+            "Pooling models with bidirectional attn do not support chunked prefill.",  # noqa: E501
         ),
         # generate models
         (
             "Qwen/Qwen3-0.6B",
             "decoder",
             True,
-            "Generative models support chunked prefill.",
+            "Generative models support chunked prefill.",  # noqa: E501
         ),
         (
             "Qwen/Qwen3-Next-80B-A3B-Instruct",
             "hybrid",
             True,
-            "Generative models support chunked prefill.",
+            "Generative models support chunked prefill.",  # noqa: E501
         ),
         (
             "ibm-granite/granite-4.0-h-small",
             "hybrid",
             True,
-            "Generative models support chunked prefill.",
+            "Generative models support chunked prefill.",  # noqa: E501
         ),
         (
             "state-spaces/mamba-130m-hf",
             "attention_free",
             True,
-            "Generative models support chunked prefill.",
+            "Generative models support chunked prefill.",  # noqa: E501
         ),
         # encoder_decoder models
         (
             "openai/whisper-small",
             "encoder_decoder",
             False,
-            "Encoder decoder models does not support chunked prefill.",
+            "Encoder decoder models do not support chunked prefill.",  # noqa: E501
         ),
     ],
 )
@@ -673,100 +685,100 @@ def test_is_chunked_prefill_supported(
         (
             "jason9693/Qwen2.5-1.5B-apeach",
             "decoder",
             True,
-            "Pooling models with causal attn and last pooling support prefix caching.",
+            "Pooling models with causal attn and LAST/ALL pooling support prefix caching.",  # noqa: E501
         ),
         (
             "Qwen/Qwen3-Embedding-0.6B",
             "decoder",
             True,
-            "Pooling models with causal attn and last pooling support prefix caching.",
+            "Pooling models with causal attn and LAST/ALL pooling support prefix caching.",  # noqa: E501
         ),
         (
             "Qwen/Qwen2.5-Math-PRM-7B",
             "decoder",
             False,
-            "Pooling models with step pooling does not support prefix caching.",
+            "Pooling models with causal attn and LAST/STEP pooling do not support prefix caching.",  # noqa: E501
         ),
         (
             "internlm/internlm2-1_8b-reward",
             "decoder",
             True,
-            "Pooling models with causal attn and all pooling support prefix caching.",
+            "Pooling models with causal attn and LAST/ALL pooling support prefix caching.",  # noqa: E501
         ),
         (
             "BAAI/bge-base-en",
             "encoder_only",
             False,
-            "Pooling models with bidirectional attn does not support prefix caching.",
+            "Pooling models with bidirectional attn do not support prefix caching.",  # noqa: E501
         ),
         (
             "boltuix/NeuroBERT-NER",
             "encoder_only",
             False,
-            "Pooling models with bidirectional attn does not support prefix caching.",
+            "Pooling models with bidirectional attn do not support prefix caching.",  # noqa: E501
         ),
         (
             "papluca/xlm-roberta-base-language-detection",
             "encoder_only",
             False,
-            "Pooling models with bidirectional attn does not support prefix caching.",
+            "Pooling models with bidirectional attn do not support prefix caching.",  # noqa: E501
         ),
         (
             "Alibaba-NLP/gte-Qwen2-1.5B-instruct",
             "encoder_only",
             False,
-            "Pooling models with bidirectional attn does not support prefix caching.",
+            "Pooling models with bidirectional attn do not support prefix caching.",  # noqa: E501
         ),
         (
             "intfloat/e5-small",
             "encoder_only",
             False,
-            "Pooling models with bidirectional attn does not support prefix caching.",
+            "Pooling models with bidirectional attn do not support prefix caching.",  # noqa: E501
         ),
         # multimodal models
         (
             "openai/clip-vit-base-patch32",
             "decoder",
             True,
-            "Pooling models with causal attn and last pooling support prefix caching.",
+            "Pooling models with causal attn and LAST/ALL pooling support prefix caching.",  # noqa: E501
         ),
         (
             "google/siglip-base-patch16-224",
             "encoder_only",
             False,
-            "Pooling models with bidirectional attn does not support prefix caching.",
+            "Pooling models with bidirectional attn do not support prefix caching.",  # noqa: E501
         ),
         # generate models
         (
             "Qwen/Qwen3-0.6B",
             "decoder",
             True,
-            "Generative models support prefix caching.",
+            "Generative models support prefix caching.",  # noqa: E501
         ),
         (
             "Qwen/Qwen3-Next-80B-A3B-Instruct",
             "hybrid",
             False,
-            "Hybrid models does not support prefix caching since the feature is still experimental.",  # noqa: E501
+            "Hybrid models do not support prefix caching since the feature is still experimental.",  # noqa: E501
         ),
         (
             "ibm-granite/granite-4.0-h-small",
             "hybrid",
             False,
-            "Hybrid models does not support prefix caching since the feature is still experimental.",  # noqa: E501
+            "Hybrid models do not support prefix caching since the feature is still experimental.",  # noqa: E501
         ),
         (
             "state-spaces/mamba-130m-hf",
             "attention_free",
             False,
-            "Attention free models does not support prefix caching since the feature is still experimental.",  # noqa: E501
+            "Attention free models do not support prefix caching since the feature is still experimental.",  # noqa: E501
         ),
         # encoder_decoder models
         (
             "openai/whisper-small",
             "encoder_decoder",
             False,
-            "Encoder decoder models does not support prefix caching.",
+            "Encoder decoder models do not support prefix caching.",  # noqa: E501
         ),
     ],
 )
diff --git a/tests/test_pooling_params.py b/tests/test_pooling_params.py
index 7812562c8..2c77c6b72 100644
--- a/tests/test_pooling_params.py
+++ b/tests/test_pooling_params.py
@@ -40,7 +40,7 @@ def test_task():

 def test_embed():
     task = "embed"
-    model_config = MockModelConfig(pooler_config=PoolerConfig(pooling_type="CLS"))
+    model_config = MockModelConfig(pooler_config=PoolerConfig(seq_pooling_type="CLS"))

     pooling_params = PoolingParams(normalize=None)
     pooling_params.verify(task=task, model_config=model_config)
@@ -86,7 +86,7 @@ def test_embed_dimensions(model_info: EmbedModelInfo):

 @pytest.mark.parametrize("task", ["score", "classify"])
 def test_classify(task):
-    model_config = MockModelConfig(pooler_config=PoolerConfig(pooling_type="CLS"))
+    model_config = MockModelConfig(pooler_config=PoolerConfig(seq_pooling_type="CLS"))

     pooling_params = PoolingParams(use_activation=None)
     pooling_params.verify(task=task, model_config=model_config)
@@ -108,7 +108,7 @@ def test_classify(task):
 def test_token_embed(pooling_type: str):
     task = "token_embed"
     model_config = MockModelConfig(
-        pooler_config=PoolerConfig(pooling_type=pooling_type)
+        pooler_config=PoolerConfig(tok_pooling_type=pooling_type)
     )

     pooling_params = PoolingParams(normalize=None)
@@ -134,7 +134,7 @@ def test_token_embed(pooling_type: str):
 def test_token_classify(pooling_type: str):
     task = "token_classify"
     model_config = MockModelConfig(
-        pooler_config=PoolerConfig(pooling_type=pooling_type)
+        pooler_config=PoolerConfig(tok_pooling_type=pooling_type)
     )

     pooling_params = PoolingParams(use_activation=None)
diff --git a/vllm/config/model.py b/vllm/config/model.py
index bec1de554..249fb5668 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -539,9 +539,12 @@ class ModelConfig:
                 if getattr(self.pooler_config, k) is None:
                     setattr(self.pooler_config, k, v)

-            default_pooling_type = self._model_info.default_pooling_type
-            if self.pooler_config.pooling_type is None:
-                self.pooler_config.pooling_type = default_pooling_type
+            default_seq_pooling_type = self._model_info.default_seq_pooling_type
+            if self.pooler_config.seq_pooling_type is None:
+                self.pooler_config.seq_pooling_type = default_seq_pooling_type
+            default_tok_pooling_type = self._model_info.default_tok_pooling_type
+            if self.pooler_config.tok_pooling_type is None:
+                self.pooler_config.tok_pooling_type = default_tok_pooling_type

         self.dtype: torch.dtype = _get_and_verify_dtype(
             self.model,
@@ -1543,8 +1546,8 @@ class ModelConfig:
     @property
     def attn_type(self) -> AttnTypeStr:
         if self.pooler_config is not None:
-            pooling_type = self._model_info.default_pooling_type.lower()
-            if pooling_type == "cls":
+            seq_pooling_type = self._model_info.default_seq_pooling_type
+            if seq_pooling_type == "CLS":
                 return "encoder_only"
         else:
             is_causal = getattr(self.hf_config, "is_causal", True)
@@ -1561,89 +1564,102 @@ class ModelConfig:
     @property
     def is_chunked_prefill_supported(self) -> bool:
         attn_type = self.attn_type
-        if self.pooler_config is not None:
+
+        if pooler_config := self.pooler_config:
             # for pooling models
             if attn_type == "encoder_only":
                 logger.debug(
-                    "Pooling models with bidirectional attn does not support "
-                    "chunked prefill."
+                    "Pooling models with bidirectional attn "
+                    "do not support chunked prefill."
                 )
                 return False
-            elif attn_type == "decoder":
-                pooling_type = self.pooler_config.pooling_type.lower()
-                if pooling_type in ["mean", "step", "cls"]:
+
+            if attn_type == "decoder":
+                if (
+                    pooler_config.seq_pooling_type in ("MEAN", "CLS")
+                    or pooler_config.tok_pooling_type == "STEP"
+                ):
                     logger.debug(
-                        "Pooling models with %s pooling does not "
-                        "support chunked prefill.",
-                        pooling_type,
+                        "Pooling models with causal attn and %s/%s pooling "
+                        "do not support chunked prefill.",
+                        pooler_config.seq_pooling_type,
+                        pooler_config.tok_pooling_type,
                     )
                     return False
-                elif pooling_type in ["all", "last"]:
+                else:
                     logger.debug(
-                        "Pooling models with causal attn and %s pooling support "
-                        "chunked prefill.",
-                        pooling_type,
+                        "Pooling models with causal attn and %s/%s pooling "
+                        "support chunked prefill.",
+                        pooler_config.seq_pooling_type,
+                        pooler_config.tok_pooling_type,
                     )
                     return True
-                else:
-                    raise ValueError(f"{pooling_type=} not supported.")
+
             # vllm currently does not have pooling models using hybrid,
             # attention_free or encoder_decoder attn types.
             return attn_type != "encoder_decoder"
         else:
+            # for generative models
             if attn_type == "encoder_decoder":
-                logger.debug("Encoder decoder models does not support chunked prefill.")
+                logger.debug("Encoder decoder models do not support chunked prefill.")
                 return False
+
             logger.debug("Generative models support chunked prefill.")
             return True

     @property
     def is_prefix_caching_supported(self) -> bool:
         attn_type = self.attn_type
-        if self.pooler_config is not None:
+
+        if pooler_config := self.pooler_config:
             # for pooling models
             if attn_type == "encoder_only":
                 logger.debug(
-                    "Pooling models with bidirectional attn does not "
-                    "support prefix caching."
+                    "Pooling models with bidirectional attn "
+                    "do not support prefix caching."
                 )
                 return False
-            elif attn_type == "decoder":
-                pooling_type = self.pooler_config.pooling_type.lower()
-                if pooling_type in ["mean", "step", "cls"]:
+
+            if attn_type == "decoder":
+                if (
+                    pooler_config.seq_pooling_type in ("MEAN", "CLS")
+                    or pooler_config.tok_pooling_type == "STEP"
+                ):
                     logger.debug(
-                        "Pooling models with %s pooling does not "
-                        "support prefix caching.",
-                        pooling_type,
+                        "Pooling models with causal attn and %s/%s pooling "
+                        "do not support prefix caching.",
+                        pooler_config.seq_pooling_type,
+                        pooler_config.tok_pooling_type,
                     )
                     return False
-                elif pooling_type in ["all", "last"]:
+                else:
                     logger.debug(
-                        "Pooling models with causal attn and %s pooling support "
-                        "prefix caching.",
-                        pooling_type,
+                        "Pooling models with causal attn and %s/%s pooling "
+                        "support prefix caching.",
+                        pooler_config.seq_pooling_type,
+                        pooler_config.tok_pooling_type,
                     )
                     return True
-                else:
-                    raise ValueError(f"{pooling_type=} not supported.")
+
             # vllm currently does not have pooling models using hybrid,
             # attention_free or encoder_decoder attn types.
             return False
         else:
+            # for generative models
             if attn_type == "hybrid":
                 logger.debug(
-                    "Hybrid models does not support prefix caching since the feature "
+                    "Hybrid models do not support prefix caching since the feature "
                     "is still experimental."
                 )
                 return False
             elif attn_type == "attention_free":
                 logger.debug(
-                    "Attention free models does not support prefix caching since the "
+                    "Attention free models do not support prefix caching since the "
                     "feature is still experimental."
                 )
                 return False
             elif attn_type == "encoder_decoder":
-                logger.debug("Encoder decoder models does not support prefix caching.")
+                logger.debug("Encoder decoder models do not support prefix caching.")
                 return False
             else:  # attn_type == "decoder"
                 logger.debug("Generative models support prefix caching.")
diff --git a/vllm/config/pooler.py b/vllm/config/pooler.py
index 008fefadf..afcc697bb 100644
--- a/vllm/config/pooler.py
+++ b/vllm/config/pooler.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import Any, Literal
+from typing import Any, Literal, get_args

 from pydantic.dataclasses import dataclass

@@ -11,7 +11,11 @@ from vllm.utils.hashing import safe_hash

 logger = init_logger(__name__)

-PoolingTypeStr = Literal["LAST", "ALL", "CLS", "STEP", "MEAN"]
+SequencePoolingType = Literal["CLS", "LAST", "MEAN"]
+SEQ_POOLING_TYPES: tuple[SequencePoolingType, ...] = get_args(SequencePoolingType)
+
+TokenPoolingType = Literal["ALL", "STEP"]
+TOK_POOLING_TYPES: tuple[TokenPoolingType, ...] = get_args(TokenPoolingType)


 @config
@@ -19,9 +23,26 @@ PoolingTypeStr = Literal["LAST", "ALL", "CLS", "STEP", "MEAN"]
 class PoolerConfig:
     """Controls the behavior of output pooling in pooling models."""

-    pooling_type: PoolingTypeStr | None = None
+    pooling_type: SequencePoolingType | TokenPoolingType | None = None
     """
-    The pooling method of the pooling model.
+    The pooling method used for pooling.
+
+    If set, `seq_pooling_type` or `tok_pooling_type` is automatically populated
+    with this field. Alternatively, users can set `seq_pooling_type` and
+    `tok_pooling_type` explicitly.
+
+    This field is mainly for user convenience. Internal code should always use
+    `seq_pooling_type` or `tok_pooling_type` instead of `pooling_type`.
+    """
+
+    seq_pooling_type: SequencePoolingType | None = None
+    """
+    The pooling method used for sequence pooling.
+    """
+
+    tok_pooling_type: TokenPoolingType | None = None
+    """
+    The pooling method used for tokenwise pooling.
     """

     ## for embeddings models
@@ -88,9 +109,40 @@ class PoolerConfig:
         # raise deprecated warning for softmax and activation
         self.use_activation = get_use_activation(self)

-    def get_pooling_type(self) -> PoolingTypeStr:
-        assert self.pooling_type is not None, "Should be resolved by ModelConfig"
-        return self.pooling_type
+        if pooling_type := self.pooling_type:
+            if self.seq_pooling_type is not None:
+                raise ValueError(
+                    "Cannot set both `pooling_type` and `seq_pooling_type`"
+                )
+            if self.tok_pooling_type is not None:
+                raise ValueError(
+                    "Cannot set both `pooling_type` and `tok_pooling_type`"
+                )
+
+            if pooling_type in SEQ_POOLING_TYPES:
+                logger.debug(
+                    "Resolved `pooling_type=%r` to `seq_pooling_type=%r`.",
+                    pooling_type,
+                    pooling_type,
+                )
+                self.seq_pooling_type = pooling_type
+            elif pooling_type in TOK_POOLING_TYPES:
+                logger.debug(
+                    "Resolved `pooling_type=%r` to `tok_pooling_type=%r`.",
+                    pooling_type,
+                    pooling_type,
+                )
+                self.tok_pooling_type = pooling_type
+            else:
+                raise NotImplementedError(pooling_type)
+
+    def get_seq_pooling_type(self) -> SequencePoolingType:
+        assert self.seq_pooling_type is not None, "Should be resolved by ModelConfig"
+        return self.seq_pooling_type
+
+    def get_tok_pooling_type(self) -> TokenPoolingType:
+        assert self.tok_pooling_type is not None, "Should be resolved by ModelConfig"
+        return self.tok_pooling_type

     def compute_hash(self) -> str:
         """
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index ca3ef37fa..9c10e28c2 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -172,7 +172,7 @@ class LLM:
             The available overrides depend on the model that is being run.
             For example, for Phi-3-Vision: `{"num_crops": 4}`.
         pooler_config: Initialize non-default pooling config for the pooling
-            model. e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
+            model. e.g. `PoolerConfig(seq_pooling_type="MEAN", normalize=False)`.
         compilation_config: Either an integer or a dictionary. If it is an
             integer, it is used as the mode of compilation optimization. If it
             is a dictionary, it can specify the full compilation configuration.
diff --git a/vllm/model_executor/layers/pooler/seqwise/methods.py b/vllm/model_executor/layers/pooler/seqwise/methods.py
index e71a9de3f..5d8551095 100644
--- a/vllm/model_executor/layers/pooler/seqwise/methods.py
+++ b/vllm/model_executor/layers/pooler/seqwise/methods.py
@@ -7,7 +7,7 @@ from typing import TypeAlias
 import torch
 import torch.nn as nn

-from vllm.config.pooler import PoolingTypeStr
+from vllm.config.pooler import SequencePoolingType
 from vllm.model_executor.layers.pooler import PoolingParamsUpdate
 from vllm.tasks import PoolingTask
 from vllm.v1.pool.metadata import PoolingMetadata
@@ -82,11 +82,11 @@ class MeanPool(SequencePoolingMethod):
         ) / prompt_lens.unsqueeze(1)


-def get_seq_pooling_method(pooling_type: PoolingTypeStr | str):
-    if pooling_type == "LAST":
-        return LastPool()
+def get_seq_pooling_method(pooling_type: SequencePoolingType | str):
     if pooling_type == "CLS":
         return CLSPool()
+    if pooling_type == "LAST":
+        return LastPool()
     if pooling_type == "MEAN":
         return MeanPool()

diff --git a/vllm/model_executor/layers/pooler/seqwise/poolers.py b/vllm/model_executor/layers/pooler/seqwise/poolers.py
index 586dcfb99..db867fb60 100644
--- a/vllm/model_executor/layers/pooler/seqwise/poolers.py
+++ b/vllm/model_executor/layers/pooler/seqwise/poolers.py
@@ -85,7 +85,7 @@ class SequencePooler(Pooler):


 def pooler_for_embed(pooler_config: PoolerConfig):
-    pooling = get_seq_pooling_method(pooler_config.get_pooling_type())
+    pooling = get_seq_pooling_method(pooler_config.get_seq_pooling_type())
     head = EmbeddingPoolerHead()

     return SequencePooler(pooling=pooling, head=head)
@@ -99,7 +99,7 @@ def pooler_for_classify(
     act_fn: PoolerActivation | str | None = None,
 ):
     if pooling is None:
-        pooling = get_seq_pooling_method(pooler_config.get_pooling_type())
+        pooling = get_seq_pooling_method(pooler_config.get_seq_pooling_type())

     head = ClassifierPoolerHead(classifier=classifier, act_fn=act_fn)

diff --git a/vllm/model_executor/layers/pooler/tokwise/methods.py b/vllm/model_executor/layers/pooler/tokwise/methods.py
index 4e84f57d7..baa9d4075 100644
--- a/vllm/model_executor/layers/pooler/tokwise/methods.py
+++ b/vllm/model_executor/layers/pooler/tokwise/methods.py
@@ -8,7 +8,7 @@ import torch
 import torch.nn as nn

 from vllm.config import get_current_vllm_config
-from vllm.config.pooler import PoolingTypeStr
+from vllm.config.pooler import TokenPoolingType
 from vllm.model_executor.layers.pooler import PoolingParamsUpdate
 from vllm.tasks import PoolingTask
 from vllm.v1.pool.metadata import PoolingMetadata
@@ -113,12 +113,10 @@ class StepPool(AllPool):
         return pooled_data


-def get_tok_pooling_method(pooling_type: PoolingTypeStr | str):
+def get_tok_pooling_method(pooling_type: TokenPoolingType | str):
     if pooling_type == "ALL":
         return AllPool()
     if pooling_type == "STEP":
         return StepPool()

-    # TODO: Separate seq and tok pooling types so we don't need this fallback
-    return AllPool()
     raise NotImplementedError(f"Unknown tokenwise pooling type: {pooling_type!r}")
diff --git a/vllm/model_executor/layers/pooler/tokwise/poolers.py b/vllm/model_executor/layers/pooler/tokwise/poolers.py
index ff68359bb..991daaeba 100644
--- a/vllm/model_executor/layers/pooler/tokwise/poolers.py
+++ b/vllm/model_executor/layers/pooler/tokwise/poolers.py
@@ -85,7 +85,7 @@ class TokenPooler(Pooler):


 def pooler_for_token_embed(pooler_config: PoolerConfig):
-    pooling = get_tok_pooling_method(pooler_config.get_pooling_type())
+    pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())
     head = TokenEmbeddingPoolerHead()

     return TokenPooler(pooling=pooling, head=head)
@@ -99,7 +99,7 @@ def pooler_for_token_classify(
     act_fn: PoolerActivation | str | None = None,
 ):
     if pooling is None:
-        pooling = get_tok_pooling_method(pooler_config.get_pooling_type())
+        pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())

     head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn)

diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py
index cce01ea50..b09e76015 100644
--- a/vllm/model_executor/models/bert.py
+++ b/vllm/model_executor/models/bert.py
@@ -357,7 +357,7 @@ class BertOutput(nn.Module):


 @support_torch_compile
-@default_pooling_type("CLS")
+@default_pooling_type(seq_pooling_type="CLS")
 class BertModel(nn.Module, SupportsQuant):
     is_pooling_model = True

@@ -461,7 +461,7 @@ class BertPoolingModel(BertModel):
         return loaded_params


-@default_pooling_type("CLS")
+@default_pooling_type(seq_pooling_type="CLS")
 class BertEmbeddingModel(nn.Module, SupportsQuant):
     """A model that uses Bert to provide embedding functionalities.

@@ -675,7 +675,7 @@ class SPLADESparsePooler(Pooler):
         return torch.stack(pooled_list, dim=0).contiguous()


-@default_pooling_type("CLS")
+@default_pooling_type(seq_pooling_type="CLS")
 class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
     """
     BertEmbeddingModel + SPLADE sparse embedding.
@@ -780,7 +780,7 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
         return loaded


-@default_pooling_type("CLS")
+@default_pooling_type(seq_pooling_type="CLS")
 class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQuant):
     """A model that uses Bert to provide embedding functionalities.

@@ -849,7 +849,7 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu


 @attn_type("encoder_only")
-@default_pooling_type("ALL")
+@default_pooling_type(tok_pooling_type="ALL")
 class BertForTokenClassification(nn.Module):
     is_pooling_model = True

diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py
index a5c43bbb3..cfe350db1 100644
--- a/vllm/model_executor/models/bert_with_rope.py
+++ b/vllm/model_executor/models/bert_with_rope.py
@@ -441,7 +441,7 @@ class BertWithRopeEncoder(nn.Module):


 @support_torch_compile
-@default_pooling_type("CLS")
+@default_pooling_type(seq_pooling_type="CLS")
 class BertWithRope(nn.Module, SupportsQuant):
     hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})

@@ -670,7 +670,7 @@ class JinaRobertaModel(BertWithRope):
         return super().load_weights(weights)


-@default_pooling_type("CLS")
+@default_pooling_type(seq_pooling_type="CLS")
 class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
     is_pooling_model = True

diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py
index d18904fdf..1a4811ef8 100644
--- a/vllm/model_executor/models/clip.py
+++ b/vllm/model_executor/models/clip.py
@@ -145,7 +145,7 @@ class CLIPProcessingInfo(BaseProcessingInfo):
                 image_width=image_width,
                 image_height=image_height,
             ),
-            _get_vision_feature_select_strategy(pooler_config.pooling_type),
+            _get_vision_feature_select_strategy(pooler_config.seq_pooling_type),
         )

     def get_image_size_with_most_features(self) -> ImageSize:
@@ -819,7 +819,7 @@ class CLIPVisionModel(nn.Module):


 # Assume EOS token corresponds to LAST token in text model
-@default_pooling_type("LAST")
+@default_pooling_type(seq_pooling_type="LAST")
 @MULTIMODAL_REGISTRY.register_processor(
     CLIPMultiModalProcessor,
     info=CLIPProcessingInfo,
@@ -908,7 +908,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
     ) -> torch.Tensor:
         if feature_select_strategy is None:
             feature_select_strategy = _get_vision_feature_select_strategy(
-                self.pooler_config.pooling_type
+                self.pooler_config.seq_pooling_type
             )

         pooled_output = self.vision_model(
diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py
index 82e6df199..92a18684b 100644
--- a/vllm/model_executor/models/config.py
+++ b/vllm/model_executor/models/config.py
@@ -94,12 +94,12 @@ class JinaRobertaModelConfig(VerifyAndUpdateConfig):
 class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
     @staticmethod
     def verify_and_update_model_config(model_config: "ModelConfig") -> None:
-        from vllm.config.pooler import PoolingTypeStr
+        from vllm.config.pooler import SequencePoolingType

         hf_config = model_config.hf_config
         hf_config.is_causal = False

-        pooling_type_map: dict[str, PoolingTypeStr] = {
+        pooling_type_map: dict[str, SequencePoolingType] = {
             "avg": "MEAN",
             "cls": "CLS",
             "last": "LAST",
@@ -107,8 +107,9 @@ class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
         pooling_type = pooling_type_map.get(hf_config.pooling, None)
         if pooling_type is None:
-            raise ValueError(f"pool_type {hf_config.pooling} not supported")
-        model_config.pooler_config.pooling_type = pooling_type
+            raise ValueError(f"pool_type {hf_config.pooling!r} not supported")
+
+        model_config.pooler_config.seq_pooling_type = pooling_type


 class NomicBertModelConfig(VerifyAndUpdateConfig):
diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py
index 34dbd8050..34d7e5c92 100644
--- a/vllm/model_executor/models/gritlm.py
+++ b/vllm/model_executor/models/gritlm.py
@@ -193,7 +193,7 @@ class GritLMPooler(SequencePooler):
         return self.activation(pooled_data)


-@default_pooling_type("MEAN")
+@default_pooling_type(seq_pooling_type="MEAN")
 class GritLM(LlamaForCausalLM):
     """This class implements the embedding model for parasail-ai/GritLM-7B-vllm.

diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py
index 134a1d948..e658825e1 100644
--- a/vllm/model_executor/models/interfaces_base.py
+++ b/vllm/model_executor/models/interfaces_base.py
@@ -20,12 +20,13 @@ from vllm.utils.func_utils import supports_kw
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
     from vllm.config.model import AttnTypeStr
-    from vllm.config.pooler import PoolingTypeStr
+    from vllm.config.pooler import SequencePoolingType, TokenPoolingType
     from vllm.model_executor.layers.pooler import Pooler
 else:
     VllmConfig = Any
     Pooler = Any
-    PoolingTypeStr = Any
+    SequencePoolingType = Any
+    TokenPoolingType = Any
     AttnTypeStr = Any

 logger = init_logger(__name__)
@@ -155,9 +156,19 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
     MRO of your model class.
     """

-    default_pooling_type: ClassVar[PoolingTypeStr] = "LAST"
+    default_seq_pooling_type: ClassVar[SequencePoolingType] = "LAST"
     """
-    Indicates the [vllm.config.pooler.PoolerConfig.pooling_type][]
+    Indicates the [vllm.config.pooler.PoolerConfig.seq_pooling_type][]
+    to use by default.
+
+    You can use the
+    [vllm.model_executor.models.interfaces_base.default_pooling_type][]
+    decorator to conveniently set this field.
+    """
+
+    default_tok_pooling_type: ClassVar[TokenPoolingType] = "ALL"
+    """
+    Indicates the [vllm.config.pooler.PoolerConfig.tok_pooling_type][]
     to use by default.

     You can use the
@@ -200,18 +211,31 @@ def is_pooling_model(
 _T = TypeVar("_T", bound=type[nn.Module])


-def default_pooling_type(pooling_type: PoolingTypeStr):
-    """Decorator to set `VllmModelForPooling.default_pooling_type`."""
+def default_pooling_type(
+    *,
+    seq_pooling_type: SequencePoolingType = "LAST",
+    tok_pooling_type: TokenPoolingType = "ALL",
+):
+    """Decorator to set `VllmModelForPooling.default_*_pooling_type`."""

     def func(model: _T) -> _T:
-        model.default_pooling_type = pooling_type  # type: ignore
+        model.default_seq_pooling_type = seq_pooling_type  # type: ignore
+        model.default_tok_pooling_type = tok_pooling_type  # type: ignore
         return model

     return func


-def get_default_pooling_type(model: type[object] | object) -> PoolingTypeStr:
-    return getattr(model, "default_pooling_type", "LAST")
+def get_default_seq_pooling_type(
+    model: type[object] | object,
+) -> SequencePoolingType:
+    return getattr(model, "default_seq_pooling_type", "LAST")
+
+
+def get_default_tok_pooling_type(
+    model: type[object] | object,
+) -> TokenPoolingType:
+    return getattr(model, "default_tok_pooling_type", "ALL")


 def attn_type(attn_type: AttnTypeStr):
diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py
index 37309cd09..45628b4fe 100644
--- a/vllm/model_executor/models/internlm2.py
+++ b/vllm/model_executor/models/internlm2.py
@@ -402,7 +402,7 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
         return loaded_params


-@default_pooling_type("ALL")
+@default_pooling_type(tok_pooling_type="ALL")
 class InternLM2ForRewardModel(InternLM2ForCausalLM):
     is_pooling_model = True

diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py
index 773948039..b80258daf 100644
--- a/vllm/model_executor/models/modernbert.py
+++ b/vllm/model_executor/models/modernbert.py
@@ -221,7 +221,7 @@ class ModernBertEncoderLayer(nn.Module):


 @support_torch_compile
-@default_pooling_type("CLS")
+@default_pooling_type(seq_pooling_type="CLS")
 class ModernBertModel(nn.Module):
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={"layers.": "encoder_layer.layers."}
@@ -308,7 +308,7 @@ class ModernBertPooler(SequencePooler):
         return self.norm(self.act(self.dense(pooled_data)))


-@default_pooling_type("CLS")
+@default_pooling_type(seq_pooling_type="CLS")
 class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
     is_pooling_model = True

@@ -395,7 +395,7 @@ class ModernBertPredictionHead(nn.Module):


 @attn_type("encoder_only")
-@default_pooling_type("ALL")
+@default_pooling_type(tok_pooling_type="ALL")
 class ModernBertForTokenClassification(nn.Module):
     is_pooling_model = True

diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py
index 963edcb75..b0fa576f5 100644
--- a/vllm/model_executor/models/qwen2_rm.py
+++ b/vllm/model_executor/models/qwen2_rm.py
@@ -96,7 +96,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
         return loader.load_weights(weights)


-@default_pooling_type("ALL")
+@default_pooling_type(tok_pooling_type="ALL")
 class Qwen2ForRewardModel(Qwen2RewardBaseModel):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         vllm_config.model_config.hf_config.num_labels = 1
@@ -108,7 +108,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
         self.pooler = pooler_for_token_classify(pooler_config)


-@default_pooling_type("STEP")
+@default_pooling_type(tok_pooling_type="STEP")
 class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         vllm_config.model_config.hf_config.num_labels = 2
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 9124f79ba..a8aca4e89 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -35,10 +35,11 @@ from vllm.utils.hashing import safe_hash

 if TYPE_CHECKING:
     from vllm.config.model import AttnTypeStr
-    from vllm.config.pooler import PoolingTypeStr
+    from vllm.config.pooler import SequencePoolingType, TokenPoolingType
 else:
     AttnTypeStr = Any
-    PoolingTypeStr = Any
+    SequencePoolingType = Any
+    TokenPoolingType = Any


 from .interfaces import (
@@ -57,7 +58,8 @@ from .interfaces import (
 )
 from .interfaces_base import (
     get_attn_type,
-    get_default_pooling_type,
+    get_default_seq_pooling_type,
+    get_default_tok_pooling_type,
     is_pooling_model,
     is_text_generation_model,
 )
@@ -548,7 +550,8 @@ class _ModelInfo:
     is_text_generation_model: bool
     is_pooling_model: bool
     attn_type: AttnTypeStr
-    default_pooling_type: PoolingTypeStr
+    default_seq_pooling_type: SequencePoolingType
+    default_tok_pooling_type: TokenPoolingType
     supports_cross_encoding: bool
     supports_multimodal: bool
     supports_multimodal_raw_input_only: bool
@@ -569,7 +572,8 @@ class _ModelInfo:
             architecture=model.__name__,
             is_text_generation_model=is_text_generation_model(model),
             is_pooling_model=is_pooling_model(model),
-            default_pooling_type=get_default_pooling_type(model),
+            default_seq_pooling_type=get_default_seq_pooling_type(model),
+            default_tok_pooling_type=get_default_tok_pooling_type(model),
             attn_type=get_attn_type(model),
             supports_cross_encoding=supports_cross_encoding(model),
             supports_multimodal=supports_multimodal(model),
diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py
index 647fb70ef..f52123901 100644
--- a/vllm/model_executor/models/roberta.py
+++ b/vllm/model_executor/models/roberta.py
@@ -93,7 +93,7 @@ class RobertaClassificationHead(nn.Module):
         return x


-@default_pooling_type("CLS")
+@default_pooling_type(seq_pooling_type="CLS")
 class RobertaEmbeddingModel(BertEmbeddingModel):
     """A model that uses Roberta to provide embedding functionalities."""

@@ -150,7 +150,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
         return loader.load_weights(weights_list, mapper=mapper)


-@default_pooling_type("CLS")
+@default_pooling_type(seq_pooling_type="CLS")
 class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
     """A model that uses Roberta to provide embedding functionalities.

diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py
index c047415d4..1bda00653 100644
--- a/vllm/model_executor/models/siglip.py
+++ b/vllm/model_executor/models/siglip.py
@@ -129,7 +129,7 @@ class SiglipProcessingInfo(BaseProcessingInfo):
                 image_width=image_width,
                 image_height=image_height,
             ),
-            _get_vision_feature_select_strategy(pooler_config.pooling_type),
+            _get_vision_feature_select_strategy(pooler_config.seq_pooling_type),
         )

     def get_image_size_with_most_features(self) -> ImageSize:
@@ -998,7 +998,7 @@ class SiglipTextEmbeddings(nn.Module):


 # Assume EOS token corresponds to CLS token in text model
-@default_pooling_type("CLS")
+@default_pooling_type(seq_pooling_type="CLS")
 @MULTIMODAL_REGISTRY.register_processor(
     SiglipMultiModalProcessor,
     info=SiglipProcessingInfo,
@@ -1125,7 +1125,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
     ) -> torch.Tensor:
         if feature_select_strategy is None:
             feature_select_strategy = _get_vision_feature_select_strategy(
-                self.pooler_config.pooling_type
+                self.pooler_config.seq_pooling_type
             )

         pooled_output = self.vision_model(
diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py
index 4a5caa7e2..09fd8d0bd 100644
--- a/vllm/pooling_params.py
+++ b/vllm/pooling_params.py
@@ -140,7 +140,7 @@ class PoolingParams(
         self, pooler_config: "PoolerConfig", valid_parameters: list[str]
     ):
         step_pooling_parameters = ["step_tag_id", "returned_token_ids"]
-        if pooler_config.pooling_type != "STEP":
+        if pooler_config.tok_pooling_type != "STEP":
             invalid_parameters = []
             for k in step_pooling_parameters:
                 if getattr(self, k, None) is not None:
diff --git a/vllm/tasks.py b/vllm/tasks.py
index b02cde74c..bd3e5af77 100644
--- a/vllm/tasks.py
+++ b/vllm/tasks.py
@@ -3,11 +3,11 @@ from typing import Literal, get_args

 GenerationTask = Literal["generate", "transcription"]
-GENERATION_TASKS = get_args(GenerationTask)
+GENERATION_TASKS: tuple[GenerationTask, ...] = get_args(GenerationTask)

 PoolingTask = Literal[
     "embed", "classify", "score", "token_embed", "token_classify", "plugin"
 ]
-POOLING_TASKS = get_args(PoolingTask)
+POOLING_TASKS: tuple[PoolingTask, ...] = get_args(PoolingTask)

 SupportedTask = Literal[GenerationTask, PoolingTask]
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 16cb2bea3..02ffc37d9 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -10,9 +10,7 @@ from pathlib import Path
 from typing import Any, Literal, TypeAlias

 import huggingface_hub
-from huggingface_hub import (
-    get_safetensors_metadata,
-)
+from huggingface_hub import get_safetensors_metadata
 from packaging.version import Version
 from transformers import GenerationConfig, PretrainedConfig
 from transformers.models.auto.image_processing_auto import get_image_processor_config
@@ -742,7 +740,10 @@ def get_config(


 @cache
-def get_pooling_config(model: str, revision: str | None = "main") -> dict | None:
+def get_pooling_config(
+    model: str,
+    revision: str | None = "main",
+) -> dict[str, Any] | None:
     """
     This function gets the pooling and normalize config from the model -
     only applies to
@@ -793,38 +794,40 @@ def get_pooling_config(
     )

     if pooling:
-        pooling_file_name = "{}/config.json".format(pooling["path"])
-        pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision)
-        pooling_type_name = next(
-            (item for item, val in pooling_dict.items() if val is True), None
-        )
+        from vllm.config.pooler import SEQ_POOLING_TYPES, TOK_POOLING_TYPES

-        if pooling_type_name is not None:
-            pooling_type_name = get_pooling_config_name(pooling_type_name)
+        pooling_file_name = "{}/config.json".format(pooling["path"])
+        pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision) or {}

         logger.info("Found pooling configuration.")
-        return {"pooling_type": pooling_type_name, "normalize": normalize}
+
+        config: dict[str, Any] = {"normalize": normalize}
+        for key, val in pooling_dict.items():
+            if val is True:
+                pooling_type = parse_pooling_type(key)
+                if pooling_type in SEQ_POOLING_TYPES:
+                    config["seq_pooling_type"] = pooling_type
+                elif pooling_type in TOK_POOLING_TYPES:
+                    config["tok_pooling_type"] = pooling_type
+                else:
+                    logger.debug("Skipping unrelated field: %r=%r", key, val)
+
+        return config

     return None


-def get_pooling_config_name(pooling_name: str) -> str | None:
+def parse_pooling_type(pooling_name: str):
     if "pooling_mode_" in pooling_name:
         pooling_name = pooling_name.replace("pooling_mode_", "")

     if "_" in pooling_name:
-        pooling_name = pooling_name.split("_")[0]
+        pooling_name = pooling_name.split("_", 1)[0]

     if "lasttoken" in pooling_name:
         pooling_name = "last"

-    supported_pooling_types = ["LAST", "ALL", "CLS", "STEP", "MEAN"]
-    pooling_type_name = pooling_name.upper()
-
-    if pooling_type_name in supported_pooling_types:
-        return pooling_type_name
-
-    raise NotImplementedError(f"Pooling type {pooling_type_name} not supported")
+    return pooling_name.upper()


 @cache
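
Editor's note (illustration, not part of the patch): a minimal sketch of how the split `PoolerConfig` fields behave under this change, assuming the patched `vllm.config.pooler` module. The legacy `pooling_type` field is resolved into exactly one of the new fields in `__post_init__`, and combining it with an explicit split field raises.

    from vllm.config.pooler import PoolerConfig

    # "MEAN" is a SequencePoolingType, so it resolves to seq_pooling_type.
    cfg = PoolerConfig(pooling_type="MEAN")
    assert cfg.seq_pooling_type == "MEAN"
    assert cfg.tok_pooling_type is None  # filled in later by ModelConfig defaults

    # "STEP" is a TokenPoolingType, so it resolves to tok_pooling_type.
    cfg = PoolerConfig(pooling_type="STEP")
    assert cfg.tok_pooling_type == "STEP"

    # Setting the legacy field together with a split field is rejected.
    try:
        PoolerConfig(pooling_type="MEAN", seq_pooling_type="CLS")
    except ValueError as err:
        print(err)  # Cannot set both `pooling_type` and `seq_pooling_type`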
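
Likewise for the decorator API: the sketch below (hypothetical model class; the decorator and getter names come from the patch) shows that `default_pooling_type` now takes keyword-only arguments and sets both class attributes, with the getters falling back to "LAST"/"ALL" for undecorated classes.

    import torch.nn as nn

    from vllm.model_executor.models.interfaces_base import (
        default_pooling_type,
        get_default_seq_pooling_type,
        get_default_tok_pooling_type,
    )

    @default_pooling_type(seq_pooling_type="CLS")
    class MyEncoderModel(nn.Module):  # hypothetical example class
        is_pooling_model = True

    assert get_default_seq_pooling_type(MyEncoderModel) == "CLS"
    # tok_pooling_type keeps the decorator's keyword default, "ALL".
    assert get_default_tok_pooling_type(MyEncoderModel) == "ALL"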
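
And for the rewritten Sentence Transformers handling: a sketch of how the loop inside the patched `get_pooling_config()` maps a typical `1_Pooling/config.json` onto the split fields (the dict literal is an invented example file body, not taken from any real model).

    from vllm.config.pooler import SEQ_POOLING_TYPES, TOK_POOLING_TYPES
    from vllm.transformers_utils.config import parse_pooling_type

    pooling_dict = {
        "word_embedding_dimension": 384,  # value is not True -> skipped
        "pooling_mode_cls_token": False,  # value is not True -> skipped
        "pooling_mode_mean_tokens": True,
    }

    config = {"normalize": True}
    for key, val in pooling_dict.items():
        if val is True:
            # "pooling_mode_mean_tokens" -> "MEAN"
            pooling_type = parse_pooling_type(key)
            if pooling_type in SEQ_POOLING_TYPES:
                config["seq_pooling_type"] = pooling_type
            elif pooling_type in TOK_POOLING_TYPES:
                config["tok_pooling_type"] = pooling_type

    print(config)  # {'normalize': True, 'seq_pooling_type': 'MEAN'}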