feat: spec decode with draft models (#24322)

Signed-off-by: Tomas Ruiz <tomas.ruiz.te@gmail.com>
This commit is contained in:
Tomas Ruiz
2026-01-19 15:05:46 -06:00
committed by GitHub
parent 73f2a81c75
commit 4a5299c93f
21 changed files with 897 additions and 115 deletions

View File

@@ -54,7 +54,7 @@ def parse_args():
"--method",
type=str,
default="eagle",
choices=["ngram", "eagle", "eagle3", "mtp"],
choices=["ngram", "eagle", "eagle3", "mtp", "draft_model"],
)
parser.add_argument("--num-spec-tokens", type=int, default=2)
parser.add_argument("--prompt-lookup-max", type=int, default=5)
@@ -70,7 +70,11 @@ def parse_args():
parser.add_argument("--output-len", type=int, default=256)
parser.add_argument("--model-dir", type=str, default=None)
parser.add_argument("--eagle-dir", type=str, default=None)
parser.add_argument("--draft-model", type=str, default=None)
parser.add_argument("--custom-mm-prompts", action="store_true")
parser.add_argument("--gpu-memory-utilization", type=float, default=0.9)
parser.add_argument("--disable-padded-drafter-batch", action="store_true")
parser.add_argument("--max-num-seqs", type=int, default=None)
return parser.parse_args()
@@ -111,6 +115,7 @@ def main(args):
"method": args.method,
"model": eagle_dir,
"num_speculative_tokens": args.num_spec_tokens,
"disable_padded_drafter_batch": args.disable_padded_drafter_batch,
}
elif args.method == "ngram":
speculative_config = {
@@ -119,6 +124,15 @@ def main(args):
"prompt_lookup_max": args.prompt_lookup_max,
"prompt_lookup_min": args.prompt_lookup_min,
}
elif args.method == "draft_model":
assert args.draft_model is not None and args.draft_model != ""
speculative_config = {
"method": args.method,
"model": args.draft_model,
"num_speculative_tokens": args.num_spec_tokens,
"enforce_eager": args.enforce_eager,
"max_model_len": args.max_model_len,
}
elif args.method == "mtp":
speculative_config = {
"method": "mtp",
@@ -133,12 +147,13 @@ def main(args):
tensor_parallel_size=args.tp,
enable_chunked_prefill=args.enable_chunked_prefill,
enforce_eager=args.enforce_eager,
gpu_memory_utilization=0.9,
gpu_memory_utilization=args.gpu_memory_utilization,
speculative_config=speculative_config,
disable_log_stats=False,
max_model_len=args.max_model_len,
limit_mm_per_prompt={"image": 5},
disable_chunked_mm_input=True,
max_num_seqs=args.max_num_seqs,
)
sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)

View File

@@ -4,13 +4,13 @@ import asyncio
import copy
import logging
import os
import re
import socket
import threading
import uuid
import aiohttp
import msgpack
import regex as re
import zmq
from quart import Quart, make_response, request

View File

@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any
import pytest
@@ -10,32 +12,45 @@ from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark
from vllm import LLM, SamplingParams
from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR
from vllm.benchmarks.datasets import InstructCoderDataset
from vllm.config.vllm import VllmConfig
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.v1.metrics.reader import Metric
from vllm.v1.spec_decode.draft_model import (
create_vllm_config_for_draft_model,
merge_toks_kernel,
)
MTP_SIMILARITY_RATE = 0.8
def _skip_if_insufficient_gpus_for_tp(tp_size: int):
"""Skip test if available GPUs < tp_size on ROCm."""
if current_platform.is_rocm():
available_gpus = torch.cuda.device_count()
if available_gpus < tp_size:
pytest.skip(
f"Test requires {tp_size} GPUs, but only {available_gpus} available"
)
available_gpus = torch.cuda.device_count()
if available_gpus < tp_size:
pytest.skip(
f"Test requires {tp_size} GPUs, but only {available_gpus} available"
)
def get_test_prompts(mm_enabled: bool):
Messages = list[dict[str, Any]]
def get_test_prompts(
mm_enabled: bool, quiet: bool = False, num_prompts: int = 100
) -> list[Messages]:
prompt_types = ["repeat", "sentence"]
if mm_enabled:
prompt_types.append("mm")
num_prompts = 100
prompts = []
random.seed(0)
random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
print(f"Prompt types: {random_prompt_type_choices}")
if not quiet:
print(f"Prompt types: {random_prompt_type_choices}")
# Generate a mixed batch of prompts, some of which can be easily
# predicted by n-gram matching and some which likely cannot.
@@ -75,11 +90,27 @@ def get_test_prompts(mm_enabled: bool):
return prompts
def get_instruct_coder_messages(n: int) -> list[Messages]:
    """Build n single-turn chat conversations from the InstructCoder dataset."""
    dataset = InstructCoderDataset(
        dataset_path="likaixin/InstructCoder", dataset_split="train"
    )
    sampled: Iterable[str] = dataset.sample_prompts(n=n)
    conversations = []
    for text in sampled:
        conversations.append([{"role": "user", "content": text}])
    return conversations
@pytest.fixture
def sampling_config():
    """Default sampling parameters for the tests: greedy decoding."""
    return greedy_sampling()
def greedy_sampling() -> SamplingParams:
    """Deterministic (temperature-0) sampling, capped at 10 output tokens."""
    params = SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)
    return params
def stochastic_sampling() -> SamplingParams:
    """Stochastic (temperature-1) sampling, capped at 10 output tokens."""
    params = SamplingParams(temperature=1.0, max_tokens=10, ignore_eos=False)
    return params
@pytest.fixture
def model_name():
    """Default target model used by the spec-decode tests."""
    name = "meta-llama/Llama-3.1-8B-Instruct"
    return name
