[CI] improve embed testing (#18747)

This commit is contained in:
wang.yuqi
2025-05-28 15:16:35 +08:00
committed by GitHub
parent 0c492b7824
commit de65fc8e1e
13 changed files with 248 additions and 178 deletions

View File

@@ -3,7 +3,8 @@ from typing import Any
import pytest
from ...utils import EmbedModelInfo, run_embedding_correctness_test
from .embed_utils import EmbedModelInfo, correctness_test_embed_models
from .mteb_utils import mteb_test_embed_models
MODELS = [
########## BertModel
@@ -53,9 +54,8 @@ MODELS = [
@pytest.mark.parametrize("model_info", MODELS)
def test_models_mteb(hf_runner, vllm_runner,
model_info: EmbedModelInfo) -> None:
from .mteb_utils import mteb_test_embed_models
def test_embed_models_mteb(hf_runner, vllm_runner,
model_info: EmbedModelInfo) -> None:
vllm_extra_kwargs: dict[str, Any] = {}
if model_info.architecture == "GteNewModel":
@@ -66,28 +66,13 @@ def test_models_mteb(hf_runner, vllm_runner,
@pytest.mark.parametrize("model_info", MODELS)
def test_models_correctness(hf_runner, vllm_runner, model_info: EmbedModelInfo,
example_prompts) -> None:
if not model_info.enable_test:
pytest.skip("Skipping test.")
# ST will strip the input texts, see test_embedding.py
example_prompts = [str(s).strip() for s in example_prompts]
def test_embed_models_correctness(hf_runner, vllm_runner,
model_info: EmbedModelInfo,
example_prompts) -> None:
vllm_extra_kwargs: dict[str, Any] = {}
if model_info.architecture == "GteNewModel":
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
with vllm_runner(model_info.name,
task="embed",
dtype=model_info.dtype,
max_model_len=None,
**vllm_extra_kwargs) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)
with hf_runner(
model_info.name,
dtype=model_info.dtype,
is_sentence_transformer=True,
) as hf_model:
run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)
correctness_test_embed_models(hf_runner, vllm_runner, model_info,
example_prompts, vllm_extra_kwargs)