[V1][Spec Decode] Handle draft tokens beyond max_model_len (#16087)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2025-04-21 12:38:50 -07:00
committed by GitHub
parent 299ebb62b2
commit 3a0fba5cf4
7 changed files with 137 additions and 15 deletions

View File

@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
"""Test whether spec decoding handles the max model length properly."""
import pytest
from vllm import LLM, SamplingParams
_PROMPTS = [
"1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1",
"Repeat the following sentence 10 times: Consistency is key to mastering any skill.", # noqa: E501
"Who won the Turing Award in 2018, and for what contribution? Describe in detail.", # noqa: E501
]
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
def test_ngram_max_len(
monkeypatch: pytest.MonkeyPatch,
num_speculative_tokens: int,
):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
llm = LLM(
model="facebook/opt-125m",
max_model_len=100,
enforce_eager=True, # For faster initialization.
speculative_config={
"method": "ngram",
"prompt_lookup_max": 5,
"prompt_lookup_min": 3,
"num_speculative_tokens": num_speculative_tokens,
},
)
sampling_params = SamplingParams(max_tokens=100, ignore_eos=True)
llm.generate(_PROMPTS, sampling_params)
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
def test_eagle_max_len(
monkeypatch: pytest.MonkeyPatch,
num_speculative_tokens: int,
):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
llm = LLM(
model="meta-llama/Meta-Llama-3-8B-Instruct",
enforce_eager=True, # For faster initialization.
speculative_config={
"method": "eagle",
"model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
"num_speculative_tokens": num_speculative_tokens,
},
max_model_len=100,
)
sampling_params = SamplingParams(max_tokens=100, ignore_eos=True)
llm.generate(_PROMPTS, sampling_params)

View File

@@ -2,7 +2,7 @@
import numpy as np
from vllm.config import SpeculativeConfig, VllmConfig
from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
from vllm.v1.spec_decode.ngram_proposer import (NgramProposer,
_find_subarray_kmp,
_kmp_lps_array)
@@ -42,14 +42,24 @@ def test_find_subarray_kmp():
def test_ngram_proposer():
def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
return NgramProposer(vllm_config=VllmConfig(
speculative_config=SpeculativeConfig.from_dict(
{
"prompt_lookup_min": min_n,
"prompt_lookup_max": max_n,
"num_speculative_tokens": k,
"method": "ngram",
})))
# Dummy model config. Just to set max_model_len.
model_config = ModelConfig(model="facebook/opt-125m",
task="generate",
max_model_len=100,
tokenizer="facebook/opt-125m",
tokenizer_mode="auto",
dtype="auto",
seed=None,
trust_remote_code=False)
return NgramProposer(
vllm_config=VllmConfig(model_config=model_config,
speculative_config=SpeculativeConfig.
from_dict({
"prompt_lookup_min": min_n,
"prompt_lookup_max": max_n,
"num_speculative_tokens": k,
"method": "ngram",
})))
# No match.
result = ngram_proposer(