[Spec Decode] Unified Parallel Drafting (#32887)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
committed by
GitHub
parent
5b2a9422f0
commit
af3162d3aa
@@ -13,15 +13,12 @@ from vllm import LLM, SamplingParams
|
||||
from vllm.assets.base import VLLM_S3_BUCKET_URL
|
||||
from vllm.assets.image import VLM_IMAGES_DIR
|
||||
from vllm.benchmarks.datasets import InstructCoderDataset
|
||||
from vllm.config.vllm import VllmConfig
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.metrics.reader import Metric
|
||||
from vllm.v1.spec_decode.draft_model import (
|
||||
create_vllm_config_for_draft_model,
|
||||
merge_toks_kernel,
|
||||
)
|
||||
from vllm.v1.spec_decode.utils import create_vllm_config_for_draft_model
|
||||
|
||||
MTP_SIMILARITY_RATE = 0.8
|
||||
|
||||
@@ -625,6 +622,8 @@ class ArgsTest:
|
||||
expected_acceptance_rate: float
|
||||
expected_acceptance_len: float
|
||||
# Defaults
|
||||
enforce_eager: bool = True
|
||||
parallel_drafting: bool = False
|
||||
target_tensor_parallel_size: int = 1
|
||||
draft_tensor_parallel_size: int = 1
|
||||
max_model_len: int = 1024
|
||||
@@ -658,7 +657,8 @@ cases = [
|
||||
@pytest.mark.parametrize("args", cases)
|
||||
@pytest.mark.parametrize("enforce_eager", [True, False])
|
||||
def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
|
||||
assert_draft_model_correctness(args, enforce_eager)
|
||||
args.enforce_eager = enforce_eager
|
||||
assert_draft_model_correctness(args)
|
||||
|
||||
|
||||
def test_draft_model_realistic_example():
|
||||
@@ -668,11 +668,28 @@ def test_draft_model_realistic_example():
|
||||
dataset="likaixin/InstructCoder",
|
||||
num_speculative_tokens=3,
|
||||
sampling_config=greedy_sampling(),
|
||||
enforce_eager=False,
|
||||
# values below are not derived, but just prevent a regression
|
||||
expected_acceptance_len=2.8,
|
||||
expected_acceptance_rate=0.55,
|
||||
)
|
||||
assert_draft_model_correctness(args, enforce_eager=False)
|
||||
assert_draft_model_correctness(args)
|
||||
|
||||
|
||||
def test_draft_model_parallel_drafting():
|
||||
args = ArgsTest(
|
||||
target_model="Qwen/Qwen3-1.7B",
|
||||
draft_model="amd/PARD-Qwen3-0.6B",
|
||||
dataset="likaixin/InstructCoder",
|
||||
num_speculative_tokens=3,
|
||||
sampling_config=greedy_sampling(),
|
||||
parallel_drafting=True,
|
||||
enforce_eager=False,
|
||||
# values below are collected from a stable run, with ~5% tolerance
|
||||
expected_acceptance_len=2.375,
|
||||
expected_acceptance_rate=0.45,
|
||||
)
|
||||
assert_draft_model_correctness(args)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -691,8 +708,9 @@ def test_draft_model_quantization(models: tuple[str, str], enforce_eager: bool):
|
||||
target_model=tgt_model,
|
||||
draft_model=draft_model,
|
||||
**some_high_acceptance_metrics(),
|
||||
enforce_eager=enforce_eager,
|
||||
)
|
||||
assert_draft_model_correctness(sd_case, enforce_eager)
|
||||
assert_draft_model_correctness(sd_case)
|
||||
|
||||
|
||||
def test_draft_model_tensor_parallelism():
|
||||
@@ -704,8 +722,9 @@ def test_draft_model_tensor_parallelism():
|
||||
draft_model="Qwen/Qwen3-0.6B",
|
||||
draft_tensor_parallel_size=2,
|
||||
**some_high_acceptance_metrics(),
|
||||
enforce_eager=False,
|
||||
)
|
||||
assert_draft_model_correctness(sd_case, enforce_eager=False)
|
||||
assert_draft_model_correctness(sd_case)
|
||||
|
||||
|
||||
def test_draft_model_engine_args_tensor_parallelism():
|
||||
@@ -750,7 +769,7 @@ def test_draft_model_engine_args_rejects_invalid_tp_argname():
|
||||
engine_args.create_engine_config()
|
||||
|
||||
|
||||
def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
|
||||
def assert_draft_model_correctness(args: ArgsTest):
|
||||
"""Compare the outputs using and not using speculative decoding.
|
||||
In the greedy decoding case, the outputs must match EXACTLY."""
|
||||
test_prompts: list[Messages] = get_messages(
|
||||
@@ -764,14 +783,15 @@ def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
|
||||
"method": "draft_model",
|
||||
"num_speculative_tokens": args.num_speculative_tokens,
|
||||
"max_model_len": args.max_model_len,
|
||||
"enforce_eager": enforce_eager,
|
||||
"enforce_eager": args.enforce_eager,
|
||||
"draft_tensor_parallel_size": args.draft_tensor_parallel_size,
|
||||
"parallel_drafting": args.parallel_drafting,
|
||||
},
|
||||
max_num_seqs=100, # limit cudagraph capture runtime
|
||||
max_model_len=args.max_model_len,
|
||||
gpu_memory_utilization=args.gpu_memory_utilization,
|
||||
tensor_parallel_size=args.target_tensor_parallel_size,
|
||||
enforce_eager=enforce_eager,
|
||||
enforce_eager=args.enforce_eager,
|
||||
disable_log_stats=False, # enables get_metrics()
|
||||
)
|
||||
# we don't check the outputs, only check the metrics
|
||||
@@ -813,57 +833,6 @@ def some_high_acceptance_metrics() -> dict:
|
||||
}
|
||||
|
||||
|
||||
def test_merge_toks_kernel():
|
||||
device = "cuda"
|
||||
merged_len = 5 + 2 # len(target_toks) = 5, batch_size = 2
|
||||
merged = torch.full((merged_len,), -100, device=device) # -100 is arbitrary
|
||||
is_rejected_tok = torch.full((merged_len,), True, device=device)
|
||||
grid = (2,)
|
||||
merge_toks_kernel[grid](
|
||||
target_toks_ptr=torch.tensor([0, 1, 2, 0, 1], device=device),
|
||||
next_toks_ptr=torch.tensor([3, 2], device=device),
|
||||
query_start_locs_ptr=torch.tensor([0, 3], device=device),
|
||||
query_end_locs_ptr=torch.tensor([2, 4], device=device),
|
||||
out_ptr_merged_toks=merged,
|
||||
out_ptr_is_rejected_tok=is_rejected_tok,
|
||||
target_toks_size=5,
|
||||
rejected_tok_fill=-1,
|
||||
)
|
||||
expected_merged = torch.tensor([0, 1, 2, 3, 0, 1, 2], device=device)
|
||||
assert torch.allclose(merged, expected_merged)
|
||||
|
||||
expected_rejected_toks = torch.tensor([False] * merged_len, device=device)
|
||||
assert torch.allclose(is_rejected_tok, expected_rejected_toks)
|
||||
|
||||
|
||||
def test_merge_toks_kernel_with_rejected_tokens():
|
||||
device = "cuda"
|
||||
merged_size = 9 + 2 # len(target_toks) = 9, batch_size = 2
|
||||
merged = torch.full((merged_size,), -100, device=device)
|
||||
is_rejected_tok = torch.full((merged_size,), True, device=device)
|
||||
grid = (2,)
|
||||
merge_toks_kernel[grid](
|
||||
# rejected tokens
|
||||
# ↓ ↓ ↓ ↓
|
||||
target_toks_ptr=torch.tensor([0, 1, 2, 13, 14, 15, 0, 1, 22], device=device),
|
||||
next_toks_ptr=torch.tensor([3, 2], device=device),
|
||||
query_start_locs_ptr=torch.tensor([0, 6], device=device),
|
||||
query_end_locs_ptr=torch.tensor([2, 7], device=device),
|
||||
out_ptr_merged_toks=merged,
|
||||
out_ptr_is_rejected_tok=is_rejected_tok,
|
||||
target_toks_size=9,
|
||||
rejected_tok_fill=-1,
|
||||
)
|
||||
expected_merged = torch.tensor([0, 1, 2, 3, -1, -1, -1, 0, 1, 2, -1], device=device)
|
||||
assert torch.allclose(merged, expected_merged)
|
||||
|
||||
expected_rejected_toks = torch.tensor(
|
||||
[False, False, False, False, True, True, True, False, False, False, True],
|
||||
device=device,
|
||||
)
|
||||
assert torch.allclose(is_rejected_tok, expected_rejected_toks)
|
||||
|
||||
|
||||
def compute_acceptance_rate(metrics: list[Metric]) -> float:
|
||||
name2metric = {metric.name: metric for metric in metrics}
|
||||
n_draft_toks = name2metric["vllm:spec_decode_num_draft_tokens"].value # type: ignore
|
||||
|
||||
Reference in New Issue
Block a user