[Refactor] Separate sequence and token pooling types (#32026)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-10 12:53:24 +08:00
committed by GitHub
parent 52d428295d
commit 583a90e005
42 changed files with 324 additions and 204 deletions

View File

@@ -161,7 +161,8 @@ def test_get_pooling_config():
assert model_config.pooler_config is not None
assert model_config.pooler_config.normalize
assert model_config.pooler_config.pooling_type == "MEAN"
assert model_config.pooler_config.seq_pooling_type == "MEAN"
assert model_config.pooler_config.tok_pooling_type == "ALL"
@pytest.mark.skipif(
@@ -169,7 +170,7 @@ def test_get_pooling_config():
)
# NOTE(review): this is a rendered diff hunk with +/- markers and indentation
# stripped; the lines after `def` are the function body.
# The test builds a ModelConfig from an explicitly supplied PoolerConfig and
# asserts that the config's pooler_config equals the supplied one when both
# are converted with asdict().
def test_get_pooling_config_from_args():
model_id = "sentence-transformers/all-MiniLM-L12-v2"
# The next two assignments appear to be the before/after pair of this change
# (`pooling_type=` renamed to `seq_pooling_type=`), not two sequential
# statements -- TODO confirm against the raw diff.
pooler_config = PoolerConfig(pooling_type="CLS", normalize=True)
pooler_config = PoolerConfig(seq_pooling_type="CLS", normalize=True)
model_config = ModelConfig(model_id, pooler_config=pooler_config)
assert asdict(model_config.pooler_config) == asdict(pooler_config)
@@ -180,14 +181,25 @@ def test_get_pooling_config_from_args():
[
("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", "LAST", "LAST"), # LLM
("intfloat/e5-small", "CLS", "MEAN"), # BertModel
],
)
# Builds a ModelConfig for `model_id` and checks two sequence-pooling values:
# the model info's `default_seq_pooling_type` against `default_pooling_type`,
# and the resolved `pooler_config.seq_pooling_type` against `pooling_type`.
# NOTE(review): indentation was lost in this rendering; the three lines after
# `def` are the function body. The parametrize decorator supplying the
# arguments starts above this span and is only partly visible here.
def test_default_seq_pooling_type(model_id, default_pooling_type, pooling_type):
model_config = ModelConfig(model_id)
assert model_config._model_info.default_seq_pooling_type == default_pooling_type
assert model_config.pooler_config.seq_pooling_type == pooling_type
# Parametrized over two models; each case supplies the expected default
# pooling type and the expected resolved pooling type (both "ALL" for the
# reward model, both "STEP" for the step-reward model).
@pytest.mark.parametrize(
("model_id", "default_pooling_type", "pooling_type"),
[
("Qwen/Qwen2.5-Math-RM-72B", "ALL", "ALL"), # reward
("Qwen/Qwen2.5-Math-PRM-7B", "STEP", "STEP"), # step reward
],
)
# NOTE(review): rendered diff with +/- markers and indentation stripped. The
# two `def` lines below appear to be the old and new name of one function
# (`test_default_pooling_type` -> `test_default_tok_pooling_type`) -- TODO
# confirm against the raw diff.
def test_default_pooling_type(model_id, default_pooling_type, pooling_type):
def test_default_tok_pooling_type(model_id, default_pooling_type, pooling_type):
model_config = ModelConfig(model_id)
# Likewise, the first pair of asserts reads the `pooling_type` attributes and
# the second pair reads the `tok_pooling_type` attributes; presumably the
# first pair is the removed side of the diff and the second the added side.
assert model_config._model_info.default_pooling_type == default_pooling_type
assert model_config.pooler_config.pooling_type == pooling_type
assert model_config._model_info.default_tok_pooling_type == default_pooling_type
assert model_config.pooler_config.tok_pooling_type == pooling_type
@pytest.mark.parametrize(
@@ -554,100 +566,100 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files):
"jason9693/Qwen2.5-1.5B-apeach",
"decoder",
True,
"Pooling models with causal attn and last pooling support chunked prefill.",
"Pooling models with causal attn and LAST/ALL pooling support chunked prefill.", # noqa: E501
),
(
"Qwen/Qwen3-Embedding-0.6B",
"decoder",
True,
"Pooling models with causal attn and last pooling support chunked prefill.",
"Pooling models with causal attn and LAST/ALL pooling support chunked prefill.", # noqa: E501
),
(
"Qwen/Qwen2.5-Math-PRM-7B",
"decoder",
False,
"Pooling models with step pooling does not support chunked prefill.",
"Pooling models with causal attn and LAST/STEP pooling do not support chunked prefill.", # noqa: E501
),
(
"internlm/internlm2-1_8b-reward",
"decoder",
True,
"Pooling models with causal attn and all pooling support chunked prefill.",
"Pooling models with causal attn and LAST/ALL pooling support chunked prefill.", # noqa: E501
),
(
"BAAI/bge-base-en",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support chunked prefill.",
"Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
),
(
"boltuix/NeuroBERT-NER",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support chunked prefill.",
"Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
),
(
"papluca/xlm-roberta-base-language-detection",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support chunked prefill.",
"Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
),
(
"Alibaba-NLP/gte-Qwen2-1.5B-instruct",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support chunked prefill.",
"Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
),
(
"intfloat/e5-small",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support chunked prefill.",
"Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
),
# multimodal models
(
"openai/clip-vit-base-patch32",
"decoder",
True,
"Pooling models with causal attn and last pooling support chunked prefill.",
"Pooling models with causal attn and LAST/ALL pooling support chunked prefill.", # noqa: E501
),
(
"google/siglip-base-patch16-224",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support chunked prefill.",
"Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
),
# generate models
(
"Qwen/Qwen3-0.6B",
"decoder",
True,
"Generative models support chunked prefill.",
"Generative models support chunked prefill.", # noqa: E501
),
(
"Qwen/Qwen3-Next-80B-A3B-Instruct",
"hybrid",
True,
"Generative models support chunked prefill.",
"Generative models support chunked prefill.", # noqa: E501
),
(
"ibm-granite/granite-4.0-h-small",
"hybrid",
True,
"Generative models support chunked prefill.",
"Generative models support chunked prefill.", # noqa: E501
),
(
"state-spaces/mamba-130m-hf",
"attention_free",
True,
"Generative models support chunked prefill.",
"Generative models support chunked prefill.", # noqa: E501
),
# encoder_decoder models
(
"openai/whisper-small",
"encoder_decoder",
False,
"Encoder decoder models does not support chunked prefill.",
"Encoder decoder models do not support chunked prefill.", # noqa: E501
),
],
)
@@ -673,100 +685,100 @@ def test_is_chunked_prefill_supported(
"jason9693/Qwen2.5-1.5B-apeach",
"decoder",
True,
"Pooling models with causal attn and last pooling support prefix caching.",
"Pooling models with causal attn and LAST/ALL pooling support prefix caching.", # noqa: E501
),
(
"Qwen/Qwen3-Embedding-0.6B",
"decoder",
True,
"Pooling models with causal attn and last pooling support prefix caching.",
"Pooling models with causal attn and LAST/ALL pooling support prefix caching.", # noqa: E501
),
(
"Qwen/Qwen2.5-Math-PRM-7B",
"decoder",
False,
"Pooling models with step pooling does not support prefix caching.",
"Pooling models with causal attn and LAST/STEP pooling do not support prefix caching.", # noqa: E501
),
(
"internlm/internlm2-1_8b-reward",
"decoder",
True,
"Pooling models with causal attn and all pooling support prefix caching.",
"Pooling models with causal attn and LAST/ALL pooling support prefix caching.", # noqa: E501
),
(
"BAAI/bge-base-en",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support prefix caching.",
"Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
),
(
"boltuix/NeuroBERT-NER",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support prefix caching.",
"Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
),
(
"papluca/xlm-roberta-base-language-detection",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support prefix caching.",
"Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
),
(
"Alibaba-NLP/gte-Qwen2-1.5B-instruct",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support prefix caching.",
"Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
),
(
"intfloat/e5-small",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support prefix caching.",
"Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
),
# multimodal models
(
"openai/clip-vit-base-patch32",
"decoder",
True,
"Pooling models with causal attn and last pooling support prefix caching.",
"Pooling models with causal attn and LAST/ALL pooling support prefix caching.", # noqa: E501
),
(
"google/siglip-base-patch16-224",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support prefix caching.",
"Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
),
# generate models
(
"Qwen/Qwen3-0.6B",
"decoder",
True,
"Generative models support prefix caching.",
"Generative models support prefix caching.", # noqa: E501
),
(
"Qwen/Qwen3-Next-80B-A3B-Instruct",
"hybrid",
False,
"Hybrid models does not support prefix caching since the feature is still experimental.", # noqa: E501
"Hybrid models do not support prefix caching since the feature is still experimental.", # noqa: E501
),
(
"ibm-granite/granite-4.0-h-small",
"hybrid",
False,
"Hybrid models does not support prefix caching since the feature is still experimental.", # noqa: E501
"Hybrid models do not support prefix caching since the feature is still experimental.", # noqa: E501
),
(
"state-spaces/mamba-130m-hf",
"attention_free",
False,
"Attention free models does not support prefix caching since the feature is still experimental.", # noqa: E501
"Attention free models do not support prefix caching since the feature is still experimental.", # noqa: E501
),
# encoder_decoder models
(
"openai/whisper-small",
"encoder_decoder",
False,
"Encoder decoder models does not support prefix caching.",
"Encoder decoder models do not support prefix caching.", # noqa: E501
),
],
)