[Spec Decode] Disable Log Prob serialization to CPU for spec decoding for both draft and target models. (#6485)
This commit is contained in:
@@ -22,10 +22,12 @@ from .conftest import get_logprobs_from_llm_generator
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 3,
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"speculative_model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs_during_spec_decoding": False,
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
@@ -59,10 +61,12 @@ def test_logprobs_equality(baseline_llm_generator, test_llm_generator,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 3,
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"speculative_model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs_during_spec_decoding": False,
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1])
|
||||
@pytest.mark.parametrize("num_logprobs", [6])
|
||||
@pytest.mark.parametrize(
|
||||
@@ -99,13 +103,16 @@ def test_diff_num_logprobs(baseline_llm_generator, test_llm_generator,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 3,
|
||||
}, {
|
||||
"speculative_model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 6,
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"speculative_model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs_during_spec_decoding": False,
|
||||
}, {
|
||||
"speculative_model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 6,
|
||||
"disable_logprobs_during_spec_decoding": False,
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
@@ -143,6 +150,7 @@ def test_logprobs_different_k(baseline_llm_generator, test_llm_generator,
|
||||
[{
|
||||
"speculative_model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs_during_spec_decoding": False,
|
||||
|
||||
# Artificially limit the draft model max model len; this forces vLLM
|
||||
# to skip speculation once the sequences grow beyond 32-k tokens.
|
||||
@@ -181,10 +189,12 @@ def test_logprobs_when_skip_speculation(baseline_llm_generator,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 3,
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"speculative_model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs_during_spec_decoding": False,
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
|
||||
Reference in New Issue
Block a user