[Feature] [Spec decode]: Enable MLPSpeculator/Medusa and prompt_logprobs with ChunkedPrefill (#10132)

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: wallashss <wallashss@ibm.com>
Co-authored-by: wallashss <wallashss@ibm.com>
This commit is contained in:
Nicolò Lucchesi
2025-01-27 22:38:35 +01:00
committed by GitHub
parent 2bc3fbba0c
commit 6116ca8cd7
16 changed files with 468 additions and 165 deletions

View File

@@ -4,26 +4,27 @@ import pytest
from vllm import SamplingParams
from ..utils import maybe_enable_chunked_prefill
from .conftest import run_equality_correctness_test
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": "JackFram/llama-68m",
"model_name": "JackFram/llama-160m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
"enforce_eager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-160m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
}, {
"speculative_model": "JackFram/llama-160m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": True,
}])
@@ -36,12 +37,15 @@ from .conftest import run_equality_correctness_test
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4, 12])
def test_logprobs_equality(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int, logprobs: int):
"""Verify output logprobs are equal with and without speculative decoding.
seed: int, logprobs: int, prefill_chunk_size: int):
"""Verify output logprobs are equal with and without speculative decoding,
as well as with and without chunked prefill.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,