@@ -583,3 +614,269 @@ def test_mtp_correctness(
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
@dataclass
class ArgsTest:
    # Parameters of one draft-model speculative-decoding test scenario.
    target_model: str
    draft_model: str
    sampling_config: SamplingParams
    num_speculative_tokens: int
    # Minimum metrics the run must reach (assertions check >=).
    expected_acceptance_rate: float
    expected_acceptance_len: float
    # Defaults
    target_tensor_parallel_size: int = 1
    draft_tensor_parallel_size: int = 1
    max_model_len: int = 1024
    gpu_memory_utilization: float = 0.5
    # "test_prompts" (synthetic) or "likaixin/InstructCoder" (real dataset).
    dataset: str = "test_prompts"
    num_prompts: int = 100
# Scenarios exercised by test_draft_model_correctness (parametrized below).
cases = [
    # Same model for draft and target, greedy sampling.
    # Drafts should always be accepted, so we expect a perfect rate.
    ArgsTest(
        target_model="Qwen/Qwen3-0.6B",
        draft_model="Qwen/Qwen3-0.6B",
        sampling_config=greedy_sampling(),
        num_speculative_tokens=3,  # K
        expected_acceptance_len=3 + 1,  # K + 1
        expected_acceptance_rate=1.0,
    ),
    # Smaller draft model, stochastic sampling.
    ArgsTest(
        target_model="Qwen/Qwen3-1.7B",
        draft_model="Qwen/Qwen3-0.6B",
        sampling_config=stochastic_sampling(),
        num_speculative_tokens=3,
        expected_acceptance_len=2.8 + 1,
        expected_acceptance_rate=0.9,
    ),
]
@pytest.mark.parametrize("args", cases)
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
    """Run each spec-decode scenario in both eager and compiled mode."""
    assert_draft_model_correctness(args, enforce_eager)
def test_draft_model_realistic_example():
    """Smaller draft model on real InstructCoder prompts (regression guard)."""
    args = ArgsTest(
        target_model="Qwen/Qwen3-1.7B",
        draft_model="Qwen/Qwen3-0.6B",
        dataset="likaixin/InstructCoder",
        num_speculative_tokens=3,
        sampling_config=greedy_sampling(),
        # values below are not derived, but just prevent a regression
        expected_acceptance_len=2.8,
        expected_acceptance_rate=0.55,
    )
    assert_draft_model_correctness(args, enforce_eager=False)
@pytest.mark.parametrize(
    "models",
    [
        # target_model, draft_model
        ("Qwen/Qwen3-1.7B-FP8", "Qwen/Qwen3-0.6B"),  # target quantized
        ("Qwen/Qwen3-1.7B", "Qwen/Qwen3-0.6B-FP8"),  # draft quantized
    ],
    ids=["target_quantized", "draft_quantized"],
)
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_draft_model_quantization(models: tuple[str, str], enforce_eager: bool):
    """Spec decode works when either the target or the draft is quantized."""
    tgt_model, draft_model = models
    sd_case = ArgsTest(
        target_model=tgt_model,
        draft_model=draft_model,
        **some_high_acceptance_metrics(),
    )
    assert_draft_model_correctness(sd_case, enforce_eager)
def test_draft_model_tensor_parallelism():
    """Ensure spec decode works when running with TP > 1.

    Target and draft both run with TP=2 (the proposer requires equal TP).
    """
    _skip_if_insufficient_gpus_for_tp(2)
    sd_case = ArgsTest(
        target_model="Qwen/Qwen3-1.7B",
        target_tensor_parallel_size=2,
        draft_model="Qwen/Qwen3-0.6B",
        draft_tensor_parallel_size=2,
        **some_high_acceptance_metrics(),
    )
    assert_draft_model_correctness(sd_case, enforce_eager=False)
def test_draft_model_engine_args_tensor_parallelism():
    """Ensure the vllm_config for the draft model is created correctly,
    and independently of the target model (quantization, TP, etc.)"""
    _skip_if_insufficient_gpus_for_tp(2)
    engine_args = EngineArgs(
        model="Qwen/Qwen3-1.7B-FP8",  # <<< tgt quantized
        tensor_parallel_size=2,
        speculative_config={
            "model": "Qwen/Qwen3-0.6B",  # <<< draft not quantized
            "method": "draft_model",
            "num_speculative_tokens": 3,
            "draft_tensor_parallel_size": 1,  # <<< valid arg name
        },
    )
    # Target config keeps its own TP and quantization...
    tgt_vllm_config: VllmConfig = engine_args.create_engine_config()
    assert tgt_vllm_config.parallel_config.tensor_parallel_size == 2
    assert tgt_vllm_config.quant_config.get_name() == "fp8"
    # ...while the derived draft config uses the draft's TP and no quant.
    draft_vllm_config: VllmConfig = create_vllm_config_for_draft_model(tgt_vllm_config)
    assert draft_vllm_config.parallel_config.tensor_parallel_size == 1
    assert draft_vllm_config.quant_config is None
def test_draft_model_engine_args_rejects_invalid_tp_argname():
    """The user should pass "draft_tensor_parallel_size" rather than
    "tensor_parallel_size". We enforce this with validation."""
    engine_args = EngineArgs(
        model="Qwen/Qwen3-1.7B",
        tensor_parallel_size=1,
        speculative_config={
            "model": "Qwen/Qwen3-0.6B",
            "method": "draft_model",
            "num_speculative_tokens": 3,
            "tensor_parallel_size": 1,  # <<< invalid arg name
        },
    )
    # Validation happens when the engine config is materialized.
    with pytest.raises(ValueError):
        engine_args.create_engine_config()
def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
    """Run spec decode with a draft model and assert its acceptance metrics.

    NOTE(review): despite the helper's name, no baseline (non-speculative)
    run is performed and generated outputs are not compared -- only the
    acceptance-rate and acceptance-length metrics are checked against the
    minimums in *args*.
    """
    test_prompts: list[Messages] = get_messages(
        dataset=args.dataset, n=args.num_prompts
    )
    spec_llm = LLM(
        model=args.target_model,
        speculative_config={
            "model": args.draft_model,
            "method": "draft_model",
            "num_speculative_tokens": args.num_speculative_tokens,
            "max_model_len": args.max_model_len,
            "enforce_eager": enforce_eager,
            "draft_tensor_parallel_size": args.draft_tensor_parallel_size,
            "max_num_seqs": 100,  # limit cudagraph capture runtime
        },
        max_model_len=args.max_model_len,
        gpu_memory_utilization=args.gpu_memory_utilization,
        tensor_parallel_size=args.target_tensor_parallel_size,
        enforce_eager=enforce_eager,
        disable_log_stats=False,  # enables get_metrics()
    )
    # we don't check the outputs, only check the metrics
    spec_llm.chat(test_prompts, args.sampling_config)
    metrics = spec_llm.get_metrics()
    acceptance_rate: float = compute_acceptance_rate(metrics)
    acceptance_len: float = compute_acceptance_len(metrics)

    del spec_llm  # CLEANUP
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()

    print(
        f"spec-decode: target={args.target_model}, draft={args.draft_model}, "
        f"temperature={args.sampling_config.temperature:.2f}, "
        f"acceptance_rate={acceptance_rate:.2f}, "
        f"acceptance_len={acceptance_len:.2f}, "
    )
    assert acceptance_rate >= args.expected_acceptance_rate
    assert acceptance_len >= args.expected_acceptance_len
def get_messages(dataset: str, n: int) -> list[Messages]:
    """Fetch n chat conversations from the named dataset."""
    if dataset == "likaixin/InstructCoder":
        return get_instruct_coder_messages(n=n)
    if dataset == "test_prompts":
        return get_test_prompts(mm_enabled=False, quiet=True, num_prompts=n)
    raise NotImplementedError(f"Dataset '{dataset}' not implemented")
def some_high_acceptance_metrics() -> dict:
    """Shared kwargs describing a high-acceptance greedy run with K=3."""
    metrics = dict(
        sampling_config=greedy_sampling(),
        num_speculative_tokens=3,
        expected_acceptance_len=2.90 + 1,
        expected_acceptance_rate=0.90,
    )
    return metrics
def test_merge_toks_kernel():
    """Merge kernel, no rejections: each request's target tokens are copied
    and its next sampled token appended, so no slot is marked rejected."""
    device = "cuda"
    merged_len = 5 + 2  # len(target_toks) = 5, batch_size = 2
    merged = torch.full((merged_len,), -100, device=device)  # -100 is arbitrary
    # Pre-filled True so the kernel must actively clear every slot.
    is_rejected_tok = torch.full((merged_len,), True, device=device)
    grid = (2,)  # one program per request
    merge_toks_kernel[grid](
        target_toks_ptr=torch.tensor([0, 1, 2, 0, 1], device=device),
        next_toks_ptr=torch.tensor([3, 2], device=device),
        query_start_locs_ptr=torch.tensor([0, 3], device=device),
        query_end_locs_ptr=torch.tensor([2, 4], device=device),
        out_ptr_merged_toks=merged,
        out_ptr_is_rejected_tok=is_rejected_tok,
        target_toks_size=5,
        rejected_tok_fill=-1,
    )
    # Request 0: [0, 1, 2] + next tok 3; request 1: [0, 1] + next tok 2.
    expected_merged = torch.tensor([0, 1, 2, 3, 0, 1, 2], device=device)
    assert torch.allclose(merged, expected_merged)
    expected_rejected_toks = torch.tensor([False] * merged_len, device=device)
    assert torch.allclose(is_rejected_tok, expected_rejected_toks)
def test_merge_toks_kernel_with_rejected_tokens():
    """Merge kernel with rejections: tokens past each request's end loc are
    filled with rejected_tok_fill and flagged in is_rejected_tok."""
    device = "cuda"
    merged_size = 9 + 2  # len(target_toks) = 9, batch_size = 2
    merged = torch.full((merged_size,), -100, device=device)
    is_rejected_tok = torch.full((merged_size,), True, device=device)
    grid = (2,)  # one program per request
    merge_toks_kernel[grid](
        #                                      rejected tokens
        #                                     ↓   ↓   ↓      ↓
        target_toks_ptr=torch.tensor([0, 1, 2, 13, 14, 15, 0, 1, 22], device=device),
        next_toks_ptr=torch.tensor([3, 2], device=device),
        query_start_locs_ptr=torch.tensor([0, 6], device=device),
        query_end_locs_ptr=torch.tensor([2, 7], device=device),
        out_ptr_merged_toks=merged,
        out_ptr_is_rejected_tok=is_rejected_tok,
        target_toks_size=9,
        rejected_tok_fill=-1,
    )
    # Accepted prefixes keep their tokens + next tok; rejected slots get -1.
    expected_merged = torch.tensor([0, 1, 2, 3, -1, -1, -1, 0, 1, 2, -1], device=device)
    assert torch.allclose(merged, expected_merged)
    expected_rejected_toks = torch.tensor(
        [False, False, False, False, True, True, True, False, False, False, True],
        device=device,
    )
    assert torch.allclose(is_rejected_tok, expected_rejected_toks)
def compute_acceptance_rate(metrics: list[Metric]) -> float:
    """Fraction of drafted tokens that the target model accepted.

    Returns NaN when no tokens were drafted (rate is undefined).
    """
    by_name = {m.name: m for m in metrics}
    drafted = by_name["vllm:spec_decode_num_draft_tokens"].value  # type: ignore
    if drafted == 0:
        return float("nan")
    accepted = by_name["vllm:spec_decode_num_accepted_tokens"].value  # type: ignore
    return accepted / drafted
def compute_acceptance_len(metrics: list[Metric]) -> float:
    """Mean tokens emitted per draft step: 1 + accepted_tokens / drafts.

    Returns 1.0 when no drafts were made (each step emits exactly one token).
    """
    name2metric = {metric.name: metric for metric in metrics}
    n_drafts = name2metric["vllm:spec_decode_num_drafts"].value  # type: ignore
    if n_drafts == 0:
        # Return a float to honor the annotated return type (the original
        # returned the int 1), and skip the accepted-tokens lookup entirely
        # so a metric set without that counter does not raise KeyError.
        return 1.0
    n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value  # type: ignore
    return 1 + (n_accepted_toks / n_drafts)

View File

@@ -55,3 +55,38 @@ def test_bind_kv_cache_non_attention(default_vllm_config):
assert runner_kv_caches[0] is kv_cache["model.layers.20.attn"]
assert runner_kv_caches[1] is kv_cache["model.layers.28.attn"]
def test_bind_kv_cache_draft_model(default_vllm_config):
    """bind_kv_cache handles a mix of target ("model.*") and draft
    ("draft_model.*") attention layers: each layer gets its own cache and
    runner_kv_caches interleaves the two models by layer index."""
    from vllm.attention.layer import Attention

    layer_names = [
        "model.layers.0.attn",
        "model.layers.1.attn",
        "draft_model.layers.0.attn",
        "draft_model.layers.1.attn",
    ]
    ctx = {
        layer_name: Attention(32, 128, 0.1, prefix=layer_name)
        for layer_name in layer_names
    }
    kv_cache = {layer_name: torch.zeros((1,)) for layer_name in layer_names}
    runner_kv_caches: list[torch.Tensor] = []
    bind_kv_cache(kv_cache, ctx, runner_kv_caches)
    # Each Attention layer must be bound to exactly its own cache tensor.
    assert ctx["model.layers.0.attn"].kv_cache[0] is kv_cache["model.layers.0.attn"]
    assert ctx["model.layers.1.attn"].kv_cache[0] is kv_cache["model.layers.1.attn"]
    assert (
        ctx["draft_model.layers.0.attn"].kv_cache[0]
        is kv_cache["draft_model.layers.0.attn"]
    )
    assert (
        ctx["draft_model.layers.1.attn"].kv_cache[0]
        is kv_cache["draft_model.layers.1.attn"]
    )

    # caches are ordered by layer_index, interleaving target and draft model
    assert runner_kv_caches[0] is kv_cache["model.layers.0.attn"]
    assert runner_kv_caches[1] is kv_cache["draft_model.layers.0.attn"]
    assert runner_kv_caches[2] is kv_cache["model.layers.1.attn"]
    assert runner_kv_caches[3] is kv_cache["draft_model.layers.1.attn"]

View File

@@ -2593,17 +2593,10 @@ class InstructCoderDataset(HuggingFaceDataset):
request_id_prefix: str = "",
no_oversample: bool = False,
**kwargs,
) -> list:
) -> list[SampleRequest]:
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
sampled_requests = []
for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests:
break
prompt = (
f"{item['input']}\n\n{item['instruction']} Just output "
"the code, do not include any explanation."
)
for i, prompt in enumerate(self.sample_prompts(n=num_requests)):
# apply template
if not skip_chat_template:
prompt = tokenizer.apply_chat_template(
@@ -2626,6 +2619,14 @@ class InstructCoderDataset(HuggingFaceDataset):
)
return sampled_requests
def sample_prompts(self, n: int) -> Iterator[str]:
    """Yield up to n formatted code-generation prompts from the dataset."""
    for record in self.data.take(n):
        yield (
            f"{record['input']}\n\n{record['instruction']} Just output "
            "the code, do not include any explanation."
        )
