[Model] Pooling models default to using chunked prefill & prefix caching if supported. (#20930)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
@@ -177,9 +177,12 @@ def mteb_test_embed_models(hf_runner,
|
||||
max_model_len=None,
|
||||
**vllm_extra_kwargs) as vllm_model:
|
||||
|
||||
model_config = vllm_model.llm.llm_engine.model_config
|
||||
|
||||
if model_info.architecture:
|
||||
assert (model_info.architecture
|
||||
in vllm_model.llm.llm_engine.model_config.architectures)
|
||||
assert model_info.architecture in model_config.architectures
|
||||
assert (model_config._model_info.default_pooling_type ==
|
||||
model_info.default_pooling_type)
|
||||
|
||||
vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
|
||||
MTEB_EMBED_TASKS)
|
||||
@@ -286,7 +289,12 @@ def mteb_test_rerank_models(hf_runner,
|
||||
**vllm_extra_kwargs) as vllm_model:
|
||||
|
||||
model_config = vllm_model.llm.llm_engine.model_config
|
||||
|
||||
if model_info.architecture:
|
||||
assert (model_info.architecture in model_config.architectures)
|
||||
assert model_config.hf_config.num_labels == 1
|
||||
assert (model_config._model_info.default_pooling_type ==
|
||||
model_info.default_pooling_type)
|
||||
|
||||
vllm_main_score = run_mteb_rerank(vllm_mteb_encoder(vllm_model),
|
||||
tasks=MTEB_RERANK_TASKS,
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
|
||||
from tests.models.language.pooling.embed_utils import (
|
||||
run_embedding_correctness_test)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "model",
    ["jason9693/Qwen2.5-1.5B-apeach"],
)
@pytest.mark.parametrize("dtype", ["half"])
def test_classify_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
) -> None:
    """Compare vLLM classification outputs with HF while prefix caching is on.

    The vLLM engine is started with ``enable_prefix_caching=True`` and the
    test first asserts that the resulting cache config still has it enabled.
    Each prompt's vLLM output is then compared with the output of the
    Hugging Face ``AutoModelForSequenceClassification`` reference.

    Args:
        hf_runner: Fixture providing the Hugging Face runner context manager.
        vllm_runner: Fixture providing the vLLM runner context manager.
        example_prompts: Fixture providing the prompt list.
        model: Model identifier under test.
        dtype: Dtype string passed to both runners.
    """
    # Every prompt is submitted twice.
    # NOTE(review): presumably so the second copy can reuse cached prefixes
    # from the first - confirm against the engine's caching behaviour.
    example_prompts = example_prompts * 2

    with vllm_runner(model,
                     max_model_len=512,
                     dtype=dtype,
                     enable_prefix_caching=True) as vllm_model:
        cache_config = vllm_model.llm.llm_engine.cache_config
        assert cache_config.enable_prefix_caching
        vllm_outputs = vllm_model.classify(example_prompts)

    with hf_runner(model,
                   dtype=dtype,
                   auto_cls=AutoModelForSequenceClassification) as hf_model:
        hf_outputs = hf_model.classify(example_prompts)

    # zip() stops at the shorter input, so a missing output would otherwise
    # go unnoticed; check the lengths explicitly first.
    assert len(hf_outputs) == len(vllm_outputs), (
        f"output count mismatch: hf={len(hf_outputs)} "
        f"vllm={len(vllm_outputs)}")

    # Relative tolerance: tighter only when running in full precision.
    rtol = 1e-3 if dtype == "float" else 1e-2
    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
        hf_output = torch.tensor(hf_output)
        vllm_output = torch.tensor(vllm_output)

        assert torch.allclose(hf_output, vllm_output, rtol=rtol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "model",
    ["Qwen/Qwen3-Embedding-0.6B"],
)
@pytest.mark.parametrize("dtype", ["half"])
def test_embed_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
):
    """Check vLLM embeddings against sentence-transformers with prefix caching.

    The vLLM engine is created with the pooling runner and
    ``enable_prefix_caching=True``; the test asserts the cache config keeps
    it enabled, embeds the prompts, and hands the result to
    ``run_embedding_correctness_test`` together with the HF reference model.
    """
    # Normalise each prompt to a stripped string, then submit the whole list
    # twice.
    stripped_prompts = [str(prompt).strip() for prompt in example_prompts]
    example_prompts = stripped_prompts * 2

    vllm_kwargs = dict(
        runner="pooling",
        max_model_len=None,
        enable_prefix_caching=True,
    )
    with vllm_runner(model, **vllm_kwargs) as vllm_model:
        assert vllm_model.llm.llm_engine.cache_config.enable_prefix_caching
        vllm_outputs = vllm_model.embed(example_prompts)

    with hf_runner(model, is_sentence_transformer=True) as hf_model:
        run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "model",
    [
        "intfloat/e5-small",
        "Alibaba-NLP/gte-Qwen2-1.5B-instruct",  # is_causal == False
        "papluca/xlm-roberta-base-language-detection",
    ],
)
@pytest.mark.parametrize("dtype", ["half"])
def test_non_causal_models(hf_runner, vllm_runner, example_prompts, model: str,
                           dtype: str) -> None:
    """Assert prefix caching ends up disabled for the listed models.

    The engine is started with ``enable_prefix_caching=True``; the test
    passes only if the resulting cache config reports it as disabled.
    No inference is run.
    """
    engine_kwargs = dict(
        max_model_len=512,
        dtype=dtype,
        enable_prefix_caching=True,
    )
    with vllm_runner(model, **engine_kwargs) as vllm_model:
        # The request to enable prefix caching must have been overridden.
        assert not vllm_model.llm.llm_engine.cache_config.enable_prefix_caching
