feat: spec decode with draft models (#24322)
Signed-off-by: Tomas Ruiz <tomas.ruiz.te@gmail.com>
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import random
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
@@ -10,32 +12,45 @@ from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.assets.base import VLLM_S3_BUCKET_URL
|
||||
from vllm.assets.image import VLM_IMAGES_DIR
|
||||
from vllm.benchmarks.datasets import InstructCoderDataset
|
||||
from vllm.config.vllm import VllmConfig
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.metrics.reader import Metric
|
||||
from vllm.v1.spec_decode.draft_model import (
|
||||
create_vllm_config_for_draft_model,
|
||||
merge_toks_kernel,
|
||||
)
|
||||
|
||||
MTP_SIMILARITY_RATE = 0.8
|
||||
|
||||
|
||||
def _skip_if_insufficient_gpus_for_tp(tp_size: int):
|
||||
"""Skip test if available GPUs < tp_size on ROCm."""
|
||||
if current_platform.is_rocm():
|
||||
available_gpus = torch.cuda.device_count()
|
||||
if available_gpus < tp_size:
|
||||
pytest.skip(
|
||||
f"Test requires {tp_size} GPUs, but only {available_gpus} available"
|
||||
)
|
||||
available_gpus = torch.cuda.device_count()
|
||||
if available_gpus < tp_size:
|
||||
pytest.skip(
|
||||
f"Test requires {tp_size} GPUs, but only {available_gpus} available"
|
||||
)
|
||||
|
||||
|
||||
def get_test_prompts(mm_enabled: bool):
|
||||
Messages = list[dict[str, Any]]
|
||||
|
||||
|
||||
def get_test_prompts(
|
||||
mm_enabled: bool, quiet: bool = False, num_prompts: int = 100
|
||||
) -> list[Messages]:
|
||||
prompt_types = ["repeat", "sentence"]
|
||||
if mm_enabled:
|
||||
prompt_types.append("mm")
|
||||
num_prompts = 100
|
||||
prompts = []
|
||||
|
||||
random.seed(0)
|
||||
random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
|
||||
print(f"Prompt types: {random_prompt_type_choices}")
|
||||
|
||||
if not quiet:
|
||||
print(f"Prompt types: {random_prompt_type_choices}")
|
||||
|
||||
# Generate a mixed batch of prompts, some of which can be easily
|
||||
# predicted by n-gram matching and some which likely cannot.
|
||||
@@ -75,11 +90,27 @@ def get_test_prompts(mm_enabled: bool):
|
||||
return prompts
|
||||
|
||||
|
||||
def get_instruct_coder_messages(n: int) -> list[Messages]:
    """Sample ``n`` InstructCoder prompts and wrap each as a one-turn chat."""
    dataset = InstructCoderDataset(
        dataset_path="likaixin/InstructCoder", dataset_split="train"
    )
    sampled: Iterable[str] = dataset.sample_prompts(n=n)
    messages: list[Messages] = []
    for prompt in sampled:
        messages.append([{"role": "user", "content": prompt}])
    return messages
|
||||
|
||||
|
||||
@pytest.fixture
def sampling_config():
    """Default sampling parameters for the shared tests: greedy decoding."""
    params = greedy_sampling()
    return params
|
||||
|
||||
|
||||
def greedy_sampling() -> SamplingParams:
    """Greedy decoding (temperature 0) with short generations."""
    return SamplingParams(
        temperature=0,
        max_tokens=10,
        ignore_eos=False,
    )
|
||||
|
||||
|
||||
def stochastic_sampling() -> SamplingParams:
    """Stochastic decoding (temperature 1.0) with short generations."""
    return SamplingParams(
        temperature=1.0,
        max_tokens=10,
        ignore_eos=False,
    )
|
||||
|
||||
|
||||
@pytest.fixture
def model_name():
    """Target model used by the shared correctness tests."""
    name = "meta-llama/Llama-3.1-8B-Instruct"
    return name
|
||||
@@ -583,3 +614,269 @@ def test_mtp_correctness(
|
||||
del spec_llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@dataclass
class ArgsTest:
    """Parameters for one draft-model speculative-decoding test case.

    Consumed by ``assert_draft_model_correctness``, which runs the
    target/draft pair and asserts the observed acceptance metrics are at
    least the expected values below.
    """

    target_model: str
    draft_model: str
    sampling_config: SamplingParams
    num_speculative_tokens: int
    # Minimum metrics the run must reach (checked with >=).
    expected_acceptance_rate: float
    expected_acceptance_len: float
    # Defaults
    target_tensor_parallel_size: int = 1
    draft_tensor_parallel_size: int = 1
    max_model_len: int = 1024
    gpu_memory_utilization: float = 0.5
    # Which prompt source get_messages() should use.
    dataset: str = "test_prompts"
    num_prompts: int = 100
|
||||
|
||||
|
||||
# Cases parametrizing test_draft_model_correctness.
cases = [
    # Same model for draft and target, greedy sampling.
    ArgsTest(
        target_model="Qwen/Qwen3-0.6B",
        draft_model="Qwen/Qwen3-0.6B",
        sampling_config=greedy_sampling(),
        num_speculative_tokens=3,  # K
        expected_acceptance_len=3 + 1,  # K + 1
        expected_acceptance_rate=1.0,
    ),
    # Smaller draft model, stochastic sampling.
    ArgsTest(
        target_model="Qwen/Qwen3-1.7B",
        draft_model="Qwen/Qwen3-0.6B",
        sampling_config=stochastic_sampling(),
        num_speculative_tokens=3,
        expected_acceptance_len=2.8 + 1,
        expected_acceptance_rate=0.9,
    ),
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("args", cases)
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
    """Run each spec-decode case and check its acceptance metrics."""
    assert_draft_model_correctness(args=args, enforce_eager=enforce_eager)
|
||||
|
||||
|
||||
def test_draft_model_realistic_example():
    """Spec decode on InstructCoder prompts — a realistic coding workload."""
    case = ArgsTest(
        target_model="Qwen/Qwen3-1.7B",
        draft_model="Qwen/Qwen3-0.6B",
        sampling_config=greedy_sampling(),
        num_speculative_tokens=3,
        dataset="likaixin/InstructCoder",
        # values below are not derived, but just prevent a regression
        expected_acceptance_rate=0.55,
        expected_acceptance_len=2.8,
    )
    assert_draft_model_correctness(case, enforce_eager=False)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "models",
    [
        # target_model, draft_model
        ("Qwen/Qwen3-1.7B-FP8", "Qwen/Qwen3-0.6B"),  # target quantized
        ("Qwen/Qwen3-1.7B", "Qwen/Qwen3-0.6B-FP8"),  # draft quantized
    ],
    ids=["target_quantized", "draft_quantized"],
)
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_draft_model_quantization(models: tuple[str, str], enforce_eager: bool):
    """Spec decode must work when either side of the pair is FP8-quantized."""
    target, draft = models
    case = ArgsTest(
        target_model=target,
        draft_model=draft,
        **some_high_acceptance_metrics(),
    )
    assert_draft_model_correctness(case, enforce_eager)
|
||||
|
||||
|
||||
def test_draft_model_tensor_parallelism():
    """Ensure spec decode works when running with TP > 1."""
    _skip_if_insufficient_gpus_for_tp(2)
    case = ArgsTest(
        target_model="Qwen/Qwen3-1.7B",
        draft_model="Qwen/Qwen3-0.6B",
        target_tensor_parallel_size=2,
        draft_tensor_parallel_size=2,
        **some_high_acceptance_metrics(),
    )
    assert_draft_model_correctness(case, enforce_eager=False)
|
||||
|
||||
|
||||
def test_draft_model_engine_args_tensor_parallelism():
    """Ensure the vllm_config for the draft model is created correctly,
    and independently of the target model (quantization, TP, etc.)"""
    _skip_if_insufficient_gpus_for_tp(2)

    spec_config = {
        "model": "Qwen/Qwen3-0.6B",  # <<< draft not quantized
        "method": "draft_model",
        "num_speculative_tokens": 3,
        "draft_tensor_parallel_size": 1,  # <<< valid arg name
    }
    engine_args = EngineArgs(
        model="Qwen/Qwen3-1.7B-FP8",  # <<< tgt quantized
        tensor_parallel_size=2,
        speculative_config=spec_config,
    )

    # The target keeps its own TP and FP8 quantization...
    target_config: VllmConfig = engine_args.create_engine_config()
    assert target_config.parallel_config.tensor_parallel_size == 2
    assert target_config.quant_config.get_name() == "fp8"

    # ...while the derived draft config uses TP=1 and no quantization.
    draft_config: VllmConfig = create_vllm_config_for_draft_model(target_config)
    assert draft_config.parallel_config.tensor_parallel_size == 1
    assert draft_config.quant_config is None
|
||||
|
||||
|
||||
def test_draft_model_engine_args_rejects_invalid_tp_argname():
    """The user should pass "draft_tensor_parallel_size" rather than
    "tensor_parallel_size". We enforce this with validation."""
    bad_spec_config = {
        "model": "Qwen/Qwen3-0.6B",
        "method": "draft_model",
        "num_speculative_tokens": 3,
        "tensor_parallel_size": 1,  # <<< invalid arg name
    }
    engine_args = EngineArgs(
        model="Qwen/Qwen3-1.7B",
        tensor_parallel_size=1,
        speculative_config=bad_spec_config,
    )
    with pytest.raises(ValueError):
        engine_args.create_engine_config()
|
||||
|
||||
|
||||
def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
    """Run speculative decoding with a draft model and assert that the
    observed acceptance metrics are at least the case's expected minimums.

    Outputs themselves are not compared; only the engine's spec-decode
    metrics are checked (see the comment below and ArgsTest's
    expected_acceptance_* fields).
    """
    # Build the prompt batch from the case's dataset choice.
    test_prompts: list[Messages] = get_messages(
        dataset=args.dataset, n=args.num_prompts
    )

    spec_llm = LLM(
        model=args.target_model,
        speculative_config={
            "model": args.draft_model,
            "method": "draft_model",
            "num_speculative_tokens": args.num_speculative_tokens,
            "max_model_len": args.max_model_len,
            "enforce_eager": enforce_eager,
            "draft_tensor_parallel_size": args.draft_tensor_parallel_size,
            "max_num_seqs": 100,  # limit cudagraph capture runtime
        },
        max_model_len=args.max_model_len,
        gpu_memory_utilization=args.gpu_memory_utilization,
        tensor_parallel_size=args.target_tensor_parallel_size,
        enforce_eager=enforce_eager,
        disable_log_stats=False,  # enables get_metrics()
    )
    # we don't check the outputs, only check the metrics
    spec_llm.chat(test_prompts, args.sampling_config)
    metrics = spec_llm.get_metrics()

    # Compute the metrics BEFORE tearing down the engine.
    acceptance_rate: float = compute_acceptance_rate(metrics)
    acceptance_len: float = compute_acceptance_len(metrics)
    del spec_llm  # CLEANUP: release the engine before freeing GPU state
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()

    print(
        f"spec-decode: target={args.target_model}, draft={args.draft_model}, "
        f"temperature={args.sampling_config.temperature:.2f}, "
        f"acceptance_rate={acceptance_rate:.2f}, "
        f"acceptance_len={acceptance_len:.2f}, "
    )

    assert acceptance_rate >= args.expected_acceptance_rate
    assert acceptance_len >= args.expected_acceptance_len
|
||||
|
||||
|
||||
def get_messages(dataset: str, n: int) -> list[Messages]:
    """Build ``n`` chat-message prompt lists from the named dataset.

    Raises NotImplementedError for any dataset name this file doesn't know.
    """
    if dataset == "likaixin/InstructCoder":
        return get_instruct_coder_messages(n=n)
    if dataset == "test_prompts":
        return get_test_prompts(mm_enabled=False, quiet=True, num_prompts=n)
    raise NotImplementedError(f"Dataset '{dataset}' not implemented")
|
||||
|
||||
|
||||
def some_high_acceptance_metrics() -> dict:
    """ArgsTest kwargs expecting high acceptance under greedy sampling."""
    metrics = dict(
        sampling_config=greedy_sampling(),
        num_speculative_tokens=3,
        expected_acceptance_len=2.90 + 1,
        expected_acceptance_rate=0.90,
    )
    return metrics
|
||||
|
||||
|
||||
def test_merge_toks_kernel():
    """Merge target tokens with one next-token per sequence (no rejections)."""
    device = "cuda"
    batch_size = 2
    target_len = 5  # len(target_toks)
    total = target_len + batch_size
    merged = torch.full((total,), -100, device=device)  # -100 is arbitrary
    is_rejected_tok = torch.full((total,), True, device=device)
    merge_toks_kernel[(batch_size,)](
        target_toks_ptr=torch.tensor([0, 1, 2, 0, 1], device=device),
        next_toks_ptr=torch.tensor([3, 2], device=device),
        query_start_locs_ptr=torch.tensor([0, 3], device=device),
        query_end_locs_ptr=torch.tensor([2, 4], device=device),
        out_ptr_merged_toks=merged,
        out_ptr_is_rejected_tok=is_rejected_tok,
        target_toks_size=5,
        rejected_tok_fill=-1,
    )
    # Each sequence's tokens followed by its next token, in order.
    expected_merged = torch.tensor([0, 1, 2, 3, 0, 1, 2], device=device)
    assert torch.allclose(merged, expected_merged)

    # No token is marked rejected.
    expected_flags = torch.tensor([False] * total, device=device)
    assert torch.allclose(is_rejected_tok, expected_flags)
|
||||
|
||||
|
||||
def test_merge_toks_kernel_with_rejected_tokens():
    """Rejected target tokens must be overwritten with ``rejected_tok_fill``."""
    device = "cuda"
    batch_size = 2
    target_len = 9  # len(target_toks)
    total = target_len + batch_size
    merged = torch.full((total,), -100, device=device)
    is_rejected_tok = torch.full((total,), True, device=device)
    # rejected tokens
    #                               ↓   ↓   ↓         ↓
    target_toks = torch.tensor([0, 1, 2, 13, 14, 15, 0, 1, 22], device=device)
    merge_toks_kernel[(batch_size,)](
        target_toks_ptr=target_toks,
        next_toks_ptr=torch.tensor([3, 2], device=device),
        query_start_locs_ptr=torch.tensor([0, 6], device=device),
        query_end_locs_ptr=torch.tensor([2, 7], device=device),
        out_ptr_merged_toks=merged,
        out_ptr_is_rejected_tok=is_rejected_tok,
        target_toks_size=9,
        rejected_tok_fill=-1,
    )
    # Accepted prefixes and next tokens survive; rejected slots become -1.
    expected_merged = torch.tensor(
        [0, 1, 2, 3, -1, -1, -1, 0, 1, 2, -1], device=device
    )
    assert torch.allclose(merged, expected_merged)

    expected_flags = torch.tensor(
        [False, False, False, False, True, True, True, False, False, False, True],
        device=device,
    )
    assert torch.allclose(is_rejected_tok, expected_flags)
|
||||
|
||||
|
||||
def compute_acceptance_rate(metrics: list[Metric]) -> float:
    """Fraction of drafted tokens the target accepted (NaN if none drafted)."""
    by_name = {m.name: m for m in metrics}
    drafted = by_name["vllm:spec_decode_num_draft_tokens"].value  # type: ignore
    if drafted == 0:
        # Avoid division by zero when spec decode never ran.
        return float("nan")
    accepted = by_name["vllm:spec_decode_num_accepted_tokens"].value  # type: ignore
    return accepted / drafted
|
||||
|
||||
|
||||
def compute_acceptance_len(metrics: list[Metric]) -> float:
    """Mean accepted length per draft, counting the bonus token (1 + accepted/drafts)."""
    by_name = {m.name: m for m in metrics}
    drafts = by_name["vllm:spec_decode_num_drafts"].value  # type: ignore
    accepted = by_name["vllm:spec_decode_num_accepted_tokens"].value  # type: ignore
    if drafts == 0:
        # No drafts: only the bonus token counts.
        return 1
    return 1 + (accepted / drafts)
|
||||
|
||||
Reference in New Issue
Block a user