[Spec Decode] Integrate Suffix Decoding from Arctic Inference (#25784)

Co-authored-by: Aurick Qiao <aurick.qiao@snowflake.com>
Author: Aurick Qiao (committed via GitHub)
Date: 2025-11-03 09:23:31 -08:00
Commit: 2c19d96777 (parent 4bc400f47e)
8 changed files with 304 additions and 11 deletions
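For orientation: suffix decoding, the technique Arctic Inference introduced, drafts speculative tokens by matching the longest suffix of the in-progress sequence against previously seen sequences and proposing whatever followed that suffix before. The sketch below is a deliberately naive, self-contained illustration of that idea only; the function name and the brute-force linear scan are hypothetical, and the actual implementation integrated by this commit uses an efficient cache rather than this search.

```python
def propose_draft(
    token_ids: list[int],
    cached_seqs: list[list[int]],
    max_spec_factor: float,
    max_suffix_len: int = 32,
) -> list[int]:
    """Toy suffix-decoding draft proposal (illustration, not vLLM's code).

    Finds the longest suffix of `token_ids` that occurs inside a cached
    sequence, then proposes the tokens that followed it, capped at
    max_spec_factor * match_length (mirroring the role of the
    suffix_decoding_max_spec_factor knob, as assumed here).
    """
    best_len, best_draft = 0, []
    for seq in cached_seqs:
        # Try suffix lengths from longest down to the best match so far.
        for n in range(min(max_suffix_len, len(token_ids)), best_len, -1):
            suffix = token_ids[-n:]
            for i in range(len(seq) - n + 1):
                # Require at least one continuation token after the match.
                if seq[i : i + n] == suffix and seq[i + n :]:
                    best_len = n
                    best_draft = seq[i + n : i + n + int(max_spec_factor * n)]
                    break
            if best_len == n:
                break
    return best_draft


# A repeated pattern in the cache lets long drafts be proposed:
cache = [[1, 2, 3, 4, 5, 6, 7, 8]]
print(propose_draft([9, 9, 2, 3], cache, max_spec_factor=2.0))  # [4, 5, 6, 7]
```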

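In user-facing terms, the feature is switched on through the same `speculative_config` dict already used for ngram speculation. A minimal usage sketch, assuming only the configuration keys exercised by the tests in this commit (the per-key comments are informed guesses about semantics, not documentation from this diff):

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    speculative_config={
        "method": "suffix",
        # Cap on draft length relative to the matched suffix length (assumed).
        "suffix_decoding_max_spec_factor": 2.0,
        # Bound on how many past requests the suffix cache retains (assumed).
        "suffix_decoding_max_cached_requests": 1000,
    },
    max_model_len=1024,
)
outputs = llm.generate(
    ["Repetitive workloads benefit most from suffix decoding because"],
    SamplingParams(temperature=0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```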

@@ -75,7 +75,23 @@ def model_name():
     return "meta-llama/Llama-3.1-8B-Instruct"
 
 
-def test_ngram_correctness(
+@pytest.mark.parametrize(
+    "speculative_config",
+    [
+        {
+            "method": "ngram",
+            "prompt_lookup_max": 5,
+            "prompt_lookup_min": 3,
+            "num_speculative_tokens": 3,
+        },
+        {
+            "method": "suffix",
+            "suffix_decoding_max_spec_factor": 2.0,
+        },
+    ],
+)
+def test_ngram_and_suffix_correctness(
+    speculative_config: dict,
     monkeypatch: pytest.MonkeyPatch,
     sampling_config: SamplingParams,
     model_name: str,
@@ -94,12 +110,7 @@ def test_ngram_correctness(
     spec_llm = LLM(
         model=model_name,
-        speculative_config={
-            "method": "ngram",
-            "prompt_lookup_max": 5,
-            "prompt_lookup_min": 3,
-            "num_speculative_tokens": 3,
-        },
+        speculative_config=speculative_config,
         max_model_len=1024,
     )
     spec_outputs = spec_llm.chat(test_prompts, sampling_config)
@@ -121,6 +132,66 @@ def test_ngram_correctness(
     cleanup_dist_env_and_memory()
 
 
+def test_suffix_decoding_acceptance(
+    monkeypatch: pytest.MonkeyPatch,
+    sampling_config: SamplingParams,
+    model_name: str,
+):
+    """
+    Check that suffix decoding caching takes effect and improves acceptance
+    lengths and acceptance rates over multiple runs of the same prompts.
+    """
+    test_prompts = get_test_prompts(mm_enabled=False)
+    spec_llm = LLM(
+        model=model_name,
+        speculative_config={
+            "method": "suffix",
+            "suffix_decoding_max_spec_factor": 2.0,
+            "suffix_decoding_max_cached_requests": 1000,
+        },
+        max_model_len=1024,
+        disable_log_stats=False,
+    )
+
+    # Run several times and check that the accepted tokens increase.
+    spec_llm.chat(test_prompts, sampling_config)
+
+    num_draft = []
+    num_accept = []
+    for i in range(10):  # Run multiple times to warm up the cache.
+        spec_llm.chat(test_prompts, sampling_config)
+        # Collect draft and acceptance stats.
+        metrics = spec_llm.get_metrics()
+        for metric in metrics:
+            if metric.name == "vllm:spec_decode_num_draft_tokens":
+                num_draft.append(metric.value)
+            if metric.name == "vllm:spec_decode_num_accepted_tokens":
+                num_accept.append(metric.value)
+
+    # Calculate the acceptance rates for the first and last runs.
+    first_accept_tokens = num_accept[0]
+    first_draft_tokens = num_draft[0]
+    first_accept_rate = first_accept_tokens / first_draft_tokens
+
+    # Take the diff since the stats are cumulative.
+    last_accept_tokens = num_accept[-1] - num_accept[-2]
+    last_draft_tokens = num_draft[-1] - num_draft[-2]
+    last_accept_rate = last_accept_tokens / last_draft_tokens
+
+    # Expect the acceptance length to improve.
+    assert first_accept_tokens < last_accept_tokens
+
+    # Expect the acceptance rate to improve.
+    assert first_accept_rate < last_accept_rate
+
+    # Heuristic: expect at least 85% acceptance rate at the end.
+    assert last_accept_rate > 0.85
+
+    del spec_llm
+    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()
+
+
 @pytest.mark.parametrize(
     "model_path",
     [
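A closing note on the metrics arithmetic in `test_suffix_decoding_acceptance`: `vllm:spec_decode_num_draft_tokens` and `vllm:spec_decode_num_accepted_tokens` are cumulative counters, so per-run acceptance rates come from differencing successive snapshots, as the test does for its last run. A small self-contained sketch of that bookkeeping (helper name and sample numbers are illustrative):

```python
def per_run_acceptance_rates(
    cum_draft: list[int], cum_accept: list[int]
) -> list[float]:
    """Convert cumulative draft/accept counters into per-run acceptance rates."""
    rates = []
    prev_draft = prev_accept = 0
    for draft, accept in zip(cum_draft, cum_accept):
        delta_draft = draft - prev_draft
        delta_accept = accept - prev_accept
        rates.append(delta_accept / delta_draft if delta_draft else 0.0)
        prev_draft, prev_accept = draft, accept
    return rates


# As the suffix cache warms up, later runs accept a larger share of drafts:
print(per_run_acceptance_rates([100, 190, 270], [60, 140, 215]))
# [0.6, 0.888..., 0.9375]
```

Note that the test's `first_accept_rate` uses the first cumulative snapshot directly, so it also folds in the warm-up `chat()` call issued before the loop; only the last run's rate is a true per-run delta.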