[Bugfix][2/n] Fix speculative decoding CI - Fix test_ngram_e2e_greedy_correctness (#19644)
This commit is contained in:
@@ -17,7 +17,10 @@ from .conftest import run_equality_correctness_test
|
||||
"model_name": "JackFram/llama-160m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True
|
||||
"enforce_eager": True,
|
||||
|
||||
# The original model is float32, keep it for numerical stability.
|
||||
"dtype": "float32",
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@@ -75,6 +78,9 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# The original model is float32, keep it for numerical stability.
|
||||
"dtype": "float32",
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@@ -128,6 +134,9 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# The original model is float32, keep it for numerical stability.
|
||||
"dtype": "float32",
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@@ -182,6 +191,9 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# The original model is float32, keep it for numerical stability.
|
||||
"dtype": "float32",
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@@ -256,8 +268,12 @@ def test_logprobs_temp_1(vllm_runner, common_llm_kwargs,
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model_name": "JackFram/llama-160m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# The original model is float32, keep it for numerical stability.
|
||||
"dtype": "float32",
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
|
||||
Reference in New Issue
Block a user