[Speculative Decoding] MLPSpeculator Tensor Parallel support (1/2) (#6050)

Co-authored-by: Sirej Dua <sirej.dua@databricks.com>
Co-authored-by: Sirej Dua <Sirej Dua>
This commit is contained in:
Sirej Dua
2024-07-02 07:20:29 -07:00
committed by GitHub
parent 31354e563f
commit 15aba081f3
3 changed files with 35 additions and 25 deletions

View File

@@ -70,10 +70,6 @@ def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Use a small model for a fast test.
# Note this is repeated in the test body; to initialize a tokenizer.
"model": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
@@ -88,15 +84,31 @@ def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
# second run of the test to fail with internal NCCL error.
"use_async": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_draft_tensor_parallel_size": 1,
},
])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs, test_llm_kwargs",
[
(
{
# Use a small model for a fast test.
# Note this is repeated in the test body; to initialize a
# tokenizer.
"model": "JackFram/llama-68m",
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_draft_tensor_parallel_size": 1,
}),
({
"model": "ibm-granite/granite-3b-code-instruct",
}, {
"speculative_model":
"ibm-granite/granite-3b-code-instruct-accelerator",
"num_speculative_tokens": 5,
"speculative_draft_tensor_parallel_size": 1,
})
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_draft_model_tp_lt_target_model_tp2(test_llm_generator,