[Frontend] support matryoshka representation / support embedding API dimensions (#16331)
This commit is contained in:
@@ -8,7 +8,8 @@ import math
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.models.embedding.utils import check_embeddings_close
|
||||
from tests.models.embedding.utils import check_embeddings_close, matryoshka_fy
|
||||
from vllm import PoolingParams
|
||||
|
||||
SCORING_MODELS = [
|
||||
"jinaai/jina-reranker-v2-base-multilingual", # Roberta
|
||||
@@ -126,3 +127,40 @@ def test_embeddings(
|
||||
name_1="vllm",
|
||||
tol=1e-2,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", EMBEDDING_MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("dimensions", [16, 32])
|
||||
def test_matryoshka(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
model,
|
||||
dtype: str,
|
||||
dimensions: int,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
|
||||
example_prompts = EMBEDDING_PROMPTS
|
||||
|
||||
with hf_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
is_sentence_transformer=True,
|
||||
) as hf_model:
|
||||
hf_outputs = hf_model.encode(example_prompts, task="text-matching")
|
||||
hf_outputs = matryoshka_fy(hf_outputs, dimensions)
|
||||
|
||||
with vllm_runner(model, task="embed", dtype=dtype,
|
||||
max_model_len=None) as vllm_model:
|
||||
vllm_outputs = vllm_model.encode(
|
||||
example_prompts,
|
||||
pooling_params=PoolingParams(dimensions=dimensions))
|
||||
|
||||
check_embeddings_close(
|
||||
embeddings_0_lst=hf_outputs,
|
||||
embeddings_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
tol=1e-2,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user