diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index a84d5b116..1e3e310e7 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -54,7 +54,7 @@ def parse_args(): "--method", type=str, default="eagle", - choices=["ngram", "eagle", "eagle3", "mtp"], + choices=["ngram", "eagle", "eagle3", "mtp", "draft_model"], ) parser.add_argument("--num-spec-tokens", type=int, default=2) parser.add_argument("--prompt-lookup-max", type=int, default=5) @@ -70,7 +70,11 @@ def parse_args(): parser.add_argument("--output-len", type=int, default=256) parser.add_argument("--model-dir", type=str, default=None) parser.add_argument("--eagle-dir", type=str, default=None) + parser.add_argument("--draft-model", type=str, default=None) parser.add_argument("--custom-mm-prompts", action="store_true") + parser.add_argument("--gpu-memory-utilization", type=float, default=0.9) + parser.add_argument("--disable-padded-drafter-batch", action="store_true") + parser.add_argument("--max-num-seqs", type=int, default=None) return parser.parse_args() @@ -111,6 +115,7 @@ def main(args): "method": args.method, "model": eagle_dir, "num_speculative_tokens": args.num_spec_tokens, + "disable_padded_drafter_batch": args.disable_padded_drafter_batch, } elif args.method == "ngram": speculative_config = { @@ -119,6 +124,15 @@ def main(args): "prompt_lookup_max": args.prompt_lookup_max, "prompt_lookup_min": args.prompt_lookup_min, } + elif args.method == "draft_model": + assert args.draft_model is not None and args.draft_model != "" + speculative_config = { + "method": args.method, + "model": args.draft_model, + "num_speculative_tokens": args.num_spec_tokens, + "enforce_eager": args.enforce_eager, + "max_model_len": args.max_model_len, + } elif args.method == "mtp": speculative_config = { "method": "mtp", @@ -133,12 +147,13 @@ def main(args): tensor_parallel_size=args.tp, enable_chunked_prefill=args.enable_chunked_prefill, enforce_eager=args.enforce_eager, - gpu_memory_utilization=0.9, + gpu_memory_utilization=args.gpu_memory_utilization, speculative_config=speculative_config, disable_log_stats=False, max_model_len=args.max_model_len, limit_mm_per_prompt={"image": 5}, disable_chunked_mm_input=True, + max_num_seqs=args.max_num_seqs, ) sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) diff --git a/examples/online_serving/disaggregated_serving/moriio_toy_proxy_server.py b/examples/online_serving/disaggregated_serving/moriio_toy_proxy_server.py index 507320612..90ffbef31 100644 --- a/examples/online_serving/disaggregated_serving/moriio_toy_proxy_server.py +++ b/examples/online_serving/disaggregated_serving/moriio_toy_proxy_server.py @@ -4,13 +4,13 @@ import asyncio import copy import logging import os -import re import socket import threading import uuid import aiohttp import msgpack +import regex as re import zmq from quart import Quart, make_response, request diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index a25114a4d..bc635cee6 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import random +from collections.abc import Iterable +from dataclasses import dataclass from typing import Any import pytest @@ -10,32 +12,45 @@ from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark from vllm import LLM, SamplingParams from vllm.assets.base import VLLM_S3_BUCKET_URL from vllm.assets.image import VLM_IMAGES_DIR +from vllm.benchmarks.datasets import InstructCoderDataset +from vllm.config.vllm import VllmConfig from vllm.distributed import cleanup_dist_env_and_memory +from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform +from vllm.v1.metrics.reader import Metric +from vllm.v1.spec_decode.draft_model import ( + create_vllm_config_for_draft_model, + merge_toks_kernel, +) MTP_SIMILARITY_RATE = 0.8 def _skip_if_insufficient_gpus_for_tp(tp_size: int): """Skip test if available GPUs < tp_size on ROCm.""" - if current_platform.is_rocm(): - available_gpus = torch.cuda.device_count() - if available_gpus < tp_size: - pytest.skip( - f"Test requires {tp_size} GPUs, but only {available_gpus} available" - ) + available_gpus = torch.cuda.device_count() + if available_gpus < tp_size: + pytest.skip( + f"Test requires {tp_size} GPUs, but only {available_gpus} available" + ) -def get_test_prompts(mm_enabled: bool): +Messages = list[dict[str, Any]] + + +def get_test_prompts( + mm_enabled: bool, quiet: bool = False, num_prompts: int = 100 +) -> list[Messages]: prompt_types = ["repeat", "sentence"] if mm_enabled: prompt_types.append("mm") - num_prompts = 100 prompts = [] random.seed(0) random_prompt_type_choices = random.choices(prompt_types, k=num_prompts) - print(f"Prompt types: {random_prompt_type_choices}") + + if not quiet: + print(f"Prompt types: {random_prompt_type_choices}") # Generate a mixed batch of prompts, some of which can be easily # predicted by n-gram matching and some which likely cannot. @@ -75,11 +90,27 @@ def get_test_prompts(mm_enabled: bool): return prompts +def get_instruct_coder_messages(n: int) -> list[Messages]: + dataset = InstructCoderDataset( + dataset_path="likaixin/InstructCoder", dataset_split="train" + ) + prompts: Iterable[str] = dataset.sample_prompts(n=n) + return [[{"role": "user", "content": prompt}] for prompt in prompts] + + @pytest.fixture def sampling_config(): + return greedy_sampling() + + +def greedy_sampling() -> SamplingParams: return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False) +def stochastic_sampling() -> SamplingParams: + return SamplingParams(temperature=1.0, max_tokens=10, ignore_eos=False) + + @pytest.fixture def model_name(): return "meta-llama/Llama-3.1-8B-Instruct" @@ -583,3 +614,269 @@ def test_mtp_correctness( del spec_llm torch.cuda.empty_cache() cleanup_dist_env_and_memory() + + +@dataclass +class ArgsTest: + target_model: str + draft_model: str + sampling_config: SamplingParams + num_speculative_tokens: int + expected_acceptance_rate: float + expected_acceptance_len: float + # Defaults + target_tensor_parallel_size: int = 1 + draft_tensor_parallel_size: int = 1 + max_model_len: int = 1024 + gpu_memory_utilization: float = 0.5 + dataset: str = "test_prompts" + num_prompts: int = 100 + + +cases = [ + # Same model for draft and target, greedy sampling. + ArgsTest( + target_model="Qwen/Qwen3-0.6B", + draft_model="Qwen/Qwen3-0.6B", + sampling_config=greedy_sampling(), + num_speculative_tokens=3, # K + expected_acceptance_len=3 + 1, # K + 1 + expected_acceptance_rate=1.0, + ), + # Smaller draft model, stochastic sampling. + ArgsTest( + target_model="Qwen/Qwen3-1.7B", + draft_model="Qwen/Qwen3-0.6B", + sampling_config=stochastic_sampling(), + num_speculative_tokens=3, + expected_acceptance_len=2.8 + 1, + expected_acceptance_rate=0.9, + ), +] + + +@pytest.mark.parametrize("args", cases) +@pytest.mark.parametrize("enforce_eager", [True, False]) +def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool): + assert_draft_model_correctness(args, enforce_eager) + + +def test_draft_model_realistic_example(): + args = ArgsTest( + target_model="Qwen/Qwen3-1.7B", + draft_model="Qwen/Qwen3-0.6B", + dataset="likaixin/InstructCoder", + num_speculative_tokens=3, + sampling_config=greedy_sampling(), + # values below are not derived, but just prevent a regression + expected_acceptance_len=2.8, + expected_acceptance_rate=0.55, + ) + assert_draft_model_correctness(args, enforce_eager=False) + + +@pytest.mark.parametrize( + "models", + [ + # target_model, draft_model + ("Qwen/Qwen3-1.7B-FP8", "Qwen/Qwen3-0.6B"), # target quantized + ("Qwen/Qwen3-1.7B", "Qwen/Qwen3-0.6B-FP8"), # draft quantized + ], + ids=["target_quantized", "draft_quantized"], +) +@pytest.mark.parametrize("enforce_eager", [True, False]) +def test_draft_model_quantization(models: tuple[str, str], enforce_eager: bool): + tgt_model, draft_model = models + sd_case = ArgsTest( + target_model=tgt_model, + draft_model=draft_model, + **some_high_acceptance_metrics(), + ) + assert_draft_model_correctness(sd_case, enforce_eager) + + +def test_draft_model_tensor_parallelism(): + """Ensure spec decode works when running with TP > 1.""" + _skip_if_insufficient_gpus_for_tp(2) + sd_case = ArgsTest( + target_model="Qwen/Qwen3-1.7B", + target_tensor_parallel_size=2, + draft_model="Qwen/Qwen3-0.6B", + draft_tensor_parallel_size=2, + **some_high_acceptance_metrics(), + ) + assert_draft_model_correctness(sd_case, enforce_eager=False) + + +def test_draft_model_engine_args_tensor_parallelism(): + """Ensure the vllm_config for the draft model is created correctly, + and independently of the target model (quantization, TP, etc.)""" + _skip_if_insufficient_gpus_for_tp(2) + + engine_args = EngineArgs( + model="Qwen/Qwen3-1.7B-FP8", # <<< tgt quantized + tensor_parallel_size=2, + speculative_config={ + "model": "Qwen/Qwen3-0.6B", # <<< draft not quantized + "method": "draft_model", + "num_speculative_tokens": 3, + "draft_tensor_parallel_size": 1, # <<< valid arg name + }, + ) + tgt_vllm_config: VllmConfig = engine_args.create_engine_config() + assert tgt_vllm_config.parallel_config.tensor_parallel_size == 2 + assert tgt_vllm_config.quant_config.get_name() == "fp8" + + draft_vllm_config: VllmConfig = create_vllm_config_for_draft_model(tgt_vllm_config) + assert draft_vllm_config.parallel_config.tensor_parallel_size == 1 + assert draft_vllm_config.quant_config is None + + +def test_draft_model_engine_args_rejects_invalid_tp_argname(): + """The user should pass "draft_tensor_parallel_size" rather than + "tensor_parallel_size". We enforce this with validation.""" + + engine_args = EngineArgs( + model="Qwen/Qwen3-1.7B", + tensor_parallel_size=1, + speculative_config={ + "model": "Qwen/Qwen3-0.6B", + "method": "draft_model", + "num_speculative_tokens": 3, + "tensor_parallel_size": 1, # <<< invalid arg name + }, + ) + with pytest.raises(ValueError): + engine_args.create_engine_config() + + +def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool): + """Compare the outputs using and not using speculative decoding. + In the greedy decoding case, the outputs must match EXACTLY.""" + test_prompts: list[Messages] = get_messages( + dataset=args.dataset, n=args.num_prompts + ) + + spec_llm = LLM( + model=args.target_model, + speculative_config={ + "model": args.draft_model, + "method": "draft_model", + "num_speculative_tokens": args.num_speculative_tokens, + "max_model_len": args.max_model_len, + "enforce_eager": enforce_eager, + "draft_tensor_parallel_size": args.draft_tensor_parallel_size, + "max_num_seqs": 100, # limit cudagraph capture runtime + }, + max_model_len=args.max_model_len, + gpu_memory_utilization=args.gpu_memory_utilization, + tensor_parallel_size=args.target_tensor_parallel_size, + enforce_eager=enforce_eager, + disable_log_stats=False, # enables get_metrics() + ) + # we don't check the outputs, only check the metrics + spec_llm.chat(test_prompts, args.sampling_config) + metrics = spec_llm.get_metrics() + + acceptance_rate: float = compute_acceptance_rate(metrics) + acceptance_len: float = compute_acceptance_len(metrics) + del spec_llm # CLEANUP + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() + + print( + f"spec-decode: target={args.target_model}, draft={args.draft_model}, " + f"temperature={args.sampling_config.temperature:.2f}, " + f"acceptance_rate={acceptance_rate:.2f}, " + f"acceptance_len={acceptance_len:.2f}, " + ) + + assert acceptance_rate >= args.expected_acceptance_rate + assert acceptance_len >= args.expected_acceptance_len + + +def get_messages(dataset: str, n: int) -> list[Messages]: + if dataset == "test_prompts": + return get_test_prompts(mm_enabled=False, quiet=True, num_prompts=n) + elif dataset == "likaixin/InstructCoder": + return get_instruct_coder_messages(n=n) + else: + raise NotImplementedError(f"Dataset '{dataset}' not implemented") + + +def some_high_acceptance_metrics() -> dict: + return { + "sampling_config": greedy_sampling(), + "num_speculative_tokens": 3, + "expected_acceptance_len": 2.90 + 1, + "expected_acceptance_rate": 0.90, + } + + +def test_merge_toks_kernel(): + device = "cuda" + merged_len = 5 + 2 # len(target_toks) = 5, batch_size = 2 + merged = torch.full((merged_len,), -100, device=device) # -100 is arbitrary + is_rejected_tok = torch.full((merged_len,), True, device=device) + grid = (2,) + merge_toks_kernel[grid]( + target_toks_ptr=torch.tensor([0, 1, 2, 0, 1], device=device), + next_toks_ptr=torch.tensor([3, 2], device=device), + query_start_locs_ptr=torch.tensor([0, 3], device=device), + query_end_locs_ptr=torch.tensor([2, 4], device=device), + out_ptr_merged_toks=merged, + out_ptr_is_rejected_tok=is_rejected_tok, + target_toks_size=5, + rejected_tok_fill=-1, + ) + expected_merged = torch.tensor([0, 1, 2, 3, 0, 1, 2], device=device) + assert torch.allclose(merged, expected_merged) + + expected_rejected_toks = torch.tensor([False] * merged_len, device=device) + assert torch.allclose(is_rejected_tok, expected_rejected_toks) + + +def test_merge_toks_kernel_with_rejected_tokens(): + device = "cuda" + merged_size = 9 + 2 # len(target_toks) = 9, batch_size = 2 + merged = torch.full((merged_size,), -100, device=device) + is_rejected_tok = torch.full((merged_size,), True, device=device) + grid = (2,) + merge_toks_kernel[grid]( + # rejected tokens + # ↓ ↓ ↓ ↓ + target_toks_ptr=torch.tensor([0, 1, 2, 13, 14, 15, 0, 1, 22], device=device), + next_toks_ptr=torch.tensor([3, 2], device=device), + query_start_locs_ptr=torch.tensor([0, 6], device=device), + query_end_locs_ptr=torch.tensor([2, 7], device=device), + out_ptr_merged_toks=merged, + out_ptr_is_rejected_tok=is_rejected_tok, + target_toks_size=9, + rejected_tok_fill=-1, + ) + expected_merged = torch.tensor([0, 1, 2, 3, -1, -1, -1, 0, 1, 2, -1], device=device) + assert torch.allclose(merged, expected_merged) + + expected_rejected_toks = torch.tensor( + [False, False, False, False, True, True, True, False, False, False, True], + device=device, + ) + assert torch.allclose(is_rejected_tok, expected_rejected_toks) + + +def compute_acceptance_rate(metrics: list[Metric]) -> float: + name2metric = {metric.name: metric for metric in metrics} + n_draft_toks = name2metric["vllm:spec_decode_num_draft_tokens"].value # type: ignore + if n_draft_toks == 0: + return float("nan") + n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value # type: ignore + return n_accepted_toks / n_draft_toks + + +def compute_acceptance_len(metrics: list[Metric]) -> float: + name2metric = {metric.name: metric for metric in metrics} + n_drafts = name2metric["vllm:spec_decode_num_drafts"].value # type: ignore + n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value # type: ignore + if n_drafts == 0: + return 1 + return 1 + (n_accepted_toks / n_drafts) diff --git a/tests/v1/worker/test_utils.py b/tests/v1/worker/test_utils.py index a13e11d71..d223ad6e0 100644 --- a/tests/v1/worker/test_utils.py +++ b/tests/v1/worker/test_utils.py @@ -55,3 +55,38 @@ def test_bind_kv_cache_non_attention(default_vllm_config): assert runner_kv_caches[0] is kv_cache["model.layers.20.attn"] assert runner_kv_caches[1] is kv_cache["model.layers.28.attn"] + + +def test_bind_kv_cache_draft_model(default_vllm_config): + from vllm.attention.layer import Attention + + layer_names = [ + "model.layers.0.attn", + "model.layers.1.attn", + "draft_model.layers.0.attn", + "draft_model.layers.1.attn", + ] + ctx = { + layer_name: Attention(32, 128, 0.1, prefix=layer_name) + for layer_name in layer_names + } + kv_cache = {layer_name: torch.zeros((1,)) for layer_name in layer_names} + runner_kv_caches: list[torch.Tensor] = [] + bind_kv_cache(kv_cache, ctx, runner_kv_caches) + + assert ctx["model.layers.0.attn"].kv_cache[0] is kv_cache["model.layers.0.attn"] + assert ctx["model.layers.1.attn"].kv_cache[0] is kv_cache["model.layers.1.attn"] + assert ( + ctx["draft_model.layers.0.attn"].kv_cache[0] + is kv_cache["draft_model.layers.0.attn"] + ) + assert ( + ctx["draft_model.layers.1.attn"].kv_cache[0] + is kv_cache["draft_model.layers.1.attn"] + ) + + # caches are ordered by layer_index, interleaving target and draft model + assert runner_kv_caches[0] is kv_cache["model.layers.0.attn"] + assert runner_kv_caches[1] is kv_cache["draft_model.layers.0.attn"] + assert runner_kv_caches[2] is kv_cache["model.layers.1.attn"] + assert runner_kv_caches[3] is kv_cache["draft_model.layers.1.attn"] diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index 20d13b167..60062aa5d 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -2593,17 +2593,10 @@ class InstructCoderDataset(HuggingFaceDataset): request_id_prefix: str = "", no_oversample: bool = False, **kwargs, - ) -> list: + ) -> list[SampleRequest]: output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN sampled_requests = [] - for i, item in enumerate(self.data): - if len(sampled_requests) >= num_requests: - break - prompt = ( - f"{item['input']}\n\n{item['instruction']} Just output " - "the code, do not include any explanation." - ) - + for i, prompt in enumerate(self.sample_prompts(n=num_requests)): # apply template if not skip_chat_template: prompt = tokenizer.apply_chat_template( @@ -2626,6 +2619,14 @@ class InstructCoderDataset(HuggingFaceDataset): ) return sampled_requests + def sample_prompts(self, n: int) -> Iterator[str]: + for item in self.data.take(n): + prompt = ( + f"{item['input']}\n\n{item['instruction']} Just output " + "the code, do not include any explanation." + ) + yield prompt + # ----------------------------------------------------------------------------- # MT-Bench Dataset Implementation diff --git a/vllm/benchmarks/lib/ready_checker.py b/vllm/benchmarks/lib/ready_checker.py index 5649faf05..0cfd053f5 100644 --- a/vllm/benchmarks/lib/ready_checker.py +++ b/vllm/benchmarks/lib/ready_checker.py @@ -8,8 +8,12 @@ import time import aiohttp from tqdm.asyncio import tqdm +from vllm.logger import init_logger + from .endpoint_request_func import RequestFunc, RequestFuncInput, RequestFuncOutput +logger = init_logger(__name__) + async def wait_for_endpoint( request_func: RequestFunc, @@ -61,6 +65,8 @@ async def wait_for_endpoint( if output.success: pbar.close() return output + else: + logger.warning("Endpoint is not ready. Error='%s'", output.error) except aiohttp.ClientConnectorError: pass diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 26b913b47..4cdf84897 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -3,6 +3,7 @@ import os from collections.abc import Callable +from dataclasses import replace from typing import TYPE_CHECKING, Any, Literal import torch @@ -709,3 +710,6 @@ class ParallelConfig: ) return self + + def replace(self, **kwargs) -> Self: + return replace(self, **kwargs) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index c89a8f0c5..8f34dadae 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -77,6 +77,9 @@ class SpeculativeConfig: draft_tensor_parallel_size: int | None = Field(default=None, ge=1) """The degree of the tensor parallelism for the draft model. Can only be 1 or the same as the target model's tensor parallel size.""" + tensor_parallel_size: int | None = None + """Users should pass "draft_tensor_parallel_size". This parameter's purpose is to + warn users when they mistakenly provide the wrong argument.""" # Draft model configuration quantization: me_quant.QuantizationMethods | None = None @@ -397,13 +400,11 @@ class SpeculativeConfig: "one layer. Might need some code changes " "to support multiple layers." ) + elif self.method == "draft_model": + pass else: - self.method = "draft_model" raise NotImplementedError( - "Speculative decoding with draft model is not " - "supported yet. Please consider using other " - "speculative decoding methods such as ngram, medusa, " - "eagle, or mtp." + f"Unsupported speculative method: '{self.method}'" ) # Replace hf_config for EAGLE draft_model @@ -631,6 +632,12 @@ class SpeculativeConfig: @model_validator(mode="after") def _verify_args(self) -> Self: + if self.tensor_parallel_size is not None: + raise ValueError( + "'tensor_parallel_size' is not a valid argument in the " + "speculative_config. Please pass 'draft_tensor_parallel_size' instead." + ) + if self.num_speculative_tokens is None: raise ValueError( "num_speculative_tokens must be provided with " @@ -669,12 +676,32 @@ class SpeculativeConfig: f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501 f"Got {self.target_model_config.hf_text_config.model_type=}" ) - + self.verify_equal_vocab_size_if_draft_model() return self + def verify_equal_vocab_size_if_draft_model(self): + if ( + self.method == "draft_model" + and self.target_model_config is not None + and self.draft_model_config is not None + ): + target_vocab_size = self.target_model_config.get_vocab_size() + draft_vocab_size = self.draft_model_config.get_vocab_size() + if target_vocab_size != draft_vocab_size: + raise ValueError( + f"Target and draft model should have the same vocabulary size. " + f"Target model vocab_size={target_vocab_size}. " + f"Draft model vocab_size={draft_vocab_size}. " + f"Using models with different tokenizers can cause out-of-bounds " + f"errors during speculative decoding." + ) + def use_eagle(self) -> bool: return self.method in ("eagle", "eagle3", "mtp") + def uses_draft_model(self) -> bool: + return self.method == "draft_model" + def __repr__(self) -> str: method = self.method model = None if method in ("ngram", "suffix") else self.draft_model_config.model diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 0181cb1f0..3d42205ca 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -1214,10 +1214,19 @@ class VllmConfig: compilation_config = self.compilation_config computed_compile_ranges_split_points = [] - # The upper bound of the compile ranges is the max_num_batched_tokens - max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens - if max_num_batched_tokens is not None: - computed_compile_ranges_split_points.append(max_num_batched_tokens) + # The upper bound of the compile ranges is the max_num_batched_tokens. + # For speculative decoding with draft model, the compile range must be extended + # by 1 for each sequence. + compile_range_end = self.scheduler_config.max_num_batched_tokens + if compile_range_end is not None: + do_extend: bool = ( + self.speculative_config is not None + and self.speculative_config.uses_draft_model() + ) + if do_extend: + compile_range_end += self.scheduler_config.max_num_seqs + + computed_compile_ranges_split_points.append(compile_range_end) # Add the compile ranges for flashinfer if compilation_config.pass_config.fuse_allreduce_rms: @@ -1228,10 +1237,7 @@ class VllmConfig: self.model_config.get_hidden_size() * self.model_config.dtype.itemsize ) - if ( - max_num_batched_tokens is not None - and max_token_num < max_num_batched_tokens - ): + if compile_range_end is not None and max_token_num < compile_range_end: computed_compile_ranges_split_points.append(max_token_num) else: logger.debug( @@ -1243,11 +1249,7 @@ class VllmConfig: for x in compilation_config.compile_ranges_split_points: assert isinstance(x, int) assert x > 0, f"Invalid compile range split point: {x}" - if ( - max_num_batched_tokens is not None - and x < max_num_batched_tokens - and x > 1 - ): + if compile_range_end is not None and x < compile_range_end and x > 1: computed_compile_ranges_split_points.append(x) compilation_config.compile_ranges_split_points = sorted( computed_compile_ranges_split_points @@ -1316,6 +1318,14 @@ class VllmConfig: path = self.compilation_config.debug_dump_path / append_path return path + def replace(self, **kwargs): + """ + Replace attributes of the config, and 'recompute' the config. + dataclass.replace() calls __init__() and __post_init__(), source: + https://docs.python.org/3/library/dataclasses.html#dataclasses.replace + """ + return replace(self, **kwargs) + def __str__(self): return ( f"model={self.model_config.model!r}, " diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index cb82be6b6..798b5fb8e 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1776,21 +1776,6 @@ class EngineArgs: ): _raise_unsupported_error(feature_name="Concurrent Partial Prefill") - # N-gram, Medusa, and Eagle are supported for speculative decoding. - if self.speculative_config is not None: - # speculative_config could still be a dict at this point - if isinstance(self.speculative_config, dict): - method = self.speculative_config.get("method", None) - else: - method = self.speculative_config.method - - if method == "draft_model": - raise NotImplementedError( - "Draft model speculative decoding is not supported yet. " - "Please consider using other speculative decoding methods " - "such as ngram, medusa, eagle, or mtp." - ) - if self.pipeline_parallel_size > 1: supports_pp = getattr( self.distributed_executor_backend, "supports_pp", False diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py index 052d2cfc1..e1d8d2ead 100644 --- a/vllm/model_executor/model_loader/__init__.py +++ b/vllm/model_executor/model_loader/__init__.py @@ -124,12 +124,17 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: def get_model( - *, vllm_config: VllmConfig, model_config: ModelConfig | None = None + *, + vllm_config: VllmConfig, + model_config: ModelConfig | None = None, + prefix: str = "", ) -> nn.Module: loader = get_model_loader(vllm_config.load_config) if model_config is None: model_config = vllm_config.model_config - return loader.load_model(vllm_config=vllm_config, model_config=model_config) + return loader.load_model( + vllm_config=vllm_config, model_config=model_config, prefix=prefix + ) __all__ = [ diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py index 2238b0cfe..4b89b3481 100644 --- a/vllm/model_executor/model_loader/base_loader.py +++ b/vllm/model_executor/model_loader/base_loader.py @@ -36,7 +36,7 @@ class BaseModelLoader(ABC): raise NotImplementedError def load_model( - self, vllm_config: VllmConfig, model_config: ModelConfig + self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = "" ) -> nn.Module: """Load a model with the given configurations.""" device_config = vllm_config.device_config @@ -48,7 +48,7 @@ class BaseModelLoader(ABC): with set_default_torch_dtype(model_config.dtype): with target_device: model = initialize_model( - vllm_config=vllm_config, model_config=model_config + vllm_config=vllm_config, model_config=model_config, prefix=prefix ) log_model_inspection(model) diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py index 7f94bd234..e1fb99a5a 100644 --- a/vllm/model_executor/model_loader/gguf_loader.py +++ b/vllm/model_executor/model_loader/gguf_loader.py @@ -335,7 +335,7 @@ class GGUFModelLoader(BaseModelLoader): ) def load_model( - self, vllm_config: VllmConfig, model_config: ModelConfig + self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = "" ) -> nn.Module: device_config = vllm_config.device_config local_model_path = self._prepare_weights(model_config) @@ -364,7 +364,7 @@ class GGUFModelLoader(BaseModelLoader): target_device = torch.device(device_config.device) with set_default_torch_dtype(model_config.dtype): with target_device: - model = initialize_model(vllm_config=vllm_config) + model = initialize_model(vllm_config=vllm_config, prefix=prefix) self.load_weights(model, model_config) process_weights_after_loading(model, model_config, target_device) diff --git a/vllm/model_executor/model_loader/tensorizer_loader.py b/vllm/model_executor/model_loader/tensorizer_loader.py index 2b3704cfe..a3e3c9fd0 100644 --- a/vllm/model_executor/model_loader/tensorizer_loader.py +++ b/vllm/model_executor/model_loader/tensorizer_loader.py @@ -68,6 +68,7 @@ class TensorizerLoader(BaseModelLoader): def _load_model_serialized_cpu( self, vllm_config: VllmConfig, + prefix: str = "", ) -> nn.Module: """Load a serialized model with tensorizer to the CPU. @@ -80,7 +81,7 @@ class TensorizerLoader(BaseModelLoader): model_config = vllm_config.model_config with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): - model = initialize_model(vllm_config=vllm_config) + model = initialize_model(vllm_config=vllm_config, prefix=prefix) model.load_weights(self._get_weights_iterator()) return model.eval() @@ -112,7 +113,7 @@ class TensorizerLoader(BaseModelLoader): model.load_weights(self._get_weights_iterator()) def load_model( - self, vllm_config: VllmConfig, model_config: ModelConfig + self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = "" ) -> nn.Module: parallel_config = vllm_config.parallel_config self._verify_config(model_config, parallel_config) @@ -134,7 +135,7 @@ class TensorizerLoader(BaseModelLoader): ) self.load_weights(model, model_config) return model - return self._load_model_serialized_cpu(vllm_config=vllm_config) + return self._load_model_serialized_cpu(vllm_config=vllm_config, prefix=prefix) @staticmethod def save_model( diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index 5ea8f0e62..b4dcea105 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from dataclasses import dataclass +from dataclasses import dataclass, replace from enum import Enum from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, get_args @@ -329,6 +329,16 @@ class CommonAttentionMetadata: _num_computed_tokens_cache: torch.Tensor | None = None + def batch_size(self) -> int: + return self.seq_lens.shape[0] + + def naive_query_lens(self) -> torch.Tensor: + """Naive because it assumes that query ends where the next query starts.""" + return self.query_start_loc[1:] - self.query_start_loc[:-1] + + def replace(self, **kwargs) -> "CommonAttentionMetadata": + return replace(self, **kwargs) + @property @deprecated( """ diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 1c254b836..82321c000 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -818,3 +818,35 @@ def get_dcp_local_seq_lens( ) dcp_local_seq_lens = base + remainder return dcp_local_seq_lens.squeeze(1) + + +def extend_all_queries_by_1( + common_attn_metadata: CommonAttentionMetadata, + arange: torch.Tensor, + new_slot_mapping: torch.Tensor, +) -> CommonAttentionMetadata: + """ + Creates a new CommonAttentionMetadata with all query lengths increased by 1. + Also all seq lens are increased by 1. + This is useful e.g. in speculative decoding with draft models, where we + extend each sequence by 1 token. + The slot mapping is computed externally, as it requires more information. + """ + cad = common_attn_metadata + # query start loc must be increased by [+0, +1, +2, ..., +batch_size] + new_query_start_loc = cad.query_start_loc + arange[: len(cad.query_start_loc)] + new_query_start_loc_cpu = cad.query_start_loc_cpu + torch.arange( + len(cad.query_start_loc_cpu), dtype=torch.int32 + ) + new_cad = cad.replace( + query_start_loc=new_query_start_loc, + query_start_loc_cpu=new_query_start_loc_cpu, + seq_lens=cad.seq_lens + 1, + # each request is extended by 1 token -> batch_size tokens are added + num_actual_tokens=cad.num_actual_tokens + cad.batch_size(), + # All query lens increase by 1, so max query len increases by 1 + max_query_len=cad.max_query_len + 1, + max_seq_len=cad.max_seq_len + 1, + slot_mapping=new_slot_mapping, + ) + return new_cad diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index a6d6ae93e..0cb65bd0f 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -208,6 +208,8 @@ class Scheduler(SchedulerInterface): if speculative_config.use_eagle(): self.use_eagle = True self.num_lookahead_tokens = self.num_spec_tokens + if speculative_config.uses_draft_model(): + self.num_lookahead_tokens = self.num_spec_tokens # Create the KV cache manager. self.kv_cache_manager = KVCacheManager( diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py new file mode 100644 index 000000000..5a54074dd --- /dev/null +++ b/vllm/v1/spec_decode/draft_model.py @@ -0,0 +1,271 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + +import torch + +from vllm.attention.layer import Attention +from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.config.speculative import SpeculativeConfig +from vllm.logger import init_logger +from vllm.model_executor.model_loader import get_model +from vllm.triton_utils import tl, triton +from vllm.v1.attention.backends.utils import ( + CommonAttentionMetadata, + extend_all_queries_by_1, +) +from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, SpecDecodeBaseProposer + +logger = init_logger(__name__) + + +class DraftModelProposer(SpecDecodeBaseProposer): + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + runner=None, + ): + super().__init__( + vllm_config=vllm_config, + device=device, + pass_hidden_states_to_model=False, + runner=runner, + ) + self._raise_if_multimodal() + self._raise_if_mrope() + self._raise_if_padded_drafter_batch_disabled() + self._raise_if_vocab_size_mismatch() + self._raise_if_draft_tp_mismatch() + + def _block_size(self) -> int: + builder = self._get_attention_metadata_builder() + return builder.kv_cache_spec.block_size + + def _raise_if_multimodal(self): + if self.supports_mm_inputs: + raise NotImplementedError( + "Speculative Decoding with draft models " + "does not support multimodal models yet" + ) + + def _raise_if_mrope(self): + if self.draft_model_config.uses_mrope: + raise NotImplementedError( + "Speculative Decoding with draft models does not support M-RoPE yet" + ) + + def _raise_if_padded_drafter_batch_disabled(self): + if self.vllm_config.speculative_config.disable_padded_drafter_batch: + raise NotImplementedError( + "Speculative Decoding with draft models only supports " + "padded drafter batch. Please don't pass --disable-padded-drafter-batch" + " in the speculative_config." + ) + + def _raise_if_vocab_size_mismatch(self): + self.vllm_config.speculative_config.verify_equal_vocab_size_if_draft_model() + + def _raise_if_draft_tp_mismatch(self): + # Note(Tomas Ruiz) If we run the target model with TP > 1 and + # the draft model with TP = 1, then the different TP ranks collide. + # Specifically when all ranks compile the draft model on rank 0 + # (because TP=1), then the torch compile cache is overwritten and corrupted. + # We need a mechanism like this: https://github.com/vllm-project/vllm/pull/5414 + # To prevent this error, we assert that both TP sizes must be the same. + spec_cfg: SpeculativeConfig = self.vllm_config.speculative_config + tgt_tp = spec_cfg.target_parallel_config.tensor_parallel_size + draft_tp = spec_cfg.draft_parallel_config.tensor_parallel_size + if draft_tp != tgt_tp: + raise ValueError( + f"Currently, 'draft_tensor_parallel_size' and 'tensor_parallel_size' " + f"must be the same. Got {draft_tp} and {tgt_tp}. " + "Please pass 'draft_tensor_parallel_size' in the speculative_config." + ) + + def set_inputs_first_pass( + self, + target_token_ids: torch.Tensor, + next_token_ids: torch.Tensor, + target_positions: torch.Tensor, + last_token_indices: torch.Tensor | None, + cad: CommonAttentionMetadata, + num_rejected_tokens_gpu: torch.Tensor | None, + ) -> tuple[int, torch.Tensor, CommonAttentionMetadata]: + batch_size = cad.batch_size() + grid = (batch_size,) + start_locs = cad.query_start_loc[:-1] + end_locs = cad.query_start_loc[1:] - 1 + if num_rejected_tokens_gpu is not None: + end_locs -= num_rejected_tokens_gpu + + num_tokens = target_token_ids.shape[0] + batch_size + is_rejected_tok = torch.empty( + (num_tokens,), device=self.input_ids.device, dtype=torch.bool + ) + merge_toks_kernel[grid]( + target_toks_ptr=target_token_ids, + next_toks_ptr=next_token_ids, + query_start_locs_ptr=start_locs, + query_end_locs_ptr=end_locs, + out_ptr_merged_toks=self.input_ids, + out_ptr_is_rejected_tok=is_rejected_tok, + target_toks_size=target_token_ids.shape[0], + # passing a negative rejected_tok_fill value will raise an error + # when the value is used to index into embeddings. + # Therefore, we pass a valid integer, e.g. 0. + rejected_tok_fill=0, + ) + merge_toks_kernel[grid]( + target_toks_ptr=target_positions, + next_toks_ptr=target_positions[end_locs] + 1, + query_start_locs_ptr=start_locs, + query_end_locs_ptr=end_locs, + out_ptr_merged_toks=self.positions, + out_ptr_is_rejected_tok=is_rejected_tok, + target_toks_size=target_positions.shape[0], + rejected_tok_fill=0, + ) + + # recompute slot mapping + new_slot_mapping = compute_new_slot_mapping( + cad=cad, + new_positions=self.positions[:num_tokens], + is_rejected_token_mask=is_rejected_tok, + block_size=self._block_size(), + max_model_len=self.max_model_len, + ) + # update common_attn_metadata + new_cad: CommonAttentionMetadata = extend_all_queries_by_1( + cad, + arange=self.arange, + new_slot_mapping=new_slot_mapping, + ) + + new_last_token_indices = new_cad.query_start_loc[1:] - 1 + if num_rejected_tokens_gpu is not None: + new_last_token_indices -= num_rejected_tokens_gpu + + return num_tokens, new_last_token_indices, new_cad + + def load_model(self, target_model: Any) -> None: + """Takes target_model to satisfy the type checker.""" + + # This must be computed before loading the draft model + # because that mutates the forward_context of the vllm_config + target_attn_layer_names = set( + get_layers_from_vllm_config(self.vllm_config, Attention).keys() + ) + + from vllm.compilation.backends import set_model_tag + + draft_vllm_config: VllmConfig = create_vllm_config_for_draft_model( + target_model_vllm_config=self.vllm_config + ) + logger.info( + "Starting to load draft model %s. TP=%d, rank=%d", + draft_vllm_config.model_config.model, + draft_vllm_config.parallel_config.tensor_parallel_size, + draft_vllm_config.parallel_config.rank, + ) + with set_model_tag("draft_model"): + self.model = get_model(vllm_config=draft_vllm_config, prefix="draft_model") + + # This must be computed after loading the draft model + # because that mutates the forward_context of the vllm_config + draft_attn_layer_names = ( + get_layers_from_vllm_config(self.vllm_config, Attention).keys() + - target_attn_layer_names + ) + self.attn_layer_names = list(draft_attn_layer_names) + + +def create_vllm_config_for_draft_model( + target_model_vllm_config: VllmConfig, +) -> VllmConfig: + """The vllm_config is configured for the target model, e.g. + its quant_config and parallel_config. But the draft model is potentially + quantized differently, and has potentially different tensor_parallel_size. + This function creates a new vllm_config configured for the draft model. + The vllm_config is useful when loading the draft model with get_model(). + """ + old = target_model_vllm_config + new_parallel_config = old.speculative_config.draft_parallel_config.replace( + rank=old.parallel_config.rank + ) + new: VllmConfig = old.replace( + quant_config=None, # quant_config is recomputed in __init__() + model_config=old.speculative_config.draft_model_config, + parallel_config=new_parallel_config, + ) + return new + + +def compute_new_slot_mapping( + cad: CommonAttentionMetadata, + new_positions: torch.Tensor, + is_rejected_token_mask: torch.Tensor, + block_size: int, + max_model_len: int, +): + batch_size, n_blocks_per_req = cad.block_table_tensor.shape + req_indices = torch.arange(batch_size, device=cad.query_start_loc.device) + req_indices = torch.repeat_interleave( + req_indices, cad.naive_query_lens() + 1, output_size=len(new_positions) + ) + # Clamp the positions to prevent an out-of-bounds error when indexing + # into block_table_tensor. + clamped_positions = torch.clamp(new_positions, max=max_model_len - 1) + block_table_indices = ( + req_indices * n_blocks_per_req + clamped_positions // block_size + ) + block_nums = cad.block_table_tensor.view(-1)[block_table_indices] + block_offsets = clamped_positions % block_size + new_slot_mapping = block_nums * block_size + block_offsets + # Mask out the position ids that exceed the max model length. + exceeds_max_model_len = new_positions >= max_model_len + new_slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID) + # Mask out rejected tokens to prevent saves to the KV cache. + new_slot_mapping.masked_fill_(is_rejected_token_mask, PADDING_SLOT_ID) + return new_slot_mapping + + +@triton.jit +def merge_toks_kernel( + target_toks_ptr, + next_toks_ptr, + query_start_locs_ptr, + query_end_locs_ptr, + out_ptr_merged_toks, + out_ptr_is_rejected_tok, + target_toks_size, + rejected_tok_fill, +): + """ + Merges the `target_toks_ptr` and the `next_toks_ptr` into a new tensor + called `out_ptr_merged_toks`. Rejected tokens are those after the + `query_end_locs_ptr` and before the next `query_start_locs_ptr`. Fills the + rejected tokens positions with the value `rejected_tok_fill`. Also fills a mask + of the rejected tokens in `out_ptr_is_rejected_tok`. + """ + pid = tl.program_id(0) + start_loc = tl.load(query_start_locs_ptr + pid) + is_last_program = pid == tl.num_programs(0) - 1 + if is_last_program: + next_start_loc = target_toks_size.to(tl.int32) + else: + next_start_loc = tl.load(query_start_locs_ptr + pid + 1).to(tl.int32) + + end_loc = tl.load(query_end_locs_ptr + pid) + new_val = tl.load(next_toks_ptr + pid) + for i in range(start_loc, next_start_loc + 1): + if i <= end_loc: # copy existing tokens + old_val = tl.load(target_toks_ptr + i) + tl.store(out_ptr_merged_toks + pid + i, old_val) + tl.store(out_ptr_is_rejected_tok + pid + i, False) + elif i == end_loc + 1: # copy bonus token + tl.store(out_ptr_merged_toks + pid + i, new_val) + tl.store(out_ptr_is_rejected_tok + pid + i, False) + else: # fill rejected tokens + tl.store(out_ptr_merged_toks + pid + i, rejected_tok_fill) + tl.store(out_ptr_is_rejected_tok + pid + i, True) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index b7693f4f7..ff34afb16 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -53,11 +53,12 @@ logger = init_logger(__name__) PADDING_SLOT_ID = -1 -class EagleProposer: +class SpecDecodeBaseProposer: def __init__( self, vllm_config: VllmConfig, device: torch.device, + pass_hidden_states_to_model: bool, runner=None, ): self.vllm_config = vllm_config @@ -65,6 +66,7 @@ class EagleProposer: assert self.speculative_config is not None self.draft_model_config = self.speculative_config.draft_model_config self.method = self.speculative_config.method + self.pass_hidden_states_to_model = pass_hidden_states_to_model self.runner = runner self.device = device @@ -72,7 +74,11 @@ class EagleProposer: self.max_model_len = vllm_config.model_config.max_model_len self.dp_rank = vllm_config.parallel_config.data_parallel_rank self.num_speculative_tokens = self.speculative_config.num_speculative_tokens - self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens + # The drafter can get longer sequences than the target model. + max_batch_size = vllm_config.scheduler_config.max_num_seqs + self.max_num_tokens = ( + vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size + ) self.token_arange_np = np.arange(self.max_num_tokens) # We need to get the hidden size from the draft model config because # the draft model's hidden size can be different from the target model's @@ -143,7 +149,6 @@ class EagleProposer: # We need +1 here because the arange is used to set query_start_loc, # which has one more element than batch_size. - max_batch_size = vllm_config.scheduler_config.max_num_seqs max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens) self.arange = torch.arange( max_num_slots_for_arange, device=device, dtype=torch.int32 @@ -245,11 +250,7 @@ class EagleProposer: mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, num_rejected_tokens_gpu: torch.Tensor | None = None, ) -> torch.Tensor: - num_tokens = target_token_ids.shape[0] - batch_size = next_token_ids.shape[0] - - if last_token_indices is None: - last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 + batch_size = common_attn_metadata.batch_size() if self.method == "eagle3": assert isinstance(self.model, Eagle3LlamaForCausalLM) @@ -257,12 +258,17 @@ class EagleProposer: target_hidden_states ) assert target_hidden_states.shape[-1] == self.hidden_size - # Shift the input ids by one token. - # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] - self.input_ids[: num_tokens - 1] = target_token_ids[1:] - # Replace the last token with the next token. - # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] - self.input_ids[last_token_indices] = next_token_ids + + num_tokens, last_token_indices, common_attn_metadata = ( + self.set_inputs_first_pass( + target_token_ids=target_token_ids, + next_token_ids=next_token_ids, + target_positions=target_positions, + last_token_indices=last_token_indices, + cad=common_attn_metadata, + num_rejected_tokens_gpu=num_rejected_tokens_gpu, + ) + ) assert self.runner is not None @@ -311,9 +317,10 @@ class EagleProposer: if num_tokens_across_dp is not None: num_tokens_across_dp[self.dp_rank] = num_input_tokens - # copy inputs to buffer for cudagraph - self._set_positions(num_tokens, target_positions) - self.hidden_states[:num_tokens] = target_hidden_states + if self.pass_hidden_states_to_model: + # target_hidden_states and self.hidden_states can have different + # hidden dims. E.g. large target model and small draft model. + self.hidden_states[:num_tokens] = target_hidden_states if self.supports_mm_inputs: mm_embeds, is_mm_embed = mm_embed_inputs or (None, None) @@ -330,6 +337,14 @@ class EagleProposer: input_ids = self.input_ids[:num_input_tokens] inputs_embeds = None + model_kwargs = { + "input_ids": input_ids, + "positions": self._get_positions(num_input_tokens), + "inputs_embeds": inputs_embeds, + } + if self.pass_hidden_states_to_model: + model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens] + with set_forward_context( per_layer_attn_metadata, self.vllm_config, @@ -337,17 +352,13 @@ class EagleProposer: num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, ): - ret_hidden_states = self.model( - input_ids=input_ids, - positions=self._get_positions(num_input_tokens), - hidden_states=self.hidden_states[:num_input_tokens], - inputs_embeds=inputs_embeds, - ) - if self.method == "mtp": + ret_hidden_states = self.model(**model_kwargs) + if not self.model_returns_tuple(): last_hidden_states = ret_hidden_states hidden_states = last_hidden_states else: last_hidden_states, hidden_states = ret_hidden_states + sample_hidden_states = last_hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states) @@ -357,9 +368,9 @@ class EagleProposer: return draft_token_ids.view(-1, 1) if self.uses_mrope: - positions = target_positions[:, last_token_indices] + positions = self.positions[:, last_token_indices] else: - positions = target_positions[last_token_indices] + positions = self.positions[last_token_indices] if self.method in ( "deepseek_mtp", "ernie_mtp", @@ -527,6 +538,14 @@ class EagleProposer: inputs_embeds = None # Run the model. + model_kwargs = { + "input_ids": input_ids, + "positions": self._get_positions(input_batch_size), + "inputs_embeds": inputs_embeds, + } + if self.pass_hidden_states_to_model: + model_kwargs["hidden_states"] = self.hidden_states[:input_batch_size] + with set_forward_context( per_layer_attn_metadata, self.vllm_config, @@ -534,17 +553,13 @@ class EagleProposer: num_tokens_across_dp=batch_size_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, ): - ret_hidden_states = self.model( - input_ids=input_ids, - positions=self._get_positions(input_batch_size), - hidden_states=self.hidden_states[:input_batch_size], - inputs_embeds=inputs_embeds, - ) - if self.method == "mtp": + ret_hidden_states = self.model(**model_kwargs) + if not self.model_returns_tuple(): last_hidden_states = ret_hidden_states hidden_states = ret_hidden_states else: last_hidden_states, hidden_states = ret_hidden_states + hidden_states = hidden_states[:batch_size] logits = self.model.compute_logits(last_hidden_states[:batch_size]) draft_token_ids = logits.argmax(dim=-1) @@ -554,6 +569,34 @@ class EagleProposer: draft_token_ids = torch.stack(draft_token_ids_list, dim=1) return draft_token_ids + def set_inputs_first_pass( + self, + target_token_ids: torch.Tensor, + next_token_ids: torch.Tensor, + target_positions: torch.Tensor, + last_token_indices: torch.Tensor | None, + cad: CommonAttentionMetadata, + num_rejected_tokens_gpu: torch.Tensor | None, + ) -> tuple[int, torch.Tensor, CommonAttentionMetadata]: + if last_token_indices is None: + last_token_indices = cad.query_start_loc[1:] - 1 + + num_tokens = target_token_ids.shape[0] + # Shift the input ids by one token. + # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] + self.input_ids[: num_tokens - 1] = target_token_ids[1:] + # Replace the last token with the next token. + # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] + self.input_ids[last_token_indices] = next_token_ids + + # copy inputs to buffer for cudagraph + self._set_positions(num_tokens, target_positions) + + return num_tokens, last_token_indices, cad + + def model_returns_tuple(self) -> bool: + return self.method not in ("mtp", "draft_model") + def prepare_next_token_ids_cpu( self, sampled_token_ids: list[list[int]], @@ -1214,12 +1257,14 @@ class EagleProposer: input_ids = self.input_ids[:num_input_tokens] inputs_embeds = None - self.model( + kwargs = dict( input_ids=input_ids, positions=self._get_positions(num_input_tokens), - hidden_states=self.hidden_states[:num_input_tokens], inputs_embeds=inputs_embeds, ) + if self.pass_hidden_states_to_model: + kwargs["hidden_states"] = self.hidden_states[:num_input_tokens] + self.model(**kwargs) def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder: """Find and return the attention metadata builders for EAGLE layers. @@ -1264,8 +1309,8 @@ class EagleProposer: def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None: """ - Validate that all eagle layers belong to the same KVCacheGroup. - Need this assumption to ensure all eagle layers can use the + Validate that all drafting layers belong to the same KVCacheGroup. + Need this assumption to ensure all drafting layers can use the same AttentionMetadata. May extend to multiple AttentionMetadata in the future. """ @@ -1283,7 +1328,7 @@ class EagleProposer: ) ) == 1 - ), "All eagle layers should belong to the same kv cache group" + ), "All drafting layers should belong to the same kv cache group" def _pad_batch_across_dp( self, @@ -1308,6 +1353,21 @@ class EagleProposer: return num_tokens_dp_padded, num_toks_across_dp +class EagleProposer(SpecDecodeBaseProposer): + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + runner=None, + ): + super().__init__( + vllm_config, + device, + pass_hidden_states_to_model=True, + runner=runner, + ) + + # NOTE(woosuk): Currently, the below code is not used and we always use argmax # to sample the draft tokens. We will use this after we find a way to manage # the draft prob tensor. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 32a07d64a..12f74dbca 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -145,6 +145,7 @@ from vllm.v1.sample.logits_processor.interface import LogitsProcessor from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.sampler import Sampler +from vllm.v1.spec_decode.draft_model import DraftModelProposer from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata @@ -432,10 +433,20 @@ class GPUModelRunner( # layers in the draft model. if self.speculative_config and get_pp_group().is_last_rank: self.drafter: ( - NgramProposer | SuffixDecodingProposer | EagleProposer | MedusaProposer + NgramProposer + | SuffixDecodingProposer + | EagleProposer + | DraftModelProposer + | MedusaProposer ) if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) + elif self.speculative_config.uses_draft_model(): + self.drafter = DraftModelProposer( + vllm_config=self.vllm_config, + device=self.device, + runner=self, + ) elif self.speculative_config.method == "suffix": self.drafter = SuffixDecodingProposer(self.vllm_config) elif self.speculative_config.use_eagle(): @@ -3443,10 +3454,13 @@ class GPUModelRunner( spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens <= self.effective_drafter_max_model_len ) - if spec_config.use_eagle() and not spec_config.disable_padded_drafter_batch: - # EAGLE speculative decoding can use the GPU sampled tokens + use_gpu_toks = ( + spec_config.use_eagle() or spec_config.uses_draft_model() + ) and not spec_config.disable_padded_drafter_batch + if use_gpu_toks: + # EAGLE/DraftModel speculative decoding can use the GPU sampled tokens # as inputs, and does not need to wait for bookkeeping to finish. - assert isinstance(self.drafter, EagleProposer) + assert isinstance(self.drafter, EagleProposer | DraftModelProposer) sampled_token_ids = sampler_output.sampled_token_ids if input_fits_in_drafter: propose_draft_token_ids(sampled_token_ids) @@ -3679,8 +3693,8 @@ class GPUModelRunner( target_hidden_states=hidden_states, sampling_metadata=sampling_metadata, ) - elif spec_config.use_eagle(): - assert isinstance(self.drafter, EagleProposer) + elif spec_config.use_eagle() or spec_config.uses_draft_model(): + assert isinstance(self.drafter, EagleProposer | DraftModelProposer) if spec_config.disable_padded_drafter_batch: # When padded-batch is disabled, the sampled_token_ids should be @@ -4475,8 +4489,12 @@ class GPUModelRunner( else: hidden_states = outputs - if self.speculative_config and self.speculative_config.use_eagle(): - assert isinstance(self.drafter, EagleProposer) + if self.speculative_config and ( + self.speculative_config.use_eagle() + or self.speculative_config.uses_draft_model() + ): + assert isinstance(self.drafter, EagleProposer | DraftModelProposer) + assert self.speculative_config is not None # Eagle currently only supports PIECEWISE cudagraphs. # Therefore only use cudagraphs if the main model uses PIECEWISE # NOTE(lucas): this is a hack, need to clean up. @@ -5652,8 +5670,11 @@ class GPUModelRunner( kv_cache_config, kernel_block_sizes ) - if self.speculative_config and self.speculative_config.use_eagle(): - assert isinstance(self.drafter, EagleProposer) + if self.speculative_config and ( + self.speculative_config.use_eagle() + or self.speculative_config.uses_draft_model() + ): + assert isinstance(self.drafter, EagleProposer | DraftModelProposer) # validate all draft model layers belong to the same kv cache # group self.drafter.validate_same_kv_cache_group(kv_cache_config) diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 810160046..ccfbc3c6b 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -352,8 +352,8 @@ def bind_kv_cache( pass else: raise NotImplementedError - layer_name = layer_names[0] - runner_kv_caches.append(kv_caches[layer_name]) + for layer_name in layer_names: + runner_kv_caches.append(kv_caches[layer_name]) # Bind kv_caches to forward context for layer_name, kv_cache in kv_caches.items():