[Spec Decode] Unified Parallel Drafting (#32887)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
Benjamin Chislett
2026-02-05 12:37:18 -05:00
committed by GitHub
parent 5b2a9422f0
commit af3162d3aa
14 changed files with 1085 additions and 392 deletions

View File

@@ -13,15 +13,12 @@ from vllm import LLM, SamplingParams
from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR
from vllm.benchmarks.datasets import InstructCoderDataset
from vllm.config.vllm import VllmConfig
from vllm.config import VllmConfig
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.v1.metrics.reader import Metric
from vllm.v1.spec_decode.draft_model import (
create_vllm_config_for_draft_model,
merge_toks_kernel,
)
from vllm.v1.spec_decode.utils import create_vllm_config_for_draft_model
MTP_SIMILARITY_RATE = 0.8
@@ -625,6 +622,8 @@ class ArgsTest:
expected_acceptance_rate: float
expected_acceptance_len: float
# Defaults
enforce_eager: bool = True
parallel_drafting: bool = False
target_tensor_parallel_size: int = 1
draft_tensor_parallel_size: int = 1
max_model_len: int = 1024
@@ -658,7 +657,8 @@ cases = [
@pytest.mark.parametrize("args", cases)
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
assert_draft_model_correctness(args, enforce_eager)
args.enforce_eager = enforce_eager
assert_draft_model_correctness(args)
def test_draft_model_realistic_example():
@@ -668,11 +668,28 @@ def test_draft_model_realistic_example():
dataset="likaixin/InstructCoder",
num_speculative_tokens=3,
sampling_config=greedy_sampling(),
enforce_eager=False,
# values below are not derived, but just prevent a regression
expected_acceptance_len=2.8,
expected_acceptance_rate=0.55,
)
assert_draft_model_correctness(args, enforce_eager=False)
assert_draft_model_correctness(args)
def test_draft_model_parallel_drafting():
args = ArgsTest(
target_model="Qwen/Qwen3-1.7B",
draft_model="amd/PARD-Qwen3-0.6B",
dataset="likaixin/InstructCoder",
num_speculative_tokens=3,
sampling_config=greedy_sampling(),
parallel_drafting=True,
enforce_eager=False,
# values below are collected from a stable run, with ~5% tolerance
expected_acceptance_len=2.375,
expected_acceptance_rate=0.45,
)
assert_draft_model_correctness(args)
@pytest.mark.parametrize(
@@ -691,8 +708,9 @@ def test_draft_model_quantization(models: tuple[str, str], enforce_eager: bool):
target_model=tgt_model,
draft_model=draft_model,
**some_high_acceptance_metrics(),
enforce_eager=enforce_eager,
)
assert_draft_model_correctness(sd_case, enforce_eager)
assert_draft_model_correctness(sd_case)
def test_draft_model_tensor_parallelism():
@@ -704,8 +722,9 @@ def test_draft_model_tensor_parallelism():
draft_model="Qwen/Qwen3-0.6B",
draft_tensor_parallel_size=2,
**some_high_acceptance_metrics(),
enforce_eager=False,
)
assert_draft_model_correctness(sd_case, enforce_eager=False)
assert_draft_model_correctness(sd_case)
def test_draft_model_engine_args_tensor_parallelism():
@@ -750,7 +769,7 @@ def test_draft_model_engine_args_rejects_invalid_tp_argname():
engine_args.create_engine_config()
def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
def assert_draft_model_correctness(args: ArgsTest):
"""Compare the outputs using and not using speculative decoding.
In the greedy decoding case, the outputs must match EXACTLY."""
test_prompts: list[Messages] = get_messages(
@@ -764,14 +783,15 @@ def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
"method": "draft_model",
"num_speculative_tokens": args.num_speculative_tokens,
"max_model_len": args.max_model_len,
"enforce_eager": enforce_eager,
"enforce_eager": args.enforce_eager,
"draft_tensor_parallel_size": args.draft_tensor_parallel_size,
"parallel_drafting": args.parallel_drafting,
},
max_num_seqs=100, # limit cudagraph capture runtime
max_model_len=args.max_model_len,
gpu_memory_utilization=args.gpu_memory_utilization,
tensor_parallel_size=args.target_tensor_parallel_size,
enforce_eager=enforce_eager,
enforce_eager=args.enforce_eager,
disable_log_stats=False, # enables get_metrics()
)
# we don't check the outputs, only check the metrics
@@ -813,57 +833,6 @@ def some_high_acceptance_metrics() -> dict:
}
def test_merge_toks_kernel():
device = "cuda"
merged_len = 5 + 2 # len(target_toks) = 5, batch_size = 2
merged = torch.full((merged_len,), -100, device=device) # -100 is arbitrary
is_rejected_tok = torch.full((merged_len,), True, device=device)
grid = (2,)
merge_toks_kernel[grid](
target_toks_ptr=torch.tensor([0, 1, 2, 0, 1], device=device),
next_toks_ptr=torch.tensor([3, 2], device=device),
query_start_locs_ptr=torch.tensor([0, 3], device=device),
query_end_locs_ptr=torch.tensor([2, 4], device=device),
out_ptr_merged_toks=merged,
out_ptr_is_rejected_tok=is_rejected_tok,
target_toks_size=5,
rejected_tok_fill=-1,
)
expected_merged = torch.tensor([0, 1, 2, 3, 0, 1, 2], device=device)
assert torch.allclose(merged, expected_merged)
expected_rejected_toks = torch.tensor([False] * merged_len, device=device)
assert torch.allclose(is_rejected_tok, expected_rejected_toks)
def test_merge_toks_kernel_with_rejected_tokens():
device = "cuda"
merged_size = 9 + 2 # len(target_toks) = 9, batch_size = 2
merged = torch.full((merged_size,), -100, device=device)
is_rejected_tok = torch.full((merged_size,), True, device=device)
grid = (2,)
merge_toks_kernel[grid](
# rejected tokens
# ↓ ↓ ↓ ↓
target_toks_ptr=torch.tensor([0, 1, 2, 13, 14, 15, 0, 1, 22], device=device),
next_toks_ptr=torch.tensor([3, 2], device=device),
query_start_locs_ptr=torch.tensor([0, 6], device=device),
query_end_locs_ptr=torch.tensor([2, 7], device=device),
out_ptr_merged_toks=merged,
out_ptr_is_rejected_tok=is_rejected_tok,
target_toks_size=9,
rejected_tok_fill=-1,
)
expected_merged = torch.tensor([0, 1, 2, 3, -1, -1, -1, 0, 1, 2, -1], device=device)
assert torch.allclose(merged, expected_merged)
expected_rejected_toks = torch.tensor(
[False, False, False, False, True, True, True, False, False, False, True],
device=device,
)
assert torch.allclose(is_rejected_tok, expected_rejected_toks)
def compute_acceptance_rate(metrics: list[Metric]) -> float:
name2metric = {metric.name: metric for metric in metrics}
n_draft_toks = name2metric["vllm:spec_decode_num_draft_tokens"].value # type: ignore