[Refactor] Separate sequence and token pooling types (#32026)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -161,7 +161,8 @@ def test_get_pooling_config():
|
||||
|
||||
assert model_config.pooler_config is not None
|
||||
assert model_config.pooler_config.normalize
|
||||
assert model_config.pooler_config.pooling_type == "MEAN"
|
||||
assert model_config.pooler_config.seq_pooling_type == "MEAN"
|
||||
assert model_config.pooler_config.tok_pooling_type == "ALL"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
@@ -169,7 +170,7 @@ def test_get_pooling_config():
|
||||
)
|
||||
def test_get_pooling_config_from_args():
|
||||
model_id = "sentence-transformers/all-MiniLM-L12-v2"
|
||||
pooler_config = PoolerConfig(pooling_type="CLS", normalize=True)
|
||||
pooler_config = PoolerConfig(seq_pooling_type="CLS", normalize=True)
|
||||
model_config = ModelConfig(model_id, pooler_config=pooler_config)
|
||||
|
||||
assert asdict(model_config.pooler_config) == asdict(pooler_config)
|
||||
@@ -180,14 +181,25 @@ def test_get_pooling_config_from_args():
|
||||
[
|
||||
("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", "LAST", "LAST"), # LLM
|
||||
("intfloat/e5-small", "CLS", "MEAN"), # BertModel
|
||||
],
|
||||
)
|
||||
def test_default_seq_pooling_type(model_id, default_pooling_type, pooling_type):
|
||||
model_config = ModelConfig(model_id)
|
||||
assert model_config._model_info.default_seq_pooling_type == default_pooling_type
|
||||
assert model_config.pooler_config.seq_pooling_type == pooling_type
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_id", "default_pooling_type", "pooling_type"),
|
||||
[
|
||||
("Qwen/Qwen2.5-Math-RM-72B", "ALL", "ALL"), # reward
|
||||
("Qwen/Qwen2.5-Math-PRM-7B", "STEP", "STEP"), # step reward
|
||||
],
|
||||
)
|
||||
def test_default_pooling_type(model_id, default_pooling_type, pooling_type):
|
||||
def test_default_tok_pooling_type(model_id, default_pooling_type, pooling_type):
|
||||
model_config = ModelConfig(model_id)
|
||||
assert model_config._model_info.default_pooling_type == default_pooling_type
|
||||
assert model_config.pooler_config.pooling_type == pooling_type
|
||||
assert model_config._model_info.default_tok_pooling_type == default_pooling_type
|
||||
assert model_config.pooler_config.tok_pooling_type == pooling_type
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -554,100 +566,100 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files):
|
||||
"jason9693/Qwen2.5-1.5B-apeach",
|
||||
"decoder",
|
||||
True,
|
||||
"Pooling models with causal attn and last pooling support chunked prefill.",
|
||||
"Pooling models with causal attn and LAST/ALL pooling support chunked prefill.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"Qwen/Qwen3-Embedding-0.6B",
|
||||
"decoder",
|
||||
True,
|
||||
"Pooling models with causal attn and last pooling support chunked prefill.",
|
||||
"Pooling models with causal attn and LAST/ALL pooling support chunked prefill.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"Qwen/Qwen2.5-Math-PRM-7B",
|
||||
"decoder",
|
||||
False,
|
||||
"Pooling models with step pooling does not support chunked prefill.",
|
||||
"Pooling models with causal attn and LAST/STEP pooling do not support chunked prefill.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"internlm/internlm2-1_8b-reward",
|
||||
"decoder",
|
||||
True,
|
||||
"Pooling models with causal attn and all pooling support chunked prefill.",
|
||||
"Pooling models with causal attn and LAST/ALL pooling support chunked prefill.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"BAAI/bge-base-en",
|
||||
"encoder_only",
|
||||
False,
|
||||
"Pooling models with bidirectional attn does not support chunked prefill.",
|
||||
"Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"boltuix/NeuroBERT-NER",
|
||||
"encoder_only",
|
||||
False,
|
||||
"Pooling models with bidirectional attn does not support chunked prefill.",
|
||||
"Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"papluca/xlm-roberta-base-language-detection",
|
||||
"encoder_only",
|
||||
False,
|
||||
"Pooling models with bidirectional attn does not support chunked prefill.",
|
||||
"Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"Alibaba-NLP/gte-Qwen2-1.5B-instruct",
|
||||
"encoder_only",
|
||||
False,
|
||||
"Pooling models with bidirectional attn does not support chunked prefill.",
|
||||
"Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"intfloat/e5-small",
|
||||
"encoder_only",
|
||||
False,
|
||||
"Pooling models with bidirectional attn does not support chunked prefill.",
|
||||
"Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
|
||||
),
|
||||
# multimodal models
|
||||
(
|
||||
"openai/clip-vit-base-patch32",
|
||||
"decoder",
|
||||
True,
|
||||
"Pooling models with causal attn and last pooling support chunked prefill.",
|
||||
"Pooling models with causal attn and LAST/ALL pooling support chunked prefill.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"google/siglip-base-patch16-224",
|
||||
"encoder_only",
|
||||
False,
|
||||
"Pooling models with bidirectional attn does not support chunked prefill.",
|
||||
"Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
|
||||
),
|
||||
# generate models
|
||||
(
|
||||
"Qwen/Qwen3-0.6B",
|
||||
"decoder",
|
||||
True,
|
||||
"Generative models support chunked prefill.",
|
||||
"Generative models support chunked prefill.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"Qwen/Qwen3-Next-80B-A3B-Instruct",
|
||||
"hybrid",
|
||||
True,
|
||||
"Generative models support chunked prefill.",
|
||||
"Generative models support chunked prefill.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"ibm-granite/granite-4.0-h-small",
|
||||
"hybrid",
|
||||
True,
|
||||
"Generative models support chunked prefill.",
|
||||
"Generative models support chunked prefill.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"state-spaces/mamba-130m-hf",
|
||||
"attention_free",
|
||||
True,
|
||||
"Generative models support chunked prefill.",
|
||||
"Generative models support chunked prefill.", # noqa: E501
|
||||
),
|
||||
# encoder_decoder models
|
||||
(
|
||||
"openai/whisper-small",
|
||||
"encoder_decoder",
|
||||
False,
|
||||
"Encoder decoder models does not support chunked prefill.",
|
||||
"Encoder decoder models do not support chunked prefill.", # noqa: E501
|
||||
),
|
||||
],
|
||||
)
|
||||
@@ -673,100 +685,100 @@ def test_is_chunked_prefill_supported(
|
||||
"jason9693/Qwen2.5-1.5B-apeach",
|
||||
"decoder",
|
||||
True,
|
||||
"Pooling models with causal attn and last pooling support prefix caching.",
|
||||
"Pooling models with causal attn and LAST/ALL pooling support prefix caching.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"Qwen/Qwen3-Embedding-0.6B",
|
||||
"decoder",
|
||||
True,
|
||||
"Pooling models with causal attn and last pooling support prefix caching.",
|
||||
"Pooling models with causal attn and LAST/ALL pooling support prefix caching.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"Qwen/Qwen2.5-Math-PRM-7B",
|
||||
"decoder",
|
||||
False,
|
||||
"Pooling models with step pooling does not support prefix caching.",
|
||||
"Pooling models with causal attn and LAST/STEP pooling do not support prefix caching.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"internlm/internlm2-1_8b-reward",
|
||||
"decoder",
|
||||
True,
|
||||
"Pooling models with causal attn and all pooling support prefix caching.",
|
||||
"Pooling models with causal attn and LAST/ALL pooling support prefix caching.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"BAAI/bge-base-en",
|
||||
"encoder_only",
|
||||
False,
|
||||
"Pooling models with bidirectional attn does not support prefix caching.",
|
||||
"Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"boltuix/NeuroBERT-NER",
|
||||
"encoder_only",
|
||||
False,
|
||||
"Pooling models with bidirectional attn does not support prefix caching.",
|
||||
"Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"papluca/xlm-roberta-base-language-detection",
|
||||
"encoder_only",
|
||||
False,
|
||||
"Pooling models with bidirectional attn does not support prefix caching.",
|
||||
"Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"Alibaba-NLP/gte-Qwen2-1.5B-instruct",
|
||||
"encoder_only",
|
||||
False,
|
||||
"Pooling models with bidirectional attn does not support prefix caching.",
|
||||
"Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"intfloat/e5-small",
|
||||
"encoder_only",
|
||||
False,
|
||||
"Pooling models with bidirectional attn does not support prefix caching.",
|
||||
"Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
|
||||
),
|
||||
# multimodal models
|
||||
(
|
||||
"openai/clip-vit-base-patch32",
|
||||
"decoder",
|
||||
True,
|
||||
"Pooling models with causal attn and last pooling support prefix caching.",
|
||||
"Pooling models with causal attn and LAST/ALL pooling support prefix caching.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"google/siglip-base-patch16-224",
|
||||
"encoder_only",
|
||||
False,
|
||||
"Pooling models with bidirectional attn does not support prefix caching.",
|
||||
"Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
|
||||
),
|
||||
# generate models
|
||||
(
|
||||
"Qwen/Qwen3-0.6B",
|
||||
"decoder",
|
||||
True,
|
||||
"Generative models support prefix caching.",
|
||||
"Generative models support prefix caching.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"Qwen/Qwen3-Next-80B-A3B-Instruct",
|
||||
"hybrid",
|
||||
False,
|
||||
"Hybrid models does not support prefix caching since the feature is still experimental.", # noqa: E501
|
||||
"Hybrid models do not support prefix caching since the feature is still experimental.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"ibm-granite/granite-4.0-h-small",
|
||||
"hybrid",
|
||||
False,
|
||||
"Hybrid models does not support prefix caching since the feature is still experimental.", # noqa: E501
|
||||
"Hybrid models do not support prefix caching since the feature is still experimental.", # noqa: E501
|
||||
),
|
||||
(
|
||||
"state-spaces/mamba-130m-hf",
|
||||
"attention_free",
|
||||
False,
|
||||
"Attention free models does not support prefix caching since the feature is still experimental.", # noqa: E501
|
||||
"Attention free models do not support prefix caching since the feature is still experimental.", # noqa: E501
|
||||
),
|
||||
# encoder_decoder models
|
||||
(
|
||||
"openai/whisper-small",
|
||||
"encoder_decoder",
|
||||
False,
|
||||
"Encoder decoder models does not support prefix caching.",
|
||||
"Encoder decoder models do not support prefix caching.", # noqa: E501
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user