[Cleanup] Remove obsolete spec decoding compatibility logic (#32003)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -940,27 +940,62 @@ def test_correct_decoded_token_preserves_valid_tokens():
|
||||
(
|
||||
"eagle",
|
||||
"meta-llama/Llama-3.2-1B-Instruct",
|
||||
"nm-testing/Llama3_2_1B_speculator.eagle3",
|
||||
{
|
||||
"method": "eagle",
|
||||
"model": "nm-testing/Llama3_2_1B_speculator.eagle3",
|
||||
"num_speculative_tokens": 3,
|
||||
},
|
||||
0,
|
||||
),
|
||||
marks=large_gpu_mark(min_gb=32),
|
||||
id="eagle0",
|
||||
),
|
||||
pytest.param(
|
||||
(
|
||||
"eagle",
|
||||
"meta-llama/Llama-3.2-1B-Instruct",
|
||||
{
|
||||
"method": "eagle",
|
||||
"model": "nm-testing/Llama3_2_1B_speculator.eagle3",
|
||||
"num_speculative_tokens": 3,
|
||||
},
|
||||
3,
|
||||
),
|
||||
marks=large_gpu_mark(min_gb=32),
|
||||
id="eagle3",
|
||||
),
|
||||
pytest.param(
|
||||
(
|
||||
"ngram",
|
||||
"meta-llama/Llama-3.2-1B-Instruct",
|
||||
{
|
||||
"method": "ngram",
|
||||
"prompt_lookup_max": 5,
|
||||
"prompt_lookup_min": 3,
|
||||
"num_speculative_tokens": 3,
|
||||
},
|
||||
3,
|
||||
),
|
||||
marks=large_gpu_mark(min_gb=32),
|
||||
id="ngram",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("top_logprobs", [0, 3])
|
||||
def test_spec_decode_logprobs(
|
||||
logprobs_mode: LogprobsMode,
|
||||
model_setup: tuple[str, str, str],
|
||||
top_logprobs: int,
|
||||
model_setup: tuple[str, str, dict, int],
|
||||
):
|
||||
"""Spec decode logprobs should match those of the base model.
|
||||
|
||||
Args:
|
||||
logprobs_mode: logprobs mode.
|
||||
model_setup: Spec decode method, base model name, and
|
||||
draft model name.
|
||||
model_setup: Tuple of (method, base model name,
|
||||
speculative_config dict, top_logprobs).
|
||||
"""
|
||||
from vllm import LLM
|
||||
|
||||
method, model_name, spec_config, top_logprobs = model_setup
|
||||
|
||||
prompt = "Hello world " * 50
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0, logprobs=top_logprobs, max_tokens=10, ignore_eos=False
|
||||
@@ -972,7 +1007,7 @@ def test_spec_decode_logprobs(
|
||||
ignore_eos=False,
|
||||
presence_penalty=-1.0,
|
||||
)
|
||||
method, model_name, spec_model_name = model_setup
|
||||
|
||||
max_model_len = 256
|
||||
|
||||
# Run base LLM.
|
||||
@@ -999,14 +1034,11 @@ def test_spec_decode_logprobs(
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
# Run spec decode LLM.
|
||||
# Add max_model_len to spec_config if not present
|
||||
spec_config_with_len = {**spec_config, "max_model_len": max_model_len}
|
||||
spec_llm = LLM(
|
||||
model_name,
|
||||
speculative_config={
|
||||
"method": method,
|
||||
"model": spec_model_name,
|
||||
"num_speculative_tokens": 3,
|
||||
"max_model_len": max_model_len,
|
||||
},
|
||||
speculative_config=spec_config_with_len,
|
||||
max_logprobs=5,
|
||||
max_model_len=max_model_len,
|
||||
seed=42,
|
||||
|
||||
Reference in New Issue
Block a user