feat: spec decode with draft models (#24322)
Signed-off-by: Tomas Ruiz <tomas.ruiz.te@gmail.com>
This commit is contained in:
@@ -54,7 +54,7 @@ def parse_args():
|
||||
"--method",
|
||||
type=str,
|
||||
default="eagle",
|
||||
choices=["ngram", "eagle", "eagle3", "mtp"],
|
||||
choices=["ngram", "eagle", "eagle3", "mtp", "draft_model"],
|
||||
)
|
||||
parser.add_argument("--num-spec-tokens", type=int, default=2)
|
||||
parser.add_argument("--prompt-lookup-max", type=int, default=5)
|
||||
@@ -70,7 +70,11 @@ def parse_args():
|
||||
parser.add_argument("--output-len", type=int, default=256)
|
||||
parser.add_argument("--model-dir", type=str, default=None)
|
||||
parser.add_argument("--eagle-dir", type=str, default=None)
|
||||
parser.add_argument("--draft-model", type=str, default=None)
|
||||
parser.add_argument("--custom-mm-prompts", action="store_true")
|
||||
parser.add_argument("--gpu-memory-utilization", type=float, default=0.9)
|
||||
parser.add_argument("--disable-padded-drafter-batch", action="store_true")
|
||||
parser.add_argument("--max-num-seqs", type=int, default=None)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -111,6 +115,7 @@ def main(args):
|
||||
"method": args.method,
|
||||
"model": eagle_dir,
|
||||
"num_speculative_tokens": args.num_spec_tokens,
|
||||
"disable_padded_drafter_batch": args.disable_padded_drafter_batch,
|
||||
}
|
||||
elif args.method == "ngram":
|
||||
speculative_config = {
|
||||
@@ -119,6 +124,15 @@ def main(args):
|
||||
"prompt_lookup_max": args.prompt_lookup_max,
|
||||
"prompt_lookup_min": args.prompt_lookup_min,
|
||||
}
|
||||
elif args.method == "draft_model":
|
||||
assert args.draft_model is not None and args.draft_model != ""
|
||||
speculative_config = {
|
||||
"method": args.method,
|
||||
"model": args.draft_model,
|
||||
"num_speculative_tokens": args.num_spec_tokens,
|
||||
"enforce_eager": args.enforce_eager,
|
||||
"max_model_len": args.max_model_len,
|
||||
}
|
||||
elif args.method == "mtp":
|
||||
speculative_config = {
|
||||
"method": "mtp",
|
||||
@@ -133,12 +147,13 @@ def main(args):
|
||||
tensor_parallel_size=args.tp,
|
||||
enable_chunked_prefill=args.enable_chunked_prefill,
|
||||
enforce_eager=args.enforce_eager,
|
||||
gpu_memory_utilization=0.9,
|
||||
gpu_memory_utilization=args.gpu_memory_utilization,
|
||||
speculative_config=speculative_config,
|
||||
disable_log_stats=False,
|
||||
max_model_len=args.max_model_len,
|
||||
limit_mm_per_prompt={"image": 5},
|
||||
disable_chunked_mm_input=True,
|
||||
max_num_seqs=args.max_num_seqs,
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
|
||||
|
||||
@@ -4,13 +4,13 @@ import asyncio
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import socket
|
||||
import threading
|
||||
import uuid
|
||||
|
||||
import aiohttp
|
||||
import msgpack
|
||||
import regex as re
|
||||
import zmq
|
||||
from quart import Quart, make_response, request
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import random
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
@@ -10,15 +12,22 @@ from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.assets.base import VLLM_S3_BUCKET_URL
|
||||
from vllm.assets.image import VLM_IMAGES_DIR
|
||||
from vllm.benchmarks.datasets import InstructCoderDataset
|
||||
from vllm.config.vllm import VllmConfig
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.metrics.reader import Metric
|
||||
from vllm.v1.spec_decode.draft_model import (
|
||||
create_vllm_config_for_draft_model,
|
||||
merge_toks_kernel,
|
||||
)
|
||||
|
||||
MTP_SIMILARITY_RATE = 0.8
|
||||
|
||||
|
||||
def _skip_if_insufficient_gpus_for_tp(tp_size: int):
|
||||
"""Skip test if available GPUs < tp_size on ROCm."""
|
||||
if current_platform.is_rocm():
|
||||
available_gpus = torch.cuda.device_count()
|
||||
if available_gpus < tp_size:
|
||||
pytest.skip(
|
||||
@@ -26,15 +35,21 @@ def _skip_if_insufficient_gpus_for_tp(tp_size: int):
|
||||
)
|
||||
|
||||
|
||||
def get_test_prompts(mm_enabled: bool):
|
||||
Messages = list[dict[str, Any]]
|
||||
|
||||
|
||||
def get_test_prompts(
|
||||
mm_enabled: bool, quiet: bool = False, num_prompts: int = 100
|
||||
) -> list[Messages]:
|
||||
prompt_types = ["repeat", "sentence"]
|
||||
if mm_enabled:
|
||||
prompt_types.append("mm")
|
||||
num_prompts = 100
|
||||
prompts = []
|
||||
|
||||
random.seed(0)
|
||||
random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
|
||||
|
||||
if not quiet:
|
||||
print(f"Prompt types: {random_prompt_type_choices}")
|
||||
|
||||
# Generate a mixed batch of prompts, some of which can be easily
|
||||
@@ -75,11 +90,27 @@ def get_test_prompts(mm_enabled: bool):
|
||||
return prompts
|
||||
|
||||
|
||||
def get_instruct_coder_messages(n: int) -> list[Messages]:
    """Build chat conversations from the first ``n`` InstructCoder prompts.

    Every prompt yielded by ``InstructCoderDataset.sample_prompts`` is wrapped
    in a conversation that holds a single ``user`` message.
    """
    instruct_coder = InstructCoderDataset(
        dataset_path="likaixin/InstructCoder", dataset_split="train"
    )
    conversations: list[Messages] = []
    for text in instruct_coder.sample_prompts(n=n):
        conversations.append([{"role": "user", "content": text}])
    return conversations
|
||||
|
||||
|
||||
@pytest.fixture
def sampling_config():
    """Fixture providing the greedy sampling parameters used by default."""
    params = greedy_sampling()
    return params
|
||||
|
||||
|
||||
def greedy_sampling() -> SamplingParams:
    """Return sampling parameters with temperature 0 and at most 10 new tokens."""
    params = SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)
    return params
|
||||
|
||||
|
||||
def stochastic_sampling() -> SamplingParams:
    """Return sampling parameters with temperature 1.0 and at most 10 new tokens."""
    params = SamplingParams(temperature=1.0, max_tokens=10, ignore_eos=False)
    return params
|
||||
|
||||
|
||||
@pytest.fixture
def model_name():
    """Fixture providing the model identifier shared by tests that request it."""
    name = "meta-llama/Llama-3.1-8B-Instruct"
    return name
