This commit is contained in:
@@ -47,6 +47,23 @@ def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]):
|
||||
assert [o.outputs for o in o1] == [o.outputs for o in o2]
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize('prompt', PROMPTS)
|
||||
def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
|
||||
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
|
||||
|
||||
with pytest.warns(DeprecationWarning, match="'prompts'"):
|
||||
v1_output = llm.generate(prompts=prompt,
|
||||
sampling_params=sampling_params)
|
||||
|
||||
v2_output = llm.generate(prompt, sampling_params=sampling_params)
|
||||
assert_outputs_equal(v1_output, v2_output)
|
||||
|
||||
v2_output = llm.generate({"prompt": prompt},
|
||||
sampling_params=sampling_params)
|
||||
assert_outputs_equal(v1_output, v2_output)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
|
||||
def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
|
||||
@@ -62,6 +79,26 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
|
||||
assert_outputs_equal(v1_output, v2_output)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
|
||||
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
|
||||
|
||||
with pytest.warns(DeprecationWarning, match="'prompts'"):
|
||||
v1_output = llm.generate(prompts=PROMPTS,
|
||||
sampling_params=sampling_params)
|
||||
|
||||
v2_output = llm.generate(PROMPTS, sampling_params=sampling_params)
|
||||
assert_outputs_equal(v1_output, v2_output)
|
||||
|
||||
v2_output = llm.generate(
|
||||
[{
|
||||
"prompt": p
|
||||
} for p in PROMPTS],
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
assert_outputs_equal(v1_output, v2_output)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
|
||||
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
|
||||
|
||||
Reference in New Issue
Block a user