[Model][0/N] Improve all pooling task | clean up (#25817)

Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
wang.yuqi
2025-10-13 16:44:50 +08:00
committed by GitHub
parent 4f207c7174
commit 767c3ab869
19 changed files with 198 additions and 189 deletions

View File

@@ -6,12 +6,16 @@ from collections.abc import Sequence
import mteb
import numpy as np
import pytest
import requests
import torch
import tests.ci_envs as ci_envs
from tests.models.utils import EmbedModelInfo, RerankModelInfo, check_embeddings_close
from tests.models.utils import (
EmbedModelInfo,
RerankModelInfo,
check_embeddings_close,
get_vllm_extra_kwargs,
)
# Most embedding models on the STS12 task (See #17175):
# - Model implementation and minor changes in tensor dtype
@@ -165,28 +169,11 @@ def mteb_test_embed_models(
hf_model_callback=None,
atol=MTEB_EMBED_TOL,
):
# A model family has many models with the same architecture,
# and we don't need to test each one.
if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test:
pytest.skip("Skipping test.")
vllm_extra_kwargs = get_vllm_extra_kwargs(model_info, vllm_extra_kwargs)
# Test embed_dims, isnan and whether to use normalize
example_prompts = ["The chef prepared a delicious meal." * 1000]
# Allow vllm to test using the given dtype, such as float32
vllm_extra_kwargs = vllm_extra_kwargs or {}
vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype
# Allow vllm to test using hf_overrides
if model_info.hf_overrides is not None:
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
# Allow changing the head dtype used by vllm in tests
if ci_envs.VLLM_CI_HEAD_DTYPE is not None:
if "hf_overrides" not in vllm_extra_kwargs:
vllm_extra_kwargs["hf_overrides"] = {}
vllm_extra_kwargs["hf_overrides"]["head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE
with vllm_runner(
model_info.name,
runner="pooling",
@@ -212,9 +199,12 @@ def mteb_test_embed_models(
vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype
head_dtype = model_config.head_dtype
# Test embed_dims, isnan and whether to use normalize
# Test embedding_size, isnan and whether to use normalize
vllm_outputs = vllm_model.embed(example_prompts, truncate_prompt_tokens=-1)
assert not torch.any(torch.isnan(torch.tensor(vllm_outputs)))
outputs_tensor = torch.tensor(vllm_outputs)
assert not torch.any(torch.isnan(outputs_tensor))
embedding_size = model_config.embedding_size
assert torch.tensor(vllm_outputs).shape[-1] == embedding_size
# Accelerate mteb test by setting
# SentenceTransformers mteb score to a constant
@@ -231,7 +221,7 @@ def mteb_test_embed_models(
st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
st_dtype = next(hf_model.model.parameters()).dtype
# Test embed_dims and whether to use normalize
# Check embeddings close to hf outputs
hf_outputs = hf_model.encode(example_prompts)
check_embeddings_close(
embeddings_0_lst=hf_outputs,
@@ -323,24 +313,7 @@ def mteb_test_rerank_models(
vllm_mteb_encoder=VllmMtebEncoder,
atol=MTEB_RERANK_TOL,
):
# A model family has many models with the same architecture,
# and we don't need to test each one.
if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test:
pytest.skip("Skipping test.")
# Allow vllm to test using the given dtype, such as float32
vllm_extra_kwargs = vllm_extra_kwargs or {}
vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype
# Allow vllm to test using hf_overrides
if model_info.hf_overrides is not None:
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
# Allow changing the head dtype used by vllm in tests
if ci_envs.VLLM_CI_HEAD_DTYPE is not None:
if "hf_overrides" not in vllm_extra_kwargs:
vllm_extra_kwargs["hf_overrides"] = {}
vllm_extra_kwargs["hf_overrides"]["head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE
vllm_extra_kwargs = get_vllm_extra_kwargs(model_info, vllm_extra_kwargs)
with vllm_runner(
model_info.name,