Support embedding models in V1 (#16188)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
4959915089
commit
799397ee4f
@@ -8,6 +8,8 @@ import pytest
|
||||
from vllm import LLM, PoolingParams, PoolingRequestOutput
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
|
||||
from ...models.utils import check_embeddings_close
|
||||
|
||||
MODEL_NAME = "intfloat/multilingual-e5-small"
|
||||
|
||||
PROMPTS = [
|
||||
@@ -27,6 +29,14 @@ TOKEN_IDS = [
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def v1(run_with_both_engines):
|
||||
# Simple autouse wrapper to run both engines for each test
|
||||
# This can be promoted up to conftest.py to run for every
|
||||
# test in a package
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def llm():
|
||||
# pytest caches the fixture so we use weakref.proxy to
|
||||
@@ -46,9 +56,15 @@ def llm():
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
-def assert_outputs_equal(o1: list[PoolingRequestOutput],
|
||||
+def assert_outputs_match(o1: list[PoolingRequestOutput],
|
||||
o2: list[PoolingRequestOutput]):
|
||||
-assert [o.outputs for o in o1] == [o.outputs for o in o2]
|
||||
+check_embeddings_close(
|
||||
+embeddings_0_lst=[o.outputs.data for o in o1],
|
||||
+embeddings_1_lst=[o.outputs.data for o in o2],
|
||||
+name_0="hf",
|
||||
+name_1="vllm",
|
||||
+tol=1e-2,
|
||||
+)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@@ -63,7 +79,7 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
|
||||
|
||||
v2_output = llm.encode({"prompt_token_ids": prompt_token_ids},
|
||||
pooling_params=pooling_params)
|
||||
-assert_outputs_equal(v1_output, v2_output)
|
||||
+assert_outputs_match(v1_output, v2_output)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@@ -80,7 +96,7 @@ def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
|
||||
} for p in TOKEN_IDS],
|
||||
pooling_params=pooling_params,
|
||||
)
|
||||
-assert_outputs_equal(v1_output, v2_output)
|
||||
+assert_outputs_match(v1_output, v2_output)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
|
||||
Reference in New Issue
Block a user