# -----------------------------------------------------------------------------
# MT-Bench Dataset Implementation

View File

@@ -8,8 +8,12 @@ import time
import aiohttp
from tqdm.asyncio import tqdm
from vllm.logger import init_logger
from .endpoint_request_func import RequestFunc, RequestFuncInput, RequestFuncOutput
logger = init_logger(__name__)
async def wait_for_endpoint(
request_func: RequestFunc,
@@ -61,6 +65,8 @@ async def wait_for_endpoint(
if output.success:
pbar.close()
return output
else:
logger.warning("Endpoint is not ready. Error='%s'", output.error)
except aiohttp.ClientConnectorError:
pass

View File

@@ -3,6 +3,7 @@
import os
from collections.abc import Callable
from dataclasses import replace
from typing import TYPE_CHECKING, Any, Literal
import torch
@@ -709,3 +710,6 @@ class ParallelConfig:
)
return self
def replace(self, **kwargs) -> Self:
    """Return a copy of this config with the given fields replaced
    (thin wrapper around dataclasses.replace)."""
    return replace(self, **kwargs)

View File

@@ -77,6 +77,9 @@ class SpeculativeConfig:
draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
"""The degree of the tensor parallelism for the draft model. Can only be 1
or the same as the target model's tensor parallel size."""
tensor_parallel_size: int | None = None
"""Users should pass "draft_tensor_parallel_size". This parameter's purpose is to
warn users when they mistakenly provide the wrong argument."""
# Draft model configuration
quantization: me_quant.QuantizationMethods | None = None
@@ -397,13 +400,11 @@ class SpeculativeConfig:
"one layer. Might need some code changes "
"to support multiple layers."
)
elif self.method == "draft_model":
pass
else:
self.method = "draft_model"
raise NotImplementedError(
"Speculative decoding with draft model is not "
"supported yet. Please consider using other "
"speculative decoding methods such as ngram, medusa, "
"eagle, or mtp."
f"Unsupported speculative method: '{self.method}'"
)
# Replace hf_config for EAGLE draft_model
@@ -631,6 +632,12 @@ class SpeculativeConfig:
@model_validator(mode="after")
def _verify_args(self) -> Self:
if self.tensor_parallel_size is not None:
raise ValueError(
"'tensor_parallel_size' is not a valid argument in the "
"speculative_config. Please pass 'draft_tensor_parallel_size' instead."
)
if self.num_speculative_tokens is None:
raise ValueError(
"num_speculative_tokens must be provided with "
@@ -669,12 +676,32 @@ class SpeculativeConfig:
f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501
f"Got {self.target_model_config.hf_text_config.model_type=}"
)
self.verify_equal_vocab_size_if_draft_model()
return self
def verify_equal_vocab_size_if_draft_model(self):
    """For method == "draft_model", raise ValueError if target and draft
    vocab sizes differ; a no-op for all other methods or missing configs."""
    if self.method != "draft_model":
        return
    if self.target_model_config is None or self.draft_model_config is None:
        return
    target_vocab_size = self.target_model_config.get_vocab_size()
    draft_vocab_size = self.draft_model_config.get_vocab_size()
    if target_vocab_size == draft_vocab_size:
        return
    raise ValueError(
        f"Target and draft model should have the same vocabulary size. "
        f"Target model vocab_size={target_vocab_size}. "
        f"Draft model vocab_size={draft_vocab_size}. "
        f"Using models with different tokenizers can cause out-of-bounds "
        f"errors during speculative decoding."
    )