|
||||
@@ -583,3 +614,269 @@ def test_mtp_correctness(
|
||||
del spec_llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@dataclass
class ArgsTest:
    """Parameters describing one draft-model speculative-decoding scenario.

    The first six fields are required; the remaining fields have defaults.
    """

    # Model identifiers for the target model and for the draft model.
    target_model: str
    draft_model: str
    # Sampling parameters handed to the chat call.
    sampling_config: SamplingParams
    # Forwarded as "num_speculative_tokens" in the speculative config.
    num_speculative_tokens: int
    # Lower bounds that the measured acceptance metrics are asserted against.
    expected_acceptance_rate: float
    expected_acceptance_len: float
    # Defaults
    target_tensor_parallel_size: int = 1
    draft_tensor_parallel_size: int = 1
    max_model_len: int = 1024
    gpu_memory_utilization: float = 0.5
    # Name of the prompt source, resolved by get_messages().
    dataset: str = "test_prompts"
    num_prompts: int = 100
|
||||
|
||||
|
||||
# Scenarios used to parametrize test_draft_model_correctness.
cases = [
    # Draft and target are the same model and sampling is greedy; the
    # thresholds require every drafted token to be accepted.
    ArgsTest(
        target_model="Qwen/Qwen3-0.6B",
        draft_model="Qwen/Qwen3-0.6B",
        num_speculative_tokens=3,  # K
        sampling_config=greedy_sampling(),
        expected_acceptance_rate=1.0,
        expected_acceptance_len=3 + 1,  # K + 1
    ),
    # A smaller draft model combined with stochastic sampling.
    ArgsTest(
        target_model="Qwen/Qwen3-1.7B",
        draft_model="Qwen/Qwen3-0.6B",
        num_speculative_tokens=3,
        sampling_config=stochastic_sampling(),
        expected_acceptance_rate=0.9,
        expected_acceptance_len=2.8 + 1,
    ),
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("args", cases)
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
    """Check every scenario in ``cases`` with ``enforce_eager`` on and off."""
    assert_draft_model_correctness(args, enforce_eager=enforce_eager)
|
||||
|
||||
|
||||
def test_draft_model_realistic_example():
    """Check acceptance metrics on InstructCoder prompts with greedy sampling."""
    # These thresholds are not derived; they only guard against a regression.
    regression_thresholds = {
        "expected_acceptance_len": 2.8,
        "expected_acceptance_rate": 0.55,
    }
    scenario = ArgsTest(
        target_model="Qwen/Qwen3-1.7B",
        draft_model="Qwen/Qwen3-0.6B",
        dataset="likaixin/InstructCoder",
        num_speculative_tokens=3,
        sampling_config=greedy_sampling(),
        **regression_thresholds,
    )
    assert_draft_model_correctness(scenario, enforce_eager=False)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "models",
    [
        # (target_model, draft_model)
        ("Qwen/Qwen3-1.7B-FP8", "Qwen/Qwen3-0.6B"),  # target quantized
        ("Qwen/Qwen3-1.7B", "Qwen/Qwen3-0.6B-FP8"),  # draft quantized
    ],
    ids=["target_quantized", "draft_quantized"],
)
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_draft_model_quantization(models: tuple[str, str], enforce_eager: bool):
    """Check acceptance metrics when exactly one of the two models is FP8."""
    target_name, draft_name = models
    thresholds = some_high_acceptance_metrics()
    scenario = ArgsTest(
        target_model=target_name,
        draft_model=draft_name,
        **thresholds,
    )
    assert_draft_model_correctness(scenario, enforce_eager)
|
||||
|
||||
|
||||
def test_draft_model_tensor_parallelism():
    """Ensure spec decode works when target and draft both run with TP=2."""
    tp_size = 2
    _skip_if_insufficient_gpus_for_tp(tp_size)

    thresholds = some_high_acceptance_metrics()
    scenario = ArgsTest(
        target_model="Qwen/Qwen3-1.7B",
        target_tensor_parallel_size=2,
        draft_model="Qwen/Qwen3-0.6B",
        draft_tensor_parallel_size=2,
        **thresholds,
    )
    assert_draft_model_correctness(scenario, enforce_eager=False)
|
||||
|
||||
|
||||
def test_draft_model_engine_args_tensor_parallelism():
    """Ensure the draft model's vllm_config is built independently of the target.

    The target is FP8-quantized with TP=2, while the draft is unquantized and
    requests TP=1; the derived draft config must reflect the draft settings.
    """
    _skip_if_insufficient_gpus_for_tp(2)

    draft_settings = {
        "model": "Qwen/Qwen3-0.6B",  # draft is not quantized
        "method": "draft_model",
        "num_speculative_tokens": 3,
        "draft_tensor_parallel_size": 1,  # the valid argument name
    }
    engine_args = EngineArgs(
        model="Qwen/Qwen3-1.7B-FP8",  # target is quantized
        tensor_parallel_size=2,
        speculative_config=draft_settings,
    )

    target_cfg: VllmConfig = engine_args.create_engine_config()
    assert target_cfg.parallel_config.tensor_parallel_size == 2
    assert target_cfg.quant_config.get_name() == "fp8"

    draft_cfg: VllmConfig = create_vllm_config_for_draft_model(target_cfg)
    assert draft_cfg.parallel_config.tensor_parallel_size == 1
    assert draft_cfg.quant_config is None
|
||||
|
||||
|
||||
def test_draft_model_engine_args_rejects_invalid_tp_argname():
    """Ensure "tensor_parallel_size" is rejected inside the speculative config.

    Users must pass "draft_tensor_parallel_size" instead; building the engine
    config with the wrong key is expected to raise ``ValueError``.
    """
    bad_settings = {
        "model": "Qwen/Qwen3-0.6B",
        "method": "draft_model",
        "num_speculative_tokens": 3,
        "tensor_parallel_size": 1,  # the invalid argument name
    }
    engine_args = EngineArgs(
        model="Qwen/Qwen3-1.7B",
        tensor_parallel_size=1,
        speculative_config=bad_settings,
    )
    with pytest.raises(ValueError):
        engine_args.create_engine_config()
|
||||
|
||||
|
||||
def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
    """Run draft-model speculative decoding and check its acceptance metrics.

    The generated text is not inspected. The engine is run on the prompts
    selected by ``args.dataset`` / ``args.num_prompts``, after which the
    measured acceptance rate and acceptance length must each be at least the
    expected values stored in ``args``.

    Args:
        args: Scenario description (models, sampling, thresholds, sizes).
        enforce_eager: Passed to both the target engine and the speculative
            config as ``enforce_eager``.
    """
    test_prompts: list[Messages] = get_messages(
        dataset=args.dataset, n=args.num_prompts
    )

    spec_llm = LLM(
        model=args.target_model,
        speculative_config={
            "model": args.draft_model,
            "method": "draft_model",
            "num_speculative_tokens": args.num_speculative_tokens,
            "max_model_len": args.max_model_len,
            "enforce_eager": enforce_eager,
            "draft_tensor_parallel_size": args.draft_tensor_parallel_size,
            "max_num_seqs": 100,  # limit cudagraph capture runtime
        },
        max_model_len=args.max_model_len,
        gpu_memory_utilization=args.gpu_memory_utilization,
        tensor_parallel_size=args.target_tensor_parallel_size,
        enforce_eager=enforce_eager,
        disable_log_stats=False,  # enables get_metrics()
    )
    try:
        # we don't check the outputs, only check the metrics
        spec_llm.chat(test_prompts, args.sampling_config)
        metrics = spec_llm.get_metrics()

        acceptance_rate: float = compute_acceptance_rate(metrics)
        acceptance_len: float = compute_acceptance_len(metrics)
    finally:
        # Always release the engine, even when generation or metric
        # extraction fails, so a failure here does not leave the engine
        # alive for the tests that run afterwards.
        del spec_llm  # CLEANUP
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

    print(
        f"spec-decode: target={args.target_model}, draft={args.draft_model}, "
        f"temperature={args.sampling_config.temperature:.2f}, "
        f"acceptance_rate={acceptance_rate:.2f}, "
        f"acceptance_len={acceptance_len:.2f}, "
    )

    assert acceptance_rate >= args.expected_acceptance_rate
    assert acceptance_len >= args.expected_acceptance_len
|
||||
|
||||
|
||||
def get_messages(dataset: str, n: int) -> list[Messages]:
    """Return ``n`` chat conversations from the prompt source named ``dataset``.

    Supported names are "test_prompts" and "likaixin/InstructCoder".

    Raises:
        NotImplementedError: If ``dataset`` is any other name.
    """
    if dataset == "test_prompts":
        return get_test_prompts(mm_enabled=False, quiet=True, num_prompts=n)
    if dataset == "likaixin/InstructCoder":
        return get_instruct_coder_messages(n=n)
    raise NotImplementedError(f"Dataset '{dataset}' not implemented")
|
||||
|
||||
|
||||
def some_high_acceptance_metrics() -> dict:
    """Return ArgsTest keyword arguments for greedy sampling with high thresholds."""
    return dict(
        sampling_config=greedy_sampling(),
        num_speculative_tokens=3,
        expected_acceptance_len=2.90 + 1,
        expected_acceptance_rate=0.90,
    )
|
||||
|
||||
|
||||
def test_merge_toks_kernel():
    """Check merge_toks_kernel on a batch of two requests with nothing rejected.

    Each request's target tokens are expected to be followed by its next
    token, and no output position should be flagged as rejected.
    """
    device = "cuda"
    merged_len = 5 + 2  # len(target_toks) = 5, batch_size = 2
    merged = torch.full((merged_len,), -100, device=device)  # -100 is arbitrary
    is_rejected_tok = torch.full((merged_len,), True, device=device)
    grid = (2,)
    merge_toks_kernel[grid](
        target_toks_ptr=torch.tensor([0, 1, 2, 0, 1], device=device),
        next_toks_ptr=torch.tensor([3, 2], device=device),
        query_start_locs_ptr=torch.tensor([0, 3], device=device),
        query_end_locs_ptr=torch.tensor([2, 4], device=device),
        out_ptr_merged_toks=merged,
        out_ptr_is_rejected_tok=is_rejected_tok,
        target_toks_size=5,
        rejected_tok_fill=-1,
    )
    # Token ids and boolean flags must match exactly, so compare with
    # torch.equal (same size and elements) instead of a tolerance-based,
    # broadcasting comparison.
    expected_merged = torch.tensor([0, 1, 2, 3, 0, 1, 2], device=device)
    assert torch.equal(merged, expected_merged)

    expected_rejected_toks = torch.tensor([False] * merged_len, device=device)
    assert torch.equal(is_rejected_tok, expected_rejected_toks)
|
||||
|
||||
|
||||
def test_merge_toks_kernel_with_rejected_tokens():
    """Check merge_toks_kernel when some target tokens lie past the query end.

    With the start/end locations below, tokens 13, 14, 15 (first request) and
    22 (second request) sit after their request's end location; the expected
    output replaces them with ``rejected_tok_fill`` and flags them as rejected.
    """
    device = "cuda"
    merged_size = 9 + 2  # len(target_toks) = 9, batch_size = 2
    merged = torch.full((merged_size,), -100, device=device)
    is_rejected_tok = torch.full((merged_size,), True, device=device)
    grid = (2,)
    merge_toks_kernel[grid](
        # rejected tokens: 13, 14, 15 and 22
        target_toks_ptr=torch.tensor([0, 1, 2, 13, 14, 15, 0, 1, 22], device=device),
        next_toks_ptr=torch.tensor([3, 2], device=device),
        query_start_locs_ptr=torch.tensor([0, 6], device=device),
        query_end_locs_ptr=torch.tensor([2, 7], device=device),
        out_ptr_merged_toks=merged,
        out_ptr_is_rejected_tok=is_rejected_tok,
        target_toks_size=9,
        rejected_tok_fill=-1,
    )
    # Token ids and boolean flags must match exactly, so compare with
    # torch.equal (same size and elements) instead of a tolerance-based,
    # broadcasting comparison.
    expected_merged = torch.tensor([0, 1, 2, 3, -1, -1, -1, 0, 1, 2, -1], device=device)
    assert torch.equal(merged, expected_merged)

    expected_rejected_toks = torch.tensor(
        [False, False, False, False, True, True, True, False, False, False, True],
        device=device,
    )
    assert torch.equal(is_rejected_tok, expected_rejected_toks)
|
||||
|
||||
|
||||
def compute_acceptance_rate(metrics: list[Metric]) -> float:
    """Return accepted draft tokens divided by proposed draft tokens.

    The counts are read from the "vllm:spec_decode_num_accepted_tokens" and
    "vllm:spec_decode_num_draft_tokens" metrics. When no draft tokens were
    recorded, NaN is returned instead of dividing by zero.
    """
    by_name = {}
    for entry in metrics:
        by_name[entry.name] = entry

    drafted = by_name["vllm:spec_decode_num_draft_tokens"].value  # type: ignore
    if drafted == 0:
        return float("nan")

    accepted = by_name["vllm:spec_decode_num_accepted_tokens"].value  # type: ignore
    return accepted / drafted
|
||||
|
||||
|
||||
def compute_acceptance_len(metrics: list[Metric]) -> float:
    """Return one plus the number of accepted tokens per draft.

    The counts are read from the "vllm:spec_decode_num_drafts" and
    "vllm:spec_decode_num_accepted_tokens" metrics. When no drafts were
    recorded the result is ``1.0``.

    Returns:
        Always a ``float``, including in the zero-drafts case.
    """
    name2metric = {metric.name: metric for metric in metrics}
    n_drafts = name2metric["vllm:spec_decode_num_drafts"].value  # type: ignore
    n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value  # type: ignore
    if n_drafts == 0:
        # Return a float so the result type matches the annotation.
        return 1.0
    return 1 + (n_accepted_toks / n_drafts)
|
||||
|
||||
@@ -55,3 +55,38 @@ def test_bind_kv_cache_non_attention(default_vllm_config):
|
||||
|
||||
assert runner_kv_caches[0] is kv_cache["model.layers.20.attn"]
|
||||
assert runner_kv_caches[1] is kv_cache["model.layers.28.attn"]
|
||||
|
||||
|
||||
def test_bind_kv_cache_draft_model(default_vllm_config):
    """Check bind_kv_cache when target and draft layers share layer indices.

    Every attention layer must be bound to its own cache tensor, and the
    runner's cache list must interleave target and draft layers by index.
    """
    from vllm.attention.layer import Attention

    target_layers = ["model.layers.0.attn", "model.layers.1.attn"]
    draft_layers = ["draft_model.layers.0.attn", "draft_model.layers.1.attn"]
    layer_names = target_layers + draft_layers

    ctx = {}
    for name in layer_names:
        ctx[name] = Attention(32, 128, 0.1, prefix=name)
    kv_cache = {}
    for name in layer_names:
        kv_cache[name] = torch.zeros((1,))
    runner_kv_caches: list[torch.Tensor] = []
    bind_kv_cache(kv_cache, ctx, runner_kv_caches)

    # Each layer's first kv_cache entry is the tensor registered under its name.
    for name in layer_names:
        assert ctx[name].kv_cache[0] is kv_cache[name]

    # caches are ordered by layer_index, interleaving target and draft model
    expected_order = [
        "model.layers.0.attn",
        "draft_model.layers.0.attn",
        "model.layers.1.attn",
        "draft_model.layers.1.attn",
    ]
    for position, name in enumerate(expected_order):
        assert runner_kv_caches[position] is kv_cache[name]
|
||||
|
||||
@@ -2593,17 +2593,10 @@ class InstructCoderDataset(HuggingFaceDataset):
|
||||
request_id_prefix: str = "",
|
||||
no_oversample: bool = False,
|
||||
**kwargs,
|
||||
) -> list:
|
||||
) -> list[SampleRequest]:
|
||||
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
|
||||
sampled_requests = []
|
||||
for i, item in enumerate(self.data):
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
prompt = (
|
||||
f"{item['input']}\n\n{item['instruction']} Just output "
|
||||
"the code, do not include any explanation."
|
||||
)
|
||||
|
||||
for i, prompt in enumerate(self.sample_prompts(n=num_requests)):
|
||||
# apply template
|
||||
if not skip_chat_template:
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
@@ -2626,6 +2619,14 @@ class InstructCoderDataset(HuggingFaceDataset):
|
||||
)
|
||||
return sampled_requests
|
||||
|
||||
def sample_prompts(self, n: int) -> Iterator[str]:
    """Yield a prompt string for each of the items in ``self.data.take(n)``.

    A prompt is the item's ``input``, a blank line, its ``instruction``, and a
    fixed request to output only code.
    """
    for record in self.data.take(n):
        code_input = record["input"]
        instruction = record["instruction"]
        yield (
            f"{code_input}\n\n{instruction} Just output "
            "the code, do not include any explanation."
        )
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# MT-Bench Dataset Implementation
|
||||
|
||||
@@ -8,8 +8,12 @@ import time
|
||||
import aiohttp
|
||||
from tqdm.asyncio import tqdm
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .endpoint_request_func import RequestFunc, RequestFuncInput, RequestFuncOutput
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
async def wait_for_endpoint(
|
||||
request_func: RequestFunc,
|
||||
@@ -61,6 +65,8 @@ async def wait_for_endpoint(
|
||||
if output.success:
|
||||
pbar.close()
|
||||
return output
|
||||
else:
|
||||
logger.warning("Endpoint is not ready. Error='%s'", output.error)
|
||||
except aiohttp.ClientConnectorError:
|
||||
pass
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from dataclasses import replace
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import torch
|
||||
@@ -709,3 +710,6 @@ class ParallelConfig:
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def replace(self, **kwargs) -> Self:
    """Return a copy of this config with the given fields overridden.

    Inside the class, the bare ``replace`` below resolves to the module-level
    ``dataclasses.replace`` import, not to this method.
    """
    updated = replace(self, **kwargs)
    return updated
|
||||
|
||||
@@ -77,6 +77,9 @@ class SpeculativeConfig:
|
||||
draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
|
||||
"""The degree of the tensor parallelism for the draft model. Can only be 1
|
||||
or the same as the target model's tensor parallel size."""
|
||||
tensor_parallel_size: int | None = None
|
||||
"""Users should pass "draft_tensor_parallel_size". This parameter's purpose is to
|
||||
warn users when they mistakenly provide the wrong argument."""
|
||||
|
||||
# Draft model configuration
|
||||
quantization: me_quant.QuantizationMethods | None = None
|
||||
@@ -397,13 +400,11 @@ class SpeculativeConfig:
|
||||
"one layer. Might need some code changes "
|
||||
"to support multiple layers."
|
||||
)
|
||||
elif self.method == "draft_model":
|
||||
pass
|
||||
else:
|
||||
self.method = "draft_model"
|
||||
raise NotImplementedError(
|
||||
"Speculative decoding with draft model is not "
|
||||
"supported yet. Please consider using other "
|
||||
"speculative decoding methods such as ngram, medusa, "
|
||||
"eagle, or mtp."
|
||||
f"Unsupported speculative method: '{self.method}'"
|
||||
)
|
||||
|
||||
# Replace hf_config for EAGLE draft_model
|
||||
@@ -631,6 +632,12 @@ class SpeculativeConfig:
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _verify_args(self) -> Self:
|
||||
if self.tensor_parallel_size is not None:
|
||||
raise ValueError(
|
||||
"'tensor_parallel_size' is not a valid argument in the "
|
||||
"speculative_config. Please pass 'draft_tensor_parallel_size' instead."
|
||||
)
|
||||
|
||||
if self.num_speculative_tokens is None:
|
||||
raise ValueError(
|
||||
"num_speculative_tokens must be provided with "
|
||||
@@ -669,12 +676,32 @@ class SpeculativeConfig:
|
||||
f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501
|
||||
f"Got {self.target_model_config.hf_text_config.model_type=}"
|
||||
)
|
||||
|
||||
self.verify_equal_vocab_size_if_draft_model()
|
||||
return self
|
||||
|
||||
def verify_equal_vocab_size_if_draft_model(self) -> None:
    """Reject draft-model setups whose vocabulary sizes differ.

    Nothing is checked unless the method is "draft_model" and both the target
    and the draft model configs are set.

    Raises:
        ValueError: If the target and draft vocabulary sizes are not equal.
    """
    if self.method != "draft_model":
        return
    if self.target_model_config is None or self.draft_model_config is None:
        return

    target_vocab_size = self.target_model_config.get_vocab_size()
    draft_vocab_size = self.draft_model_config.get_vocab_size()
    if target_vocab_size == draft_vocab_size:
        return

    raise ValueError(
        "Target and draft model should have the same vocabulary size. "
        f"Target model vocab_size={target_vocab_size}. "
        f"Draft model vocab_size={draft_vocab_size}. "
        "Using models with different tokenizers can cause out-of-bounds "
        "errors during speculative decoding."
    )
|
||||
|
||||
def use_eagle(self) -> bool:
    """Return True when the method is "eagle", "eagle3" or "mtp"."""
    eagle_style_methods = ("eagle", "eagle3", "mtp")
    return self.method in eagle_style_methods
|
||||
|
||||
def uses_draft_model(self) -> bool:
    """Return True when the speculative method is exactly "draft_model"."""
    is_draft_model = self.method == "draft_model"
    return is_draft_model
|
||||
|
||||
def __repr__(self) -> str:
|
||||
method = self.method
|
||||
model = None if method in ("ngram", "suffix") else self.draft_model_config.model
|
||||
|
||||
@@ -1214,10 +1214,19 @@ class VllmConfig:
|
||||
compilation_config = self.compilation_config
|
||||
computed_compile_ranges_split_points = []
|
||||
|
||||
# The upper bound of the compile ranges is the max_num_batched_tokens
|
||||
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
if max_num_batched_tokens is not None:
|
||||
computed_compile_ranges_split_points.append(max_num_batched_tokens)
|
||||
# The upper bound of the compile ranges is the max_num_batched_tokens.
|
||||
# For speculative decoding with draft model, the compile range must be extended
|
||||
# by 1 for each sequence.
|
||||
compile_range_end = self.scheduler_config.max_num_batched_tokens
|
||||
if compile_range_end is not None:
|
||||
do_extend: bool = (
|
||||
self.speculative_config is not None
|
||||
and self.speculative_config.uses_draft_model()
|
||||
)
|
||||
if do_extend:
|
||||
compile_range_end += self.scheduler_config.max_num_seqs
|
||||
|
||||
computed_compile_ranges_split_points.append(compile_range_end)
|
||||
|
||||
# Add the compile ranges for flashinfer
|
||||
if compilation_config.pass_config.fuse_allreduce_rms:
|
||||
@@ -1228,10 +1237,7 @@ class VllmConfig:
|
||||
self.model_config.get_hidden_size()
|
||||
* self.model_config.dtype.itemsize
|
||||
)
|
||||
if (
|
||||
max_num_batched_tokens is not None
|
||||
and max_token_num < max_num_batched_tokens
|
||||
):
|
||||
if compile_range_end is not None and max_token_num < compile_range_end:
|
||||
computed_compile_ranges_split_points.append(max_token_num)
|
||||
else:
|
||||
logger.debug(
|
||||
@@ -1243,11 +1249,7 @@ class VllmConfig:
|
||||
for x in compilation_config.compile_ranges_split_points:
|
||||
assert isinstance(x, int)
|
||||
assert x > 0, f"Invalid compile range split point: {x}"
|
||||
if (
|
||||
max_num_batched_tokens is not None
|
||||
and x < max_num_batched_tokens
|
||||
and x > 1
|
||||
):
|
||||
if compile_range_end is not None and x < compile_range_end and x > 1:
|
||||
computed_compile_ranges_split_points.append(x)
|
||||
compilation_config.compile_ranges_split_points = sorted(
|
||||
computed_compile_ranges_split_points
|
||||
@@ -1316,6 +1318,14 @@ class VllmConfig:
|
||||
path = self.compilation_config.debug_dump_path / append_path
|
||||
return path
|
||||
|
||||
def replace(self, **kwargs):
    """
    Return a new config with the given attributes replaced.

    The config is 'recomputed' because dataclasses.replace() calls
    __init__() and __post_init__() on the new object, source:
    https://docs.python.org/3/library/dataclasses.html#dataclasses.replace
    """
    recomputed = replace(self, **kwargs)
    return recomputed
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"model={self.model_config.model!r}, "
|
||||
|
||||
@@ -1776,21 +1776,6 @@ class EngineArgs:
|
||||
):
|
||||
_raise_unsupported_error(feature_name="Concurrent Partial Prefill")
|
||||
|
||||
# N-gram, Medusa, and Eagle are supported for speculative decoding.
|
||||
if self.speculative_config is not None:
|
||||
# speculative_config could still be a dict at this point
|
||||
if isinstance(self.speculative_config, dict):
|
||||
method = self.speculative_config.get("method", None)
|
||||
else:
|
||||
method = self.speculative_config.method
|
||||
|
||||
if method == "draft_model":
|
||||
raise NotImplementedError(
|
||||
"Draft model speculative decoding is not supported yet. "
|
||||
"Please consider using other speculative decoding methods "
|
||||
"such as ngram, medusa, eagle, or mtp."
|
||||
)
|
||||
|
||||
if self.pipeline_parallel_size > 1:
|
||||
supports_pp = getattr(
|
||||
self.distributed_executor_backend, "supports_pp", False
|
||||
|
||||
@@ -124,12 +124,17 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
|
||||
|
||||
|
||||
def get_model(
|
||||
*, vllm_config: VllmConfig, model_config: ModelConfig | None = None
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
model_config: ModelConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> nn.Module:
|
||||
loader = get_model_loader(vllm_config.load_config)
|
||||
if model_config is None:
|
||||
model_config = vllm_config.model_config
|
||||
return loader.load_model(vllm_config=vllm_config, model_config=model_config)
|
||||
return loader.load_model(
|
||||
vllm_config=vllm_config, model_config=model_config, prefix=prefix
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -36,7 +36,7 @@ class BaseModelLoader(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
def load_model(
|
||||
self, vllm_config: VllmConfig, model_config: ModelConfig
|
||||
self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = ""
|
||||
) -> nn.Module:
|
||||
"""Load a model with the given configurations."""
|
||||
device_config = vllm_config.device_config
|
||||
@@ -48,7 +48,7 @@ class BaseModelLoader(ABC):
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with target_device:
|
||||
model = initialize_model(
|
||||
vllm_config=vllm_config, model_config=model_config
|
||||
vllm_config=vllm_config, model_config=model_config, prefix=prefix
|
||||
)
|
||||
|
||||
log_model_inspection(model)
|
||||
|
||||
@@ -335,7 +335,7 @@ class GGUFModelLoader(BaseModelLoader):
|
||||
)
|
||||
|
||||
def load_model(
|
||||
self, vllm_config: VllmConfig, model_config: ModelConfig
|
||||
self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = ""
|
||||
) -> nn.Module:
|
||||
device_config = vllm_config.device_config
|
||||
local_model_path = self._prepare_weights(model_config)
|
||||
@@ -364,7 +364,7 @@ class GGUFModelLoader(BaseModelLoader):
|
||||
target_device = torch.device(device_config.device)
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with target_device:
|
||||
model = initialize_model(vllm_config=vllm_config)
|
||||
model = initialize_model(vllm_config=vllm_config, prefix=prefix)
|
||||
self.load_weights(model, model_config)
|
||||
|
||||
process_weights_after_loading(model, model_config, target_device)
|
||||
|
||||
@@ -68,6 +68,7 @@ class TensorizerLoader(BaseModelLoader):
|
||||
def _load_model_serialized_cpu(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> nn.Module:
|
||||
"""Load a serialized model with tensorizer to the CPU.
|
||||
|
||||
@@ -80,7 +81,7 @@ class TensorizerLoader(BaseModelLoader):
|
||||
model_config = vllm_config.model_config
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with torch.device(device_config.device):
|
||||
model = initialize_model(vllm_config=vllm_config)
|
||||
model = initialize_model(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
model.load_weights(self._get_weights_iterator())
|
||||
return model.eval()
|
||||
@@ -112,7 +113,7 @@ class TensorizerLoader(BaseModelLoader):
|
||||
model.load_weights(self._get_weights_iterator())
|
||||
|
||||
def load_model(
|
||||
self, vllm_config: VllmConfig, model_config: ModelConfig
|
||||
self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = ""
|
||||
) -> nn.Module:
|
||||
parallel_config = vllm_config.parallel_config
|
||||
self._verify_config(model_config, parallel_config)
|
||||
@@ -134,7 +135,7 @@ class TensorizerLoader(BaseModelLoader):
|
||||
)
|
||||
self.load_weights(model, model_config)
|
||||
return model
|
||||
return self._load_model_serialized_cpu(vllm_config=vllm_config)
|
||||
return self._load_model_serialized_cpu(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
@staticmethod
|
||||
def save_model(
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, replace
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, get_args
|
||||
|
||||
@@ -329,6 +329,16 @@ class CommonAttentionMetadata:
|
||||
|
||||
_num_computed_tokens_cache: torch.Tensor | None = None
|
||||
|
||||
def batch_size(self) -> int:
    """Return the length of the first dimension of ``seq_lens``."""
    seq_lens_shape = self.seq_lens.shape
    return seq_lens_shape[0]
|
||||
|
||||
def naive_query_lens(self) -> torch.Tensor:
    """Return the differences between consecutive ``query_start_loc`` entries.

    Naive because it assumes each query ends where the next query starts.
    """
    starts = self.query_start_loc[:-1]
    next_starts = self.query_start_loc[1:]
    return next_starts - starts
|
||||
|
||||
def replace(self, **kwargs) -> "CommonAttentionMetadata":
    """Return a copy of this metadata with the given fields overridden.

    Inside the class, the bare ``replace`` below resolves to the module-level
    ``dataclasses.replace`` import, not to this method.
    """
    updated = replace(self, **kwargs)
    return updated
|
||||
|
||||
@property
|
||||
@deprecated(
|
||||
"""
|
||||
|
||||
@@ -818,3 +818,35 @@ def get_dcp_local_seq_lens(
|
||||
)
|
||||
dcp_local_seq_lens = base + remainder
|
||||
return dcp_local_seq_lens.squeeze(1)
|
||||
|
||||
|
||||
def extend_all_queries_by_1(
    common_attn_metadata: CommonAttentionMetadata,
    arange: torch.Tensor,
    new_slot_mapping: torch.Tensor,
) -> CommonAttentionMetadata:
    """
    Creates a new CommonAttentionMetadata with all query lengths increased by 1.
    Also all seq lens are increased by 1.
    This is useful e.g. in speculative decoding with draft models, where we
    extend each sequence by 1 token.
    The slot mapping is computed externally, as it requires more information.

    Args:
        common_attn_metadata: Metadata to extend. It is not modified; new
            tensors are built and handed to `replace()`.
        arange: Pre-built `[0, 1, 2, ...]` tensor used to shift
            `query_start_loc`. Must hold at least `len(query_start_loc)`
            elements.
        new_slot_mapping: Slot mapping to store in the returned metadata.

    Returns:
        The metadata produced by `common_attn_metadata.replace(...)`.

    Raises:
        ValueError: If `arange` is shorter than `query_start_loc`.
    """
    cad = common_attn_metadata
    num_boundaries = len(cad.query_start_loc)
    # A too-short arange would either fail with an opaque shape error or,
    # for a single-element arange, silently broadcast and leave the
    # boundaries unshifted. Fail loudly instead.
    if arange.shape[0] < num_boundaries:
        raise ValueError(
            f"arange has {arange.shape[0]} elements but at least "
            f"{num_boundaries} are required to shift query_start_loc."
        )
    # query start loc must be increased by [+0, +1, +2, ..., +batch_size]
    new_query_start_loc = cad.query_start_loc + arange[:num_boundaries]
    new_query_start_loc_cpu = cad.query_start_loc_cpu + torch.arange(
        len(cad.query_start_loc_cpu), dtype=torch.int32
    )
    new_cad = cad.replace(
        query_start_loc=new_query_start_loc,
        query_start_loc_cpu=new_query_start_loc_cpu,
        seq_lens=cad.seq_lens + 1,
        # each request is extended by 1 token -> batch_size tokens are added
        num_actual_tokens=cad.num_actual_tokens + cad.batch_size(),
        # All query lens increase by 1, so max query len increases by 1
        max_query_len=cad.max_query_len + 1,
        max_seq_len=cad.max_seq_len + 1,
        slot_mapping=new_slot_mapping,
    )
    return new_cad
|
||||
|
||||
@@ -208,6 +208,8 @@ class Scheduler(SchedulerInterface):
|
||||
if speculative_config.use_eagle():
|
||||
self.use_eagle = True
|
||||
self.num_lookahead_tokens = self.num_spec_tokens
|
||||
if speculative_config.uses_draft_model():
|
||||
self.num_lookahead_tokens = self.num_spec_tokens
|
||||
|
||||
# Create the KV cache manager.
|
||||
self.kv_cache_manager = KVCacheManager(
|
||||
|
||||
271
vllm/v1/spec_decode/draft_model.py
Normal file
271
vllm/v1/spec_decode/draft_model.py
Normal file
@@ -0,0 +1,271 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.config.speculative import SpeculativeConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
CommonAttentionMetadata,
|
||||
extend_all_queries_by_1,
|
||||
)
|
||||
from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, SpecDecodeBaseProposer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DraftModelProposer(SpecDecodeBaseProposer):
    """Proposer that drafts speculative tokens with a separate draft model.

    The base proposer is configured with `pass_hidden_states_to_model=False`,
    so no hidden states are passed to the draft model's forward call.
    Construction fails fast for setups this proposer does not support yet
    (see the `_raise_if_*` helpers).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ) -> None:
        super().__init__(
            vllm_config=vllm_config,
            device=device,
            pass_hidden_states_to_model=False,
            runner=runner,
        )
        # Validate the configuration up front so unsupported setups fail at
        # construction time rather than during drafting.
        self._raise_if_multimodal()
        self._raise_if_mrope()
        self._raise_if_padded_drafter_batch_disabled()
        self._raise_if_vocab_size_mismatch()
        self._raise_if_draft_tp_mismatch()

    def _block_size(self) -> int:
        """Return the block size of the attention metadata builder's KV cache spec."""
        builder = self._get_attention_metadata_builder()
        return builder.kv_cache_spec.block_size

    def _raise_if_multimodal(self) -> None:
        """Raise NotImplementedError if `self.supports_mm_inputs` is set."""
        if self.supports_mm_inputs:
            raise NotImplementedError(
                "Speculative Decoding with draft models "
                "does not support multimodal models yet"
            )

    def _raise_if_mrope(self) -> None:
        """Raise NotImplementedError if the draft model config uses M-RoPE."""
        if self.draft_model_config.uses_mrope:
            raise NotImplementedError(
                "Speculative Decoding with draft models does not support M-RoPE yet"
            )

    def _raise_if_padded_drafter_batch_disabled(self) -> None:
        """Raise NotImplementedError if the padded drafter batch is disabled."""
        if self.vllm_config.speculative_config.disable_padded_drafter_batch:
            raise NotImplementedError(
                "Speculative Decoding with draft models only supports "
                "padded drafter batch. Please don't pass --disable-padded-drafter-batch"
                " in the speculative_config."
            )

    def _raise_if_vocab_size_mismatch(self) -> None:
        """Delegate the vocabulary-size check to the speculative config.

        NOTE(review): presumably raises when the draft and target vocab sizes
        differ; confirm against `SpeculativeConfig`.
        """
        self.vllm_config.speculative_config.verify_equal_vocab_size_if_draft_model()

    def _raise_if_draft_tp_mismatch(self) -> None:
        """Raise ValueError if draft and target tensor-parallel sizes differ."""
        # Note(Tomas Ruiz) If we run the target model with TP > 1 and
        # the draft model with TP = 1, then the different TP ranks collide.
        # Specifically when all ranks compile the draft model on rank 0
        # (because TP=1), then the torch compile cache is overwritten and corrupted.
        # We need a mechanism like this: https://github.com/vllm-project/vllm/pull/5414
        # To prevent this error, we assert that both TP sizes must be the same.
        spec_cfg: SpeculativeConfig = self.vllm_config.speculative_config
        tgt_tp = spec_cfg.target_parallel_config.tensor_parallel_size
        draft_tp = spec_cfg.draft_parallel_config.tensor_parallel_size
        if draft_tp != tgt_tp:
            raise ValueError(
                "Currently, 'draft_tensor_parallel_size' and 'tensor_parallel_size' "
                f"must be the same. Got {draft_tp} and {tgt_tp}. "
                "Please pass 'draft_tensor_parallel_size' in the speculative_config."
            )

    def set_inputs_first_pass(
        self,
        target_token_ids: torch.Tensor,
        next_token_ids: torch.Tensor,
        target_positions: torch.Tensor,
        last_token_indices: torch.Tensor | None,
        cad: CommonAttentionMetadata,
        num_rejected_tokens_gpu: torch.Tensor | None,
    ) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
        """Build the first-pass drafter inputs, adding one token per request.

        `merge_toks_kernel` is launched twice (token ids, then positions) to
        write `self.input_ids` and `self.positions`: per request it copies the
        kept target entries, appends the new entry (the sampled next token,
        respectively the last kept position + 1) and fills rejected slots
        with 0. The slot mapping is then recomputed and the attention
        metadata is extended by one token per request.

        Args:
            target_token_ids: Token ids of the target model's input.
            next_token_ids: One token id per request to append.
            target_positions: Positions matching `target_token_ids`.
            last_token_indices: Not used by this implementation; the indices
                are recomputed from the extended metadata.
            cad: Attention metadata of the target pass.
            num_rejected_tokens_gpu: Optional per-request count of rejected
                tokens, subtracted from each request's end location.

        Returns:
            `(num_tokens, new_last_token_indices, new_cad)`, where
            `num_tokens` is the number of target tokens plus the batch size.
        """
        batch_size = cad.batch_size()
        grid = (batch_size,)
        start_locs = cad.query_start_loc[:-1]
        # `- 1` yields a fresh tensor, so the in-place subtraction below does
        # not touch `cad.query_start_loc`.
        end_locs = cad.query_start_loc[1:] - 1
        if num_rejected_tokens_gpu is not None:
            end_locs -= num_rejected_tokens_gpu

        num_tokens = target_token_ids.shape[0] + batch_size
        is_rejected_tok = torch.empty(
            (num_tokens,), device=self.input_ids.device, dtype=torch.bool
        )
        merge_toks_kernel[grid](
            target_toks_ptr=target_token_ids,
            next_toks_ptr=next_token_ids,
            query_start_locs_ptr=start_locs,
            query_end_locs_ptr=end_locs,
            out_ptr_merged_toks=self.input_ids,
            out_ptr_is_rejected_tok=is_rejected_tok,
            target_toks_size=target_token_ids.shape[0],
            # passing a negative rejected_tok_fill value will raise an error
            # when the value is used to index into embeddings.
            # Therefore, we pass a valid integer, e.g. 0.
            rejected_tok_fill=0,
        )
        merge_toks_kernel[grid](
            target_toks_ptr=target_positions,
            # The appended token sits right after the last kept position.
            next_toks_ptr=target_positions[end_locs] + 1,
            query_start_locs_ptr=start_locs,
            query_end_locs_ptr=end_locs,
            out_ptr_merged_toks=self.positions,
            out_ptr_is_rejected_tok=is_rejected_tok,
            target_toks_size=target_positions.shape[0],
            rejected_tok_fill=0,
        )

        # recompute slot mapping
        new_slot_mapping = compute_new_slot_mapping(
            cad=cad,
            new_positions=self.positions[:num_tokens],
            is_rejected_token_mask=is_rejected_tok,
            block_size=self._block_size(),
            max_model_len=self.max_model_len,
        )
        # update common_attn_metadata
        new_cad: CommonAttentionMetadata = extend_all_queries_by_1(
            cad,
            arange=self.arange,
            new_slot_mapping=new_slot_mapping,
        )

        new_last_token_indices = new_cad.query_start_loc[1:] - 1
        if num_rejected_tokens_gpu is not None:
            new_last_token_indices -= num_rejected_tokens_gpu

        return num_tokens, new_last_token_indices, new_cad

    def load_model(self, target_model: Any) -> None:
        """Load the draft model and record the names of its attention layers.

        Takes target_model to satisfy the type checker; it is not used.
        """

        # This must be computed before loading the draft model
        # because that mutates the forward_context of the vllm_config
        target_attn_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config, Attention).keys()
        )

        from vllm.compilation.backends import set_model_tag

        draft_vllm_config: VllmConfig = create_vllm_config_for_draft_model(
            target_model_vllm_config=self.vllm_config
        )
        logger.info(
            "Starting to load draft model %s. TP=%d, rank=%d",
            draft_vllm_config.model_config.model,
            draft_vllm_config.parallel_config.tensor_parallel_size,
            draft_vllm_config.parallel_config.rank,
        )
        with set_model_tag("draft_model"):
            self.model = get_model(vllm_config=draft_vllm_config, prefix="draft_model")

        # This must be computed after loading the draft model
        # because that mutates the forward_context of the vllm_config
        draft_attn_layer_names = (
            get_layers_from_vllm_config(self.vllm_config, Attention).keys()
            - target_attn_layer_names
        )
        self.attn_layer_names = list(draft_attn_layer_names)
|
||||
|
||||
|
||||
def create_vllm_config_for_draft_model(
    target_model_vllm_config: VllmConfig,
) -> VllmConfig:
    """The vllm_config is configured for the target model, e.g.
    its quant_config and parallel_config. But the draft model is potentially
    quantized differently, and has potentially different tensor_parallel_size.
    This function creates a new vllm_config configured for the draft model.
    The vllm_config is useful when loading the draft model with get_model().

    Args:
        target_model_vllm_config: Config of the target model. Its
            `speculative_config` supplies the draft model config and the
            draft parallel config.

    Returns:
        The result of `target_model_vllm_config.replace(...)` with the draft
        model config, the draft parallel config (carrying the target's rank)
        and `quant_config=None`.

    Raises:
        ValueError: If `speculative_config` is None.
    """
    old = target_model_vllm_config
    spec_config = old.speculative_config
    # Without this guard a missing speculative config surfaces as an opaque
    # AttributeError on NoneType.
    if spec_config is None:
        raise ValueError(
            "Cannot create a draft model config: speculative_config is not set."
        )
    # The draft model runs on the same rank as the target model.
    new_parallel_config = spec_config.draft_parallel_config.replace(
        rank=old.parallel_config.rank
    )
    new: VllmConfig = old.replace(
        quant_config=None,  # quant_config is recomputed in __init__()
        model_config=spec_config.draft_model_config,
        parallel_config=new_parallel_config,
    )
    return new
|
||||
|
||||
|
||||
def compute_new_slot_mapping(
    cad: CommonAttentionMetadata,
    new_positions: torch.Tensor,
    is_rejected_token_mask: torch.Tensor,
    block_size: int,
    max_model_len: int,
):
    """Compute one KV-cache slot per token of the extended token layout.

    Each request owns `naive_query_len + 1` consecutive entries of
    `new_positions`. A token's slot is looked up through the request's row of
    `cad.block_table_tensor`: `block * block_size + position % block_size`.
    Tokens whose position is `>= max_model_len` and tokens flagged in
    `is_rejected_token_mask` get `PADDING_SLOT_ID` instead.

    Returns:
        A tensor with one slot id per entry of `new_positions`.
    """
    num_reqs, blocks_per_req = cad.block_table_tensor.shape
    flat_block_table = cad.block_table_tensor.view(-1)

    # Request index of every token: request r is repeated (query_len_r + 1) times.
    tokens_per_req = cad.naive_query_lens() + 1
    req_of_token = torch.repeat_interleave(
        torch.arange(num_reqs, device=cad.query_start_loc.device),
        tokens_per_req,
        output_size=len(new_positions),
    )

    # Clamp the positions to prevent an out-of-bounds error when indexing
    # into the block table.
    safe_positions = torch.clamp(new_positions, max=max_model_len - 1)
    block_idx_in_req = safe_positions // block_size
    offset_in_block = safe_positions % block_size

    physical_blocks = flat_block_table[req_of_token * blocks_per_req + block_idx_in_req]
    slots = physical_blocks * block_size + offset_in_block

    # Positions beyond the max model length must not be written to the KV cache.
    slots.masked_fill_(new_positions >= max_model_len, PADDING_SLOT_ID)
    # Rejected tokens must not be written to the KV cache either.
    slots.masked_fill_(is_rejected_token_mask, PADDING_SLOT_ID)
    return slots
|
||||
|
||||
|
||||
@triton.jit
def merge_toks_kernel(
    target_toks_ptr,
    next_toks_ptr,
    query_start_locs_ptr,
    query_end_locs_ptr,
    out_ptr_merged_toks,
    out_ptr_is_rejected_tok,
    target_toks_size,
    rejected_tok_fill,
):
    """
    Merges `target_toks_ptr` and `next_toks_ptr` into `out_ptr_merged_toks`.
    Each program handles one request: it copies the entries from the request's
    start location up to and including its `query_end_locs_ptr` entry, writes
    the request's `next_toks_ptr` value right after them, and fills every
    remaining slot up to the next request's start with `rejected_tok_fill`.
    Output indices are shifted by the program id, since every earlier request
    gained one extra entry. `out_ptr_is_rejected_tok` receives True for the
    filled (rejected) slots and False otherwise.
    """
    pid = tl.program_id(0)
    first_in = tl.load(query_start_locs_ptr + pid)
    # The last request has no successor; its span ends at the input size.
    if pid == tl.num_programs(0) - 1:
        span_end = target_toks_size.to(tl.int32)
    else:
        span_end = tl.load(query_start_locs_ptr + pid + 1).to(tl.int32)

    last_kept = tl.load(query_end_locs_ptr + pid)
    appended_val = tl.load(next_toks_ptr + pid)
    # One more iteration than the input span: the output has an extra slot.
    for i in range(first_in, span_end + 1):
        dst = pid + i
        if i <= last_kept:  # copy existing tokens
            kept_val = tl.load(target_toks_ptr + i)
            tl.store(out_ptr_merged_toks + dst, kept_val)
            tl.store(out_ptr_is_rejected_tok + dst, False)
        elif i == last_kept + 1:  # copy bonus token
            tl.store(out_ptr_merged_toks + dst, appended_val)
            tl.store(out_ptr_is_rejected_tok + dst, False)
        else:  # fill rejected tokens
            tl.store(out_ptr_merged_toks + dst, rejected_tok_fill)
            tl.store(out_ptr_is_rejected_tok + dst, True)
|
||||
@@ -53,11 +53,12 @@ logger = init_logger(__name__)
|
||||
PADDING_SLOT_ID = -1
|
||||
|
||||
|
||||
class EagleProposer:
|
||||
class SpecDecodeBaseProposer:
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
pass_hidden_states_to_model: bool,
|
||||
runner=None,
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
@@ -65,6 +66,7 @@ class EagleProposer:
|
||||
assert self.speculative_config is not None
|
||||
self.draft_model_config = self.speculative_config.draft_model_config
|
||||
self.method = self.speculative_config.method
|
||||
self.pass_hidden_states_to_model = pass_hidden_states_to_model
|
||||
|
||||
self.runner = runner
|
||||
self.device = device
|
||||
@@ -72,7 +74,11 @@ class EagleProposer:
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
|
||||
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
|
||||
# The drafter can get longer sequences than the target model.
|
||||
max_batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||
self.max_num_tokens = (
|
||||
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size
|
||||
)
|
||||
self.token_arange_np = np.arange(self.max_num_tokens)
|
||||
# We need to get the hidden size from the draft model config because
|
||||
# the draft model's hidden size can be different from the target model's
|
||||
@@ -143,7 +149,6 @@ class EagleProposer:
|
||||
|
||||
# We need +1 here because the arange is used to set query_start_loc,
|
||||
# which has one more element than batch_size.
|
||||
max_batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||
max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
|
||||
self.arange = torch.arange(
|
||||
max_num_slots_for_arange, device=device, dtype=torch.int32
|
||||
@@ -245,11 +250,7 @@ class EagleProposer:
|
||||
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
|
||||
num_rejected_tokens_gpu: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
num_tokens = target_token_ids.shape[0]
|
||||
batch_size = next_token_ids.shape[0]
|
||||
|
||||
if last_token_indices is None:
|
||||
last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
|
||||
batch_size = common_attn_metadata.batch_size()
|
||||
|
||||
if self.method == "eagle3":
|
||||
assert isinstance(self.model, Eagle3LlamaForCausalLM)
|
||||
@@ -257,12 +258,17 @@ class EagleProposer:
|
||||
target_hidden_states
|
||||
)
|
||||
assert target_hidden_states.shape[-1] == self.hidden_size
|
||||
# Shift the input ids by one token.
|
||||
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
|
||||
self.input_ids[: num_tokens - 1] = target_token_ids[1:]
|
||||
# Replace the last token with the next token.
|
||||
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
|
||||
self.input_ids[last_token_indices] = next_token_ids
|
||||
|
||||
num_tokens, last_token_indices, common_attn_metadata = (
|
||||
self.set_inputs_first_pass(
|
||||
target_token_ids=target_token_ids,
|
||||
next_token_ids=next_token_ids,
|
||||
target_positions=target_positions,
|
||||
last_token_indices=last_token_indices,
|
||||
cad=common_attn_metadata,
|
||||
num_rejected_tokens_gpu=num_rejected_tokens_gpu,
|
||||
)
|
||||
)
|
||||
|
||||
assert self.runner is not None
|
||||
|
||||
@@ -311,8 +317,9 @@ class EagleProposer:
|
||||
if num_tokens_across_dp is not None:
|
||||
num_tokens_across_dp[self.dp_rank] = num_input_tokens
|
||||
|
||||
# copy inputs to buffer for cudagraph
|
||||
self._set_positions(num_tokens, target_positions)
|
||||
if self.pass_hidden_states_to_model:
|
||||
# target_hidden_states and self.hidden_states can have different
|
||||
# hidden dims. E.g. large target model and small draft model.
|
||||
self.hidden_states[:num_tokens] = target_hidden_states
|
||||
|
||||
if self.supports_mm_inputs:
|
||||
@@ -330,6 +337,14 @@ class EagleProposer:
|
||||
input_ids = self.input_ids[:num_input_tokens]
|
||||
inputs_embeds = None
|
||||
|
||||
model_kwargs = {
|
||||
"input_ids": input_ids,
|
||||
"positions": self._get_positions(num_input_tokens),
|
||||
"inputs_embeds": inputs_embeds,
|
||||
}
|
||||
if self.pass_hidden_states_to_model:
|
||||
model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
|
||||
|
||||
with set_forward_context(
|
||||
per_layer_attn_metadata,
|
||||
self.vllm_config,
|
||||
@@ -337,17 +352,13 @@ class EagleProposer:
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
):
|
||||
ret_hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=self._get_positions(num_input_tokens),
|
||||
hidden_states=self.hidden_states[:num_input_tokens],
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
if self.method == "mtp":
|
||||
ret_hidden_states = self.model(**model_kwargs)
|
||||
if not self.model_returns_tuple():
|
||||
last_hidden_states = ret_hidden_states
|
||||
hidden_states = last_hidden_states
|
||||
else:
|
||||
last_hidden_states, hidden_states = ret_hidden_states
|
||||
|
||||
sample_hidden_states = last_hidden_states[last_token_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
|
||||
@@ -357,9 +368,9 @@ class EagleProposer:
|
||||
return draft_token_ids.view(-1, 1)
|
||||
|
||||
if self.uses_mrope:
|
||||
positions = target_positions[:, last_token_indices]
|
||||
positions = self.positions[:, last_token_indices]
|
||||
else:
|
||||
positions = target_positions[last_token_indices]
|
||||
positions = self.positions[last_token_indices]
|
||||
if self.method in (
|
||||
"deepseek_mtp",
|
||||
"ernie_mtp",
|
||||
@@ -527,6 +538,14 @@ class EagleProposer:
|
||||
inputs_embeds = None
|
||||
|
||||
# Run the model.
|
||||
model_kwargs = {
|
||||
"input_ids": input_ids,
|
||||
"positions": self._get_positions(input_batch_size),
|
||||
"inputs_embeds": inputs_embeds,
|
||||
}
|
||||
if self.pass_hidden_states_to_model:
|
||||
model_kwargs["hidden_states"] = self.hidden_states[:input_batch_size]
|
||||
|
||||
with set_forward_context(
|
||||
per_layer_attn_metadata,
|
||||
self.vllm_config,
|
||||
@@ -534,17 +553,13 @@ class EagleProposer:
|
||||
num_tokens_across_dp=batch_size_across_dp,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
):
|
||||
ret_hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=self._get_positions(input_batch_size),
|
||||
hidden_states=self.hidden_states[:input_batch_size],
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
if self.method == "mtp":
|
||||
ret_hidden_states = self.model(**model_kwargs)
|
||||
if not self.model_returns_tuple():
|
||||
last_hidden_states = ret_hidden_states
|
||||
hidden_states = ret_hidden_states
|
||||
else:
|
||||
last_hidden_states, hidden_states = ret_hidden_states
|
||||
|
||||
hidden_states = hidden_states[:batch_size]
|
||||
logits = self.model.compute_logits(last_hidden_states[:batch_size])
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
@@ -554,6 +569,34 @@ class EagleProposer:
|
||||
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
|
||||
return draft_token_ids
|
||||
|
||||
def set_inputs_first_pass(
    self,
    target_token_ids: torch.Tensor,
    next_token_ids: torch.Tensor,
    target_positions: torch.Tensor,
    last_token_indices: torch.Tensor | None,
    cad: CommonAttentionMetadata,
    num_rejected_tokens_gpu: torch.Tensor | None,
) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
    """Fill the input buffers for the first drafting pass.

    The target token ids are shifted left by one into `self.input_ids`, the
    entry at each request's last token index is overwritten with that
    request's next token, and the positions are copied via
    `self._set_positions`.

    Args:
        target_token_ids: Token ids of the target model's input.
        next_token_ids: One token id per request, written at
            `last_token_indices`.
        target_positions: Positions matching `target_token_ids`.
        last_token_indices: Index of each request's last token; derived from
            `cad.query_start_loc` when None.
        cad: Attention metadata; returned unchanged.
        num_rejected_tokens_gpu: Not used by this implementation.

    Returns:
        `(num_tokens, last_token_indices, cad)`.
    """
    if last_token_indices is None:
        last_token_indices = cad.query_start_loc[1:] - 1

    num_tokens = target_token_ids.shape[0]
    # With zero tokens, `[: num_tokens - 1]` would become `[:-1]` and the
    # assignment would fail with a shape mismatch, so skip the shift.
    if num_tokens > 0:
        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        self.input_ids[: num_tokens - 1] = target_token_ids[1:]
    # Replace the last token with the next token.
    # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
    self.input_ids[last_token_indices] = next_token_ids

    # copy inputs to buffer for cudagraph
    self._set_positions(num_tokens, target_positions)

    return num_tokens, last_token_indices, cad
|
||||
|
||||
def model_returns_tuple(self) -> bool:
    """Tell whether the drafting model's output is unpacked as a pair.

    Returns False for the "mtp" and "draft_model" methods, True otherwise.
    """
    single_output_methods = ("mtp", "draft_model")
    return self.method not in single_output_methods
|
||||
|
||||
def prepare_next_token_ids_cpu(
|
||||
self,
|
||||
sampled_token_ids: list[list[int]],
|
||||
@@ -1214,12 +1257,14 @@ class EagleProposer:
|
||||
input_ids = self.input_ids[:num_input_tokens]
|
||||
inputs_embeds = None
|
||||
|
||||
self.model(
|
||||
kwargs = dict(
|
||||
input_ids=input_ids,
|
||||
positions=self._get_positions(num_input_tokens),
|
||||
hidden_states=self.hidden_states[:num_input_tokens],
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
if self.pass_hidden_states_to_model:
|
||||
kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
|
||||
self.model(**kwargs)
|
||||
|
||||
def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
|
||||
"""Find and return the attention metadata builders for EAGLE layers.
|
||||
@@ -1264,8 +1309,8 @@ class EagleProposer:
|
||||
|
||||
def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""
|
||||
Validate that all eagle layers belong to the same KVCacheGroup.
|
||||
Need this assumption to ensure all eagle layers can use the
|
||||
Validate that all drafting layers belong to the same KVCacheGroup.
|
||||
Need this assumption to ensure all drafting layers can use the
|
||||
same AttentionMetadata.
|
||||
May extend to multiple AttentionMetadata in the future.
|
||||
"""
|
||||
@@ -1283,7 +1328,7 @@ class EagleProposer:
|
||||
)
|
||||
)
|
||||
== 1
|
||||
), "All eagle layers should belong to the same kv cache group"
|
||||
), "All drafting layers should belong to the same kv cache group"
|
||||
|
||||
def _pad_batch_across_dp(
|
||||
self,
|
||||
@@ -1308,6 +1353,21 @@ class EagleProposer:
|
||||
return num_tokens_dp_padded, num_toks_across_dp
|
||||
|
||||
|
||||
class EagleProposer(SpecDecodeBaseProposer):
    """Proposer whose drafting model is also passed hidden states.

    Thin subclass of `SpecDecodeBaseProposer` that fixes
    `pass_hidden_states_to_model=True`.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        # Everything else is handled by the shared base implementation.
        super().__init__(
            vllm_config=vllm_config,
            device=device,
            pass_hidden_states_to_model=True,
            runner=runner,
        )
|
||||
|
||||
|
||||
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
|
||||
# to sample the draft tokens. We will use this after we find a way to manage
|
||||
# the draft prob tensor.
|
||||
|
||||
@@ -145,6 +145,7 @@ from vllm.v1.sample.logits_processor.interface import LogitsProcessor
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.rejection_sampler import RejectionSampler
|
||||
from vllm.v1.sample.sampler import Sampler
|
||||
from vllm.v1.spec_decode.draft_model import DraftModelProposer
|
||||
from vllm.v1.spec_decode.eagle import EagleProposer
|
||||
from vllm.v1.spec_decode.medusa import MedusaProposer
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
@@ -432,10 +433,20 @@ class GPUModelRunner(
|
||||
# layers in the draft model.
|
||||
if self.speculative_config and get_pp_group().is_last_rank:
|
||||
self.drafter: (
|
||||
NgramProposer | SuffixDecodingProposer | EagleProposer | MedusaProposer
|
||||
NgramProposer
|
||||
| SuffixDecodingProposer
|
||||
| EagleProposer
|
||||
| DraftModelProposer
|
||||
| MedusaProposer
|
||||
)
|
||||
if self.speculative_config.method == "ngram":
|
||||
self.drafter = NgramProposer(self.vllm_config)
|
||||
elif self.speculative_config.uses_draft_model():
|
||||
self.drafter = DraftModelProposer(
|
||||
vllm_config=self.vllm_config,
|
||||
device=self.device,
|
||||
runner=self,
|
||||
)
|
||||
elif self.speculative_config.method == "suffix":
|
||||
self.drafter = SuffixDecodingProposer(self.vllm_config)
|
||||
elif self.speculative_config.use_eagle():
|
||||
@@ -3443,10 +3454,13 @@ class GPUModelRunner(
|
||||
spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens
|
||||
<= self.effective_drafter_max_model_len
|
||||
)
|
||||
if spec_config.use_eagle() and not spec_config.disable_padded_drafter_batch:
|
||||
# EAGLE speculative decoding can use the GPU sampled tokens
|
||||
use_gpu_toks = (
|
||||
spec_config.use_eagle() or spec_config.uses_draft_model()
|
||||
) and not spec_config.disable_padded_drafter_batch
|
||||
if use_gpu_toks:
|
||||
# EAGLE/DraftModel speculative decoding can use the GPU sampled tokens
|
||||
# as inputs, and does not need to wait for bookkeeping to finish.
|
||||
assert isinstance(self.drafter, EagleProposer)
|
||||
assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
|
||||
sampled_token_ids = sampler_output.sampled_token_ids
|
||||
if input_fits_in_drafter:
|
||||
propose_draft_token_ids(sampled_token_ids)
|
||||
@@ -3679,8 +3693,8 @@ class GPUModelRunner(
|
||||
target_hidden_states=hidden_states,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
elif spec_config.use_eagle():
|
||||
assert isinstance(self.drafter, EagleProposer)
|
||||
elif spec_config.use_eagle() or spec_config.uses_draft_model():
|
||||
assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
|
||||
|
||||
if spec_config.disable_padded_drafter_batch:
|
||||
# When padded-batch is disabled, the sampled_token_ids should be
|
||||
@@ -4475,8 +4489,12 @@ class GPUModelRunner(
|
||||
else:
|
||||
hidden_states = outputs
|
||||
|
||||
if self.speculative_config and self.speculative_config.use_eagle():
|
||||
assert isinstance(self.drafter, EagleProposer)
|
||||
if self.speculative_config and (
|
||||
self.speculative_config.use_eagle()
|
||||
or self.speculative_config.uses_draft_model()
|
||||
):
|
||||
assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
|
||||
assert self.speculative_config is not None
|
||||
# Eagle currently only supports PIECEWISE cudagraphs.
|
||||
# Therefore only use cudagraphs if the main model uses PIECEWISE
|
||||
# NOTE(lucas): this is a hack, need to clean up.
|
||||
@@ -5652,8 +5670,11 @@ class GPUModelRunner(
|
||||
kv_cache_config, kernel_block_sizes
|
||||
)
|
||||
|
||||
if self.speculative_config and self.speculative_config.use_eagle():
|
||||
assert isinstance(self.drafter, EagleProposer)
|
||||
if self.speculative_config and (
|
||||
self.speculative_config.use_eagle()
|
||||
or self.speculative_config.uses_draft_model()
|
||||
):
|
||||
assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
|
||||
# validate all draft model layers belong to the same kv cache
|
||||
# group
|
||||
self.drafter.validate_same_kv_cache_group(kv_cache_config)
|
||||
|
||||
@@ -352,7 +352,7 @@ def bind_kv_cache(
|
||||
pass
|
||||
else:
|
||||
raise NotImplementedError
|
||||
layer_name = layer_names[0]
|
||||
for layer_name in layer_names:
|
||||
runner_kv_caches.append(kv_caches[layer_name])
|
||||
|
||||
# Bind kv_caches to forward context
|
||||
|
||||
Reference in New Issue
Block a user