[Model][0/N] Improve all pooling task | clean up (#25817)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
@@ -6,12 +6,16 @@ from collections.abc import Sequence
|
||||
|
||||
import mteb
|
||||
import numpy as np
|
||||
import pytest
|
||||
import requests
|
||||
import torch
|
||||
|
||||
import tests.ci_envs as ci_envs
|
||||
from tests.models.utils import EmbedModelInfo, RerankModelInfo, check_embeddings_close
|
||||
from tests.models.utils import (
|
||||
EmbedModelInfo,
|
||||
RerankModelInfo,
|
||||
check_embeddings_close,
|
||||
get_vllm_extra_kwargs,
|
||||
)
|
||||
|
||||
# Most embedding models on the STS12 task (See #17175):
|
||||
# - Model implementation and minor changes in tensor dtype
|
||||
@@ -165,28 +169,11 @@ def mteb_test_embed_models(
|
||||
hf_model_callback=None,
|
||||
atol=MTEB_EMBED_TOL,
|
||||
):
|
||||
# A model family has many models with the same architecture,
|
||||
# and we don't need to test each one.
|
||||
if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test:
|
||||
pytest.skip("Skipping test.")
|
||||
vllm_extra_kwargs = get_vllm_extra_kwargs(model_info, vllm_extra_kwargs)
|
||||
|
||||
# Test embed_dims, isnan and whether to use normalize
|
||||
example_prompts = ["The chef prepared a delicious meal." * 1000]
|
||||
|
||||
# Allow vllm to test using the given dtype, such as float32
|
||||
vllm_extra_kwargs = vllm_extra_kwargs or {}
|
||||
vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype
|
||||
|
||||
# Allow vllm to test using hf_overrides
|
||||
if model_info.hf_overrides is not None:
|
||||
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
|
||||
|
||||
# Allow changing the head dtype used by vllm in tests
|
||||
if ci_envs.VLLM_CI_HEAD_DTYPE is not None:
|
||||
if "hf_overrides" not in vllm_extra_kwargs:
|
||||
vllm_extra_kwargs["hf_overrides"] = {}
|
||||
vllm_extra_kwargs["hf_overrides"]["head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE
|
||||
|
||||
with vllm_runner(
|
||||
model_info.name,
|
||||
runner="pooling",
|
||||
@@ -212,9 +199,12 @@ def mteb_test_embed_models(
|
||||
vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype
|
||||
head_dtype = model_config.head_dtype
|
||||
|
||||
# Test embed_dims, isnan and whether to use normalize
|
||||
# Test embedding_size, isnan and whether to use normalize
|
||||
vllm_outputs = vllm_model.embed(example_prompts, truncate_prompt_tokens=-1)
|
||||
assert not torch.any(torch.isnan(torch.tensor(vllm_outputs)))
|
||||
outputs_tensor = torch.tensor(vllm_outputs)
|
||||
assert not torch.any(torch.isnan(outputs_tensor))
|
||||
embedding_size = model_config.embedding_size
|
||||
assert torch.tensor(vllm_outputs).shape[-1] == embedding_size
|
||||
|
||||
# Accelerate mteb test by setting
|
||||
# SentenceTransformers mteb score to a constant
|
||||
@@ -231,7 +221,7 @@ def mteb_test_embed_models(
|
||||
st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
|
||||
st_dtype = next(hf_model.model.parameters()).dtype
|
||||
|
||||
# Test embed_dims and whether to use normalize
|
||||
# Check embeddings close to hf outputs
|
||||
hf_outputs = hf_model.encode(example_prompts)
|
||||
check_embeddings_close(
|
||||
embeddings_0_lst=hf_outputs,
|
||||
@@ -323,24 +313,7 @@ def mteb_test_rerank_models(
|
||||
vllm_mteb_encoder=VllmMtebEncoder,
|
||||
atol=MTEB_RERANK_TOL,
|
||||
):
|
||||
# A model family has many models with the same architecture,
|
||||
# and we don't need to test each one.
|
||||
if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test:
|
||||
pytest.skip("Skipping test.")
|
||||
|
||||
# Allow vllm to test using the given dtype, such as float32
|
||||
vllm_extra_kwargs = vllm_extra_kwargs or {}
|
||||
vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype
|
||||
|
||||
# Allow vllm to test using hf_overrides
|
||||
if model_info.hf_overrides is not None:
|
||||
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
|
||||
|
||||
# Allow changing the head dtype used by vllm in tests
|
||||
if ci_envs.VLLM_CI_HEAD_DTYPE is not None:
|
||||
if "hf_overrides" not in vllm_extra_kwargs:
|
||||
vllm_extra_kwargs["hf_overrides"] = {}
|
||||
vllm_extra_kwargs["hf_overrides"]["head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE
|
||||
vllm_extra_kwargs = get_vllm_extra_kwargs(model_info, vllm_extra_kwargs)
|
||||
|
||||
with vllm_runner(
|
||||
model_info.name,
|
||||
|
||||
Reference in New Issue
Block a user