feat: spec decode with draft models (#24322)

Signed-off-by: Tomas Ruiz <tomas.ruiz.te@gmail.com>
Author: Tomas Ruiz
Date: 2026-01-19 15:05:46 -06:00
Committed by: GitHub
parent 73f2a81c75
commit 4a5299c93f
21 changed files with 897 additions and 115 deletions
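
For context, the feature is enabled through the `speculative_config` argument
with `method: "draft_model"`. A minimal usage sketch based on the config keys
exercised in the tests below (the model names are examples taken from those
tests; the snippet is illustrative and not part of this diff):

    from vllm import LLM, SamplingParams

    # The target model verifies tokens proposed by a smaller draft model.
    llm = LLM(
        model="Qwen/Qwen3-1.7B",  # target model
        speculative_config={
            "model": "Qwen/Qwen3-0.6B",  # draft model
            "method": "draft_model",
            "num_speculative_tokens": 3,  # K tokens proposed per step
        },
    )
    outputs = llm.generate(["The capital of France is"], SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)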


@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any

import pytest
import torch
@@ -10,32 +12,45 @@ from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark
from vllm import LLM, SamplingParams
from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR
from vllm.benchmarks.datasets import InstructCoderDataset
from vllm.config.vllm import VllmConfig
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.v1.metrics.reader import Metric
from vllm.v1.spec_decode.draft_model import (
    create_vllm_config_for_draft_model,
    merge_toks_kernel,
)
MTP_SIMILARITY_RATE = 0.8


def _skip_if_insufficient_gpus_for_tp(tp_size: int):
    """Skip test if available GPUs < tp_size."""
    available_gpus = torch.cuda.device_count()
    if available_gpus < tp_size:
        pytest.skip(
            f"Test requires {tp_size} GPUs, but only {available_gpus} available"
        )


Messages = list[dict[str, Any]]


def get_test_prompts(
    mm_enabled: bool, quiet: bool = False, num_prompts: int = 100
) -> list[Messages]:
    prompt_types = ["repeat", "sentence"]
    if mm_enabled:
        prompt_types.append("mm")
    prompts = []
    random.seed(0)
    random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
    if not quiet:
        print(f"Prompt types: {random_prompt_type_choices}")

    # Generate a mixed batch of prompts, some of which can be easily
    # predicted by n-gram matching and some which likely cannot.
@@ -75,11 +90,27 @@ def get_test_prompts(mm_enabled: bool):
    return prompts


def get_instruct_coder_messages(n: int) -> list[Messages]:
    dataset = InstructCoderDataset(
        dataset_path="likaixin/InstructCoder", dataset_split="train"
    )
    prompts: Iterable[str] = dataset.sample_prompts(n=n)
    return [[{"role": "user", "content": prompt}] for prompt in prompts]


@pytest.fixture
def sampling_config():
    return greedy_sampling()


def greedy_sampling() -> SamplingParams:
    return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)


def stochastic_sampling() -> SamplingParams:
    return SamplingParams(temperature=1.0, max_tokens=10, ignore_eos=False)


@pytest.fixture
def model_name():
    return "meta-llama/Llama-3.1-8B-Instruct"
@@ -583,3 +614,269 @@ def test_mtp_correctness(
    del spec_llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()


@dataclass
class ArgsTest:
    target_model: str
    draft_model: str
    sampling_config: SamplingParams
    num_speculative_tokens: int
    expected_acceptance_rate: float
    expected_acceptance_len: float
    # Defaults
    target_tensor_parallel_size: int = 1
    draft_tensor_parallel_size: int = 1
    max_model_len: int = 1024
    gpu_memory_utilization: float = 0.5
    dataset: str = "test_prompts"
    num_prompts: int = 100


cases = [
    # Same model for draft and target, greedy sampling.
    ArgsTest(
        target_model="Qwen/Qwen3-0.6B",
        draft_model="Qwen/Qwen3-0.6B",
        sampling_config=greedy_sampling(),
        num_speculative_tokens=3,  # K
        expected_acceptance_len=3 + 1,  # K + 1
        expected_acceptance_rate=1.0,
    ),
    # Smaller draft model, stochastic sampling.
    ArgsTest(
        target_model="Qwen/Qwen3-1.7B",
        draft_model="Qwen/Qwen3-0.6B",
        sampling_config=stochastic_sampling(),
        num_speculative_tokens=3,
        expected_acceptance_len=2.8 + 1,
        expected_acceptance_rate=0.9,
    ),
]
@pytest.mark.parametrize("args", cases)
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
assert_draft_model_correctness(args, enforce_eager)


def test_draft_model_realistic_example():
    args = ArgsTest(
        target_model="Qwen/Qwen3-1.7B",
        draft_model="Qwen/Qwen3-0.6B",
        dataset="likaixin/InstructCoder",
        num_speculative_tokens=3,
        sampling_config=greedy_sampling(),
        # The values below are not derived; they just guard against regressions.
        expected_acceptance_len=2.8,
        expected_acceptance_rate=0.55,
    )
    assert_draft_model_correctness(args, enforce_eager=False)


@pytest.mark.parametrize(
    "models",
    [
        # target_model, draft_model
        ("Qwen/Qwen3-1.7B-FP8", "Qwen/Qwen3-0.6B"),  # target quantized
        ("Qwen/Qwen3-1.7B", "Qwen/Qwen3-0.6B-FP8"),  # draft quantized
    ],
    ids=["target_quantized", "draft_quantized"],
)
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_draft_model_quantization(models: tuple[str, str], enforce_eager: bool):
    tgt_model, draft_model = models
    sd_case = ArgsTest(
        target_model=tgt_model,
        draft_model=draft_model,
        **some_high_acceptance_metrics(),
    )
    assert_draft_model_correctness(sd_case, enforce_eager)


def test_draft_model_tensor_parallelism():
    """Ensure spec decode works when running with TP > 1."""
    _skip_if_insufficient_gpus_for_tp(2)
    sd_case = ArgsTest(
        target_model="Qwen/Qwen3-1.7B",
        target_tensor_parallel_size=2,
        draft_model="Qwen/Qwen3-0.6B",
        draft_tensor_parallel_size=2,
        **some_high_acceptance_metrics(),
    )
    assert_draft_model_correctness(sd_case, enforce_eager=False)


def test_draft_model_engine_args_tensor_parallelism():
    """Ensure the vllm_config for the draft model is created correctly,
    and independently of the target model (quantization, TP, etc.)."""
    _skip_if_insufficient_gpus_for_tp(2)
    engine_args = EngineArgs(
        model="Qwen/Qwen3-1.7B-FP8",  # <<< tgt quantized
        tensor_parallel_size=2,
        speculative_config={
            "model": "Qwen/Qwen3-0.6B",  # <<< draft not quantized
            "method": "draft_model",
            "num_speculative_tokens": 3,
            "draft_tensor_parallel_size": 1,  # <<< valid arg name
        },
    )
    tgt_vllm_config: VllmConfig = engine_args.create_engine_config()
    assert tgt_vllm_config.parallel_config.tensor_parallel_size == 2
    assert tgt_vllm_config.quant_config.get_name() == "fp8"

    draft_vllm_config: VllmConfig = create_vllm_config_for_draft_model(tgt_vllm_config)
    assert draft_vllm_config.parallel_config.tensor_parallel_size == 1
    assert draft_vllm_config.quant_config is None


def test_draft_model_engine_args_rejects_invalid_tp_argname():
    """The user should pass "draft_tensor_parallel_size" rather than
    "tensor_parallel_size". We enforce this with validation."""
    engine_args = EngineArgs(
        model="Qwen/Qwen3-1.7B",
        tensor_parallel_size=1,
        speculative_config={
            "model": "Qwen/Qwen3-0.6B",
            "method": "draft_model",
            "num_speculative_tokens": 3,
            "tensor_parallel_size": 1,  # <<< invalid arg name
        },
    )
    with pytest.raises(ValueError):
        engine_args.create_engine_config()


def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
    """Run spec decode with a draft model and check the acceptance metrics.

    Outputs are not compared against a baseline; only the acceptance rate
    and acceptance length reported by the metrics are checked."""
    test_prompts: list[Messages] = get_messages(
        dataset=args.dataset, n=args.num_prompts
    )
    spec_llm = LLM(
        model=args.target_model,
        speculative_config={
            "model": args.draft_model,
            "method": "draft_model",
            "num_speculative_tokens": args.num_speculative_tokens,
            "max_model_len": args.max_model_len,
            "enforce_eager": enforce_eager,
            "draft_tensor_parallel_size": args.draft_tensor_parallel_size,
            "max_num_seqs": 100,  # limit cudagraph capture runtime
        },
        max_model_len=args.max_model_len,
        gpu_memory_utilization=args.gpu_memory_utilization,
        tensor_parallel_size=args.target_tensor_parallel_size,
        enforce_eager=enforce_eager,
        disable_log_stats=False,  # enables get_metrics()
    )
    # We don't check the outputs, only the metrics.
    spec_llm.chat(test_prompts, args.sampling_config)
    metrics = spec_llm.get_metrics()
    acceptance_rate: float = compute_acceptance_rate(metrics)
    acceptance_len: float = compute_acceptance_len(metrics)

    # Cleanup
    del spec_llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()

    print(
        f"spec-decode: target={args.target_model}, draft={args.draft_model}, "
        f"temperature={args.sampling_config.temperature:.2f}, "
        f"acceptance_rate={acceptance_rate:.2f}, "
        f"acceptance_len={acceptance_len:.2f}"
    )
    assert acceptance_rate >= args.expected_acceptance_rate
    assert acceptance_len >= args.expected_acceptance_len
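

# A hypothetical companion check (not part of this diff): for greedy sampling,
# spec decode should in principle not change outputs, so one could also compare
# against a non-speculative baseline, as the n-gram tests earlier in this file
# do. In practice numerical noise can cause small divergences, which is why
# those tests assert a similarity rate rather than strict equality.
def assert_outputs_match_baseline_sketch(args: ArgsTest):
    prompts = get_messages(dataset=args.dataset, n=args.num_prompts)
    ref_llm = LLM(model=args.target_model, max_model_len=args.max_model_len)
    ref_outputs = ref_llm.chat(prompts, greedy_sampling())
    del ref_llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()

    spec_llm = LLM(
        model=args.target_model,
        speculative_config={
            "model": args.draft_model,
            "method": "draft_model",
            "num_speculative_tokens": args.num_speculative_tokens,
        },
        max_model_len=args.max_model_len,
    )
    spec_outputs = spec_llm.chat(prompts, greedy_sampling())
    matches = sum(
        ref.outputs[0].text == spec.outputs[0].text
        for ref, spec in zip(ref_outputs, spec_outputs)
    )
    assert matches >= 0.8 * len(ref_outputs)  # tolerate some numerical divergence
    del spec_llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()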


def get_messages(dataset: str, n: int) -> list[Messages]:
    if dataset == "test_prompts":
        return get_test_prompts(mm_enabled=False, quiet=True, num_prompts=n)
    elif dataset == "likaixin/InstructCoder":
        return get_instruct_coder_messages(n=n)
    else:
        raise NotImplementedError(f"Dataset '{dataset}' not implemented")


def some_high_acceptance_metrics() -> dict:
    return {
        "sampling_config": greedy_sampling(),
        "num_speculative_tokens": 3,
        "expected_acceptance_len": 2.90 + 1,
        "expected_acceptance_rate": 0.90,
    }


def test_merge_toks_kernel():
    device = "cuda"
    merged_len = 5 + 2  # len(target_toks) = 5, batch_size = 2
    merged = torch.full((merged_len,), -100, device=device)  # -100 is arbitrary
    is_rejected_tok = torch.full((merged_len,), True, device=device)
    grid = (2,)
    merge_toks_kernel[grid](
        target_toks_ptr=torch.tensor([0, 1, 2, 0, 1], device=device),
        next_toks_ptr=torch.tensor([3, 2], device=device),
        query_start_locs_ptr=torch.tensor([0, 3], device=device),
        query_end_locs_ptr=torch.tensor([2, 4], device=device),
        out_ptr_merged_toks=merged,
        out_ptr_is_rejected_tok=is_rejected_tok,
        target_toks_size=5,
        rejected_tok_fill=-1,
    )
    expected_merged = torch.tensor([0, 1, 2, 3, 0, 1, 2], device=device)
    assert torch.allclose(merged, expected_merged)
    expected_rejected_toks = torch.tensor([False] * merged_len, device=device)
    assert torch.allclose(is_rejected_tok, expected_rejected_toks)


def test_merge_toks_kernel_with_rejected_tokens():
    device = "cuda"
    merged_size = 9 + 2  # len(target_toks) = 9, batch_size = 2
    merged = torch.full((merged_size,), -100, device=device)
    is_rejected_tok = torch.full((merged_size,), True, device=device)
    grid = (2,)
    merge_toks_kernel[grid](
        # rejected tokens
        #                                      ↓   ↓   ↓         ↓
        target_toks_ptr=torch.tensor([0, 1, 2, 13, 14, 15, 0, 1, 22], device=device),
        next_toks_ptr=torch.tensor([3, 2], device=device),
        query_start_locs_ptr=torch.tensor([0, 6], device=device),
        query_end_locs_ptr=torch.tensor([2, 7], device=device),
        out_ptr_merged_toks=merged,
        out_ptr_is_rejected_tok=is_rejected_tok,
        target_toks_size=9,
        rejected_tok_fill=-1,
    )
    expected_merged = torch.tensor([0, 1, 2, 3, -1, -1, -1, 0, 1, 2, -1], device=device)
    assert torch.allclose(merged, expected_merged)
    expected_rejected_toks = torch.tensor(
        [False, False, False, False, True, True, True, False, False, False, True],
        device=device,
    )
    assert torch.allclose(is_rejected_tok, expected_rejected_toks)
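

# For readability, a pure-Python reference of the semantics merge_toks_kernel
# appears to implement, inferred from the two test cases above (this helper is
# illustrative only and not part of the diff): each request contributes its
# accepted target tokens, then its bonus token from next_toks, and any
# remaining draft positions are filled as rejected.
def merge_toks_reference(
    target_toks: list[int],
    next_toks: list[int],
    query_start_locs: list[int],
    query_end_locs: list[int],
    rejected_tok_fill: int = -1,
) -> tuple[list[int], list[bool]]:
    merged: list[int] = []
    is_rejected: list[bool] = []
    batch_size = len(next_toks)
    for i in range(batch_size):
        # This request's slice of target_toks runs up to the next request's start.
        seg_end = query_start_locs[i + 1] if i + 1 < batch_size else len(target_toks)
        for j in range(query_start_locs[i], seg_end):
            if j <= query_end_locs[i]:
                merged.append(target_toks[j])  # accepted token
                is_rejected.append(False)
            else:
                merged.append(rejected_tok_fill)  # rejected draft position
                is_rejected.append(True)
            if j == query_end_locs[i]:
                merged.append(next_toks[i])  # bonus token after last accepted
                is_rejected.append(False)
    return merged, is_rejected


# Example: reproduces expected_merged from the test above.
# merge_toks_reference([0, 1, 2, 13, 14, 15, 0, 1, 22], [3, 2], [0, 6], [2, 7])
# -> ([0, 1, 2, 3, -1, -1, -1, 0, 1, 2, -1], [False, ..., True])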


def compute_acceptance_rate(metrics: list[Metric]) -> float:
    name2metric = {metric.name: metric for metric in metrics}
    n_draft_toks = name2metric["vllm:spec_decode_num_draft_tokens"].value  # type: ignore
    if n_draft_toks == 0:
        return float("nan")
    n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value  # type: ignore
    return n_accepted_toks / n_draft_toks


def compute_acceptance_len(metrics: list[Metric]) -> float:
    name2metric = {metric.name: metric for metric in metrics}
    n_drafts = name2metric["vllm:spec_decode_num_drafts"].value  # type: ignore
    n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value  # type: ignore
    if n_drafts == 0:
        return 1
    return 1 + (n_accepted_toks / n_drafts)
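

# For intuition, a toy computation relating the two metrics above, with
# made-up counter values (K = 3 draft tokens per draft, as in the cases above;
# this helper is illustrative only and not part of the diff):
def _toy_metrics_example() -> None:
    n_drafts = 100  # vllm:spec_decode_num_drafts
    n_draft_toks = 300  # vllm:spec_decode_num_draft_tokens (K * n_drafts)
    n_accepted_toks = 270  # vllm:spec_decode_num_accepted_tokens

    acceptance_rate = n_accepted_toks / n_draft_toks  # = 0.9
    acceptance_len = 1 + n_accepted_toks / n_drafts  # = 3.7; the "+1" is the bonus token
    print(acceptance_rate, acceptance_len)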