def use_eagle(self) -> bool:
    """Whether the method drafts with an EAGLE-style head (includes mtp)."""
    return self.method in ("eagle", "eagle3", "mtp")
def uses_draft_model(self) -> bool:
    """Whether a separate, full draft model generates the speculations."""
    return self.method == "draft_model"
def __repr__(self) -> str:
method = self.method
model = None if method in ("ngram", "suffix") else self.draft_model_config.model

View File

@@ -1214,10 +1214,19 @@ class VllmConfig:
compilation_config = self.compilation_config
computed_compile_ranges_split_points = []
# The upper bound of the compile ranges is the max_num_batched_tokens
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
if max_num_batched_tokens is not None:
computed_compile_ranges_split_points.append(max_num_batched_tokens)
# The upper bound of the compile ranges is the max_num_batched_tokens.
# For speculative decoding with draft model, the compile range must be extended
# by 1 for each sequence.
compile_range_end = self.scheduler_config.max_num_batched_tokens
if compile_range_end is not None:
do_extend: bool = (
self.speculative_config is not None
and self.speculative_config.uses_draft_model()
)
if do_extend:
compile_range_end += self.scheduler_config.max_num_seqs
computed_compile_ranges_split_points.append(compile_range_end)
# Add the compile ranges for flashinfer
if compilation_config.pass_config.fuse_allreduce_rms:
@@ -1228,10 +1237,7 @@ class VllmConfig:
self.model_config.get_hidden_size()
* self.model_config.dtype.itemsize
)
if (
max_num_batched_tokens is not None
and max_token_num < max_num_batched_tokens
):
if compile_range_end is not None and max_token_num < compile_range_end:
computed_compile_ranges_split_points.append(max_token_num)
else:
logger.debug(
@@ -1243,11 +1249,7 @@ class VllmConfig:
for x in compilation_config.compile_ranges_split_points:
assert isinstance(x, int)
assert x > 0, f"Invalid compile range split point: {x}"
if (
max_num_batched_tokens is not None
and x < max_num_batched_tokens
and x > 1
):
if compile_range_end is not None and x < compile_range_end and x > 1:
computed_compile_ranges_split_points.append(x)
compilation_config.compile_ranges_split_points = sorted(
computed_compile_ranges_split_points
@@ -1316,6 +1318,14 @@ class VllmConfig:
path = self.compilation_config.debug_dump_path / append_path
return path
def replace(self, **kwargs):
    """
    Replace attributes of the config, and 'recompute' the config.

    dataclass.replace() calls __init__() and __post_init__(), so derived
    state is rebuilt for the new field values. Source:
    https://docs.python.org/3/library/dataclasses.html#dataclasses.replace
    """
    return replace(self, **kwargs)
def __str__(self):
return (
f"model={self.model_config.model!r}, "

View File

@@ -1776,21 +1776,6 @@ class EngineArgs:
):
_raise_unsupported_error(feature_name="Concurrent Partial Prefill")
# N-gram, Medusa, and Eagle are supported for speculative decoding.
if self.speculative_config is not None:
# speculative_config could still be a dict at this point
if isinstance(self.speculative_config, dict):
method = self.speculative_config.get("method", None)
else:
method = self.speculative_config.method
if method == "draft_model":
raise NotImplementedError(
"Draft model speculative decoding is not supported yet. "
"Please consider using other speculative decoding methods "
"such as ngram, medusa, eagle, or mtp."
)
if self.pipeline_parallel_size > 1:
supports_pp = getattr(
self.distributed_executor_backend, "supports_pp", False

View File

@@ -124,12 +124,17 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
def get_model(
*, vllm_config: VllmConfig, model_config: ModelConfig | None = None
*,
vllm_config: VllmConfig,
model_config: ModelConfig | None = None,
prefix: str = "",
) -> nn.Module:
loader = get_model_loader(vllm_config.load_config)
if model_config is None:
model_config = vllm_config.model_config
return loader.load_model(vllm_config=vllm_config, model_config=model_config)
return loader.load_model(
vllm_config=vllm_config, model_config=model_config, prefix=prefix
)
__all__ = [

View File

@@ -36,7 +36,7 @@ class BaseModelLoader(ABC):
raise NotImplementedError
def load_model(
self, vllm_config: VllmConfig, model_config: ModelConfig
self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = ""
) -> nn.Module:
"""Load a model with the given configurations."""
device_config = vllm_config.device_config
@@ -48,7 +48,7 @@ class BaseModelLoader(ABC):
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = initialize_model(
vllm_config=vllm_config, model_config=model_config
vllm_config=vllm_config, model_config=model_config, prefix=prefix
)
log_model_inspection(model)

View File

@@ -335,7 +335,7 @@ class GGUFModelLoader(BaseModelLoader):
)
def load_model(
self, vllm_config: VllmConfig, model_config: ModelConfig
self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = ""
) -> nn.Module:
device_config = vllm_config.device_config
local_model_path = self._prepare_weights(model_config)
@@ -364,7 +364,7 @@ class GGUFModelLoader(BaseModelLoader):
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = initialize_model(vllm_config=vllm_config)
model = initialize_model(vllm_config=vllm_config, prefix=prefix)
self.load_weights(model, model_config)
process_weights_after_loading(model, model_config, target_device)

View File

@@ -68,6 +68,7 @@ class TensorizerLoader(BaseModelLoader):
def _load_model_serialized_cpu(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> nn.Module:
"""Load a serialized model with tensorizer to the CPU.
@@ -80,7 +81,7 @@ class TensorizerLoader(BaseModelLoader):
model_config = vllm_config.model_config
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = initialize_model(vllm_config=vllm_config)
model = initialize_model(vllm_config=vllm_config, prefix=prefix)
model.load_weights(self._get_weights_iterator())
return model.eval()
@@ -112,7 +113,7 @@ class TensorizerLoader(BaseModelLoader):
model.load_weights(self._get_weights_iterator())
def load_model(
self, vllm_config: VllmConfig, model_config: ModelConfig
self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = ""
) -> nn.Module:
parallel_config = vllm_config.parallel_config
self._verify_config(model_config, parallel_config)
@@ -134,7 +135,7 @@ class TensorizerLoader(BaseModelLoader):
)
self.load_weights(model, model_config)
return model
return self._load_model_serialized_cpu(vllm_config=vllm_config)
return self._load_model_serialized_cpu(vllm_config=vllm_config, prefix=prefix)
@staticmethod
def save_model(

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from dataclasses import dataclass
from dataclasses import dataclass, replace
from enum import Enum
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, get_args
@@ -329,6 +329,16 @@ class CommonAttentionMetadata:
_num_computed_tokens_cache: torch.Tensor | None = None
def batch_size(self) -> int:
    """Number of requests in the batch (one seq_lens entry per request)."""
    return self.seq_lens.shape[0]
def naive_query_lens(self) -> torch.Tensor:
    """Per-request query lengths from consecutive query_start_loc deltas.

    Naive because it assumes that query ends where the next query starts.
    """
    return self.query_start_loc[1:] - self.query_start_loc[:-1]
def replace(self, **kwargs) -> "CommonAttentionMetadata":
    """Return a copy with the given fields replaced (dataclasses.replace)."""
    return replace(self, **kwargs)
@property
@deprecated(
"""

View File

@@ -818,3 +818,35 @@ def get_dcp_local_seq_lens(
)
dcp_local_seq_lens = base + remainder
return dcp_local_seq_lens.squeeze(1)
def extend_all_queries_by_1(
    common_attn_metadata: CommonAttentionMetadata,
    arange: torch.Tensor,
    new_slot_mapping: torch.Tensor,
) -> CommonAttentionMetadata:
    """
    Creates a new CommonAttentionMetadata with all query lengths increased by 1.
    Also all seq lens are increased by 1.

    This is useful e.g. in speculative decoding with draft models, where we
    extend each sequence by 1 token.

    The slot mapping is computed externally, as it requires more information.
    """
    cad = common_attn_metadata
    # query start loc must be increased by [+0, +1, +2, ..., +batch_size]
    new_query_start_loc = cad.query_start_loc + arange[: len(cad.query_start_loc)]
    # NOTE(review): the CPU copy builds its own int32 arange rather than
    # slicing `arange`; assumes `arange` holds the same increasing 0..N
    # sequence on the device side -- confirm dtype/device at the call site.
    new_query_start_loc_cpu = cad.query_start_loc_cpu + torch.arange(
        len(cad.query_start_loc_cpu), dtype=torch.int32
    )
    new_cad = cad.replace(
        query_start_loc=new_query_start_loc,
        query_start_loc_cpu=new_query_start_loc_cpu,
        seq_lens=cad.seq_lens + 1,
        # each request is extended by 1 token -> batch_size tokens are added
        num_actual_tokens=cad.num_actual_tokens + cad.batch_size(),
        # All query lens increase by 1, so max query len increases by 1
        max_query_len=cad.max_query_len + 1,
        max_seq_len=cad.max_seq_len + 1,
        slot_mapping=new_slot_mapping,
    )
    return new_cad

View File

@@ -208,6 +208,8 @@ class Scheduler(SchedulerInterface):
if speculative_config.use_eagle():
self.use_eagle = True
self.num_lookahead_tokens = self.num_spec_tokens
if speculative_config.uses_draft_model():
self.num_lookahead_tokens = self.num_spec_tokens
# Create the KV cache manager.
self.kv_cache_manager = KVCacheManager(

View File

@@ -0,0 +1,271 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config.speculative import SpeculativeConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
extend_all_queries_by_1,
)
from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, SpecDecodeBaseProposer
logger = init_logger(__name__)
class DraftModelProposer(SpecDecodeBaseProposer):
def __init__(
    self,
    vllm_config: VllmConfig,
    device: torch.device,
    runner=None,
):
    """Set up a proposer that drafts with a separate full model.

    pass_hidden_states_to_model=False: unlike EAGLE-style heads, the draft
    model runs from token ids alone and takes no target hidden states.
    """
    super().__init__(
        vllm_config=vllm_config,
        device=device,
        pass_hidden_states_to_model=False,
        runner=runner,
    )
    # Fail fast on configurations the draft-model path does not support.
    self._raise_if_multimodal()
    self._raise_if_mrope()
    self._raise_if_padded_drafter_batch_disabled()
    self._raise_if_vocab_size_mismatch()
    self._raise_if_draft_tp_mismatch()
def _block_size(self) -> int:
    """KV-cache block size from the attention metadata builder's spec."""
    builder = self._get_attention_metadata_builder()
    return builder.kv_cache_spec.block_size
def _raise_if_multimodal(self):
if self.supports_mm_inputs:
raise NotImplementedError(
"Speculative Decoding with draft models "
"does not support multimodal models yet"
)
def _raise_if_mrope(self):
if self.draft_model_config.uses_mrope:
raise NotImplementedError(
"Speculative Decoding with draft models does not support M-RoPE yet"
)
def _raise_if_padded_drafter_batch_disabled(self):
if self.vllm_config.speculative_config.disable_padded_drafter_batch:
raise NotImplementedError(
"Speculative Decoding with draft models only supports "
"padded drafter batch. Please don't pass --disable-padded-drafter-batch"
" in the speculative_config."
)
def _raise_if_vocab_size_mismatch(self):
    """Delegate to the speculative config's target/draft vocab-size check."""
    self.vllm_config.speculative_config.verify_equal_vocab_size_if_draft_model()
def _raise_if_draft_tp_mismatch(self):
# Note(Tomas Ruiz) If we run the target model with TP > 1 and
# the draft model with TP = 1, then the different TP ranks collide.
# Specifically when all ranks compile the draft model on rank 0
# (because TP=1), then the torch compile cache is overwritten and corrupted.
# We need a mechanism like this: https://github.com/vllm-project/vllm/pull/5414
# To prevent this error, we assert that both TP sizes must be the same.
spec_cfg: SpeculativeConfig = self.vllm_config.speculative_config
tgt_tp = spec_cfg.target_parallel_config.tensor_parallel_size
draft_tp = spec_cfg.draft_parallel_config.tensor_parallel_size
if draft_tp != tgt_tp:
raise ValueError(
f"Currently, 'draft_tensor_parallel_size' and 'tensor_parallel_size' "
f"must be the same. Got {draft_tp} and {tgt_tp}. "
"Please pass 'draft_tensor_parallel_size' in the speculative_config."
)
def set_inputs_first_pass(
    self,
    target_token_ids: torch.Tensor,
    next_token_ids: torch.Tensor,
    target_positions: torch.Tensor,
    last_token_indices: torch.Tensor | None,
    cad: CommonAttentionMetadata,
    num_rejected_tokens_gpu: torch.Tensor | None,
) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
    """Build the drafter's first-pass inputs from the target model's outputs.

    Merges ``target_token_ids`` with one "bonus" token per request
    (``next_token_ids``) into ``self.input_ids``, fills ``self.positions``
    to match, recomputes the paged-KV slot mapping, and extends every query
    in ``cad`` by one token. Slots holding rejected draft tokens are marked
    so they are never written to the KV cache.

    Args:
        target_token_ids: Flattened token ids of all requests, as fed to the
            target model.
        next_token_ids: One sampled (bonus) token per request.
        target_positions: Positions corresponding to ``target_token_ids``.
        last_token_indices: Unused here (recomputed from the extended
            metadata); kept for signature parity with the base class.
        cad: Attention metadata describing the current (pre-extension) batch.
        num_rejected_tokens_gpu: Per-request count of rejected draft tokens,
            or None if nothing was rejected.

    Returns:
        Tuple of (total merged token count, per-request index of the last
        accepted token in the merged layout, extended attention metadata).
    """
    batch_size = cad.batch_size()
    # One Triton program per request.
    grid = (batch_size,)
    start_locs = cad.query_start_loc[:-1]
    # Inclusive index of each request's last token in the flat layout.
    end_locs = cad.query_start_loc[1:] - 1
    if num_rejected_tokens_gpu is not None:
        # Pull each request's end back past its rejected draft tokens.
        end_locs -= num_rejected_tokens_gpu
    # Every request gains exactly one bonus token in the merged layout.
    num_tokens = target_token_ids.shape[0] + batch_size
    is_rejected_tok = torch.empty(
        (num_tokens,), device=self.input_ids.device, dtype=torch.bool
    )
    # First launch: merge token ids into self.input_ids and produce the
    # rejected-token mask as a side output.
    merge_toks_kernel[grid](
        target_toks_ptr=target_token_ids,
        next_toks_ptr=next_token_ids,
        query_start_locs_ptr=start_locs,
        query_end_locs_ptr=end_locs,
        out_ptr_merged_toks=self.input_ids,
        out_ptr_is_rejected_tok=is_rejected_tok,
        target_toks_size=target_token_ids.shape[0],
        # passing a negative rejected_tok_fill value will raise an error
        # when the value is used to index into embeddings.
        # Therefore, we pass a valid integer, e.g. 0.
        rejected_tok_fill=0,
    )
    # Second launch: merge positions the same way; the bonus token's
    # position is one past the last accepted token's position.
    merge_toks_kernel[grid](
        target_toks_ptr=target_positions,
        next_toks_ptr=target_positions[end_locs] + 1,
        query_start_locs_ptr=start_locs,
        query_end_locs_ptr=end_locs,
        out_ptr_merged_toks=self.positions,
        out_ptr_is_rejected_tok=is_rejected_tok,
        target_toks_size=target_positions.shape[0],
        rejected_tok_fill=0,
    )
    # recompute slot mapping
    new_slot_mapping = compute_new_slot_mapping(
        cad=cad,
        new_positions=self.positions[:num_tokens],
        is_rejected_token_mask=is_rejected_tok,
        block_size=self._block_size(),
        max_model_len=self.max_model_len,
    )
    # update common_attn_metadata
    new_cad: CommonAttentionMetadata = extend_all_queries_by_1(
        cad,
        arange=self.arange,
        new_slot_mapping=new_slot_mapping,
    )
    # Last-token indices in the extended layout, again skipping rejected
    # tokens if any were reported.
    new_last_token_indices = new_cad.query_start_loc[1:] - 1
    if num_rejected_tokens_gpu is not None:
        new_last_token_indices -= num_rejected_tokens_gpu
    return num_tokens, new_last_token_indices, new_cad
def load_model(self, target_model: Any) -> None:
    """Load the draft model and record which attention layers belong to it.

    Takes target_model to satisfy the type checker; it is not used here
    because the draft model is loaded from its own derived VllmConfig.
    The attention-layer bookkeeping is order-sensitive: layer names are
    snapshotted before and after loading, and the difference identifies
    the draft model's layers.
    """
    # This must be computed before loading the draft model
    # because that mutates the forward_context of the vllm_config
    target_attn_layer_names = set(
        get_layers_from_vllm_config(self.vllm_config, Attention).keys()
    )
    from vllm.compilation.backends import set_model_tag
    # The draft model may differ from the target in quantization and
    # parallel layout, so build a config specific to it.
    draft_vllm_config: VllmConfig = create_vllm_config_for_draft_model(
        target_model_vllm_config=self.vllm_config
    )
    logger.info(
        "Starting to load draft model %s. TP=%d, rank=%d",
        draft_vllm_config.model_config.model,
        draft_vllm_config.parallel_config.tensor_parallel_size,
        draft_vllm_config.parallel_config.rank,
    )
    # Tag the compile artifacts so they don't collide with the target model's.
    with set_model_tag("draft_model"):
        self.model = get_model(vllm_config=draft_vllm_config, prefix="draft_model")
    # This must be computed after loading the draft model
    # because that mutates the forward_context of the vllm_config
    draft_attn_layer_names = (
        get_layers_from_vllm_config(self.vllm_config, Attention).keys()
        - target_attn_layer_names
    )
    self.attn_layer_names = list(draft_attn_layer_names)
def create_vllm_config_for_draft_model(
    target_model_vllm_config: VllmConfig,
) -> VllmConfig:
    """Derive a VllmConfig suitable for loading the draft model.

    The incoming vllm_config describes the target model, e.g. its
    quant_config and parallel_config. The draft model is potentially
    quantized differently and may use a different tensor_parallel_size,
    so this returns a copy with the draft model's model/parallel configs
    swapped in. The result is useful when loading the draft model with
    get_model().
    """
    target_cfg = target_model_vllm_config
    spec_cfg = target_cfg.speculative_config
    # Preserve this process's rank; only the draft parallel layout changes.
    draft_parallel_config = spec_cfg.draft_parallel_config.replace(
        rank=target_cfg.parallel_config.rank
    )
    draft_vllm_config: VllmConfig = target_cfg.replace(
        quant_config=None,  # quant_config is recomputed in __init__()
        model_config=spec_cfg.draft_model_config,
        parallel_config=draft_parallel_config,
    )
    return draft_vllm_config
def compute_new_slot_mapping(
    cad: CommonAttentionMetadata,
    new_positions: torch.Tensor,
    is_rejected_token_mask: torch.Tensor,
    block_size: int,
    max_model_len: int,
):
    """Recompute the paged-KV slot mapping after each query grew by one token.

    Maps every token in ``new_positions`` to a flat slot in the paged KV
    cache via ``cad.block_table_tensor``. Positions at or beyond
    ``max_model_len`` and rejected tokens are mapped to ``PADDING_SLOT_ID``
    so they are never written to the cache.
    """
    num_reqs, blocks_per_req = cad.block_table_tensor.shape
    # One request index per output token; each request now holds
    # naive_query_lens() + 1 tokens (the extra one is the bonus token).
    req_ids = torch.arange(num_reqs, device=cad.query_start_loc.device)
    req_ids = torch.repeat_interleave(
        req_ids, cad.naive_query_lens() + 1, output_size=len(new_positions)
    )
    # Clamp the positions to prevent an out-of-bounds error when indexing
    # into block_table_tensor; clamped entries are masked out below.
    safe_positions = torch.clamp(new_positions, max=max_model_len - 1)
    flat_block_indices = req_ids * blocks_per_req + safe_positions // block_size
    block_numbers = cad.block_table_tensor.view(-1)[flat_block_indices]
    new_slot_mapping = block_numbers * block_size + safe_positions % block_size
    # Mask out the position ids that exceed the max model length.
    new_slot_mapping.masked_fill_(new_positions >= max_model_len, PADDING_SLOT_ID)
    # Mask out rejected tokens to prevent saves to the KV cache.
    new_slot_mapping.masked_fill_(is_rejected_token_mask, PADDING_SLOT_ID)
    return new_slot_mapping
@triton.jit
def merge_toks_kernel(
    target_toks_ptr,
    next_toks_ptr,
    query_start_locs_ptr,
    query_end_locs_ptr,
    out_ptr_merged_toks,
    out_ptr_is_rejected_tok,
    target_toks_size,
    rejected_tok_fill,
):
    """
    Merges the `target_toks_ptr` and the `next_toks_ptr` into a new tensor
    called `out_ptr_merged_toks`. Rejected tokens are those after the
    `query_end_locs_ptr` and before the next `query_start_locs_ptr`. Fills the
    rejected tokens positions with the value `rejected_tok_fill`. Also fills a mask
    of the rejected tokens in `out_ptr_is_rejected_tok`.

    One program handles one request; each request's output slice is shifted
    right by `pid` because every preceding request gains one bonus token in
    the merged output.
    """
    pid = tl.program_id(0)
    start_loc = tl.load(query_start_locs_ptr + pid)
    is_last_program = pid == tl.num_programs(0) - 1
    if is_last_program:
        # The last request's input tokens run to the end of the buffer, so
        # there is no next start_loc to load.
        next_start_loc = target_toks_size.to(tl.int32)
    else:
        next_start_loc = tl.load(query_start_locs_ptr + pid + 1).to(tl.int32)
    # Inclusive index of this request's last accepted token.
    end_loc = tl.load(query_end_locs_ptr + pid)
    # The bonus token appended after the accepted tokens.
    new_val = tl.load(next_toks_ptr + pid)
    # Iterate one slot past the input range: the extra slot holds the bonus
    # token (or a rejected-token fill). Output index is pid + i (see above).
    for i in range(start_loc, next_start_loc + 1):
        if i <= end_loc:  # copy existing tokens
            old_val = tl.load(target_toks_ptr + i)
            tl.store(out_ptr_merged_toks + pid + i, old_val)
            tl.store(out_ptr_is_rejected_tok + pid + i, False)
        elif i == end_loc + 1:  # copy bonus token
            tl.store(out_ptr_merged_toks + pid + i, new_val)
            tl.store(out_ptr_is_rejected_tok + pid + i, False)
        else:  # fill rejected tokens
            tl.store(out_ptr_merged_toks + pid + i, rejected_tok_fill)
            tl.store(out_ptr_is_rejected_tok + pid + i, True)

View File

@@ -53,11 +53,12 @@ logger = init_logger(__name__)
PADDING_SLOT_ID = -1
class EagleProposer:
class SpecDecodeBaseProposer:
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
pass_hidden_states_to_model: bool,
runner=None,
):
self.vllm_config = vllm_config
@@ -65,6 +66,7 @@ class EagleProposer:
assert self.speculative_config is not None
self.draft_model_config = self.speculative_config.draft_model_config
self.method = self.speculative_config.method
self.pass_hidden_states_to_model = pass_hidden_states_to_model
self.runner = runner
self.device = device
@@ -72,7 +74,11 @@ class EagleProposer:
self.max_model_len = vllm_config.model_config.max_model_len
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
# The drafter can get longer sequences than the target model.
max_batch_size = vllm_config.scheduler_config.max_num_seqs
self.max_num_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size
)
self.token_arange_np = np.arange(self.max_num_tokens)
# We need to get the hidden size from the draft model config because
# the draft model's hidden size can be different from the target model's
@@ -143,7 +149,6 @@ class EagleProposer:
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
max_batch_size = vllm_config.scheduler_config.max_num_seqs
max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
self.arange = torch.arange(
max_num_slots_for_arange, device=device, dtype=torch.int32
@@ -245,11 +250,7 @@ class EagleProposer:
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
num_rejected_tokens_gpu: torch.Tensor | None = None,
) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
if last_token_indices is None:
last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
batch_size = common_attn_metadata.batch_size()
if self.method == "eagle3":
assert isinstance(self.model, Eagle3LlamaForCausalLM)
@@ -257,12 +258,17 @@ class EagleProposer:
target_hidden_states
)
assert target_hidden_states.shape[-1] == self.hidden_size
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
self.input_ids[: num_tokens - 1] = target_token_ids[1:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self.input_ids[last_token_indices] = next_token_ids
num_tokens, last_token_indices, common_attn_metadata = (
self.set_inputs_first_pass(
target_token_ids=target_token_ids,
next_token_ids=next_token_ids,
target_positions=target_positions,
last_token_indices=last_token_indices,
cad=common_attn_metadata,
num_rejected_tokens_gpu=num_rejected_tokens_gpu,
)
)
assert self.runner is not None
@@ -311,9 +317,10 @@ class EagleProposer:
if num_tokens_across_dp is not None:
num_tokens_across_dp[self.dp_rank] = num_input_tokens
# copy inputs to buffer for cudagraph
self._set_positions(num_tokens, target_positions)
self.hidden_states[:num_tokens] = target_hidden_states
if self.pass_hidden_states_to_model:
# target_hidden_states and self.hidden_states can have different
# hidden dims. E.g. large target model and small draft model.
self.hidden_states[:num_tokens] = target_hidden_states
if self.supports_mm_inputs:
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
@@ -330,6 +337,14 @@ class EagleProposer:
input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None
model_kwargs = {
"input_ids": input_ids,
"positions": self._get_positions(num_input_tokens),
"inputs_embeds": inputs_embeds,
}
if self.pass_hidden_states_to_model:
model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
with set_forward_context(
per_layer_attn_metadata,
self.vllm_config,
@@ -337,17 +352,13 @@ class EagleProposer:
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
):
ret_hidden_states = self.model(
input_ids=input_ids,
positions=self._get_positions(num_input_tokens),
hidden_states=self.hidden_states[:num_input_tokens],
inputs_embeds=inputs_embeds,
)
if self.method == "mtp":
ret_hidden_states = self.model(**model_kwargs)
if not self.model_returns_tuple():
last_hidden_states = ret_hidden_states
hidden_states = last_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states)
@@ -357,9 +368,9 @@ class EagleProposer:
return draft_token_ids.view(-1, 1)
if self.uses_mrope:
positions = target_positions[:, last_token_indices]
positions = self.positions[:, last_token_indices]
else:
positions = target_positions[last_token_indices]
positions = self.positions[last_token_indices]
if self.method in (
"deepseek_mtp",
"ernie_mtp",
@@ -527,6 +538,14 @@ class EagleProposer:
inputs_embeds = None
# Run the model.
model_kwargs = {
"input_ids": input_ids,
"positions": self._get_positions(input_batch_size),
"inputs_embeds": inputs_embeds,
}
if self.pass_hidden_states_to_model:
model_kwargs["hidden_states"] = self.hidden_states[:input_batch_size]
with set_forward_context(
per_layer_attn_metadata,
self.vllm_config,
@@ -534,17 +553,13 @@ class EagleProposer:
num_tokens_across_dp=batch_size_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
):
ret_hidden_states = self.model(
input_ids=input_ids,
positions=self._get_positions(input_batch_size),
hidden_states=self.hidden_states[:input_batch_size],
inputs_embeds=inputs_embeds,
)
if self.method == "mtp":
ret_hidden_states = self.model(**model_kwargs)
if not self.model_returns_tuple():
last_hidden_states = ret_hidden_states
hidden_states = ret_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size])
draft_token_ids = logits.argmax(dim=-1)
@@ -554,6 +569,34 @@ class EagleProposer:
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids
def set_inputs_first_pass(
self,
target_token_ids: torch.Tensor,
next_token_ids: torch.Tensor,
target_positions: torch.Tensor,
last_token_indices: torch.Tensor | None,
cad: CommonAttentionMetadata,
num_rejected_tokens_gpu: torch.Tensor | None,
) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
if last_token_indices is None:
last_token_indices = cad.query_start_loc[1:] - 1
num_tokens = target_token_ids.shape[0]
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
self.input_ids[: num_tokens - 1] = target_token_ids[1:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self.input_ids[last_token_indices] = next_token_ids
# copy inputs to buffer for cudagraph
self._set_positions(num_tokens, target_positions)
return num_tokens, last_token_indices, cad
def model_returns_tuple(self) -> bool:
return self.method not in ("mtp", "draft_model")
def prepare_next_token_ids_cpu(
self,
sampled_token_ids: list[list[int]],
@@ -1214,12 +1257,14 @@ class EagleProposer:
input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None
self.model(
kwargs = dict(
input_ids=input_ids,
positions=self._get_positions(num_input_tokens),
hidden_states=self.hidden_states[:num_input_tokens],
inputs_embeds=inputs_embeds,
)
if self.pass_hidden_states_to_model:
kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
self.model(**kwargs)
def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
"""Find and return the attention metadata builders for EAGLE layers.
@@ -1264,8 +1309,8 @@ class EagleProposer:
def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
"""
Validate that all eagle layers belong to the same KVCacheGroup.
Need this assumption to ensure all eagle layers can use the
Validate that all drafting layers belong to the same KVCacheGroup.
Need this assumption to ensure all drafting layers can use the
same AttentionMetadata.
May extend to multiple AttentionMetadata in the future.
"""
@@ -1283,7 +1328,7 @@ class EagleProposer:
)
)
== 1
), "All eagle layers should belong to the same kv cache group"
), "All drafting layers should belong to the same kv cache group"
def _pad_batch_across_dp(
self,
@@ -1308,6 +1353,21 @@ class EagleProposer:
return num_tokens_dp_padded, num_toks_across_dp
class EagleProposer(SpecDecodeBaseProposer):
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
runner=None,
):
super().__init__(
vllm_config,
device,
pass_hidden_states_to_model=True,
runner=runner,
)
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
# the draft prob tensor.

View File

@@ -145,6 +145,7 @@ from vllm.v1.sample.logits_processor.interface import LogitsProcessor
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.draft_model import DraftModelProposer
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -432,10 +433,20 @@ class GPUModelRunner(
# layers in the draft model.
if self.speculative_config and get_pp_group().is_last_rank:
self.drafter: (
NgramProposer | SuffixDecodingProposer | EagleProposer | MedusaProposer
NgramProposer
| SuffixDecodingProposer
| EagleProposer
| DraftModelProposer
| MedusaProposer
)
if self.speculative_config.method == "ngram":
self.drafter = NgramProposer(self.vllm_config)
elif self.speculative_config.uses_draft_model():
self.drafter = DraftModelProposer(
vllm_config=self.vllm_config,
device=self.device,
runner=self,
)
elif self.speculative_config.method == "suffix":
self.drafter = SuffixDecodingProposer(self.vllm_config)
elif self.speculative_config.use_eagle():
@@ -3443,10 +3454,13 @@ class GPUModelRunner(
spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens
<= self.effective_drafter_max_model_len
)
if spec_config.use_eagle() and not spec_config.disable_padded_drafter_batch:
# EAGLE speculative decoding can use the GPU sampled tokens
use_gpu_toks = (
spec_config.use_eagle() or spec_config.uses_draft_model()
) and not spec_config.disable_padded_drafter_batch
if use_gpu_toks:
# EAGLE/DraftModel speculative decoding can use the GPU sampled tokens
# as inputs, and does not need to wait for bookkeeping to finish.
assert isinstance(self.drafter, EagleProposer)
assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
sampled_token_ids = sampler_output.sampled_token_ids
if input_fits_in_drafter:
propose_draft_token_ids(sampled_token_ids)
@@ -3679,8 +3693,8 @@ class GPUModelRunner(
target_hidden_states=hidden_states,
sampling_metadata=sampling_metadata,
)
elif spec_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
elif spec_config.use_eagle() or spec_config.uses_draft_model():
assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
if spec_config.disable_padded_drafter_batch:
# When padded-batch is disabled, the sampled_token_ids should be
@@ -4475,8 +4489,12 @@ class GPUModelRunner(
else:
hidden_states = outputs
if self.speculative_config and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
if self.speculative_config and (
self.speculative_config.use_eagle()
or self.speculative_config.uses_draft_model()
):
assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
assert self.speculative_config is not None
# Eagle currently only supports PIECEWISE cudagraphs.
# Therefore only use cudagraphs if the main model uses PIECEWISE
# NOTE(lucas): this is a hack, need to clean up.
@@ -5652,8 +5670,11 @@ class GPUModelRunner(
kv_cache_config, kernel_block_sizes
)
if self.speculative_config and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
if self.speculative_config and (
self.speculative_config.use_eagle()
or self.speculative_config.uses_draft_model()
):
assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
# validate all draft model layers belong to the same kv cache
# group
self.drafter.validate_same_kv_cache_group(kv_cache_config)

View File

@@ -352,8 +352,8 @@ def bind_kv_cache(
pass
else:
raise NotImplementedError
layer_name = layer_names[0]
runner_kv_caches.append(kv_caches[layer_name])
for layer_name in layer_names:
runner_kv_caches.append(kv_caches[layer_name])
# Bind kv_caches to forward context
for layer_name, kv_cache in kv_caches.items():