[Speculative Decoding] Test refactor (#8317)

Co-authored-by: youkaichao <youkaichao@126.com>
This commit is contained in:
Lily Liu
2024-09-11 14:07:34 -07:00
committed by GitHub
parent 8baa454937
commit 775f00f81e
12 changed files with 927 additions and 1042 deletions

View File

@@ -4,7 +4,9 @@ other features, e.g. cuda graphs.
import pytest
from .conftest import run_greedy_equality_correctness_test
from .conftest import run_equality_correctness_test
MAIN_MODEL = "JackFram/llama-68m"
@pytest.mark.parametrize(
@@ -15,7 +17,7 @@ from .conftest import run_greedy_equality_correctness_test
# Verify equality when cuda graphs allowed.
"enforce_eager": False,
"model": "JackFram/llama-68m",
"model_name": "JackFram/llama-68m",
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
@@ -31,23 +33,27 @@ from .conftest import run_greedy_equality_correctness_test
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator,
batch_size, output_len):
def test_spec_decode_cuda_graph(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int):
"""Verify spec decode equality when cuda graphs are enabled.
"""
run_greedy_equality_correctness_test(
baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True,
)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-160m",
"model_name": "JackFram/llama-160m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
@@ -80,13 +86,19 @@ def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_speculative_model_quantization_config(baseline_llm_generator,
test_llm_generator,
batch_size: int):
def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size: int, seed: int):
"""Verify spec decode works well with draft model quantization configs.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=32,
force_output_len=True)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=32,
seed=seed,
temperature=0.0)