|
||||
@@ -2,73 +2,78 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
|
||||
from ...utils import EmbedModelInfo, RerankModelInfo
|
||||
from ...utils import (CLSPoolingEmbedModelInfo, CLSPoolingRerankModelInfo,
|
||||
EmbedModelInfo, LASTPoolingEmbedModelInfo,
|
||||
RerankModelInfo)
|
||||
from .embed_utils import correctness_test_embed_models
|
||||
from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models
|
||||
|
||||
MODELS = [
|
||||
########## BertModel
|
||||
EmbedModelInfo("BAAI/bge-base-en",
|
||||
architecture="BertModel",
|
||||
enable_test=True),
|
||||
EmbedModelInfo("BAAI/bge-base-zh",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
EmbedModelInfo("BAAI/bge-small-en",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
EmbedModelInfo("BAAI/bge-small-zh",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
EmbedModelInfo("BAAI/bge-large-en",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
EmbedModelInfo("BAAI/bge-large-zh",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
EmbedModelInfo("BAAI/bge-large-zh-noinstruct",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
EmbedModelInfo("BAAI/bge-base-en-v1.5",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
EmbedModelInfo("BAAI/bge-base-zh-v1.5",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
EmbedModelInfo("BAAI/bge-small-en-v1.5",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
EmbedModelInfo("BAAI/bge-small-zh-v1.5",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
EmbedModelInfo("BAAI/bge-large-en-v1.5",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
EmbedModelInfo("BAAI/bge-large-zh-v1.5",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
CLSPoolingEmbedModelInfo("BAAI/bge-base-en",
|
||||
architecture="BertModel",
|
||||
enable_test=True),
|
||||
CLSPoolingEmbedModelInfo("BAAI/bge-base-zh",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
CLSPoolingEmbedModelInfo("BAAI/bge-small-en",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
CLSPoolingEmbedModelInfo("BAAI/bge-small-zh",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
CLSPoolingEmbedModelInfo("BAAI/bge-large-en",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
CLSPoolingEmbedModelInfo("BAAI/bge-large-zh",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
CLSPoolingEmbedModelInfo("BAAI/bge-large-zh-noinstruct",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
CLSPoolingEmbedModelInfo("BAAI/bge-base-en-v1.5",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
CLSPoolingEmbedModelInfo("BAAI/bge-base-zh-v1.5",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
CLSPoolingEmbedModelInfo("BAAI/bge-small-en-v1.5",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
CLSPoolingEmbedModelInfo("BAAI/bge-small-zh-v1.5",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
CLSPoolingEmbedModelInfo("BAAI/bge-large-en-v1.5",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
CLSPoolingEmbedModelInfo("BAAI/bge-large-zh-v1.5",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
########## XLMRobertaModel
|
||||
EmbedModelInfo("BAAI/bge-m3",
|
||||
architecture="XLMRobertaModel",
|
||||
enable_test=True),
|
||||
CLSPoolingEmbedModelInfo("BAAI/bge-m3",
|
||||
architecture="XLMRobertaModel",
|
||||
enable_test=True),
|
||||
########## Qwen2Model
|
||||
EmbedModelInfo("BAAI/bge-code-v1",
|
||||
architecture="Qwen2Model",
|
||||
dtype="float32",
|
||||
enable_test=True),
|
||||
LASTPoolingEmbedModelInfo("BAAI/bge-code-v1",
|
||||
architecture="Qwen2Model",
|
||||
dtype="float32",
|
||||
enable_test=True),
|
||||
]
|
||||
|
||||
RERANK_MODELS = [
|
||||
########## XLMRobertaForSequenceClassification
|
||||
RerankModelInfo("BAAI/bge-reranker-base",
|
||||
architecture="XLMRobertaForSequenceClassification",
|
||||
enable_test=True),
|
||||
RerankModelInfo("BAAI/bge-reranker-large",
|
||||
architecture="XLMRobertaForSequenceClassification",
|
||||
enable_test=False),
|
||||
RerankModelInfo("BAAI/bge-reranker-v2-m3",
|
||||
architecture="XLMRobertaForSequenceClassification",
|
||||
enable_test=False)
|
||||
CLSPoolingRerankModelInfo(
|
||||
"BAAI/bge-reranker-base",
|
||||
architecture="XLMRobertaForSequenceClassification",
|
||||
enable_test=True),
|
||||
CLSPoolingRerankModelInfo(
|
||||
"BAAI/bge-reranker-large",
|
||||
architecture="XLMRobertaForSequenceClassification",
|
||||
enable_test=False),
|
||||
CLSPoolingRerankModelInfo(
|
||||
"BAAI/bge-reranker-v2-m3",
|
||||
architecture="XLMRobertaForSequenceClassification",
|
||||
enable_test=False)
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -8,12 +8,12 @@ import torch
|
||||
|
||||
from tests.conftest import HfRunner
|
||||
|
||||
from .mteb_utils import (RerankModelInfo, VllmMtebEncoder,
|
||||
mteb_test_rerank_models)
|
||||
from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo
|
||||
from .mteb_utils import VllmMtebEncoder, mteb_test_rerank_models
|
||||
|
||||
RERANK_MODELS = [
|
||||
RerankModelInfo("BAAI/bge-reranker-v2-gemma",
|
||||
architecture="GemmaForSequenceClassification"),
|
||||
LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma",
|
||||
architecture="GemmaForSequenceClassification"),
|
||||
]
|
||||
|
||||
PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." # noqa: E501
|
||||
|
||||
@@ -2,13 +2,15 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
|
||||
from .mteb_utils import RerankModelInfo, mteb_test_rerank_models
|
||||
from ...utils import (CLSPoolingRerankModelInfo, LASTPoolingRerankModelInfo,
|
||||
RerankModelInfo)
|
||||
from .mteb_utils import mteb_test_rerank_models
|
||||
|
||||
RERANK_MODELS = [
|
||||
RerankModelInfo("cross-encoder/ms-marco-TinyBERT-L-2-v2",
|
||||
architecture="BertForSequenceClassification"),
|
||||
RerankModelInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls",
|
||||
architecture="Qwen3ForSequenceClassification")
|
||||
CLSPoolingRerankModelInfo("cross-encoder/ms-marco-TinyBERT-L-2-v2",
|
||||
architecture="BertForSequenceClassification"),
|
||||
LASTPoolingRerankModelInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls",
|
||||
architecture="Qwen3ForSequenceClassification")
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -4,57 +4,58 @@ from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from ...utils import check_transformers_version
|
||||
from .embed_utils import EmbedModelInfo, correctness_test_embed_models
|
||||
from ...utils import (CLSPoolingEmbedModelInfo, EmbedModelInfo,
|
||||
LASTPoolingEmbedModelInfo, check_transformers_version)
|
||||
from .embed_utils import correctness_test_embed_models
|
||||
from .mteb_utils import mteb_test_embed_models
|
||||
|
||||
MODELS = [
|
||||
########## BertModel
|
||||
EmbedModelInfo("thenlper/gte-large",
|
||||
architecture="BertModel",
|
||||
enable_test=True),
|
||||
EmbedModelInfo("thenlper/gte-base",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
EmbedModelInfo("thenlper/gte-small",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
EmbedModelInfo("thenlper/gte-large-zh",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
EmbedModelInfo("thenlper/gte-base-zh",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
EmbedModelInfo("thenlper/gte-small-zh",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
CLSPoolingEmbedModelInfo("thenlper/gte-large",
|
||||
architecture="BertModel",
|
||||
enable_test=True),
|
||||
CLSPoolingEmbedModelInfo("thenlper/gte-base",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
CLSPoolingEmbedModelInfo("thenlper/gte-small",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
CLSPoolingEmbedModelInfo("thenlper/gte-large-zh",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
CLSPoolingEmbedModelInfo("thenlper/gte-base-zh",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
CLSPoolingEmbedModelInfo("thenlper/gte-small-zh",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
########### NewModel
|
||||
EmbedModelInfo("Alibaba-NLP/gte-multilingual-base",
|
||||
architecture="GteNewModel",
|
||||
enable_test=True),
|
||||
EmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5",
|
||||
architecture="GteNewModel",
|
||||
enable_test=True),
|
||||
EmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5",
|
||||
architecture="GteNewModel",
|
||||
enable_test=True),
|
||||
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-multilingual-base",
|
||||
architecture="GteNewModel",
|
||||
enable_test=True),
|
||||
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5",
|
||||
architecture="GteNewModel",
|
||||
enable_test=True),
|
||||
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5",
|
||||
architecture="GteNewModel",
|
||||
enable_test=True),
|
||||
########### Qwen2ForCausalLM
|
||||
EmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
|
||||
architecture="Qwen2ForCausalLM",
|
||||
enable_test=True),
|
||||
LASTPoolingEmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
|
||||
architecture="Qwen2ForCausalLM",
|
||||
enable_test=True),
|
||||
########## ModernBertModel
|
||||
EmbedModelInfo("Alibaba-NLP/gte-modernbert-base",
|
||||
architecture="ModernBertModel",
|
||||
enable_test=True),
|
||||
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-modernbert-base",
|
||||
architecture="ModernBertModel",
|
||||
enable_test=True),
|
||||
########## Qwen3ForCausalLM
|
||||
EmbedModelInfo("Qwen/Qwen3-Embedding-0.6B",
|
||||
architecture="Qwen3ForCausalLM",
|
||||
dtype="float32",
|
||||
enable_test=True),
|
||||
EmbedModelInfo("Qwen/Qwen3-Embedding-4B",
|
||||
architecture="Qwen3ForCausalLM",
|
||||
dtype="float32",
|
||||
enable_test=False),
|
||||
LASTPoolingEmbedModelInfo("Qwen/Qwen3-Embedding-0.6B",
|
||||
architecture="Qwen3ForCausalLM",
|
||||
dtype="float32",
|
||||
enable_test=True),
|
||||
LASTPoolingEmbedModelInfo("Qwen/Qwen3-Embedding-4B",
|
||||
architecture="Qwen3ForCausalLM",
|
||||
dtype="float32",
|
||||
enable_test=False),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -2,34 +2,34 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
|
||||
from ...utils import EmbedModelInfo
|
||||
from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo
|
||||
from .embed_utils import correctness_test_embed_models
|
||||
from .mteb_utils import mteb_test_embed_models
|
||||
|
||||
MODELS = [
|
||||
########## BertModel
|
||||
EmbedModelInfo("intfloat/e5-small",
|
||||
architecture="BertModel",
|
||||
enable_test=True),
|
||||
EmbedModelInfo("intfloat/e5-base",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
EmbedModelInfo("intfloat/e5-large",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
EmbedModelInfo("intfloat/multilingual-e5-small",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
CLSPoolingEmbedModelInfo("intfloat/e5-small",
|
||||
architecture="BertModel",
|
||||
enable_test=True),
|
||||
CLSPoolingEmbedModelInfo("intfloat/e5-base",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
CLSPoolingEmbedModelInfo("intfloat/e5-large",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-small",
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
########## XLMRobertaModel
|
||||
EmbedModelInfo("intfloat/multilingual-e5-base",
|
||||
architecture="XLMRobertaModel",
|
||||
enable_test=True),
|
||||
EmbedModelInfo("intfloat/multilingual-e5-large",
|
||||
architecture="XLMRobertaModel",
|
||||
enable_test=False),
|
||||
EmbedModelInfo("intfloat/multilingual-e5-large-instruct",
|
||||
architecture="XLMRobertaModel",
|
||||
enable_test=False),
|
||||
CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-base",
|
||||
architecture="XLMRobertaModel",
|
||||
enable_test=True),
|
||||
CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-large",
|
||||
architecture="XLMRobertaModel",
|
||||
enable_test=False),
|
||||
CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-large-instruct",
|
||||
architecture="XLMRobertaModel",
|
||||
enable_test=False),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -6,20 +6,22 @@ import pytest
|
||||
|
||||
from vllm import PoolingParams
|
||||
|
||||
from ...utils import EmbedModelInfo, RerankModelInfo
|
||||
from ...utils import (CLSPoolingEmbedModelInfo, CLSPoolingRerankModelInfo,
|
||||
EmbedModelInfo, RerankModelInfo)
|
||||
from .embed_utils import (check_embeddings_close,
|
||||
correctness_test_embed_models, matryoshka_fy)
|
||||
from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models
|
||||
|
||||
EMBEDDING_MODELS = [
|
||||
EmbedModelInfo("jinaai/jina-embeddings-v3",
|
||||
architecture="XLMRobertaModel",
|
||||
is_matryoshka=True)
|
||||
CLSPoolingEmbedModelInfo("jinaai/jina-embeddings-v3",
|
||||
architecture="XLMRobertaModel",
|
||||
is_matryoshka=True)
|
||||
]
|
||||
|
||||
RERANK_MODELS = [
|
||||
RerankModelInfo("jinaai/jina-reranker-v2-base-multilingual",
|
||||
architecture="XLMRobertaForSequenceClassification")
|
||||
CLSPoolingRerankModelInfo(
|
||||
"jinaai/jina-reranker-v2-base-multilingual",
|
||||
architecture="XLMRobertaForSequenceClassification")
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -7,15 +7,16 @@ import torch
|
||||
|
||||
from tests.conftest import HfRunner
|
||||
|
||||
from .mteb_utils import RerankModelInfo, mteb_test_rerank_models
|
||||
from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo
|
||||
from .mteb_utils import mteb_test_rerank_models
|
||||
|
||||
RERANK_MODELS = [
|
||||
RerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2",
|
||||
architecture="Qwen2ForSequenceClassification",
|
||||
enable_test=True),
|
||||
RerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2",
|
||||
architecture="Qwen2ForSequenceClassification",
|
||||
enable_test=False)
|
||||
LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2",
|
||||
architecture="Qwen2ForSequenceClassification",
|
||||
enable_test=True),
|
||||
LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2",
|
||||
architecture="Qwen2ForSequenceClassification",
|
||||
enable_test=False)
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -3,22 +3,23 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from .embed_utils import EmbedModelInfo, correctness_test_embed_models
|
||||
from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo
|
||||
from .embed_utils import correctness_test_embed_models
|
||||
from .mteb_utils import mteb_test_embed_models
|
||||
|
||||
MODELS = [
|
||||
EmbedModelInfo("nomic-ai/nomic-embed-text-v1",
|
||||
architecture="NomicBertModel",
|
||||
enable_test=True),
|
||||
EmbedModelInfo("nomic-ai/nomic-embed-text-v1.5",
|
||||
architecture="NomicBertModel",
|
||||
enable_test=False),
|
||||
EmbedModelInfo("nomic-ai/CodeRankEmbed",
|
||||
architecture="NomicBertModel",
|
||||
enable_test=False),
|
||||
EmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe",
|
||||
architecture="NomicBertModel",
|
||||
enable_test=True)
|
||||
CLSPoolingEmbedModelInfo("nomic-ai/nomic-embed-text-v1",
|
||||
architecture="NomicBertModel",
|
||||
enable_test=True),
|
||||
CLSPoolingEmbedModelInfo("nomic-ai/nomic-embed-text-v1.5",
|
||||
architecture="NomicBertModel",
|
||||
enable_test=False),
|
||||
CLSPoolingEmbedModelInfo("nomic-ai/CodeRankEmbed",
|
||||
architecture="NomicBertModel",
|
||||
enable_test=False),
|
||||
CLSPoolingEmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe",
|
||||
architecture="NomicBertModel",
|
||||
enable_test=True)
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -8,15 +8,16 @@ import torch
|
||||
from tests.conftest import HfRunner
|
||||
from tests.utils import multi_gpu_test
|
||||
|
||||
from .mteb_utils import RerankModelInfo, mteb_test_rerank_models
|
||||
from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo
|
||||
from .mteb_utils import mteb_test_rerank_models
|
||||
|
||||
RERANK_MODELS = [
|
||||
RerankModelInfo("Qwen/Qwen3-Reranker-0.6B",
|
||||
architecture="Qwen3ForSequenceClassification",
|
||||
enable_test=True),
|
||||
RerankModelInfo("Qwen/Qwen3-Reranker-4B",
|
||||
architecture="Qwen3ForSequenceClassification",
|
||||
enable_test=False)
|
||||
LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-0.6B",
|
||||
architecture="Qwen3ForSequenceClassification",
|
||||
enable_test=True),
|
||||
LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-4B",
|
||||
architecture="Qwen3ForSequenceClassification",
|
||||
enable_test=False)
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -3,42 +3,43 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from .embed_utils import EmbedModelInfo, correctness_test_embed_models
|
||||
from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo
|
||||
from .embed_utils import correctness_test_embed_models
|
||||
from .mteb_utils import mteb_test_embed_models
|
||||
|
||||
MODELS = [
|
||||
EmbedModelInfo("Snowflake/snowflake-arctic-embed-xs",
|
||||
is_matryoshka=False,
|
||||
architecture="BertModel",
|
||||
enable_test=True),
|
||||
EmbedModelInfo("Snowflake/snowflake-arctic-embed-s",
|
||||
is_matryoshka=False,
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
EmbedModelInfo("Snowflake/snowflake-arctic-embed-m",
|
||||
is_matryoshka=False,
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long",
|
||||
is_matryoshka=False,
|
||||
architecture="NomicBertModel",
|
||||
enable_test=True),
|
||||
EmbedModelInfo("Snowflake/snowflake-arctic-embed-l",
|
||||
is_matryoshka=False,
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5",
|
||||
is_matryoshka=True,
|
||||
architecture="BertModel",
|
||||
enable_test=True),
|
||||
EmbedModelInfo("Snowflake/snowflake-arctic-embed-l-v2.0",
|
||||
is_matryoshka=True,
|
||||
architecture="XLMRobertaModel",
|
||||
enable_test=True),
|
||||
EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
|
||||
is_matryoshka=True,
|
||||
architecture="GteModel",
|
||||
enable_test=True),
|
||||
CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-xs",
|
||||
is_matryoshka=False,
|
||||
architecture="BertModel",
|
||||
enable_test=True),
|
||||
CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-s",
|
||||
is_matryoshka=False,
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m",
|
||||
is_matryoshka=False,
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long",
|
||||
is_matryoshka=False,
|
||||
architecture="NomicBertModel",
|
||||
enable_test=True),
|
||||
CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-l",
|
||||
is_matryoshka=False,
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5",
|
||||
is_matryoshka=True,
|
||||
architecture="BertModel",
|
||||
enable_test=True),
|
||||
CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-l-v2.0",
|
||||
is_matryoshka=True,
|
||||
architecture="XLMRobertaModel",
|
||||
enable_test=True),
|
||||
CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
|
||||
is_matryoshka=True,
|
||||
architecture="GteModel",
|
||||
enable_test=True),
|
||||
]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user