[Misc] Upgrade to pytorch 2.5 (#9588)
Signed-off-by: Bill Nell <bill@neuralmagic.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
@@ -8,7 +8,7 @@ import pytest
 
 from vllm.platforms import current_platform
 
-from ...utils import check_outputs_equal
+from ...utils import check_logprobs_close, check_outputs_equal
 
 MODELS = [
     "meta-llama/Llama-2-7b-hf",
@@ -43,18 +43,40 @@ def test_models(
     dtype: str,
     max_tokens: int,
 ) -> None:
-    with hf_runner(model, dtype=dtype) as hf_model:
-        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
-
-    with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model:
-        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
-
-    check_outputs_equal(
-        outputs_0_lst=hf_outputs,
-        outputs_1_lst=vllm_outputs,
-        name_0="hf",
-        name_1="vllm",
-    )
+    if model == "openbmb/MiniCPM3-4B":
+        # the output becomes slightly different when upgrading to
+        # pytorch 2.5 . Changing to logprobs checks instead of exact
+        # output checks.
+        NUM_LOG_PROBS = 8
+        with hf_runner(model, dtype=dtype) as hf_model:
+            hf_outputs = hf_model.generate_greedy_logprobs_limit(
+                example_prompts, max_tokens, NUM_LOG_PROBS)
+
+        with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model:
+            vllm_outputs = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, NUM_LOG_PROBS)
+
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=vllm_outputs,
+            name_0="hf",
+            name_1="vllm",
+        )
+    else:
+        with hf_runner(model, dtype=dtype) as hf_model:
+            hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+
+        with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model:
+            vllm_outputs = vllm_model.generate_greedy(example_prompts,
+                                                      max_tokens)
+
+        check_outputs_equal(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=vllm_outputs,
+            name_0="hf",
+            name_1="vllm",
+        )
 
 
 @pytest.mark.parametrize("model", MODELS)
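For context on the relaxed check: rather than requiring byte-identical outputs, a logprobs comparison tolerates the small greedy-decoding divergences introduced by the pytorch 2.5 upgrade, accepting a mismatch as long as each run's chosen token still ranks among the other run's top candidates at that step. Below is a minimal sketch of that idea, assuming an illustrative per-prompt output layout of (token_ids, text, logprobs_per_step); it is not the check_logprobs_close helper imported above, whose actual signature and data layout live in the tests' utils module.

def logprobs_close(output_0, output_1) -> bool:
    # Illustrative only: each output is assumed to be
    # (token_ids, text, logprobs), where logprobs[i] maps candidate
    # token ids to log-probabilities at decoding step i.
    ids_0, _, logprobs_0 = output_0
    ids_1, _, logprobs_1 = output_1
    for i, (tok_0, tok_1) in enumerate(zip(ids_0, ids_1)):
        if tok_0 == tok_1:
            continue  # tokens agree at this step; keep scanning
        # First divergence: tolerate it only if each side's pick is
        # still among the other side's top candidates at this step.
        return tok_0 in logprobs_1[i] and tok_1 in logprobs_0[i]
    return True  # no divergence found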