This commit is contained in:
@@ -49,6 +49,21 @@ def assert_outputs_equal(o1: List[EmbeddingRequestOutput],
|
||||
assert [o.outputs for o in o1] == [o.outputs for o in o2]
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize('prompt', PROMPTS)
|
||||
def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
|
||||
pooling_params = PoolingParams()
|
||||
|
||||
with pytest.warns(DeprecationWarning, match="'prompts'"):
|
||||
v1_output = llm.encode(prompts=prompt, pooling_params=pooling_params)
|
||||
|
||||
v2_output = llm.encode(prompt, pooling_params=pooling_params)
|
||||
assert_outputs_equal(v1_output, v2_output)
|
||||
|
||||
v2_output = llm.encode({"prompt": prompt}, pooling_params=pooling_params)
|
||||
assert_outputs_equal(v1_output, v2_output)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
|
||||
def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
|
||||
@@ -64,6 +79,25 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
|
||||
assert_outputs_equal(v1_output, v2_output)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
|
||||
pooling_params = PoolingParams()
|
||||
|
||||
with pytest.warns(DeprecationWarning, match="'prompts'"):
|
||||
v1_output = llm.encode(prompts=PROMPTS, pooling_params=pooling_params)
|
||||
|
||||
v2_output = llm.encode(PROMPTS, pooling_params=pooling_params)
|
||||
assert_outputs_equal(v1_output, v2_output)
|
||||
|
||||
v2_output = llm.encode(
|
||||
[{
|
||||
"prompt": p
|
||||
} for p in PROMPTS],
|
||||
pooling_params=pooling_params,
|
||||
)
|
||||
assert_outputs_equal(v1_output, v2_output)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
|
||||
pooling_params = PoolingParams()
|
||||
|
||||
Reference in New Issue
Block a user