[Feat][Spec Decode] DFlash (#36847)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
Benjamin Chislett
2026-03-30 15:03:15 -04:00
committed by GitHub
parent ab1a6a43fa
commit 494636b29d
17 changed files with 1577 additions and 107 deletions

View File

@@ -1163,6 +1163,14 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
# "JackFram/llama-160m",
# speculative_model="ibm-ai-platform/llama-160m-accelerator"
# ),
# [DFlash]
"DFlashDraftModel": _HfExamplesInfo(
"Qwen/Qwen3.5-4B",
speculative_model="z-lab/Qwen3.5-4B-DFlash",
use_original_num_layers=True, # Need all layers since DFlash has >1 layer,
max_model_len=8192, # Reduce max len to ensure test runs in low-VRAM CI env
max_num_seqs=32,
),
# [Eagle]
"EagleDeepSeekMTPModel": _HfExamplesInfo(
"eagle618/deepseek-v3-random",

View File

@@ -7,6 +7,7 @@ from typing import Any
import pytest
import torch
from tqdm import tqdm
from tests.evals.gsm8k.gsm8k_eval import _build_gsm8k_prompts, evaluate_gsm8k_offline
from tests.utils import (
@@ -1105,19 +1106,178 @@ def some_high_acceptance_metrics() -> dict:
}
def compute_acceptance_rate(
    metrics: list[Metric], prev_metrics: list[Metric] | None = None
) -> float:
    """Compute the speculative-decoding acceptance rate from vLLM metrics.

    Args:
        metrics: Current cumulative metrics snapshot from ``LLM.get_metrics()``.
        prev_metrics: Optional earlier snapshot. When given, the rate is
            computed over the delta between the two snapshots instead of over
            the cumulative totals (metrics are monotonically accumulating).

    Returns:
        ``accepted_tokens / draft_tokens`` for the measured window, or NaN
        when no draft tokens were produced in that window.
    """
    name2metric = {metric.name: metric for metric in metrics}
    n_draft_toks = name2metric["vllm:spec_decode_num_draft_tokens"].value
    if n_draft_toks == 0:
        return float("nan")
    n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value
    if prev_metrics is not None:
        prev_name2metric = {metric.name: metric for metric in prev_metrics}
        n_draft_toks -= prev_name2metric["vllm:spec_decode_num_draft_tokens"].value
        n_accepted_toks -= prev_name2metric[
            "vllm:spec_decode_num_accepted_tokens"
        ].value
        # A non-positive delta means no new drafts since prev_metrics.
        if n_draft_toks <= 0:
            return float("nan")
    return n_accepted_toks / n_draft_toks
def compute_acceptance_len(
    metrics: list[Metric], prev_metrics: list[Metric] | None = None
) -> float:
    """Compute the mean acceptance length (AL) from vLLM metrics.

    AL = 1 + accepted_tokens / drafts: the average number of tokens emitted
    per target-model step (one bonus token plus the accepted draft tokens).

    Args:
        metrics: Current cumulative metrics snapshot from ``LLM.get_metrics()``.
        prev_metrics: Optional earlier snapshot. When given, AL is computed
            over the delta between the two snapshots.

    Returns:
        The acceptance length, or 1 when no drafts occurred in the
        measured window (the minimum possible AL).
    """
    name2metric = {metric.name: metric for metric in metrics}
    n_drafts = name2metric["vllm:spec_decode_num_drafts"].value
    n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value
    if n_drafts == 0:
        return 1
    if prev_metrics is not None:
        prev_name2metric = {metric.name: metric for metric in prev_metrics}
        n_drafts -= prev_name2metric["vllm:spec_decode_num_drafts"].value
        n_accepted_toks -= prev_name2metric[
            "vllm:spec_decode_num_accepted_tokens"
        ].value
        # A non-positive delta means no new drafts since prev_metrics.
        if n_drafts <= 0:
            return 1
    return 1 + (n_accepted_toks / n_drafts)
# Datasets in the format used in DFlash validations
def load_and_process_dataset(data_name: str):
    """Load a benchmark dataset and normalize each row to a ``turns`` list.

    Mirrors the preprocessing used by the DFlash reference benchmarks: each
    row gains a ``turns`` field holding the user prompt(s) for that sample.

    Args:
        data_name: One of "gsm8k", "mt-bench", or "humaneval".

    Returns:
        The processed HuggingFace dataset.

    Raises:
        ValueError: If ``data_name`` is not a recognized dataset. (Previously
            an unknown name fell through to ``return dataset`` and raised a
            confusing UnboundLocalError.)
    """
    # Validate before the lazy import so an unknown name fails fast with a
    # clear error, independent of the `datasets` package being installed.
    if data_name not in ("gsm8k", "mt-bench", "humaneval"):
        raise ValueError(f"Unknown dataset name: {data_name!r}")

    from datasets import load_dataset

    if data_name == "gsm8k":
        dataset = load_dataset("openai/gsm8k", "main", split="test")
        prompt_fmt = (
            "{question}\nPlease reason step by step,"
            " and put your final answer within \\boxed{{}}."
        )
        dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]})
    elif data_name == "mt-bench":
        dataset = load_dataset("HuggingFaceH4/mt_bench_prompts", split="train")
        # mt-bench prompts are already a list of turns.
        dataset = dataset.map(lambda x: {"turns": x["prompt"]})
    else:  # humaneval
        dataset = load_dataset("openai/openai_humaneval", split="test")
        prompt_fmt = (
            "Write a solution to the following problem and make sure"
            " that it passes the tests:\n```python\n{prompt}\n```"
        )
        dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]})
    return dataset
@pytest.fixture
def dflash_config():
    """Shared ``LLM(**kwargs)`` construction kwargs for the DFlash e2e tests.

    Pairs the Qwen3-8B target model with the z-lab DFlash-b16 draft model.
    Stats logging is left enabled so the tests can read spec-decode metrics
    back via ``LLM.get_metrics()``.
    """
    target_model = "Qwen/Qwen3-8B"
    draft_model = "z-lab/Qwen3-8B-DFlash-b16"
    return dict(
        model=target_model,
        trust_remote_code=True,
        speculative_config={
            "method": "dflash",
            "model": draft_model,
            "num_speculative_tokens": 16,
            "max_model_len": 32768,
        },
        max_model_len=32768,
        max_num_seqs=128,
        gpu_memory_utilization=0.85,
        enforce_eager=False,  # keep cudagraph/compile path enabled
        disable_log_stats=False,  # metrics needed for acceptance checks
    )
def test_dflash_acceptance_rates(dflash_config):
    """
    E2E test for DFlash (block diffusion) speculative decoding.

    Runs acceptance rate validation on GSM8k, MT-Bench, and HumanEval
    comparing against baseline results from the paper (Table 1).

    See https://github.com/z-lab/dflash/blob/main/benchmark_sglang.py for methodology.
    """
    spec_llm = LLM(**dflash_config)
    max_prompts_per_dataset = 200  # mt-bench has 80, humaneval has 164, truncates gsm8k
    # All scores from Table 1 in https://arxiv.org/pdf/2602.06036
    expected_acceptance_lengths = {
        "mt-bench": 4.24,
        "humaneval": 6.50,
        "gsm8k": 6.54 * 0.95,  # runs with a subset of prompts so extra wide tol here
    }
    tokenizer = spec_llm.get_tokenizer()
    for dataset_name, expected_len in expected_acceptance_lengths.items():
        dataset = load_and_process_dataset(dataset_name)
        prev_metrics = None
        acceptance_lengths = []
        for i in tqdm(
            range(min(max_prompts_per_dataset, len(dataset))),
            desc=f"Processing {dataset_name}",
        ):
            # Only the first turn is used; prompts run one at a time so AL is
            # measured per-request, matching the paper's methodology.
            user_content = dataset[i]["turns"][0]
            prompt_text = tokenizer.apply_chat_template(
                [{"role": "user", "content": user_content}],
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=False,
            )
            # Temp=0, MaxTokens=2048 from the paper
            spec_llm.generate(
                [prompt_text],
                SamplingParams(temperature=0, max_tokens=2048),
                use_tqdm=False,
            )
            # Metrics accumulate across requests; diff against the previous
            # snapshot to isolate this request's acceptance length.
            current_metrics = spec_llm.get_metrics()
            acceptance_len = compute_acceptance_len(current_metrics, prev_metrics)
            prev_metrics = current_metrics
            acceptance_lengths.append(acceptance_len)
        mean_acceptance_length = sum(acceptance_lengths) / len(acceptance_lengths)
        # Allow 10% slack below the published baseline.
        expected_len = expected_len * 0.9
        print(
            f"DFlash acceptance_len for {dataset_name}: {mean_acceptance_length:.2f}"
            f" (expected at least {expected_len:.2f})"
        )
        assert mean_acceptance_length >= expected_len, (
            f"DFlash acceptance_len for {dataset_name} is below expected threshold:"
            f"{mean_acceptance_length:.2f} < {expected_len:.2f}"
        )
    del spec_llm
    torch.accelerator.empty_cache()
    cleanup_dist_env_and_memory()
