[Speculative Decoding] MLPSpeculator Tensor Parallel support (1/2) (#6050)
Co-authored-by: Sirej Dua <sirej.dua@databricks.com> Co-authored-by: Sirej Dua <Sirej Dua>
This commit is contained in:
@@ -70,10 +70,6 @@ def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Use a small model for a fast test.
|
||||
# Note this is repeated in the test body; to initialize a tokenizer.
|
||||
"model": "JackFram/llama-68m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
@@ -88,15 +84,31 @@ def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
|
||||
# second run of the test to fail with internal NCCL error.
|
||||
"use_async": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_draft_tensor_parallel_size": 1,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"per_test_common_llm_kwargs, test_llm_kwargs",
|
||||
[
|
||||
(
|
||||
{
|
||||
# Use a small model for a fast test.
|
||||
# Note this is repeated in the test body; to initialize a
|
||||
# tokenizer.
|
||||
"model": "JackFram/llama-68m",
|
||||
},
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_draft_tensor_parallel_size": 1,
|
||||
}),
|
||||
({
|
||||
"model": "ibm-granite/granite-3b-code-instruct",
|
||||
}, {
|
||||
"speculative_model":
|
||||
"ibm-granite/granite-3b-code-instruct-accelerator",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_draft_tensor_parallel_size": 1,
|
||||
})
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_draft_model_tp_lt_target_model_tp2(test_llm_generator,
|
||||
|
||||
Reference in New Issue
Block a user