def test_dflash_correctness(dflash_config):
    """
    E2E test for DFlash (block diffusion) speculative decoding.

    Ensures output correctness on GSM8k, with cudagraphs and batching on.
    """
    spec_llm = LLM(**dflash_config)
    # Evaluate GSM8k accuracy (Qwen3-8B ref: ~87-92% on GSM8k)
    evaluate_llm_for_gsm8k(spec_llm, expected_accuracy_threshold=0.8)
    # Cumulative metrics over the whole eval (no prev snapshot needed).
    current_metrics = spec_llm.get_metrics()
    acceptance_len = compute_acceptance_len(current_metrics)
    # AR is thoroughly validated in test_dflash_acceptance_rates, in a manner consistent
    # with the DFlash paper. However, that test measures AL per-request and thus runs
    # with a batch size of 1. To ensure that AL does not collapse with large batch sizes
    # we enforce a baseline on the AL over the full lm-eval-style GSM8k test.
    expected_len = 3.5  # Measured is 3.9 to 4.0
    print(f"DFlash GSM8k correctness test got AL {acceptance_len}")
    assert acceptance_len >= expected_len, (
        "DFlash correctness check failed with"
        f" {acceptance_len=}, expected at least {expected_len}"
    )
    del spec_llm
    torch.accelerator.empty_cache()
    cleanup_dist_env_and_memory()

View File

@@ -27,6 +27,7 @@ from vllm.config.load import LoadConfig
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.spec_decode.dflash import DFlashProposer
from vllm.v1.spec_decode.draft_model import DraftModelProposer
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -36,6 +37,8 @@ model_dir = "meta-llama/Llama-3.1-8B-Instruct"
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
ar_draft_model_dir = "amd/PARD-Llama-3.2-1B" # Compatible with parallel and AR drafting
dflash_target_dir = "Qwen/Qwen3-8B"
dflash_dir = "z-lab/Qwen3-8B-DFlash-b16"
BLOCK_SIZE = 16
@@ -47,18 +50,29 @@ def _create_proposer(
speculative_token_tree: list[tuple[int, ...]] | None = None,
parallel_drafting: bool = False,
) -> EagleProposer:
model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100)
# Method-dependent setup
if method == "eagle":
target_model_dir = model_dir
draft_model_dir = eagle_dir
elif method == "eagle3":
target_model_dir = model_dir
draft_model_dir = eagle3_dir
elif method == "draft_model":
target_model_dir = model_dir
draft_model_dir = ar_draft_model_dir
elif method == "dflash":
target_model_dir = dflash_target_dir
draft_model_dir = dflash_dir
else:
raise ValueError(f"Unknown method: {method}")
model_config = ModelConfig(
model=target_model_dir,
runner="generate",
max_model_len=100,
trust_remote_code=(method == "dflash"),
)
spec_token_tree_str = None
if speculative_token_tree is not None:
assert num_speculative_tokens == len(speculative_token_tree)
@@ -92,7 +106,9 @@ def _create_proposer(
attention_config=AttentionConfig(backend=attention_backend),
)
if "eagle" in method:
if method == "dflash":
proposer = DFlashProposer(vllm_config=vllm_config, device=device)
elif "eagle" in method:
proposer = EagleProposer(vllm_config=vllm_config, device=device)
else:
proposer = DraftModelProposer(vllm_config=vllm_config, device=device)
@@ -1152,3 +1168,136 @@ def test_propose_tree(spec_token_tree):
# Verify that the draft tokens match our expectations.
assert torch.equal(result, expected_tokens)
def test_set_inputs_first_pass_dflash():
    """
    Test for DFlash set_inputs_first_pass.

    DFlash uses cross-attention: context tokens become K/V and only
    query tokens (bonus + mask) are Q. This tests the DFlash-specific
    input preparation where:
    - Context hidden states are stored by reference (no copy)
    - Query input_ids are [next_token, mask, mask, ...] per request
    - Context and query positions are written to separate buffers
    - token_indices_to_sample points to mask token positions only
    - A new CommonAttentionMetadata is returned with causal=False

    Setup:
    - 3 requests with query_lens [3, 2, 4]
    - num_speculative_tokens = 3
    - num_query_per_req = 4 (1 bonus + 3 mask tokens)
    - next_token_ids: [100, 200, 300]

    Expected output layout (query tokens only, 12 total):
        Request 0 (indices 0-3): [100, mask, mask, mask]
        Request 1 (indices 4-7): [200, mask, mask, mask]
        Request 2 (indices 8-11): [300, mask, mask, mask]

    Expected positions layout (separate buffers):
        Context (_context_positions_buffer, 9 tokens): copied from target_positions
        Query (positions, 12 tokens):
            Request 0: last_pos=9, query=[10, 11, 12, 13]
            Request 1: last_pos=7, query=[8, 9, 10, 11]
            Request 2: last_pos=11, query=[12, 13, 14, 15]
    """
    device = torch.device(current_platform.device_type)
    num_speculative_tokens = 3
    proposer = _create_proposer("dflash", num_speculative_tokens)
    # The mask token id that fills the draft query slots.
    mask_token_id = proposer.parallel_drafting_token_id

    # Setup batch with 3 requests
    batch_spec = BatchSpec(
        seq_lens=[10, 8, 12],
        query_lens=[3, 2, 4],
    )
    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=BLOCK_SIZE,
        device=device,
        arange_block_indices=True,
    )

    # Input tensors
    # Request 0: tokens [10, 11, 12] at positions [7, 8, 9]
    # Request 1: tokens [20, 21] at positions [6, 7]
    # Request 2: tokens [30, 31, 32, 33] at positions [8, 9, 10, 11]
    target_token_ids = torch.tensor(
        [10, 11, 12, 20, 21, 30, 31, 32, 33], dtype=torch.int32, device=device
    )
    target_positions = torch.tensor(
        [7, 8, 9, 6, 7, 8, 9, 10, 11], dtype=torch.int64, device=device
    )
    target_hidden_states = torch.randn(
        9, proposer.hidden_size, dtype=proposer.dtype, device=device
    )
    next_token_ids = torch.tensor([100, 200, 300], dtype=torch.int32, device=device)

    num_tokens, token_indices_to_sample, output_cad = proposer.set_inputs_first_pass(
        target_token_ids=target_token_ids,
        next_token_ids=next_token_ids,
        target_positions=target_positions,
        target_hidden_states=target_hidden_states,
        token_indices_to_sample=None,
        cad=common_attn_metadata,
        num_rejected_tokens_gpu=None,
    )

    num_query_per_req = 1 + num_speculative_tokens  # 4
    num_context = 9
    # num_tokens is the query-only count
    assert num_tokens == 3 * num_query_per_req  # 12

    # Verify input_ids (query tokens only)
    # Each request: [next_token, mask, mask, mask]
    M = mask_token_id
    expected_input_ids = torch.tensor(
        [100, M, M, M, 200, M, M, M, 300, M, M, M],
        dtype=torch.int32,
        device=device,
    )
    assert torch.equal(proposer.input_ids[:num_tokens], expected_input_ids)

    # Verify context positions (separate buffer): copied from target_positions
    assert torch.equal(
        proposer._context_positions_buffer[:num_context], target_positions
    )

    # Verify query positions (separate buffer, starts at index 0):
    # req0: last_pos=9, query=[10, 11, 12, 13]
    # req1: last_pos=7, query=[8, 9, 10, 11]
    # req2: last_pos=11, query=[12, 13, 14, 15]
    expected_query_positions = torch.tensor(
        [10, 11, 12, 13, 8, 9, 10, 11, 12, 13, 14, 15],
        dtype=torch.int64,
        device=device,
    )
    assert torch.equal(
        proposer.positions[:num_tokens],
        expected_query_positions,
    )

    # Verify token_indices_to_sample (mask tokens only, skip bonus at offset 0)
    # req0: query indices 0-3, mask at 1,2,3
    # req1: query indices 4-7, mask at 5,6,7
    # req2: query indices 8-11, mask at 9,10,11
    expected_token_indices_to_sample = torch.tensor(
        [1, 2, 3, 5, 6, 7, 9, 10, 11], dtype=torch.int32, device=device
    )
    assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample)

    # Verify the new CAD has DFlash-specific properties
    assert output_cad.causal is False  # DFlash requires non-causal attention
    assert output_cad.num_actual_tokens == num_tokens  # query-only count
    assert output_cad.max_query_len == num_query_per_req
    expected_query_start_loc = torch.tensor(
        [0, 4, 8, 12], dtype=torch.int32, device=device
    )
    assert torch.equal(output_cad.query_start_loc, expected_query_start_loc)

    # Verify hidden states (stored by reference, not copied)
    assert proposer._dflash_hidden_states is target_hidden_states

View File

@@ -47,8 +47,11 @@ MTPModelTypes = Literal[
"pangu_ultra_moe_mtp",
"step3p5_mtp",
]
EagleModelTypes = Literal["eagle", "eagle3", "extract_hidden_states", MTPModelTypes]
NgramGPUTypes = Literal["ngram_gpu"]
DFlashModelTypes = Literal["dflash"]
EagleModelTypes = Literal[
"eagle", "eagle3", "extract_hidden_states", MTPModelTypes, DFlashModelTypes
]
SpeculativeMethod = Literal[
"ngram",
"medusa",
@@ -206,7 +209,11 @@ class SpeculativeConfig:
factors: list[Any] = []
# Eagle3 and extract_hidden_states affect the computation graph because
# they return intermediate hidden states in addition to the final hidden state.
uses_aux_hidden_states = self.method in ("eagle3", "extract_hidden_states")
uses_aux_hidden_states = self.method in (
"eagle3",
"extract_hidden_states",
"dflash",
)
factors.append(uses_aux_hidden_states)
# The specific layers used also affect the computation graph
@@ -490,7 +497,7 @@ class SpeculativeConfig:
)
# Automatically detect the method
if self.method in ("eagle", "eagle3"):
if self.method in ("eagle", "eagle3", "dflash"):
pass
# examples:
# yuhuili/EAGLE-LLaMA3-Instruct-8B
@@ -500,6 +507,8 @@ class SpeculativeConfig:
self.method = "eagle"
elif "eagle3" in self.draft_model_config.model.lower():
self.method = "eagle3"
elif "dflash" in self.draft_model_config.model.lower():
self.method = "dflash"
elif self.draft_model_config.hf_config.model_type == "medusa":
self.method = "medusa"
elif self.draft_model_config.hf_config.model_type == "mlp_speculator":
@@ -532,7 +541,7 @@ class SpeculativeConfig:
)
# Replace hf_config for EAGLE draft_model
if self.method in ("eagle", "eagle3"):
if self.method in ("eagle", "eagle3", "dflash"):
from vllm.transformers_utils.configs.eagle import EAGLEConfig
from vllm.transformers_utils.configs.speculators import (
SpeculatorsConfig,
@@ -552,6 +561,9 @@ class SpeculativeConfig:
self.draft_model_config.hf_config = eagle_config
self.update_arch_()
if self.method == "dflash":
self.parallel_drafting = True
if self.num_speculative_tokens is not None and hasattr(
self.draft_model_config.hf_config, "num_lookahead_tokens"
):
@@ -807,7 +819,7 @@ class SpeculativeConfig:
"kimi_k25",
]
if (
self.method in ("eagle3", "extract_hidden_states")
self.method in ("eagle3", "extract_hidden_states", "dflash")
and self.target_model_config
and not any(
supported_model in self.target_model_config.hf_text_config.model_type
@@ -855,7 +867,10 @@ class SpeculativeConfig:
return slots_per_req
def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "mtp")
return self.method in ("eagle", "eagle3", "mtp", "dflash")
def use_dflash(self) -> bool:
return self.method == "dflash"
def uses_draft_model(self) -> bool:
return self.method == "draft_model"

View File

@@ -1327,6 +1327,26 @@ class VllmConfig:
max_num_batched_tokens - scheduled_token_delta
)
if self.scheduler_config.max_num_scheduled_tokens <= 0:
raise ValueError(
"max_num_scheduled_tokens is set to"
f" {self.scheduler_config.max_num_scheduled_tokens} based on"
" the speculative decoding settings, which does not allow"
" any tokens to be scheduled. Increase max_num_batched_tokens"
" to accommodate the additional draft token slots, or decrease"
" num_speculative_tokens or max_num_seqs."
)
if self.scheduler_config.max_num_scheduled_tokens < 8192:
logger.warning_once(
"max_num_scheduled_tokens is set to"
f" {self.scheduler_config.max_num_scheduled_tokens} based on"
" the speculative decoding settings. This may lead to suboptimal"
" performance. Consider increasing max_num_batched_tokens to"
" accommodate the additional draft token slots, or decrease"
" num_speculative_tokens or max_num_seqs.",
scope="local",
)
max_num_scheduled_tokens = self.scheduler_config.max_num_scheduled_tokens
if max_num_batched_tokens < max_num_scheduled_tokens + (
self.speculative_config.max_num_new_slots_for_drafting

View File

@@ -285,6 +285,7 @@ class Qwen3ForCausalLM(
self.config = config
self.vllm_config = vllm_config
self.quant_config = quant_config
self.model = Qwen3Model(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")

View File

@@ -0,0 +1,619 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
import torch
import torch.nn.functional as F
from torch import nn
from transformers import Qwen3Config
from vllm import _custom_ops as ops
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.multimodal.inputs import NestedTensors
from vllm.transformers_utils.config import set_default_rope_theta
from vllm.v1.attention.backend import AttentionType
from .qwen2 import Qwen2MLP as Qwen3MLP
from .qwen3 import Qwen3ForCausalLM
from .utils import (
AutoWeightsLoader,
get_draft_quant_config,
maybe_prefix,
process_eagle_weight,
)
logger = init_logger(__name__)
class DFlashQwen3Attention(nn.Module):
    """Attention for DFlash speculative decoding.

    Context KVs are pre-inserted into the KV cache before the forward pass.
    This layer handles only query tokens via standard attention.
    Adapted from Qwen3Attention."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_parameters: dict,
        max_position: int = 4096 * 32,
        head_dim: int | None = None,
        rms_norm_eps: float = 1e-06,
        attention_bias: bool = False,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        super().__init__()
        self.layer_name = prefix
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        # KV heads must either divide evenly across TP ranks or be evenly
        # replicated (tp_size a multiple of the KV head count).
        if self.total_num_kv_heads >= tp_size:
            assert self.total_num_kv_heads % tp_size == 0
        else:
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim or hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=attention_bias,  # DFlash has o_proj bias when using attention bias
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position,
            rope_parameters=rope_parameters,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            attn_type=attn_type,
        )
        # Qwen3-style per-head QK normalization.
        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """DFlash attention assumes that the KV cache is already populated
        with the context K/V from the target model's hidden states. This forward op
        computes attention for the query tokens only.

        See also: precompute_and_store_context_kv"""
        # NOTE(review): F.linear on qkv_proj.weight/.bias bypasses the
        # QKVParallelLinear forward (and thus any quantized apply path) —
        # confirm this is intentional for quantized draft models.
        qkv = F.linear(hidden_states, self.qkv_proj.weight, self.qkv_proj.bias)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # Per-head RMSNorm
        q_shape, k_shape = q.shape, k.shape
        q = self.q_norm(
            q.view(*q_shape[:-1], q_shape[-1] // self.head_dim, self.head_dim)
        ).view(q_shape)
        k = self.k_norm(
            k.view(*k_shape[:-1], k_shape[-1] // self.head_dim, self.head_dim)
        ).view(k_shape)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
class DFlashQwen3DecoderLayer(nn.Module):
    """A single pre-norm decoder layer for the DFlash draft model:
    DFlash attention followed by a Qwen3 MLP, each with RMSNorm."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        *,
        config: Qwen3Config,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        set_default_rope_theta(config, default_theta=1000000)
        attn_type = AttentionType.DECODER
        self.self_attn = DFlashQwen3Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rms_norm_eps=config.rms_norm_eps,
            attention_bias=getattr(config, "attention_bias", False),
            head_dim=getattr(config, "head_dim", None),
            cache_config=cache_config,
            quant_config=quant_config,
            rope_parameters=config.rope_parameters,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
        )
        self.mlp = Qwen3MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # residual is None only for the first layer; afterwards the fused
        # RMSNorm variant (norm + residual add) is used.
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
@support_torch_compile
class DFlashQwen3Model(nn.Module):
def __init__(
self,
*,
vllm_config: VllmConfig,
start_layer_id: int = 0,
prefix: str = "",
) -> None:
super().__init__()
self.config = vllm_config.speculative_config.draft_model_config.hf_config
self.vocab_size = self.config.vocab_size
self.quant_config = get_draft_quant_config(vllm_config)
drafter_config = getattr(self.config, "eagle_config", {})
drafter_config.update(getattr(self.config, "dflash_config", {}))
if drafter_config is not None and "use_aux_hidden_state" in drafter_config:
self.use_aux_hidden_state = drafter_config["use_aux_hidden_state"]
else:
self.use_aux_hidden_state = True
current_vllm_config = get_current_vllm_config()
self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size,
self.config.hidden_size,
prefix=maybe_prefix(prefix, "embed_tokens"),
)
self.layers = nn.ModuleList(
[
DFlashQwen3DecoderLayer(
current_vllm_config,
prefix=maybe_prefix(prefix, f"layers.{layer_idx + start_layer_id}"),
config=self.config,
)
for layer_idx in range(self.config.num_hidden_layers)
]
)
if self.use_aux_hidden_state:
num_features_to_use = self.config.num_hidden_layers
if "target_layer_ids" in drafter_config:
num_features_to_use = len(drafter_config["target_layer_ids"])
elif "layer_ids" in drafter_config:
num_features_to_use = len(drafter_config["layer_ids"])
if hasattr(self.config, "target_hidden_size"):
fc_input_size = self.config.target_hidden_size * num_features_to_use
else:
fc_input_size = self.config.hidden_size * num_features_to_use
self.fc = ReplicatedLinear(
input_size=fc_input_size,
output_size=self.config.hidden_size,
bias=False,
params_dtype=vllm_config.model_config.dtype,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "fc"),
return_bias=False,
)
self.hidden_norm = RMSNorm(
self.config.hidden_size,
eps=self.config.rms_norm_eps,
)
self.norm = RMSNorm(
self.config.hidden_size,
eps=self.config.rms_norm_eps,
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def _build_fused_kv_buffers(self) -> None:
"""Build fused weight buffers for precompute_and_store_context_kv.
Must be called after weights are loaded. Stacks the KV-projection
weights, K-norm weights, and RoPE parameters from every attention
layer so that precompute_and_store_context_kv can run one fused
GEMM for all layers at once. Also aliases the weight of the hidden_norm.
"""
layers_attn = [layer.self_attn for layer in self.layers]
attn0 = layers_attn[0]
has_bias = attn0.qkv_proj.bias is not None
self._hidden_norm_weight = self.hidden_norm.weight.data
# KV projection weights: [num_layers * 2 * kv_size, hidden_size]
kv_weights = [a.qkv_proj.weight[a.q_size :] for a in layers_attn]
self._fused_kv_weight = torch.cat(kv_weights, dim=0)
if has_bias:
kv_biases = [a.qkv_proj.bias[a.q_size :] for a in layers_attn]
self._fused_kv_bias: torch.Tensor | None = torch.cat(kv_biases, dim=0)
else:
self._fused_kv_bias = None
# K-norm weights: list of [head_dim] tensors, one per layer.
self._k_norm_weights = [a.k_norm.weight.data for a in layers_attn]
# RoPE parameters
self._rope_head_size = attn0.rotary_emb.head_size
self._rope_cos_sin_cache = attn0.rotary_emb.cos_sin_cache
self._rope_is_neox = attn0.rotary_emb.is_neox_style
# Validation that RoPE params are the same across all layers
for attn in layers_attn[1:]:
assert (
attn.rotary_emb.head_size == self._rope_head_size
and attn.rotary_emb.is_neox_style == self._rope_is_neox
), "All layers must have the same RoPE parameters for DFlash precomputation"
# Layer metadata
self._num_attn_layers = len(layers_attn)
self._kv_size = attn0.kv_size
self._head_dim = attn0.head_dim
self._num_kv_heads = attn0.num_kv_heads
self._rms_norm_eps = attn0.q_norm.variance_epsilon
# Validation that all layers have the same attention config
for attn in layers_attn[1:]:
assert (
attn.kv_size == self._kv_size
and attn.head_dim == self._head_dim
and attn.num_kv_heads == self._num_kv_heads
and attn.q_norm.variance_epsilon == self._rms_norm_eps
), "All layers must have the same attn config for DFlash precomputation"
# References to inner Attention layers for direct cache writes
self._attn_layers = [layer.self_attn.attn for layer in self.layers]
def precompute_and_store_context_kv(
self,
context_states: torch.Tensor,
context_positions: torch.Tensor,
context_slot_mapping: torch.Tensor | None = None,
) -> None:
"""Precompute K/V for context states write them into each layer's KV cache.
Input context states are projected to K/V, normed, and have RoPE applied.
Since the context shape is different than the query shape, we can't rely on the
regular forward pass to apply torch.compile and CUDA graphs to this section.
As such, this function is optimized to minimize the number of torch ops present:
we use fused vLLM kernels for RMSNorm and RoPE, fuse the GEMM into one
large projection, and avoid cloning buffers (with .contiguous()) where possible.
When context_slot_mapping is None (e.g. during dummy_run) only
the computation runs, and no K/V is written to cache.
"""
if not hasattr(self, "_num_attn_layers"):
logger.warning_once(
"DFlash buffer initialization was skipped. If dummy weights are not "
"in use, this may indicate an error in weight loading."
)
self._build_fused_kv_buffers()
num_ctx = context_states.shape[0]
L = self._num_attn_layers
kv = self._kv_size
hd = self._head_dim
nkv = self._num_kv_heads
# --- Fused KV projection (one GEMM for all layers) ---
normed_context_states = torch.empty_like(context_states)
ops.rms_norm(
normed_context_states,
context_states,
self._hidden_norm_weight,
self._rms_norm_eps,
)
all_kv_flat = F.linear(
normed_context_states, self._fused_kv_weight, self._fused_kv_bias
)
# Single contiguous copy that separates K/V and transposes to
# layer-major layout. Result: [2, L, num_ctx, nkv, hd] contiguous.
# Indexing dim-0 gives contiguous [L, num_ctx, nkv, hd] for K and V.
all_kv = (
all_kv_flat.view(num_ctx, L, 2, nkv, hd).permute(2, 1, 0, 3, 4).contiguous()
)
all_k = all_kv[0] # [L, num_ctx, nkv, hd], contiguous
all_v = all_kv[1] # [L, num_ctx, nkv, hd], contiguous
# --- Per-layer RMSNorm K (3D: [num_ctx, nkv, hd] per layer) ---
all_k_normed = torch.empty_like(all_k)
for i in range(L):
ops.rms_norm(
all_k_normed[i],
all_k[i],
self._k_norm_weights[i],
self._rms_norm_eps,
)
# --- Fused RoPE across all layers ---
# View as [L * num_ctx, kv] so RoPE sees one big batch (no copy).
# In-place RoPE: pass K as the "query" arg with key=None.
all_k_flat = all_k_normed.view(L * num_ctx, kv)
positions_repeated = context_positions.repeat(L)
cos_sin_cache = self._rope_cos_sin_cache
if cos_sin_cache.dtype != all_k_flat.dtype:
cos_sin_cache = cos_sin_cache.to(dtype=all_k_flat.dtype)
ops.rotary_embedding(
positions_repeated,
all_k_flat,
None,
self._rope_head_size,
cos_sin_cache,
self._rope_is_neox,
)
if context_slot_mapping is None:
return
# --- Per-layer cache insert ---
all_k_final = all_k_flat.view(L, num_ctx, nkv, hd)
for i in range(L):
attn = self._attn_layers[i]
kv_cache = attn.kv_cache
attn.impl.do_kv_cache_update(
attn,
all_k_final[i],
all_v[i],
kv_cache,
context_slot_mapping,
)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_embeds: torch.Tensor | None = None,
) -> torch.Tensor:
if input_embeds is None:
input_embeds = self.embed_input_ids(input_ids)
hidden_states = input_embeds
residual = None
for layer in self.layers:
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
residual=residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load checkpoint weights into this draft model.

    Remaps checkpoint names onto vLLM's fused parameters (q/k/v ->
    ``qkv_proj``, gate/up -> ``gate_up_proj``), rewrites the DFlash
    checkpoint's ``midlayer.`` prefix to ``layers.0.``, and routes
    quantization cache scales through the quant config. Returns the set
    of parameter names that were actually loaded.
    """
    # (fused param name fragment, checkpoint name fragment, shard id)
    stacked_params_mapping = [
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
        (".gate_up_proj", ".gate_proj", 0),
        (".gate_up_proj", ".up_proj", 1),
    ]
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        # DFlash checkpoints name their single decoder layer "midlayer";
        # map it to index 0 of this model's layer list.
        if "midlayer." in name:
            name = name.replace("midlayer.", "layers.0.")
        # KV-cache quantization scales are looked up by a remapped name
        # and loaded directly, bypassing the stacked-param handling.
        if self.quant_config is not None and (
            scale_name := self.quant_config.get_cache_scale(name)
        ):
            param = params_dict[scale_name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            # Scalar scales may be stored as 1-element tensors; take element 0.
            loaded_weight = (
                loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
            )
            weight_loader(param, loaded_weight)
            loaded_params.add(scale_name)
            continue
        if "scale" in name:
            # Remapping may return None for scales that should be skipped.
            name = maybe_remap_kv_scale_name(name, params_dict)
            if name is None:
                continue
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            # Load this shard into its slice of the fused parameter.
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # Not part of a fused parameter: load as-is.
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
        loaded_params.add(name)
    return loaded_params
class DFlashQwen3ForCausalLM(Qwen3ForCausalLM):
    """Qwen3-based DFlash draft model for speculative decoding.

    Wraps a ``DFlashQwen3Model`` whose KV-cache layers start after the
    target model's layers, plus a separate LM head over the (possibly
    reduced) draft vocabulary.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Bypass Qwen3ForCausalLM.__init__: this class builds its own
        # submodules from the speculative (draft) config instead.
        nn.Module.__init__(self)
        self.config = vllm_config.speculative_config.draft_model_config.hf_config
        # Fall back to the full vocab when the draft vocab is unspecified.
        if getattr(self.config, "draft_vocab_size", None) is None:
            self.config.draft_vocab_size = getattr(self.config, "vocab_size", None)
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config
        )
        self.config.target_layer_count = target_layer_num
        # Draft layers are numbered after the target model's layers so
        # their KV-cache entries do not collide.
        self.model = DFlashQwen3Model(
            vllm_config=vllm_config,
            prefix="model",
            start_layer_id=target_layer_num,
        )
        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.lm_head = ParallelLMHead(
            self.config.draft_vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(
            self.config.draft_vocab_size, scale=logit_scale
        )
        # Optional per-draft-id offset tensor (loaded from "d2t" weights);
        # None means draft and target vocabularies coincide.
        self.draft_id_to_target_id = None

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: NestedTensors | None = None,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Multimodal arguments are accepted for interface compatibility
        # but ignored; only token embeddings are produced here.
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Delegate to the inner model; returns final hidden states."""
        return self.model(input_ids, positions, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Compute logits, scattering draft-vocab logits into target-vocab
        positions when a draft-to-target id mapping is loaded."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        if self.draft_id_to_target_id is None:
            return logits
        # draft_id_to_target_id holds per-draft-id offsets: draft id i
        # corresponds to target id i + offset[i].
        base = torch.arange(self.config.draft_vocab_size, device=logits.device)
        targets = base + self.draft_id_to_target_id
        # Unmapped target-vocab positions get -inf so they are never sampled.
        logits_new = logits.new_full(
            (logits.shape[0], self.config.vocab_size),
            float("-inf"),
        )
        logits_new[:, targets] = logits
        return logits_new

    def precompute_and_store_context_kv(
        self,
        context_states: torch.Tensor,
        context_positions: torch.Tensor,
        context_slot_mapping: torch.Tensor | None = None,
    ) -> None:
        """Precompute projected + RoPE'd K/V and write to cache."""
        self.model.precompute_and_store_context_kv(
            context_states, context_positions, context_slot_mapping
        )

    def combine_hidden_states(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project (concatenated) auxiliary hidden states through the
        model's fc layer; a no-op when aux hidden states are disabled."""
        if not self.model.use_aux_hidden_state:
            return hidden_states
        # fc expects a batch dimension; temporarily add one for 1-D input.
        needs_squeeze = hidden_states.dim() == 1
        if needs_squeeze:
            hidden_states = hidden_states.unsqueeze(0)
        result = self.model.fc(hidden_states)
        if needs_squeeze:
            result = result.squeeze(0)
        return result

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load draft weights, prefixing inner-model names and skipping
        parameters absent from the checkpoint."""
        model_weights = {}
        includes_draft_id_mapping = False
        includes_embed_tokens = False
        for name, loaded_weight in weights:
            assert "mask_hidden" not in name, (
                "DFlash should use mask_token_id to embed the padding hidden state"
            )
            # target-to-draft mapping is unused; only d2t is kept.
            if "t2d" in name:
                continue
            if "d2t" in name:
                name = name.replace("d2t", "draft_id_to_target_id")
                includes_draft_id_mapping = True
            elif "lm_head" not in name:
                # Everything except lm_head lives under the inner model.
                name = "model." + name
            if "embed_tokens" in name:
                includes_embed_tokens = True
            model_weights[name] = loaded_weight
            process_eagle_weight(self, name)
        # Skip parameters the checkpoint does not provide so the loader
        # does not complain about them being uninitialized.
        skip_substrs = []
        if not includes_draft_id_mapping:
            skip_substrs.append("draft_id_to_target_id")
        if not includes_embed_tokens:
            skip_substrs.append("embed_tokens")
        if not self.model.use_aux_hidden_state:
            skip_substrs.append("fc.")
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
            skip_substrs=skip_substrs,
        )
        loader.load_weights(model_weights.items())
        # Rebuild the fused K/V projection buffers now that per-layer
        # weights are in place.
        self.model._build_fused_kv_buffers()

View File

@@ -56,6 +56,7 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig
from .interfaces import (
EagleModelMixin,
HasInnerState,
IsHybrid,
MixtureOfExperts,
@@ -454,7 +455,7 @@ class Qwen3NextDecoderLayer(nn.Module):
@support_torch_compile
class Qwen3NextModel(nn.Module):
class Qwen3NextModel(nn.Module, EagleModelMixin):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@@ -492,8 +493,6 @@ class Qwen3NextModel(nn.Module):
else:
self.norm = PPMissingLayer()
self.aux_hidden_state_layers: tuple[int, ...] = ()
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
@@ -515,20 +514,19 @@ class Qwen3NextModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
aux_hidden_states = []
aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
for layer_idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer),
start=self.start_layer,
):
if layer_idx in self.aux_hidden_state_layers:
aux_hidden_states.append(
hidden_states + residual if residual is not None else hidden_states
)
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
residual=residual,
)
self._maybe_add_hidden_state(
aux_hidden_states, layer_idx + 1, hidden_states, residual
)
if not get_pp_group().is_last_rank:
return IntermediateTensors(

View File

@@ -546,6 +546,7 @@ _SPECULATIVE_DECODING_MODELS = {
"EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
"EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
"EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
"DFlashDraftModel": ("qwen3_dflash", "DFlashQwen3ForCausalLM"),
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),

View File

@@ -62,9 +62,20 @@ class EAGLEConfig(PretrainedConfig):
else f"Eagle3{arch}"
for arch in self.model.architectures
]
elif method == "dflash":
assert self.model is not None, (
"model should not be None when method is dflash"
)
kwargs["architectures"] = [
arch
if arch.startswith("DFlash") or arch.endswith("DFlash")
else f"DFlash{arch}"
for arch in self.model.architectures
]
else:
raise ValueError(
f"Invalid method {method}. Supported methods are eagle and eagle3."
f"Invalid method {method}. Supported methods are "
"eagle, eagle3, and dflash."
)
super().__init__(**kwargs)

View File

@@ -220,6 +220,17 @@ class AttentionBackend(ABC):
def supports_per_head_quant_scales(cls) -> bool:
return False
@classmethod
def supports_non_causal(cls) -> bool:
"""Check if backend supports non-causal (bidirectional) attention
for decoder models.
Unlike ENCODER_ONLY attention type which implies a different
execution model, this refers to non-causal attention within the
standard paged-KV-cache decoder path.
"""
return False
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""Check if backend supports a given attention type.
@@ -261,6 +272,7 @@ class AttentionBackend(ABC):
use_per_head_quant_scales: bool,
device_capability: "DeviceCapability",
attn_type: str,
use_non_causal: bool = False,
) -> list[str]:
invalid_reasons = []
if not cls.supports_head_size(head_size):
@@ -293,6 +305,8 @@ class AttentionBackend(ABC):
invalid_reasons.append("compute capability not supported")
if not cls.supports_attn_type(attn_type):
invalid_reasons.append(f"attention type {attn_type} not supported")
if use_non_causal and not cls.supports_non_causal():
invalid_reasons.append("non-causal attention not supported")
combination_reason = cls.supports_combination(
head_size,
dtype,

View File

@@ -101,6 +101,10 @@ class FlashAttentionBackend(AttentionBackend):
def get_name() -> str:
return "FLASH_ATTN"
@classmethod
def supports_non_causal(cls) -> bool:
return True
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""FlashAttention supports all attention types."""

View File

@@ -29,6 +29,7 @@ class AttentionSelectorConfig(NamedTuple):
use_mm_prefix: bool = False
use_per_head_quant_scales: bool = False
attn_type: str = AttentionType.DECODER
use_non_causal: bool = False
def __repr__(self):
return (
@@ -41,7 +42,8 @@ class AttentionSelectorConfig(NamedTuple):
f"use_sparse={self.use_sparse}, "
f"use_mm_prefix={self.use_mm_prefix}, "
f"use_per_head_quant_scales={self.use_per_head_quant_scales}, "
f"attn_type={self.attn_type})"
f"attn_type={self.attn_type}, "
f"use_non_causal={self.use_non_causal})"
)
@@ -76,6 +78,11 @@ def get_attn_backend(
else:
block_size = None
speculative_config = vllm_config.speculative_config
use_non_causal = (
speculative_config is not None and speculative_config.method == "dflash"
)
attn_selector_config = AttentionSelectorConfig(
head_size=head_size,
dtype=dtype,
@@ -87,6 +94,7 @@ def get_attn_backend(
use_mm_prefix=use_mm_prefix,
use_per_head_quant_scales=use_per_head_quant_scales,
attn_type=attn_type or AttentionType.DECODER,
use_non_causal=use_non_causal,
)
return _cached_get_attn_backend(

View File

@@ -0,0 +1,282 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
from typing_extensions import override
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.triton_utils import triton
from vllm.v1.attention.backend import CommonAttentionMetadata
from vllm.v1.spec_decode.eagle import SpecDecodeBaseProposer
from vllm.v1.spec_decode.utils import copy_and_expand_dflash_inputs_kernel
logger = init_logger(__name__)
class DFlashProposer(SpecDecodeBaseProposer):
    """Speculative-decoding proposer for the DFlash method.

    DFlash drafts all speculative tokens in a single forward pass: the
    target model's hidden states are projected into context K/V entries
    that are written directly into the draft KV cache, while only the
    bonus token plus mask tokens run through the draft model as queries
    with non-causal attention.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        assert vllm_config.speculative_config is not None
        assert vllm_config.speculative_config.method == "dflash"
        super().__init__(
            vllm_config=vllm_config,
            device=device,
            pass_hidden_states_to_model=True,
            runner=runner,
        )
        # Only next_token_ids and mask tokens are query tokens, all other context is K/V
        self.max_query_tokens = self.max_batch_size * (1 + self.num_speculative_tokens)
        # Positions covers both context states + query states
        self.max_positions = self.max_num_tokens + self.max_query_tokens
        # Separate context buffers to keep query buffer addresses stable for CUDA graphs
        self._context_slot_mapping_buffer = torch.zeros(
            self.max_num_tokens,
            dtype=torch.int64,
            device=device,
        )
        self._slot_mapping_buffer = torch.zeros(
            self.max_query_tokens,
            dtype=torch.int64,
            device=device,
        )
        self._context_positions_buffer = torch.zeros(
            self.max_num_tokens,
            dtype=torch.int64,
            device=device,
        )
        self.positions = torch.zeros(
            self.max_query_tokens,
            dtype=torch.int64,
            device=device,
        )
        self.arange = torch.arange(
            self.max_positions + 1, device=device, dtype=torch.int32
        )
        # For DFlash we use the input embeddings to embed the mask token
        self.parallel_drafting_hidden_state_tensor = None

    @override
    def _raise_if_multimodal(self):
        # Override to allow multimodal inputs since DFlash supports Qwen3.5 models
        # Support for multimodal inputs has not been tested.
        pass

    @override
    def set_inputs_first_pass(
        self,
        target_token_ids: torch.Tensor,
        next_token_ids: torch.Tensor,
        target_positions: torch.Tensor,
        target_hidden_states: torch.Tensor,
        token_indices_to_sample: torch.Tensor | None,
        cad: CommonAttentionMetadata,
        num_rejected_tokens_gpu: torch.Tensor | None,
    ) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
        """Prepare first-pass inputs and attention metadata for drafting.

        Returns (total query tokens, indices of tokens to sample, new
        attention metadata covering the query tokens only).
        """
        # DFlash cross-attention: context K/V from target hidden states,
        # Q from query embeddings (bonus + mask tokens).
        batch_size = cad.batch_size()
        num_context = target_token_ids.shape[0]
        # One bonus token plus one mask token per speculative position.
        num_query_per_req = 1 + self.num_speculative_tokens
        num_query_total = batch_size * num_query_per_req
        # Store for build_model_inputs_first_pass to use
        self._dflash_num_context = num_context
        # We don't need to copy into a buffer here since the context preprocessing
        # does not run in a CUDA graph
        self._dflash_hidden_states = target_hidden_states
        token_indices_to_sample = torch.empty(
            batch_size * self.num_speculative_tokens,
            dtype=torch.int32,
            device=self.device,
        )
        # Launch fused triton kernel for input_ids, positions, slot_mapping,
        # and token_indices_to_sample
        max_ctx_per_req = cad.max_query_len
        max_tokens_per_req = max_ctx_per_req + num_query_per_req
        BLOCK_SIZE = min(256, triton.next_power_of_2(max_tokens_per_req))
        num_blocks = triton.cdiv(max_tokens_per_req, BLOCK_SIZE)
        # One program per (request, token-block) pair.
        grid = (batch_size, num_blocks)
        has_num_rejected = num_rejected_tokens_gpu is not None
        copy_and_expand_dflash_inputs_kernel[grid](
            # Inputs
            next_token_ids_ptr=next_token_ids,
            target_positions_ptr=target_positions,
            # Outputs
            out_input_ids_ptr=self.input_ids,
            out_context_positions_ptr=self._context_positions_buffer,
            out_query_positions_ptr=self.positions,
            out_context_slot_mapping_ptr=self._context_slot_mapping_buffer,
            out_query_slot_mapping_ptr=self._slot_mapping_buffer,
            out_token_indices_ptr=token_indices_to_sample,
            # Block table
            block_table_ptr=cad.block_table_tensor,
            block_table_stride=cad.block_table_tensor.stride(0),
            # Metadata
            query_start_loc_ptr=cad.query_start_loc,
            num_rejected_tokens_ptr=(
                num_rejected_tokens_gpu if has_num_rejected else 0
            ),
            # Scalars
            parallel_drafting_token_id=self.parallel_drafting_token_id,
            block_size=self.block_size,
            num_query_per_req=num_query_per_req,
            num_speculative_tokens=self.num_speculative_tokens,
            total_input_tokens=num_context,
            BLOCK_SIZE=BLOCK_SIZE,
            HAS_NUM_REJECTED=has_num_rejected,
        )
        query_slot_mapping = self._slot_mapping_buffer[:num_query_total]
        # Query tokens are evenly packed: num_query_per_req per request.
        new_query_start_loc = self.arange[: batch_size + 1] * num_query_per_req
        # In padded mode, cad.seq_lens includes rejected tokens. Subtract
        # them so attention only sees the valid prefix of context states.
        effective_seq_lens = cad.seq_lens
        if has_num_rejected:
            effective_seq_lens = effective_seq_lens - num_rejected_tokens_gpu
        new_cad = CommonAttentionMetadata(
            query_start_loc=new_query_start_loc,
            seq_lens=effective_seq_lens + num_query_per_req,
            query_start_loc_cpu=(
                torch.from_numpy(self.token_arange_np[: batch_size + 1]).clone()
                * num_query_per_req
            ),
            _seq_lens_cpu=None,
            _num_computed_tokens_cpu=None,
            num_reqs=cad.num_reqs,
            num_actual_tokens=num_query_total,
            max_query_len=num_query_per_req,
            max_seq_len=cad.max_seq_len + num_query_per_req,
            block_table_tensor=cad.block_table_tensor,
            slot_mapping=query_slot_mapping,
            causal=False,  # Non-causal attention is required for DFlash
        )
        return num_query_total, token_indices_to_sample, new_cad

    @override
    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
        is_graph_capturing: bool = False,
        slot_mappings: dict[str, torch.Tensor] | None = None,
    ) -> None:
        """
        Key differences to default dummy_run:
        - Only one forward pass due to parallel drafting
        - DFlash uses context states as unpadded metadata, so hidden_states will
          use the unpadded num_tokens instead of num_input_tokens
        - max_query_tokens is quite small, DFlash only sees spec tokens as queries
        - Multimodal inputs are not currently supported
        """
        num_query_tokens = min(num_tokens, self.max_query_tokens)
        cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
            self._determine_batch_execution_and_padding(
                num_query_tokens, use_cudagraphs=use_cudagraphs
            )
        )
        # Slot mapping sized to num_input_tokens (query only), matching
        # the K/V tensor size from the model forward. Context KVs are
        # pre-inserted separately and don't flow through the model.
        if (
            self._draft_attn_layer_names
            and slot_mappings is not None
            and next(iter(self._draft_attn_layer_names)) in slot_mappings
        ):
            slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
        else:
            slot_mapping_dict = slot_mappings or {}
        # Context and query positions use separate buffers; no copy needed.
        context_positions = self._context_positions_buffer[:num_tokens]
        # Context states will be passed directly to the precomputation without
        # going through the buffer, since no CUDA graph is used for the precomputation.
        # For the dummy run, we use the dummy buffer.
        context_states = self.hidden_states[:num_tokens]
        # Run the KV projection (GEMM + norms + RoPE) for memory profiling.
        self.model.precompute_and_store_context_kv(context_states, context_positions)
        with set_forward_context(
            None,
            self.vllm_config,
            num_tokens=num_input_tokens,
            num_tokens_across_dp=num_tokens_across_dp,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
            slot_mapping=slot_mapping_dict,
        ):
            self.model(
                input_ids=self.input_ids[:num_input_tokens],
                positions=self._get_positions(num_input_tokens),
                inputs_embeds=None,
            )

    @override
    def build_model_inputs_first_pass(
        self,
        num_tokens: int,
        num_input_tokens: int,
        mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None,
    ) -> tuple[dict[str, Any], int]:
        """Insert context K/V into the cache and build model kwargs for
        the single drafting forward pass."""
        # Context and query positions/slots were written to separate
        # buffers by the kernel — no copy needed.
        num_context = self._dflash_num_context
        # Pre-insert context KVs directly into cache
        self.model.precompute_and_store_context_kv(
            self._dflash_hidden_states,  # Shape is already [num_context, hidden_size]
            self._context_positions_buffer[:num_context],
            self._context_slot_mapping_buffer[:num_context],
        )
        return (
            dict(
                input_ids=self.input_ids[:num_input_tokens],
                positions=self._get_positions(num_input_tokens),
                inputs_embeds=None,
            ),
            num_input_tokens,
        )

    @override
    def build_per_layer_attn_metadata(
        self, cad: CommonAttentionMetadata, draft_index: int = 0
    ) -> dict[str, object]:
        """Build per-layer metadata, then verify every layer honored the
        non-causal flag DFlash requires."""
        per_layer_attention_metadata = super().build_per_layer_attn_metadata(
            cad, draft_index
        )
        for layer_name, attn_metadata in per_layer_attention_metadata.items():
            assert getattr(attn_metadata, "causal", None) is False, (
                f"Attention metadata for layer {layer_name} does not have"
                " non-causal support, which is required for DFlash."
                " Consider using a different attention backend, such as FlashAttention."
            )
        return per_layer_attention_metadata

    @override
    def _get_eagle3_use_aux_hidden_state_from_config(self):
        # Defaults to True; can be disabled via dflash_config in the
        # draft model's HF config.
        use_aux_hidden_state = True
        dflash_config = getattr(
            self.draft_model_config.hf_config, "dflash_config", None
        )
        if dflash_config is not None:
            use_aux_hidden_state = dflash_config.get("use_aux_hidden_state", True)
        return use_aux_hidden_state

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
from importlib.util import find_spec
from typing import cast
from typing import Any, cast
import numpy as np
import torch
@@ -23,6 +23,7 @@ from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.deepseek_eagle3 import Eagle3DeepseekV2ForCausalLM
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.model_executor.models.qwen3_dflash import DFlashQwen3ForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.platforms import current_platform
from vllm.triton_utils import triton
@@ -83,13 +84,15 @@ class SpecDecodeBaseProposer:
self.hidden_size = self.draft_model_config.get_hidden_size()
self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()
# Unifying eagle, draft model, and parallel drafting support
# Unifying eagle, draft model, and parallel drafting support.
# DFlash always uses parallel drafting (all tokens in one pass),
# but has an additional slot for the next_token_id (does not shift like EAGLE)
self.parallel_drafting: bool = self.speculative_config.parallel_drafting
self.extra_slots_per_request = (
1 if not self.parallel_drafting else self.num_speculative_tokens
)
self.net_num_new_slots_per_request = self.extra_slots_per_request - (
1 if self.pass_hidden_states_to_model else 0
1 if (self.pass_hidden_states_to_model and self.method != "dflash") else 0
)
self.needs_extra_input_slots = self.net_num_new_slots_per_request > 0
@@ -101,10 +104,14 @@ class SpecDecodeBaseProposer:
self.speculative_config.use_local_argmax_reduction
)
max_batch_size = vllm_config.scheduler_config.max_num_seqs
self.max_batch_size = vllm_config.scheduler_config.max_num_seqs
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
self.token_arange_np = np.arange(self.max_num_tokens)
# Can be specialized by methods like DFlash to reduce the limit
self.max_query_tokens = self.max_num_tokens
self.max_positions = self.max_num_tokens
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
@@ -146,18 +153,20 @@ class SpecDecodeBaseProposer:
# 1D-RoPE.
# See page 5 of https://arxiv.org/abs/2409.12191
self.mrope_positions = torch.zeros(
(3, self.max_num_tokens + 1), dtype=torch.int64, device=device
(3, self.max_positions + 1), dtype=torch.int64, device=device
)
elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
self.xdrope_positions = torch.zeros(
(self.uses_xdrope_dim, self.max_num_tokens + 1),
(self.uses_xdrope_dim, self.max_positions + 1),
dtype=torch.int64,
device=device,
)
else:
# RoPE need (max_num_tokens,)
self.positions = torch.zeros(
self.max_num_tokens, dtype=torch.int64, device=device
self.max_positions,
dtype=torch.int64,
device=device,
)
self.hidden_states = torch.zeros(
(self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
@@ -168,7 +177,7 @@ class SpecDecodeBaseProposer:
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
max_num_slots_for_arange = max(self.max_batch_size + 1, self.max_num_tokens)
self.arange = torch.arange(
max_num_slots_for_arange, device=device, dtype=torch.int32
)
@@ -200,7 +209,7 @@ class SpecDecodeBaseProposer:
)
self.backup_next_token_ids = CpuGpuBuffer(
max_batch_size,
self.max_batch_size,
dtype=torch.int32,
pin_memory=is_pin_memory_available(),
device=device,
@@ -208,7 +217,9 @@ class SpecDecodeBaseProposer:
)
self._slot_mapping_buffer = torch.zeros(
self.max_num_tokens, dtype=torch.int64, device=device
self.max_positions,
dtype=torch.int64,
device=device,
)
# Determine allowed attention backends once during initialization.
@@ -275,7 +286,7 @@ class SpecDecodeBaseProposer:
# Precompute draft position offsets in flattened tree.
self.tree_draft_pos_offsets = torch.arange(
1, len(self.tree_choices) + 1, device=device, dtype=torch.int32
).repeat(max_batch_size, 1)
).repeat(self.max_batch_size, 1)
def _raise_if_padded_drafter_batch_disabled(self):
if self.speculative_config.disable_padded_drafter_batch:
@@ -305,14 +316,19 @@ class SpecDecodeBaseProposer:
# for those masked slots.
model_hf_config = self.draft_model_config.hf_config
if hasattr(model_hf_config, "pard_token"):
# DFlash stores mask_token_id in dflash_config
dflash_config = getattr(model_hf_config, "dflash_config", None)
if dflash_config and "mask_token_id" in dflash_config:
self.parallel_drafting_token_id = dflash_config["mask_token_id"]
elif hasattr(model_hf_config, "pard_token"):
self.parallel_drafting_token_id = model_hf_config.pard_token
elif hasattr(model_hf_config, "ptd_token_id"):
self.parallel_drafting_token_id = model_hf_config.ptd_token_id
else:
raise ValueError(
"For parallel drafting, the draft model config must have "
"`pard_token` or `ptd_token_id` specified in its config.json."
"`pard_token`, `ptd_token_id`, or "
"`dflash_config.mask_token_id` specified in its config.json."
)
if self.pass_hidden_states_to_model:
@@ -402,9 +418,14 @@ class SpecDecodeBaseProposer:
) -> torch.Tensor:
batch_size = common_attn_metadata.batch_size()
if self.method == "eagle3":
if self.method in ("eagle3", "dflash"):
assert isinstance(
self.model, (Eagle3LlamaForCausalLM, Eagle3DeepseekV2ForCausalLM)
self.model,
(
Eagle3LlamaForCausalLM,
Eagle3DeepseekV2ForCausalLM,
DFlashQwen3ForCausalLM,
),
)
target_hidden_states = self.model.combine_hidden_states(
target_hidden_states
@@ -423,41 +444,18 @@ class SpecDecodeBaseProposer:
)
)
per_layer_attn_metadata: dict[str, object] = {}
for attn_group in self.draft_attn_groups:
attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=0
per_layer_attn_metadata = self.build_per_layer_attn_metadata(
common_attn_metadata
)
for layer_name in attn_group.layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
self._determine_batch_execution_and_padding(num_tokens)
)
if self.supports_mm_inputs:
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
self.input_ids[:num_tokens],
multimodal_embeddings=mm_embeds,
is_multimodal=is_mm_embed,
model_kwargs, slot_mapping_size = self.build_model_inputs_first_pass(
num_tokens, num_input_tokens, mm_embed_inputs
)
input_ids = None
inputs_embeds = self.inputs_embeds[:num_input_tokens]
else:
input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None
model_kwargs = {
"input_ids": input_ids,
"positions": self._get_positions(num_input_tokens),
"inputs_embeds": inputs_embeds,
}
if self.pass_hidden_states_to_model:
model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
with set_forward_context(
per_layer_attn_metadata,
self.vllm_config,
@@ -465,7 +463,7 @@ class SpecDecodeBaseProposer:
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=self._get_slot_mapping(
num_input_tokens, common_attn_metadata.slot_mapping
slot_mapping_size, common_attn_metadata.slot_mapping
),
):
ret_hidden_states = self.model(**model_kwargs)
@@ -488,7 +486,10 @@ class SpecDecodeBaseProposer:
positions = self.positions[token_indices_to_sample]
hidden_states = hidden_states[token_indices_to_sample]
if isinstance(attn_metadata, TreeAttentionMetadata):
if any(
isinstance(attn_metadata, TreeAttentionMetadata)
for attn_metadata in per_layer_attn_metadata.values()
):
# Draft using tree attention - requires full logits for top-k
logits = self.model.compute_logits(sample_hidden_states)
draft_token_ids_list = self.propose_tree(
@@ -504,6 +505,7 @@ class SpecDecodeBaseProposer:
draft_token_ids = self._greedy_sample(sample_hidden_states)
for attn_metadata in per_layer_attn_metadata.values():
if self.allowed_attn_types is not None and not isinstance(
attn_metadata, self.allowed_attn_types
):
@@ -593,13 +595,9 @@ class SpecDecodeBaseProposer:
common_attn_metadata._num_computed_tokens_cpu += 1
# Rebuild attention metadata
for attn_group in self.draft_attn_groups:
attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
common_attn_metadata=common_attn_metadata,
draft_index=token_index + 1,
per_layer_attn_metadata = self.build_per_layer_attn_metadata(
common_attn_metadata, draft_index=token_index + 1
)
for layer_name in attn_group.layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
# copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids
@@ -780,8 +778,51 @@ class SpecDecodeBaseProposer:
return total_num_output_tokens, token_indices_to_sample, new_cad
def build_model_inputs_first_pass(
self,
num_tokens: int,
num_input_tokens: int,
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None,
) -> tuple[dict[str, Any], int]:
if self.supports_mm_inputs:
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
self.input_ids[:num_tokens],
multimodal_embeddings=mm_embeds,
is_multimodal=is_mm_embed,
)
input_ids = None
inputs_embeds = self.inputs_embeds[:num_input_tokens]
else:
input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None
model_kwargs = {
"input_ids": input_ids,
"positions": self._get_positions(num_input_tokens),
"inputs_embeds": inputs_embeds,
}
if self.pass_hidden_states_to_model:
model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
return model_kwargs, num_input_tokens
def build_per_layer_attn_metadata(
self, common_attn_metadata: CommonAttentionMetadata, draft_index: int = 0
) -> dict[str, object]:
per_layer_attn_metadata: dict[str, object] = {}
for attn_group in self.draft_attn_groups:
attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=draft_index
)
for layer_name in attn_group.layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
return per_layer_attn_metadata
def model_returns_tuple(self) -> bool:
return self.method not in ("mtp", "draft_model")
return self.method not in ("mtp", "draft_model", "dflash")
def prepare_next_token_ids_cpu(
self,
@@ -1310,15 +1351,20 @@ class SpecDecodeBaseProposer:
self._maybe_share_embeddings(target_language_model)
self._maybe_share_lm_head(target_language_model)
if self.parallel_drafting and self.pass_hidden_states_to_model:
assert self.parallel_drafting_hidden_state_tensor is not None
if (
self.parallel_drafting
and self.pass_hidden_states_to_model
and self.parallel_drafting_hidden_state_tensor is not None
):
flat_mask = self.model.mask_hidden.view(-1)
if self.eagle3_use_aux_hidden_state:
# EAGLE3: mask_hidden stores all aux hidden states,
# project through combine_hidden_states
self.parallel_drafting_hidden_state_tensor.copy_(
self.model.combine_hidden_states(
self.model.mask_hidden.view(3 * self.hidden_size)
)
if self.eagle3_use_aux_hidden_state
else self.model.mask_hidden.view(self.hidden_size)
self.model.combine_hidden_states(flat_mask)
)
else:
self.parallel_drafting_hidden_state_tensor.copy_(flat_mask)
def _maybe_share_embeddings(self, target_language_model: nn.Module) -> None:
"""
@@ -1493,8 +1539,9 @@ class SpecDecodeBaseProposer:
) -> None:
# FIXME: when using tree-based specdec, adjust number of forward-passes
# according to the depth of the tree.
only_one_forward_pass = is_graph_capturing or self.parallel_drafting
for fwd_idx in range(
self.num_speculative_tokens if not is_graph_capturing else 1
1 if only_one_forward_pass else self.num_speculative_tokens
):
if fwd_idx <= 1:
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (

View File

@@ -441,6 +441,114 @@ def copy_and_expand_eagle_inputs_kernel(
)
@triton.jit
def copy_and_expand_dflash_inputs_kernel(
    # Inputs
    next_token_ids_ptr,  # [num_reqs]
    target_positions_ptr,  # [num_context]
    # Outputs
    out_input_ids_ptr,  # [num_query_total] (output)
    out_context_positions_ptr,  # [num_context] (output)
    out_query_positions_ptr,  # [num_query_total] (output)
    out_context_slot_mapping_ptr,  # [num_context] (output)
    out_query_slot_mapping_ptr,  # [num_query_total] (output)
    out_token_indices_ptr,  # [num_reqs * num_speculative_tokens] (output)
    # Block table
    block_table_ptr,  # [max_reqs, max_blocks]
    block_table_stride,  # stride of block_table dim 0 (in elements)
    # Metadata
    query_start_loc_ptr,  # [num_reqs + 1]
    num_rejected_tokens_ptr,  # [num_reqs] or null (0) when not padded
    # Scalars
    parallel_drafting_token_id,  # tl.int32
    block_size,  # tl.int32
    num_query_per_req,  # tl.int32
    num_speculative_tokens,  # tl.int32
    total_input_tokens,  # tl.int32
    BLOCK_SIZE: tl.constexpr,
    HAS_NUM_REJECTED: tl.constexpr = False,
):
    """
    Fused kernel for DFlash first-pass input setup.

    Launch grid: axis 0 indexes the request, axis 1 indexes a tile of
    BLOCK_SIZE token slots covering that request's context + query tokens.

    Per request, this kernel:
    1. Copies context positions from target_positions to
       out_context_positions.
    2. Computes query positions (last_target_pos + 1 + offset) and writes
       them to out_query_positions.
    3. Writes input_ids for query tokens: [next_token, mask, mask, ...].
    4. Computes slot_mapping for context and query positions into separate
       buffers via block_table lookup.
    5. Writes token_indices_to_sample for the mask (speculative) tokens.
    """
    req_idx = tl.program_id(axis=0)
    block_idx = tl.program_id(axis=1)

    # Load context token range for this request (CSR-style offsets).
    ctx_start = tl.load(query_start_loc_ptr + req_idx)
    ctx_end = tl.load(query_start_loc_ptr + req_idx + 1)
    num_ctx = ctx_end - ctx_start
    total_tokens = num_ctx + num_query_per_req

    # Lane j covers context slots [0, num_ctx) then query slots
    # [num_ctx, num_ctx + num_query_per_req).
    j = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = j < total_tokens
    is_ctx = j < num_ctx
    is_query = (~is_ctx) & in_bounds
    query_off = j - num_ctx  # offset within query portion (0-indexed)

    # --- Positions ---
    # Context: load from target_positions. Clamp the index so that lanes
    # masked out by `is_ctx` still compute an in-range address.
    ctx_pos_idx = tl.minimum(ctx_start + j, total_input_tokens - 1)
    ctx_pos = tl.load(target_positions_ptr + ctx_pos_idx, mask=is_ctx, other=0)

    # Query: last_valid_pos + 1 + query_off
    # In padded mode, ctx_end includes rejected tokens; use valid_ctx_end
    # to find the last accepted context position.
    if HAS_NUM_REJECTED:
        num_rejected = tl.load(num_rejected_tokens_ptr + req_idx)
        valid_ctx_end = ctx_end - num_rejected
    else:
        valid_ctx_end = ctx_end
    # NOTE(review): this load is unmasked, so it assumes every request has
    # at least one valid context token (valid_ctx_end > ctx_start) — confirm
    # the caller guarantees this.
    last_pos = tl.load(target_positions_ptr + valid_ctx_end - 1)
    query_pos = last_pos + 1 + query_off
    positions = tl.where(is_ctx, ctx_pos, query_pos)

    # Context and query positions go to separate buffers.
    ctx_pos_out = ctx_start + j
    tl.store(out_context_positions_ptr + ctx_pos_out, ctx_pos, mask=is_ctx)
    query_out = req_idx * num_query_per_req + query_off
    tl.store(out_query_positions_ptr + query_out, query_pos, mask=is_query)

    # --- Slot mapping (block_table lookup for all positions) ---
    block_num = positions // block_size
    # Clamp block_number to avoid OOB when position is at max
    block_num = tl.minimum(block_num, block_table_stride - 1)
    block_id = tl.load(
        block_table_ptr + req_idx * block_table_stride + block_num,
        mask=in_bounds,
        other=0,
    ).to(tl.int64)  # widen so block_id * block_size cannot overflow int32
    slot = block_id * block_size + (positions % block_size)
    tl.store(out_context_slot_mapping_ptr + ctx_pos_out, slot, mask=is_ctx)
    tl.store(out_query_slot_mapping_ptr + query_out, slot, mask=is_query)

    # --- Input IDs (query tokens only) ---
    # First query slot carries the bonus (sampled) token; the remaining
    # slots carry the parallel-drafting mask token id.
    bonus_token = tl.load(next_token_ids_ptr + req_idx)
    is_bonus = is_query & (query_off == 0)
    input_id = tl.where(is_bonus, bonus_token, parallel_drafting_token_id)
    tl.store(out_input_ids_ptr + query_out, input_id, mask=is_query)

    # --- Token indices to sample (mask tokens, skip the bonus token) ---
    is_sample = is_query & (query_off > 0)
    sample_out_idx = req_idx * num_speculative_tokens + (query_off - 1)
    tl.store(
        out_token_indices_ptr + sample_out_idx,
        query_out,
        mask=is_sample,
    )
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def update_num_computed_tokens_for_batch_change(
num_computed_tokens: torch.Tensor,

View File

@@ -160,6 +160,7 @@ from vllm.v1.sample.logits_processor.interface import LogitsProcessor
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.dflash import DFlashProposer
from vllm.v1.spec_decode.draft_model import DraftModelProposer
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer
@@ -515,6 +516,7 @@ class GPUModelRunner(
| NgramProposerGPU
| SuffixDecodingProposer
| EagleProposer
| DFlashProposer
| DraftModelProposer
| MedusaProposer
| ExtractHiddenStatesProposer
@@ -546,6 +548,9 @@ class GPUModelRunner(
self._ngram_pinned_val_buf = torch.zeros(
self.max_num_reqs, dtype=torch.int32, pin_memory=True
)
elif self.speculative_config.use_dflash():
self.drafter = DFlashProposer(self.vllm_config, self.device, self)
self.use_aux_hidden_state_outputs = True
elif self.speculative_config.method == "suffix":
self.drafter = SuffixDecodingProposer(self.vllm_config)
elif self.speculative_config.use_eagle():
@@ -2289,7 +2294,7 @@ class GPUModelRunner(
cm.slot_mapping = slot_mappings[kv_cache_gid]
if self.speculative_config and spec_decode_common_attn_metadata is None:
if isinstance(self.drafter, EagleProposer):
if isinstance(self.drafter, (EagleProposer, DFlashProposer)):
if self.drafter.kv_cache_gid == kv_cache_gid:
spec_decode_common_attn_metadata = cm
else:
@@ -4202,7 +4207,10 @@ class GPUModelRunner(
# as inputs, and does not need to wait for bookkeeping to finish.
assert isinstance(
self.drafter,
EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer,
EagleProposer
| DFlashProposer
| DraftModelProposer
| ExtractHiddenStatesProposer,
)
sampled_token_ids = sampler_output.sampled_token_ids
if input_fits_in_drafter:
@@ -4589,8 +4597,14 @@ class GPUModelRunner(
next_token_ids, valid_sampled_tokens_count
)
elif spec_config.use_eagle() or spec_config.uses_draft_model():
assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
elif (
spec_config.use_eagle()
or spec_config.use_dflash()
or spec_config.uses_draft_model()
):
assert isinstance(
self.drafter, EagleProposer | DFlashProposer | DraftModelProposer
)
if spec_config.disable_padded_drafter_batch:
# When padded-batch is disabled, the sampled_token_ids should be
@@ -4889,10 +4903,13 @@ class GPUModelRunner(
return None
hf_config = self.speculative_config.draft_model_config.hf_config
if not hasattr(hf_config, "eagle_aux_hidden_state_layer_ids"):
return None
layer_ids = hf_config.eagle_aux_hidden_state_layer_ids
layer_ids = getattr(hf_config, "eagle_aux_hidden_state_layer_ids", None)
if not layer_ids:
dflash_config = getattr(hf_config, "dflash_config", None)
if dflash_config and isinstance(dflash_config, dict):
layer_ids = dflash_config.get("target_layer_ids")
if layer_ids and isinstance(layer_ids, (list, tuple)):
return tuple(layer_ids)
@@ -5479,7 +5496,10 @@ class GPUModelRunner(
):
assert isinstance(
self.drafter,
EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer,
EagleProposer
| DFlashProposer
| DraftModelProposer
| ExtractHiddenStatesProposer,
)
assert self.speculative_config is not None
# Eagle currently only supports PIECEWISE cudagraphs.
@@ -6236,7 +6256,9 @@ class GPUModelRunner(
self.speculative_config.use_eagle()
or self.speculative_config.uses_draft_model()
):
assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
assert isinstance(
self.drafter, EagleProposer | DFlashProposer | DraftModelProposer
)
self.drafter.initialize_attn_backend(kv_cache_config, kernel_block_sizes)
def _check_and_update_cudagraph_mode(
@@ -6420,7 +6442,10 @@ class GPUModelRunner(
self.speculative_config.use_eagle()
or self.speculative_config.uses_extract_hidden_states()
):
assert isinstance(self.drafter, EagleProposer | ExtractHiddenStatesProposer)
assert isinstance(
self.drafter,
EagleProposer | DFlashProposer | ExtractHiddenStatesProposer,
)
self.drafter.initialize_cudagraph_keys(cudagraph_mode)
def calculate_reorder_batch_threshold(self) -> None: