[Feat][Spec Decode] DFlash (#36847)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
committed by
GitHub
parent
ab1a6a43fa
commit
494636b29d
@@ -1163,6 +1163,14 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
||||
# "JackFram/llama-160m",
|
||||
# speculative_model="ibm-ai-platform/llama-160m-accelerator"
|
||||
# ),
|
||||
# [DFlash]
|
||||
"DFlashDraftModel": _HfExamplesInfo(
|
||||
"Qwen/Qwen3.5-4B",
|
||||
speculative_model="z-lab/Qwen3.5-4B-DFlash",
|
||||
use_original_num_layers=True, # Need all layers since DFlash has >1 layer,
|
||||
max_model_len=8192, # Reduce max len to ensure test runs in low-VRAM CI env
|
||||
max_num_seqs=32,
|
||||
),
|
||||
# [Eagle]
|
||||
"EagleDeepSeekMTPModel": _HfExamplesInfo(
|
||||
"eagle618/deepseek-v3-random",
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Any
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from tests.evals.gsm8k.gsm8k_eval import _build_gsm8k_prompts, evaluate_gsm8k_offline
|
||||
from tests.utils import (
|
||||
@@ -1105,19 +1106,178 @@ def some_high_acceptance_metrics() -> dict:
|
||||
}
|
||||
|
||||
|
||||
def compute_acceptance_rate(metrics: list[Metric]) -> float:
|
||||
def compute_acceptance_rate(
|
||||
metrics: list[Metric], prev_metrics: list[Metric] | None = None
|
||||
) -> float:
|
||||
name2metric = {metric.name: metric for metric in metrics}
|
||||
n_draft_toks = name2metric["vllm:spec_decode_num_draft_tokens"].value # type: ignore
|
||||
n_draft_toks = name2metric["vllm:spec_decode_num_draft_tokens"].value
|
||||
if n_draft_toks == 0:
|
||||
return float("nan")
|
||||
n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value # type: ignore
|
||||
n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value
|
||||
if prev_metrics is not None:
|
||||
prev_name2metric = {metric.name: metric for metric in prev_metrics}
|
||||
n_draft_toks -= prev_name2metric["vllm:spec_decode_num_draft_tokens"].value
|
||||
n_accepted_toks -= prev_name2metric[
|
||||
"vllm:spec_decode_num_accepted_tokens"
|
||||
].value
|
||||
if n_draft_toks <= 0:
|
||||
return float("nan")
|
||||
return n_accepted_toks / n_draft_toks
|
||||
|
||||
|
||||
def compute_acceptance_len(metrics: list[Metric]) -> float:
|
||||
def compute_acceptance_len(
|
||||
metrics: list[Metric], prev_metrics: list[Metric] | None = None
|
||||
) -> float:
|
||||
name2metric = {metric.name: metric for metric in metrics}
|
||||
n_drafts = name2metric["vllm:spec_decode_num_drafts"].value # type: ignore
|
||||
n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value # type: ignore
|
||||
n_drafts = name2metric["vllm:spec_decode_num_drafts"].value
|
||||
n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value
|
||||
if n_drafts == 0:
|
||||
return 1
|
||||
if prev_metrics is not None:
|
||||
prev_name2metric = {metric.name: metric for metric in prev_metrics}
|
||||
n_drafts -= prev_name2metric["vllm:spec_decode_num_drafts"].value
|
||||
n_accepted_toks -= prev_name2metric[
|
||||
"vllm:spec_decode_num_accepted_tokens"
|
||||
].value
|
||||
if n_drafts <= 0:
|
||||
return 1
|
||||
return 1 + (n_accepted_toks / n_drafts)
|
||||
|
||||
|
||||
# Datasets in the format used in DFlash validations
|
||||
def load_and_process_dataset(data_name: str):
|
||||
from datasets import load_dataset
|
||||
|
||||
if data_name == "gsm8k":
|
||||
dataset = load_dataset("openai/gsm8k", "main", split="test")
|
||||
prompt_fmt = (
|
||||
"{question}\nPlease reason step by step,"
|
||||
" and put your final answer within \\boxed{{}}."
|
||||
)
|
||||
dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]})
|
||||
elif data_name == "mt-bench":
|
||||
dataset = load_dataset("HuggingFaceH4/mt_bench_prompts", split="train")
|
||||
dataset = dataset.map(lambda x: {"turns": x["prompt"]})
|
||||
elif data_name == "humaneval":
|
||||
dataset = load_dataset("openai/openai_humaneval", split="test")
|
||||
prompt_fmt = (
|
||||
"Write a solution to the following problem and make sure"
|
||||
" that it passes the tests:\n```python\n{prompt}\n```"
|
||||
)
|
||||
dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]})
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dflash_config():
|
||||
target_model = "Qwen/Qwen3-8B"
|
||||
draft_model = "z-lab/Qwen3-8B-DFlash-b16"
|
||||
|
||||
return dict(
|
||||
model=target_model,
|
||||
trust_remote_code=True,
|
||||
speculative_config={
|
||||
"method": "dflash",
|
||||
"model": draft_model,
|
||||
"num_speculative_tokens": 16,
|
||||
"max_model_len": 32768,
|
||||
},
|
||||
max_model_len=32768,
|
||||
max_num_seqs=128,
|
||||
gpu_memory_utilization=0.85,
|
||||
enforce_eager=False,
|
||||
disable_log_stats=False,
|
||||
)
|
||||
|
||||
|
||||
def test_dflash_acceptance_rates(dflash_config):
|
||||
"""
|
||||
E2E test for DFlash (block diffusion) speculative decoding.
|
||||
Runs acceptance rate validation on GSM8k, MT-Bench, and HumanEval
|
||||
comparing against baseline results from the paper (Table 1).
|
||||
See https://github.com/z-lab/dflash/blob/main/benchmark_sglang.py for methodology.
|
||||
"""
|
||||
spec_llm = LLM(**dflash_config)
|
||||
|
||||
max_prompts_per_dataset = 200 # mt-bench has 80, humaneval has 164, truncates gsm8k
|
||||
|
||||
# All scores from Table 1 in https://arxiv.org/pdf/2602.06036
|
||||
expected_acceptance_lengths = {
|
||||
"mt-bench": 4.24,
|
||||
"humaneval": 6.50,
|
||||
"gsm8k": 6.54 * 0.95, # runs with a subset of prompts so extra wide tol here
|
||||
}
|
||||
|
||||
tokenizer = spec_llm.get_tokenizer()
|
||||
for dataset_name, expected_len in expected_acceptance_lengths.items():
|
||||
dataset = load_and_process_dataset(dataset_name)
|
||||
prev_metrics = None
|
||||
acceptance_lengths = []
|
||||
for i in tqdm(
|
||||
range(min(max_prompts_per_dataset, len(dataset))),
|
||||
desc=f"Processing {dataset_name}",
|
||||
):
|
||||
user_content = dataset[i]["turns"][0]
|
||||
prompt_text = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": user_content}],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False,
|
||||
)
|
||||
|
||||
# Temp=0, MaxTokens=2048 from the paper
|
||||
spec_llm.generate(
|
||||
[prompt_text],
|
||||
SamplingParams(temperature=0, max_tokens=2048),
|
||||
use_tqdm=False,
|
||||
)
|
||||
current_metrics = spec_llm.get_metrics()
|
||||
acceptance_len = compute_acceptance_len(current_metrics, prev_metrics)
|
||||
prev_metrics = current_metrics
|
||||
acceptance_lengths.append(acceptance_len)
|
||||
|
||||
mean_acceptance_length = sum(acceptance_lengths) / len(acceptance_lengths)
|
||||
expected_len = expected_len * 0.9
|
||||
print(
|
||||
f"DFlash acceptance_len for {dataset_name}: {mean_acceptance_length:.2f}"
|
||||
f" (expected at least {expected_len:.2f})"
|
||||
)
|
||||
|
||||
assert mean_acceptance_length >= expected_len, (
|
||||
f"DFlash acceptance_len for {dataset_name} is below expected threshold:"
|
||||
f"{mean_acceptance_length:.2f} < {expected_len:.2f}"
|
||||
)
|
||||
|
||||
del spec_llm
|
||||
torch.accelerator.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
def test_dflash_correctness(dflash_config):
|
||||
"""
|
||||
E2E test for DFlash (block diffusion) speculative decoding.
|
||||
Ensures output correctness on GSM8k, with cudagraphs and batching on.
|
||||
"""
|
||||
spec_llm = LLM(**dflash_config)
|
||||
|
||||
# Evaluate GSM8k accuracy (Qwen3-8B ref: ~87-92% on GSM8k)
|
||||
evaluate_llm_for_gsm8k(spec_llm, expected_accuracy_threshold=0.8)
|
||||
|
||||
current_metrics = spec_llm.get_metrics()
|
||||
acceptance_len = compute_acceptance_len(current_metrics)
|
||||
|
||||
# AR is thoroughly validated in test_dflash_acceptance_rates, in a manner consistent
|
||||
# with the DFlash paper. However, that test measures AL per-request and thus runs
|
||||
# with a batch size of 1. To ensure that AL does not collapse with large batch sizes
|
||||
# we enforce a baseline on the AL over the full lm-eval-style GSM8k test.
|
||||
expected_len = 3.5 # Measured is 3.9 to 4.0
|
||||
print(f"DFlash GSM8k correctness test got AL {acceptance_len}")
|
||||
assert acceptance_len >= expected_len, (
|
||||
"DFlash correctness check failed with"
|
||||
f" {acceptance_len=}, expected at least {expected_len}"
|
||||
)
|
||||
|
||||
del spec_llm
|
||||
torch.accelerator.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
@@ -27,6 +27,7 @@ from vllm.config.load import LoadConfig
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.v1.spec_decode.dflash import DFlashProposer
|
||||
from vllm.v1.spec_decode.draft_model import DraftModelProposer
|
||||
from vllm.v1.spec_decode.eagle import EagleProposer
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
@@ -36,6 +37,8 @@ model_dir = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
|
||||
eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
|
||||
ar_draft_model_dir = "amd/PARD-Llama-3.2-1B" # Compatible with parallel and AR drafting
|
||||
dflash_target_dir = "Qwen/Qwen3-8B"
|
||||
dflash_dir = "z-lab/Qwen3-8B-DFlash-b16"
|
||||
|
||||
BLOCK_SIZE = 16
|
||||
|
||||
@@ -47,18 +50,29 @@ def _create_proposer(
|
||||
speculative_token_tree: list[tuple[int, ...]] | None = None,
|
||||
parallel_drafting: bool = False,
|
||||
) -> EagleProposer:
|
||||
model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100)
|
||||
|
||||
# Method-dependent setup
|
||||
if method == "eagle":
|
||||
target_model_dir = model_dir
|
||||
draft_model_dir = eagle_dir
|
||||
elif method == "eagle3":
|
||||
target_model_dir = model_dir
|
||||
draft_model_dir = eagle3_dir
|
||||
elif method == "draft_model":
|
||||
target_model_dir = model_dir
|
||||
draft_model_dir = ar_draft_model_dir
|
||||
elif method == "dflash":
|
||||
target_model_dir = dflash_target_dir
|
||||
draft_model_dir = dflash_dir
|
||||
else:
|
||||
raise ValueError(f"Unknown method: {method}")
|
||||
|
||||
model_config = ModelConfig(
|
||||
model=target_model_dir,
|
||||
runner="generate",
|
||||
max_model_len=100,
|
||||
trust_remote_code=(method == "dflash"),
|
||||
)
|
||||
|
||||
spec_token_tree_str = None
|
||||
if speculative_token_tree is not None:
|
||||
assert num_speculative_tokens == len(speculative_token_tree)
|
||||
@@ -92,7 +106,9 @@ def _create_proposer(
|
||||
attention_config=AttentionConfig(backend=attention_backend),
|
||||
)
|
||||
|
||||
if "eagle" in method:
|
||||
if method == "dflash":
|
||||
proposer = DFlashProposer(vllm_config=vllm_config, device=device)
|
||||
elif "eagle" in method:
|
||||
proposer = EagleProposer(vllm_config=vllm_config, device=device)
|
||||
else:
|
||||
proposer = DraftModelProposer(vllm_config=vllm_config, device=device)
|
||||
@@ -1152,3 +1168,136 @@ def test_propose_tree(spec_token_tree):
|
||||
|
||||
# Verify that the draft tokens match our expectations.
|
||||
assert torch.equal(result, expected_tokens)
|
||||
|
||||
|
||||
def test_set_inputs_first_pass_dflash():
|
||||
"""
|
||||
Test for DFlash set_inputs_first_pass.
|
||||
|
||||
DFlash uses cross-attention: context tokens become K/V and only
|
||||
query tokens (bonus + mask) are Q. This tests the DFlash-specific
|
||||
input preparation where:
|
||||
- Context hidden states are stored by reference (no copy)
|
||||
- Query input_ids are [next_token, mask, mask, ...] per request
|
||||
- Context and query positions are written to separate buffers
|
||||
- token_indices_to_sample points to mask token positions only
|
||||
- A new CommonAttentionMetadata is returned with causal=False
|
||||
|
||||
Setup:
|
||||
- 3 requests with query_lens [3, 2, 4]
|
||||
- num_speculative_tokens = 3
|
||||
- num_query_per_req = 4 (1 bonus + 3 mask tokens)
|
||||
- next_token_ids: [100, 200, 300]
|
||||
|
||||
Expected output layout (query tokens only, 12 total):
|
||||
Request 0 (indices 0-3): [100, mask, mask, mask]
|
||||
Request 1 (indices 4-7): [200, mask, mask, mask]
|
||||
Request 2 (indices 8-11): [300, mask, mask, mask]
|
||||
|
||||
Expected positions layout (separate buffers):
|
||||
Context (_context_positions_buffer, 9 tokens): copied from target_positions
|
||||
Query (positions, 12 tokens):
|
||||
Request 0: last_pos=9, query=[10, 11, 12, 13]
|
||||
Request 1: last_pos=7, query=[8, 9, 10, 11]
|
||||
Request 2: last_pos=11, query=[12, 13, 14, 15]
|
||||
"""
|
||||
device = torch.device(current_platform.device_type)
|
||||
|
||||
num_speculative_tokens = 3
|
||||
proposer = _create_proposer("dflash", num_speculative_tokens)
|
||||
mask_token_id = proposer.parallel_drafting_token_id
|
||||
|
||||
# Setup batch with 3 requests
|
||||
batch_spec = BatchSpec(
|
||||
seq_lens=[10, 8, 12],
|
||||
query_lens=[3, 2, 4],
|
||||
)
|
||||
|
||||
common_attn_metadata = create_common_attn_metadata(
|
||||
batch_spec,
|
||||
block_size=BLOCK_SIZE,
|
||||
device=device,
|
||||
arange_block_indices=True,
|
||||
)
|
||||
|
||||
# Input tensors
|
||||
# Request 0: tokens [10, 11, 12] at positions [7, 8, 9]
|
||||
# Request 1: tokens [20, 21] at positions [6, 7]
|
||||
# Request 2: tokens [30, 31, 32, 33] at positions [8, 9, 10, 11]
|
||||
target_token_ids = torch.tensor(
|
||||
[10, 11, 12, 20, 21, 30, 31, 32, 33], dtype=torch.int32, device=device
|
||||
)
|
||||
target_positions = torch.tensor(
|
||||
[7, 8, 9, 6, 7, 8, 9, 10, 11], dtype=torch.int64, device=device
|
||||
)
|
||||
target_hidden_states = torch.randn(
|
||||
9, proposer.hidden_size, dtype=proposer.dtype, device=device
|
||||
)
|
||||
next_token_ids = torch.tensor([100, 200, 300], dtype=torch.int32, device=device)
|
||||
|
||||
num_tokens, token_indices_to_sample, output_cad = proposer.set_inputs_first_pass(
|
||||
target_token_ids=target_token_ids,
|
||||
next_token_ids=next_token_ids,
|
||||
target_positions=target_positions,
|
||||
target_hidden_states=target_hidden_states,
|
||||
token_indices_to_sample=None,
|
||||
cad=common_attn_metadata,
|
||||
num_rejected_tokens_gpu=None,
|
||||
)
|
||||
|
||||
num_query_per_req = 1 + num_speculative_tokens # 4
|
||||
num_context = 9
|
||||
|
||||
# num_tokens is the query-only count
|
||||
assert num_tokens == 3 * num_query_per_req # 12
|
||||
|
||||
# Verify input_ids (query tokens only)
|
||||
# Each request: [next_token, mask, mask, mask]
|
||||
M = mask_token_id
|
||||
expected_input_ids = torch.tensor(
|
||||
[100, M, M, M, 200, M, M, M, 300, M, M, M],
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
assert torch.equal(proposer.input_ids[:num_tokens], expected_input_ids)
|
||||
|
||||
# Verify context positions (separate buffer): copied from target_positions
|
||||
assert torch.equal(
|
||||
proposer._context_positions_buffer[:num_context], target_positions
|
||||
)
|
||||
|
||||
# Verify query positions (separate buffer, starts at index 0):
|
||||
# req0: last_pos=9, query=[10, 11, 12, 13]
|
||||
# req1: last_pos=7, query=[8, 9, 10, 11]
|
||||
# req2: last_pos=11, query=[12, 13, 14, 15]
|
||||
expected_query_positions = torch.tensor(
|
||||
[10, 11, 12, 13, 8, 9, 10, 11, 12, 13, 14, 15],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
assert torch.equal(
|
||||
proposer.positions[:num_tokens],
|
||||
expected_query_positions,
|
||||
)
|
||||
|
||||
# Verify token_indices_to_sample (mask tokens only, skip bonus at offset 0)
|
||||
# req0: query indices 0-3, mask at 1,2,3
|
||||
# req1: query indices 4-7, mask at 5,6,7
|
||||
# req2: query indices 8-11, mask at 9,10,11
|
||||
expected_token_indices_to_sample = torch.tensor(
|
||||
[1, 2, 3, 5, 6, 7, 9, 10, 11], dtype=torch.int32, device=device
|
||||
)
|
||||
assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample)
|
||||
|
||||
# Verify the new CAD has DFlash-specific properties
|
||||
assert output_cad.causal is False # DFlash requires non-causal attention
|
||||
assert output_cad.num_actual_tokens == num_tokens # query-only count
|
||||
assert output_cad.max_query_len == num_query_per_req
|
||||
|
||||
expected_query_start_loc = torch.tensor(
|
||||
[0, 4, 8, 12], dtype=torch.int32, device=device
|
||||
)
|
||||
assert torch.equal(output_cad.query_start_loc, expected_query_start_loc)
|
||||
|
||||
# Verify hidden states (stored by reference, not copied)
|
||||
assert proposer._dflash_hidden_states is target_hidden_states
|
||||
|
||||
@@ -47,8 +47,11 @@ MTPModelTypes = Literal[
|
||||
"pangu_ultra_moe_mtp",
|
||||
"step3p5_mtp",
|
||||
]
|
||||
EagleModelTypes = Literal["eagle", "eagle3", "extract_hidden_states", MTPModelTypes]
|
||||
NgramGPUTypes = Literal["ngram_gpu"]
|
||||
DFlashModelTypes = Literal["dflash"]
|
||||
EagleModelTypes = Literal[
|
||||
"eagle", "eagle3", "extract_hidden_states", MTPModelTypes, DFlashModelTypes
|
||||
]
|
||||
SpeculativeMethod = Literal[
|
||||
"ngram",
|
||||
"medusa",
|
||||
@@ -206,7 +209,11 @@ class SpeculativeConfig:
|
||||
factors: list[Any] = []
|
||||
# Eagle3 and extract_hidden_states affect the computation graph because
|
||||
# they return intermediate hidden states in addition to the final hidden state.
|
||||
uses_aux_hidden_states = self.method in ("eagle3", "extract_hidden_states")
|
||||
uses_aux_hidden_states = self.method in (
|
||||
"eagle3",
|
||||
"extract_hidden_states",
|
||||
"dflash",
|
||||
)
|
||||
factors.append(uses_aux_hidden_states)
|
||||
|
||||
# The specific layers used also affect the computation graph
|
||||
@@ -490,7 +497,7 @@ class SpeculativeConfig:
|
||||
)
|
||||
|
||||
# Automatically detect the method
|
||||
if self.method in ("eagle", "eagle3"):
|
||||
if self.method in ("eagle", "eagle3", "dflash"):
|
||||
pass
|
||||
# examples:
|
||||
# yuhuili/EAGLE-LLaMA3-Instruct-8B
|
||||
@@ -500,6 +507,8 @@ class SpeculativeConfig:
|
||||
self.method = "eagle"
|
||||
elif "eagle3" in self.draft_model_config.model.lower():
|
||||
self.method = "eagle3"
|
||||
elif "dflash" in self.draft_model_config.model.lower():
|
||||
self.method = "dflash"
|
||||
elif self.draft_model_config.hf_config.model_type == "medusa":
|
||||
self.method = "medusa"
|
||||
elif self.draft_model_config.hf_config.model_type == "mlp_speculator":
|
||||
@@ -532,7 +541,7 @@ class SpeculativeConfig:
|
||||
)
|
||||
|
||||
# Replace hf_config for EAGLE draft_model
|
||||
if self.method in ("eagle", "eagle3"):
|
||||
if self.method in ("eagle", "eagle3", "dflash"):
|
||||
from vllm.transformers_utils.configs.eagle import EAGLEConfig
|
||||
from vllm.transformers_utils.configs.speculators import (
|
||||
SpeculatorsConfig,
|
||||
@@ -552,6 +561,9 @@ class SpeculativeConfig:
|
||||
self.draft_model_config.hf_config = eagle_config
|
||||
self.update_arch_()
|
||||
|
||||
if self.method == "dflash":
|
||||
self.parallel_drafting = True
|
||||
|
||||
if self.num_speculative_tokens is not None and hasattr(
|
||||
self.draft_model_config.hf_config, "num_lookahead_tokens"
|
||||
):
|
||||
@@ -807,7 +819,7 @@ class SpeculativeConfig:
|
||||
"kimi_k25",
|
||||
]
|
||||
if (
|
||||
self.method in ("eagle3", "extract_hidden_states")
|
||||
self.method in ("eagle3", "extract_hidden_states", "dflash")
|
||||
and self.target_model_config
|
||||
and not any(
|
||||
supported_model in self.target_model_config.hf_text_config.model_type
|
||||
@@ -855,7 +867,10 @@ class SpeculativeConfig:
|
||||
return slots_per_req
|
||||
|
||||
def use_eagle(self) -> bool:
|
||||
return self.method in ("eagle", "eagle3", "mtp")
|
||||
return self.method in ("eagle", "eagle3", "mtp", "dflash")
|
||||
|
||||
def use_dflash(self) -> bool:
|
||||
return self.method == "dflash"
|
||||
|
||||
def uses_draft_model(self) -> bool:
|
||||
return self.method == "draft_model"
|
||||
|
||||
@@ -1327,6 +1327,26 @@ class VllmConfig:
|
||||
max_num_batched_tokens - scheduled_token_delta
|
||||
)
|
||||
|
||||
if self.scheduler_config.max_num_scheduled_tokens <= 0:
|
||||
raise ValueError(
|
||||
"max_num_scheduled_tokens is set to"
|
||||
f" {self.scheduler_config.max_num_scheduled_tokens} based on"
|
||||
" the speculative decoding settings, which does not allow"
|
||||
" any tokens to be scheduled. Increase max_num_batched_tokens"
|
||||
" to accommodate the additional draft token slots, or decrease"
|
||||
" num_speculative_tokens or max_num_seqs."
|
||||
)
|
||||
if self.scheduler_config.max_num_scheduled_tokens < 8192:
|
||||
logger.warning_once(
|
||||
"max_num_scheduled_tokens is set to"
|
||||
f" {self.scheduler_config.max_num_scheduled_tokens} based on"
|
||||
" the speculative decoding settings. This may lead to suboptimal"
|
||||
" performance. Consider increasing max_num_batched_tokens to"
|
||||
" accommodate the additional draft token slots, or decrease"
|
||||
" num_speculative_tokens or max_num_seqs.",
|
||||
scope="local",
|
||||
)
|
||||
|
||||
max_num_scheduled_tokens = self.scheduler_config.max_num_scheduled_tokens
|
||||
if max_num_batched_tokens < max_num_scheduled_tokens + (
|
||||
self.speculative_config.max_num_new_slots_for_drafting
|
||||
|
||||
@@ -285,6 +285,7 @@ class Qwen3ForCausalLM(
|
||||
|
||||
self.config = config
|
||||
|
||||
self.vllm_config = vllm_config
|
||||
self.quant_config = quant_config
|
||||
self.model = Qwen3Model(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
|
||||
619
vllm/model_executor/models/qwen3_dflash.py
Normal file
619
vllm/model_executor/models/qwen3_dflash.py
Normal file
@@ -0,0 +1,619 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from transformers import Qwen3Config
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
maybe_remap_kv_scale_name,
|
||||
)
|
||||
from vllm.multimodal.inputs import NestedTensors
|
||||
from vllm.transformers_utils.config import set_default_rope_theta
|
||||
from vllm.v1.attention.backend import AttentionType
|
||||
|
||||
from .qwen2 import Qwen2MLP as Qwen3MLP
|
||||
from .qwen3 import Qwen3ForCausalLM
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
get_draft_quant_config,
|
||||
maybe_prefix,
|
||||
process_eagle_weight,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DFlashQwen3Attention(nn.Module):
|
||||
"""Attention for DFlash speculative decoding.
|
||||
|
||||
Context KVs are pre-inserted into the KV cache before the forward pass.
|
||||
This layer handles only query tokens via standard attention.
|
||||
Adapted from Qwen3Attention."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_parameters: dict,
|
||||
max_position: int = 4096 * 32,
|
||||
head_dim: int | None = None,
|
||||
rms_norm_eps: float = 1e-06,
|
||||
attention_bias: bool = False,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.layer_name = prefix
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = head_dim or hidden_size // self.total_num_heads
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=attention_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=attention_bias, # DFlash has o_proj bias when using attention bias
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
max_position=max_position,
|
||||
rope_parameters=rope_parameters,
|
||||
)
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
attn_type=attn_type,
|
||||
)
|
||||
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""DFlash attention assumes that the KV cache is already populated
|
||||
with the context K/V from the target model's hidden states. This forward op
|
||||
computes attention for the query tokens only.
|
||||
See also: precompute_and_store_context_kv"""
|
||||
qkv = F.linear(hidden_states, self.qkv_proj.weight, self.qkv_proj.bias)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
# Per-head RMSNorm
|
||||
q_shape, k_shape = q.shape, k.shape
|
||||
q = self.q_norm(
|
||||
q.view(*q_shape[:-1], q_shape[-1] // self.head_dim, self.head_dim)
|
||||
).view(q_shape)
|
||||
k = self.k_norm(
|
||||
k.view(*k_shape[:-1], k_shape[-1] // self.head_dim, self.head_dim)
|
||||
).view(k_shape)
|
||||
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class DFlashQwen3DecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
*,
|
||||
config: Qwen3Config,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
set_default_rope_theta(config, default_theta=1000000)
|
||||
attn_type = AttentionType.DECODER
|
||||
|
||||
self.self_attn = DFlashQwen3Attention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
max_position=config.max_position_embeddings,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rms_norm_eps=config.rms_norm_eps,
|
||||
attention_bias=getattr(config, "attention_bias", False),
|
||||
head_dim=getattr(config, "head_dim", None),
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
rope_parameters=config.rope_parameters,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
attn_type=attn_type,
|
||||
)
|
||||
self.mlp = Qwen3MLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if residual is not None:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
else:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
|
||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class DFlashQwen3Model(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
start_layer_id: int = 0,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = vllm_config.speculative_config.draft_model_config.hf_config
|
||||
self.vocab_size = self.config.vocab_size
|
||||
self.quant_config = get_draft_quant_config(vllm_config)
|
||||
|
||||
drafter_config = getattr(self.config, "eagle_config", {})
|
||||
drafter_config.update(getattr(self.config, "dflash_config", {}))
|
||||
|
||||
if drafter_config is not None and "use_aux_hidden_state" in drafter_config:
|
||||
self.use_aux_hidden_state = drafter_config["use_aux_hidden_state"]
|
||||
else:
|
||||
self.use_aux_hidden_state = True
|
||||
|
||||
current_vllm_config = get_current_vllm_config()
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
self.config.vocab_size,
|
||||
self.config.hidden_size,
|
||||
prefix=maybe_prefix(prefix, "embed_tokens"),
|
||||
)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
DFlashQwen3DecoderLayer(
|
||||
current_vllm_config,
|
||||
prefix=maybe_prefix(prefix, f"layers.{layer_idx + start_layer_id}"),
|
||||
config=self.config,
|
||||
)
|
||||
for layer_idx in range(self.config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
if self.use_aux_hidden_state:
|
||||
num_features_to_use = self.config.num_hidden_layers
|
||||
if "target_layer_ids" in drafter_config:
|
||||
num_features_to_use = len(drafter_config["target_layer_ids"])
|
||||
elif "layer_ids" in drafter_config:
|
||||
num_features_to_use = len(drafter_config["layer_ids"])
|
||||
if hasattr(self.config, "target_hidden_size"):
|
||||
fc_input_size = self.config.target_hidden_size * num_features_to_use
|
||||
else:
|
||||
fc_input_size = self.config.hidden_size * num_features_to_use
|
||||
self.fc = ReplicatedLinear(
|
||||
input_size=fc_input_size,
|
||||
output_size=self.config.hidden_size,
|
||||
bias=False,
|
||||
params_dtype=vllm_config.model_config.dtype,
|
||||
quant_config=self.quant_config,
|
||||
prefix=maybe_prefix(prefix, "fc"),
|
||||
return_bias=False,
|
||||
)
|
||||
self.hidden_norm = RMSNorm(
|
||||
self.config.hidden_size,
|
||||
eps=self.config.rms_norm_eps,
|
||||
)
|
||||
self.norm = RMSNorm(
|
||||
self.config.hidden_size,
|
||||
eps=self.config.rms_norm_eps,
|
||||
)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def _build_fused_kv_buffers(self) -> None:
    """Build fused weight buffers for precompute_and_store_context_kv.

    Must be called after weights are loaded. Stacks the KV-projection
    weights, K-norm weights, and RoPE parameters from every attention
    layer so that precompute_and_store_context_kv can run one fused
    GEMM for all layers at once. Also aliases the weight of the hidden_norm.
    """
    layers_attn = [layer.self_attn for layer in self.layers]
    # Layer 0 is the reference; asserts below verify the rest match it.
    attn0 = layers_attn[0]
    has_bias = attn0.qkv_proj.bias is not None

    # Alias (not copy) the hidden-norm weight for use by the fused kernel.
    self._hidden_norm_weight = self.hidden_norm.weight.data

    # KV projection weights: [num_layers * 2 * kv_size, hidden_size]
    # Slicing off the first q_size rows keeps only the K/V part of qkv_proj.
    kv_weights = [a.qkv_proj.weight[a.q_size :] for a in layers_attn]
    self._fused_kv_weight = torch.cat(kv_weights, dim=0)
    if has_bias:
        kv_biases = [a.qkv_proj.bias[a.q_size :] for a in layers_attn]
        self._fused_kv_bias: torch.Tensor | None = torch.cat(kv_biases, dim=0)
    else:
        self._fused_kv_bias = None

    # K-norm weights: list of [head_dim] tensors, one per layer.
    self._k_norm_weights = [a.k_norm.weight.data for a in layers_attn]

    # RoPE parameters
    self._rope_head_size = attn0.rotary_emb.head_size
    self._rope_cos_sin_cache = attn0.rotary_emb.cos_sin_cache
    self._rope_is_neox = attn0.rotary_emb.is_neox_style
    # Validation that RoPE params are the same across all layers
    for attn in layers_attn[1:]:
        assert (
            attn.rotary_emb.head_size == self._rope_head_size
            and attn.rotary_emb.is_neox_style == self._rope_is_neox
        ), "All layers must have the same RoPE parameters for DFlash precomputation"

    # Layer metadata
    self._num_attn_layers = len(layers_attn)
    self._kv_size = attn0.kv_size
    self._head_dim = attn0.head_dim
    self._num_kv_heads = attn0.num_kv_heads
    # NOTE(review): q_norm and k_norm are assumed to share the same epsilon —
    # only q_norm's epsilon is recorded and it is later used to normalize K.
    self._rms_norm_eps = attn0.q_norm.variance_epsilon
    # Validation that all layers have the same attention config
    for attn in layers_attn[1:]:
        assert (
            attn.kv_size == self._kv_size
            and attn.head_dim == self._head_dim
            and attn.num_kv_heads == self._num_kv_heads
            and attn.q_norm.variance_epsilon == self._rms_norm_eps
        ), "All layers must have the same attn config for DFlash precomputation"

    # References to inner Attention layers for direct cache writes
    self._attn_layers = [layer.self_attn.attn for layer in self.layers]
|
||||
|
||||
def precompute_and_store_context_kv(
    self,
    context_states: torch.Tensor,
    context_positions: torch.Tensor,
    context_slot_mapping: torch.Tensor | None = None,
) -> None:
    """Precompute K/V for context states and write them into each layer's KV cache.

    Input context states are projected to K/V, normed, and have RoPE applied.
    Since the context shape is different than the query shape, we can't rely on the
    regular forward pass to apply torch.compile and CUDA graphs to this section.
    As such, this function is optimized to minimize the number of torch ops present:
    we use fused vLLM kernels for RMSNorm and RoPE, fuse the GEMM into one
    large projection, and avoid cloning buffers (with .contiguous()) where possible.

    When context_slot_mapping is None (e.g. during dummy_run) only
    the computation runs, and no K/V is written to cache.
    """
    # Buffers are normally built at the end of load_weights; build lazily
    # here so paths that skip weight loading (e.g. dummy weights) still work.
    if not hasattr(self, "_num_attn_layers"):
        logger.warning_once(
            "DFlash buffer initialization was skipped. If dummy weights are not "
            "in use, this may indicate an error in weight loading."
        )
        self._build_fused_kv_buffers()

    num_ctx = context_states.shape[0]
    L = self._num_attn_layers
    kv = self._kv_size
    hd = self._head_dim
    nkv = self._num_kv_heads

    # --- Fused KV projection (one GEMM for all layers) ---
    normed_context_states = torch.empty_like(context_states)
    ops.rms_norm(
        normed_context_states,
        context_states,
        self._hidden_norm_weight,
        self._rms_norm_eps,
    )
    all_kv_flat = F.linear(
        normed_context_states, self._fused_kv_weight, self._fused_kv_bias
    )
    # Single contiguous copy that separates K/V and transposes to
    # layer-major layout. Result: [2, L, num_ctx, nkv, hd] contiguous.
    # Indexing dim-0 gives contiguous [L, num_ctx, nkv, hd] for K and V.
    all_kv = (
        all_kv_flat.view(num_ctx, L, 2, nkv, hd).permute(2, 1, 0, 3, 4).contiguous()
    )
    all_k = all_kv[0]  # [L, num_ctx, nkv, hd], contiguous
    all_v = all_kv[1]  # [L, num_ctx, nkv, hd], contiguous

    # --- Per-layer RMSNorm K (3D: [num_ctx, nkv, hd] per layer) ---
    all_k_normed = torch.empty_like(all_k)
    for i in range(L):
        ops.rms_norm(
            all_k_normed[i],
            all_k[i],
            self._k_norm_weights[i],
            self._rms_norm_eps,
        )

    # --- Fused RoPE across all layers ---
    # View as [L * num_ctx, kv] so RoPE sees one big batch (no copy).
    # In-place RoPE: pass K as the "query" arg with key=None.
    all_k_flat = all_k_normed.view(L * num_ctx, kv)
    positions_repeated = context_positions.repeat(L)
    cos_sin_cache = self._rope_cos_sin_cache
    # Cast the cos/sin cache once if its dtype differs from the activations.
    if cos_sin_cache.dtype != all_k_flat.dtype:
        cos_sin_cache = cos_sin_cache.to(dtype=all_k_flat.dtype)
    ops.rotary_embedding(
        positions_repeated,
        all_k_flat,
        None,
        self._rope_head_size,
        cos_sin_cache,
        self._rope_is_neox,
    )

    # Dummy run: computation only (for profiling/graph warmup), no cache write.
    if context_slot_mapping is None:
        return

    # --- Per-layer cache insert ---
    all_k_final = all_k_flat.view(L, num_ctx, nkv, hd)
    for i in range(L):
        attn = self._attn_layers[i]
        kv_cache = attn.kv_cache
        attn.impl.do_kv_cache_update(
            attn,
            all_k_final[i],
            all_v[i],
            kv_cache,
            context_slot_mapping,
        )
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_embeds: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if input_embeds is None:
|
||||
input_embeds = self.embed_input_ids(input_ids)
|
||||
|
||||
hidden_states = input_embeds
|
||||
|
||||
residual = None
|
||||
for layer in self.layers:
|
||||
hidden_states, residual = layer(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load draft-model weights, fusing per-shard checkpoint tensors.

    Returns the set of parameter names that were actually loaded.
    """
    # Map checkpoint shard names onto vLLM's fused parameters:
    # (fused param suffix, checkpoint suffix, shard id).
    stacked_params_mapping = [
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
        (".gate_up_proj", ".gate_proj", 0),
        (".gate_up_proj", ".up_proj", 1),
    ]
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        # Some DFlash checkpoints name the single decoder layer "midlayer";
        # normalize it to the standard "layers.0" prefix.
        if "midlayer." in name:
            name = name.replace("midlayer.", "layers.0.")
        # KV-cache quantization scales go through a dedicated loading path.
        if self.quant_config is not None and (
            scale_name := self.quant_config.get_cache_scale(name)
        ):
            param = params_dict[scale_name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            # Scale tensors may be stored as 1-element arrays; unwrap to scalar.
            loaded_weight = (
                loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
            )
            weight_loader(param, loaded_weight)
            loaded_params.add(scale_name)
            continue
        if "scale" in name:
            # Remap legacy scale names; None means this scale has no
            # destination parameter and is skipped.
            name = maybe_remap_kv_scale_name(name, params_dict)
            if name is None:
                continue
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # Non-stacked parameter: load directly with its own loader
            # (or the default one).
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
        loaded_params.add(name)
    return loaded_params
|
||||
|
||||
|
||||
class DFlashQwen3ForCausalLM(Qwen3ForCausalLM):
    """DFlash speculative-decoding draft model built on the Qwen3 architecture.

    Wraps DFlashQwen3Model and adds the draft LM head, optional
    draft-to-target vocabulary remapping, and DFlash-specific hooks
    (context-KV precomputation and auxiliary hidden-state fusion).
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Intentionally bypass Qwen3ForCausalLM.__init__: the draft model
        # builds its own decoder stack from the speculative config.
        nn.Module.__init__(self)
        self.config = vllm_config.speculative_config.draft_model_config.hf_config
        # Default the draft vocab to the full vocab when no reduced draft
        # vocabulary is configured.
        if getattr(self.config, "draft_vocab_size", None) is None:
            self.config.draft_vocab_size = getattr(self.config, "vocab_size", None)
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config
        )
        self.config.target_layer_count = target_layer_num
        # Draft layers are numbered after the target model's layers.
        self.model = DFlashQwen3Model(
            vllm_config=vllm_config,
            prefix="model",
            start_layer_id=target_layer_num,
        )

        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.lm_head = ParallelLMHead(
            self.config.draft_vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(
            self.config.draft_vocab_size, scale=logit_scale
        )
        # Set during load_weights when the checkpoint provides a "d2t"
        # (draft-id -> target-id offset) mapping tensor; None means the
        # draft vocab is identical to the target vocab.
        self.draft_id_to_target_id = None

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: NestedTensors | None = None,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embed token ids; multimodal arguments are accepted but ignored."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Delegate to the inner DFlashQwen3Model forward pass."""
        return self.model(input_ids, positions, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Compute logits, remapping the draft vocab to the target vocab.

        When draft_id_to_target_id is set, the draft logits are scattered
        into a full-target-vocab tensor (unmapped positions get -inf so
        they can never be sampled).
        """
        logits = self.logits_processor(self.lm_head, hidden_states)
        if self.draft_id_to_target_id is None:
            # Draft vocab == target vocab; no remapping needed.
            return logits

        # target id = draft id + per-token offset from the d2t mapping.
        base = torch.arange(self.config.draft_vocab_size, device=logits.device)
        targets = base + self.draft_id_to_target_id
        logits_new = logits.new_full(
            (logits.shape[0], self.config.vocab_size),
            float("-inf"),
        )
        logits_new[:, targets] = logits
        return logits_new

    def precompute_and_store_context_kv(
        self,
        context_states: torch.Tensor,
        context_positions: torch.Tensor,
        context_slot_mapping: torch.Tensor | None = None,
    ) -> None:
        """Precompute projected + RoPE'd K/V and write to cache."""
        self.model.precompute_and_store_context_kv(
            context_states, context_positions, context_slot_mapping
        )

    def combine_hidden_states(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Fuse (concatenated) auxiliary hidden states through the fc layer.

        Passes through unchanged when aux hidden states are disabled.
        A 1-D input is temporarily promoted to 2-D for the linear layer.
        """
        if not self.model.use_aux_hidden_state:
            return hidden_states
        needs_squeeze = hidden_states.dim() == 1
        if needs_squeeze:
            hidden_states = hidden_states.unsqueeze(0)
        result = self.model.fc(hidden_states)
        if needs_squeeze:
            result = result.squeeze(0)
        return result

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights, renaming DFlash-specific tensors.

        After loading, builds the fused KV buffers needed by
        precompute_and_store_context_kv.
        """
        model_weights = {}
        includes_draft_id_mapping = False
        includes_embed_tokens = False
        for name, loaded_weight in weights:
            assert "mask_hidden" not in name, (
                "DFlash should use mask_token_id to embed the padding hidden state"
            )
            # t2d (target->draft) mapping is unused; only d2t is kept.
            if "t2d" in name:
                continue
            if "d2t" in name:
                name = name.replace("d2t", "draft_id_to_target_id")
                includes_draft_id_mapping = True
            elif "lm_head" not in name:
                # Everything except lm_head lives under the inner model.
                name = "model." + name
            if "embed_tokens" in name:
                includes_embed_tokens = True
            model_weights[name] = loaded_weight
            process_eagle_weight(self, name)

        # Skip parameters the checkpoint does not provide so the loader
        # does not error on them.
        skip_substrs = []
        if not includes_draft_id_mapping:
            skip_substrs.append("draft_id_to_target_id")
        if not includes_embed_tokens:
            skip_substrs.append("embed_tokens")
        if not self.model.use_aux_hidden_state:
            skip_substrs.append("fc.")
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
            skip_substrs=skip_substrs,
        )
        loader.load_weights(model_weights.items())
        self.model._build_fused_kv_buffers()
|
||||
@@ -56,6 +56,7 @@ from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig
|
||||
|
||||
from .interfaces import (
|
||||
EagleModelMixin,
|
||||
HasInnerState,
|
||||
IsHybrid,
|
||||
MixtureOfExperts,
|
||||
@@ -454,7 +455,7 @@ class Qwen3NextDecoderLayer(nn.Module):
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class Qwen3NextModel(nn.Module):
|
||||
class Qwen3NextModel(nn.Module, EagleModelMixin):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
@@ -492,8 +493,6 @@ class Qwen3NextModel(nn.Module):
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
|
||||
self.aux_hidden_state_layers: tuple[int, ...] = ()
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
@@ -515,20 +514,19 @@ class Qwen3NextModel(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
aux_hidden_states = []
|
||||
aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
|
||||
for layer_idx, layer in enumerate(
|
||||
islice(self.layers, self.start_layer, self.end_layer),
|
||||
start=self.start_layer,
|
||||
):
|
||||
if layer_idx in self.aux_hidden_state_layers:
|
||||
aux_hidden_states.append(
|
||||
hidden_states + residual if residual is not None else hidden_states
|
||||
)
|
||||
hidden_states, residual = layer(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
)
|
||||
self._maybe_add_hidden_state(
|
||||
aux_hidden_states, layer_idx + 1, hidden_states, residual
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors(
|
||||
|
||||
@@ -546,6 +546,7 @@ _SPECULATIVE_DECODING_MODELS = {
|
||||
"EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
|
||||
"EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
|
||||
"EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
|
||||
"DFlashDraftModel": ("qwen3_dflash", "DFlashQwen3ForCausalLM"),
|
||||
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
|
||||
"LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
|
||||
"Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
|
||||
|
||||
@@ -62,9 +62,20 @@ class EAGLEConfig(PretrainedConfig):
|
||||
else f"Eagle3{arch}"
|
||||
for arch in self.model.architectures
|
||||
]
|
||||
elif method == "dflash":
|
||||
assert self.model is not None, (
|
||||
"model should not be None when method is dflash"
|
||||
)
|
||||
kwargs["architectures"] = [
|
||||
arch
|
||||
if arch.startswith("DFlash") or arch.endswith("DFlash")
|
||||
else f"DFlash{arch}"
|
||||
for arch in self.model.architectures
|
||||
]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid method {method}. Supported methods are eagle and eagle3."
|
||||
f"Invalid method {method}. Supported methods are "
|
||||
"eagle, eagle3, and dflash."
|
||||
)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -220,6 +220,17 @@ class AttentionBackend(ABC):
|
||||
def supports_per_head_quant_scales(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
def supports_non_causal(cls) -> bool:
    """Report whether this backend can run non-causal (bidirectional)
    attention for decoder models.

    This is distinct from the ENCODER_ONLY attention type (which implies
    a different execution model): it means non-causal attention within
    the standard paged-KV-cache decoder path. Backends opt in by
    overriding this to return True; the base default is unsupported.
    """
    return False
|
||||
|
||||
@classmethod
|
||||
def supports_attn_type(cls, attn_type: str) -> bool:
|
||||
"""Check if backend supports a given attention type.
|
||||
@@ -261,6 +272,7 @@ class AttentionBackend(ABC):
|
||||
use_per_head_quant_scales: bool,
|
||||
device_capability: "DeviceCapability",
|
||||
attn_type: str,
|
||||
use_non_causal: bool = False,
|
||||
) -> list[str]:
|
||||
invalid_reasons = []
|
||||
if not cls.supports_head_size(head_size):
|
||||
@@ -293,6 +305,8 @@ class AttentionBackend(ABC):
|
||||
invalid_reasons.append("compute capability not supported")
|
||||
if not cls.supports_attn_type(attn_type):
|
||||
invalid_reasons.append(f"attention type {attn_type} not supported")
|
||||
if use_non_causal and not cls.supports_non_causal():
|
||||
invalid_reasons.append("non-causal attention not supported")
|
||||
combination_reason = cls.supports_combination(
|
||||
head_size,
|
||||
dtype,
|
||||
|
||||
@@ -101,6 +101,10 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
def get_name() -> str:
|
||||
return "FLASH_ATTN"
|
||||
|
||||
@classmethod
|
||||
def supports_non_causal(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def supports_attn_type(cls, attn_type: str) -> bool:
|
||||
"""FlashAttention supports all attention types."""
|
||||
|
||||
@@ -29,6 +29,7 @@ class AttentionSelectorConfig(NamedTuple):
|
||||
use_mm_prefix: bool = False
|
||||
use_per_head_quant_scales: bool = False
|
||||
attn_type: str = AttentionType.DECODER
|
||||
use_non_causal: bool = False
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
@@ -41,7 +42,8 @@ class AttentionSelectorConfig(NamedTuple):
|
||||
f"use_sparse={self.use_sparse}, "
|
||||
f"use_mm_prefix={self.use_mm_prefix}, "
|
||||
f"use_per_head_quant_scales={self.use_per_head_quant_scales}, "
|
||||
f"attn_type={self.attn_type})"
|
||||
f"attn_type={self.attn_type}, "
|
||||
f"use_non_causal={self.use_non_causal})"
|
||||
)
|
||||
|
||||
|
||||
@@ -76,6 +78,11 @@ def get_attn_backend(
|
||||
else:
|
||||
block_size = None
|
||||
|
||||
speculative_config = vllm_config.speculative_config
|
||||
use_non_causal = (
|
||||
speculative_config is not None and speculative_config.method == "dflash"
|
||||
)
|
||||
|
||||
attn_selector_config = AttentionSelectorConfig(
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
@@ -87,6 +94,7 @@ def get_attn_backend(
|
||||
use_mm_prefix=use_mm_prefix,
|
||||
use_per_head_quant_scales=use_per_head_quant_scales,
|
||||
attn_type=attn_type or AttentionType.DECODER,
|
||||
use_non_causal=use_non_causal,
|
||||
)
|
||||
|
||||
return _cached_get_attn_backend(
|
||||
|
||||
282
vllm/v1/spec_decode/dflash.py
Normal file
282
vllm/v1/spec_decode/dflash.py
Normal file
@@ -0,0 +1,282 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.v1.attention.backend import CommonAttentionMetadata
|
||||
from vllm.v1.spec_decode.eagle import SpecDecodeBaseProposer
|
||||
from vllm.v1.spec_decode.utils import copy_and_expand_dflash_inputs_kernel
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DFlashProposer(SpecDecodeBaseProposer):
    """Speculative-decoding proposer for the DFlash method.

    DFlash drafts all speculative tokens in ONE forward pass: the target
    model's hidden states become cross-attention context K/V (pre-inserted
    into the draft KV cache), while only the bonus token plus mask tokens
    flow through the draft model as queries with non-causal attention.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        assert vllm_config.speculative_config is not None
        assert vllm_config.speculative_config.method == "dflash"
        super().__init__(
            vllm_config=vllm_config,
            device=device,
            pass_hidden_states_to_model=True,
            runner=runner,
        )

        # Only next_token_ids and mask tokens are query tokens, all other context is K/V
        self.max_query_tokens = self.max_batch_size * (1 + self.num_speculative_tokens)
        # Positions covers both context states + query states
        self.max_positions = self.max_num_tokens + self.max_query_tokens

        # Separate context buffers to keep query buffer addresses stable for CUDA graphs
        self._context_slot_mapping_buffer = torch.zeros(
            self.max_num_tokens,
            dtype=torch.int64,
            device=device,
        )
        self._slot_mapping_buffer = torch.zeros(
            self.max_query_tokens,
            dtype=torch.int64,
            device=device,
        )
        self._context_positions_buffer = torch.zeros(
            self.max_num_tokens,
            dtype=torch.int64,
            device=device,
        )
        # Query positions (overrides the base-class buffer with a
        # query-token-sized one).
        self.positions = torch.zeros(
            self.max_query_tokens,
            dtype=torch.int64,
            device=device,
        )

        # Reusable arange used to build query_start_loc (+1 because
        # query_start_loc has batch_size + 1 entries).
        self.arange = torch.arange(
            self.max_positions + 1, device=device, dtype=torch.int32
        )

        # For DFlash we use the input embeddings to embed the mask token
        self.parallel_drafting_hidden_state_tensor = None

    @override
    def _raise_if_multimodal(self):
        # Override to allow multimodal inputs since DFlash supports Qwen3.5 models
        # Support for multimodal inputs has not been tested.
        pass

    @override
    def set_inputs_first_pass(
        self,
        target_token_ids: torch.Tensor,
        next_token_ids: torch.Tensor,
        target_positions: torch.Tensor,
        target_hidden_states: torch.Tensor,
        token_indices_to_sample: torch.Tensor | None,
        cad: CommonAttentionMetadata,
        num_rejected_tokens_gpu: torch.Tensor | None,
    ) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
        """Prepare DFlash inputs and attention metadata for the single draft pass.

        Returns (num_query_tokens, token_indices_to_sample, new attention
        metadata restricted to the query tokens).
        """
        # DFlash cross-attention: context K/V from target hidden states,
        # Q from query embeddings (bonus + mask tokens).
        batch_size = cad.batch_size()
        num_context = target_token_ids.shape[0]
        num_query_per_req = 1 + self.num_speculative_tokens
        num_query_total = batch_size * num_query_per_req

        # Store for build_model_inputs_first_pass to use
        self._dflash_num_context = num_context

        # We don't need to copy into a buffer here since the context preprocessing
        # does not run in a CUDA graph
        self._dflash_hidden_states = target_hidden_states

        token_indices_to_sample = torch.empty(
            batch_size * self.num_speculative_tokens,
            dtype=torch.int32,
            device=self.device,
        )

        # Launch fused triton kernel for input_ids, positions, slot_mapping,
        # and token_indices_to_sample
        max_ctx_per_req = cad.max_query_len
        max_tokens_per_req = max_ctx_per_req + num_query_per_req
        BLOCK_SIZE = min(256, triton.next_power_of_2(max_tokens_per_req))
        num_blocks = triton.cdiv(max_tokens_per_req, BLOCK_SIZE)
        # One program per (request, token-block) pair.
        grid = (batch_size, num_blocks)

        has_num_rejected = num_rejected_tokens_gpu is not None
        copy_and_expand_dflash_inputs_kernel[grid](
            # Inputs
            next_token_ids_ptr=next_token_ids,
            target_positions_ptr=target_positions,
            # Outputs
            out_input_ids_ptr=self.input_ids,
            out_context_positions_ptr=self._context_positions_buffer,
            out_query_positions_ptr=self.positions,
            out_context_slot_mapping_ptr=self._context_slot_mapping_buffer,
            out_query_slot_mapping_ptr=self._slot_mapping_buffer,
            out_token_indices_ptr=token_indices_to_sample,
            # Block table
            block_table_ptr=cad.block_table_tensor,
            block_table_stride=cad.block_table_tensor.stride(0),
            # Metadata
            query_start_loc_ptr=cad.query_start_loc,
            num_rejected_tokens_ptr=(
                num_rejected_tokens_gpu if has_num_rejected else 0
            ),
            # Scalars
            parallel_drafting_token_id=self.parallel_drafting_token_id,
            block_size=self.block_size,
            num_query_per_req=num_query_per_req,
            num_speculative_tokens=self.num_speculative_tokens,
            total_input_tokens=num_context,
            BLOCK_SIZE=BLOCK_SIZE,
            HAS_NUM_REJECTED=has_num_rejected,
        )

        query_slot_mapping = self._slot_mapping_buffer[:num_query_total]
        # Every request contributes exactly num_query_per_req query tokens.
        new_query_start_loc = self.arange[: batch_size + 1] * num_query_per_req

        # In padded mode, cad.seq_lens includes rejected tokens. Subtract
        # them so attention only sees the valid prefix of context states.
        effective_seq_lens = cad.seq_lens
        if has_num_rejected:
            effective_seq_lens = effective_seq_lens - num_rejected_tokens_gpu

        new_cad = CommonAttentionMetadata(
            query_start_loc=new_query_start_loc,
            seq_lens=effective_seq_lens + num_query_per_req,
            query_start_loc_cpu=(
                torch.from_numpy(self.token_arange_np[: batch_size + 1]).clone()
                * num_query_per_req
            ),
            _seq_lens_cpu=None,
            _num_computed_tokens_cpu=None,
            num_reqs=cad.num_reqs,
            num_actual_tokens=num_query_total,
            max_query_len=num_query_per_req,
            max_seq_len=cad.max_seq_len + num_query_per_req,
            block_table_tensor=cad.block_table_tensor,
            slot_mapping=query_slot_mapping,
            causal=False,  # Non-causal attention is required for DFlash
        )

        return num_query_total, token_indices_to_sample, new_cad

    @override
    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
        is_graph_capturing: bool = False,
        slot_mappings: dict[str, torch.Tensor] | None = None,
    ) -> None:
        """
        Key differences to default dummy_run:
        - Only one forward pass due to parallel drafting
        - DFlash uses context states as unpadded metadata, so hidden_states will
          use the unpadded num_tokens instead of num_input_tokens
        - max_query_tokens is quite small, DFlash only sees spec tokens as queries
        - Multimodal inputs are not currently supported
        """
        num_query_tokens = min(num_tokens, self.max_query_tokens)
        cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
            self._determine_batch_execution_and_padding(
                num_query_tokens, use_cudagraphs=use_cudagraphs
            )
        )

        # Slot mapping sized to num_input_tokens (query only), matching
        # the K/V tensor size from the model forward. Context KVs are
        # pre-inserted separately and don't flow through the model.
        if (
            self._draft_attn_layer_names
            and slot_mappings is not None
            and next(iter(self._draft_attn_layer_names)) in slot_mappings
        ):
            slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
        else:
            slot_mapping_dict = slot_mappings or {}

        # Context and query positions use separate buffers; no copy needed.
        context_positions = self._context_positions_buffer[:num_tokens]
        # Context states will be passed directly to the precomputation without
        # going through the buffer, since no CUDA graph is used for the precomputation.
        # For the dummy run, we use the dummy buffer.
        context_states = self.hidden_states[:num_tokens]

        # Run the KV projection (GEMM + norms + RoPE) for memory profiling,
        self.model.precompute_and_store_context_kv(context_states, context_positions)
        with set_forward_context(
            None,
            self.vllm_config,
            num_tokens=num_input_tokens,
            num_tokens_across_dp=num_tokens_across_dp,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
            slot_mapping=slot_mapping_dict,
        ):
            self.model(
                input_ids=self.input_ids[:num_input_tokens],
                positions=self._get_positions(num_input_tokens),
                inputs_embeds=None,
            )

    @override
    def build_model_inputs_first_pass(
        self,
        num_tokens: int,
        num_input_tokens: int,
        mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None,
    ) -> tuple[dict[str, Any], int]:
        """Insert context K/V into cache and return the query-side model kwargs."""
        # Context and query positions/slots were written to separate
        # buffers by the kernel — no copy needed.
        num_context = self._dflash_num_context

        # Pre-insert context KVs directly into cache
        self.model.precompute_and_store_context_kv(
            self._dflash_hidden_states,  # Shape is already [num_context, hidden_size]
            self._context_positions_buffer[:num_context],
            self._context_slot_mapping_buffer[:num_context],
        )
        return (
            dict(
                input_ids=self.input_ids[:num_input_tokens],
                positions=self._get_positions(num_input_tokens),
                inputs_embeds=None,
            ),
            num_input_tokens,
        )

    @override
    def build_per_layer_attn_metadata(
        self, cad: CommonAttentionMetadata, draft_index: int = 0
    ) -> dict[str, object]:
        """Build per-layer metadata, asserting every layer is non-causal."""
        per_layer_attention_metadata = super().build_per_layer_attn_metadata(
            cad, draft_index
        )
        for layer_name, attn_metadata in per_layer_attention_metadata.items():
            # causal must be explicitly False (not merely absent/None):
            # DFlash queries attend bidirectionally over the context.
            assert getattr(attn_metadata, "causal", None) is False, (
                f"Attention metadata for layer {layer_name} does not have"
                " non-causal support, which is required for DFlash."
                " Consider using a different attention backend, such as FlashAttention."
            )
        return per_layer_attention_metadata

    @override
    def _get_eagle3_use_aux_hidden_state_from_config(self):
        """Read use_aux_hidden_state from the draft config's dflash_config.

        Defaults to True when the config section or key is absent.
        """
        use_aux_hidden_state = True
        dflash_config = getattr(
            self.draft_model_config.hf_config, "dflash_config", None
        )
        if dflash_config is not None:
            use_aux_hidden_state = dflash_config.get("use_aux_hidden_state", True)
        return use_aux_hidden_state
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import ast
|
||||
from importlib.util import find_spec
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -23,6 +23,7 @@ from vllm.model_executor.models import supports_multimodal
|
||||
from vllm.model_executor.models.deepseek_eagle3 import Eagle3DeepseekV2ForCausalLM
|
||||
from vllm.model_executor.models.interfaces import SupportsMultiModal
|
||||
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||
from vllm.model_executor.models.qwen3_dflash import DFlashQwen3ForCausalLM
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import triton
|
||||
@@ -83,13 +84,15 @@ class SpecDecodeBaseProposer:
|
||||
self.hidden_size = self.draft_model_config.get_hidden_size()
|
||||
self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()
|
||||
|
||||
# Unifying eagle, draft model, and parallel drafting support
|
||||
# Unifying eagle, draft model, and parallel drafting support.
|
||||
# DFlash always uses parallel drafting (all tokens in one pass),
|
||||
# but has an additional slot for the next_token_id (does not shift like EAGLE)
|
||||
self.parallel_drafting: bool = self.speculative_config.parallel_drafting
|
||||
self.extra_slots_per_request = (
|
||||
1 if not self.parallel_drafting else self.num_speculative_tokens
|
||||
)
|
||||
self.net_num_new_slots_per_request = self.extra_slots_per_request - (
|
||||
1 if self.pass_hidden_states_to_model else 0
|
||||
1 if (self.pass_hidden_states_to_model and self.method != "dflash") else 0
|
||||
)
|
||||
self.needs_extra_input_slots = self.net_num_new_slots_per_request > 0
|
||||
|
||||
@@ -101,10 +104,14 @@ class SpecDecodeBaseProposer:
|
||||
self.speculative_config.use_local_argmax_reduction
|
||||
)
|
||||
|
||||
max_batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||
self.max_batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
|
||||
self.token_arange_np = np.arange(self.max_num_tokens)
|
||||
|
||||
# Can be specialized by methods like DFlash to reduce the limit
|
||||
self.max_query_tokens = self.max_num_tokens
|
||||
self.max_positions = self.max_num_tokens
|
||||
|
||||
# Multi-modal data support
|
||||
self.mm_registry = MULTIMODAL_REGISTRY
|
||||
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
|
||||
@@ -146,18 +153,20 @@ class SpecDecodeBaseProposer:
|
||||
# 1D-RoPE.
|
||||
# See page 5 of https://arxiv.org/abs/2409.12191
|
||||
self.mrope_positions = torch.zeros(
|
||||
(3, self.max_num_tokens + 1), dtype=torch.int64, device=device
|
||||
(3, self.max_positions + 1), dtype=torch.int64, device=device
|
||||
)
|
||||
elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
|
||||
self.xdrope_positions = torch.zeros(
|
||||
(self.uses_xdrope_dim, self.max_num_tokens + 1),
|
||||
(self.uses_xdrope_dim, self.max_positions + 1),
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
# RoPE need (max_num_tokens,)
|
||||
self.positions = torch.zeros(
|
||||
self.max_num_tokens, dtype=torch.int64, device=device
|
||||
self.max_positions,
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
self.hidden_states = torch.zeros(
|
||||
(self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
|
||||
@@ -168,7 +177,7 @@ class SpecDecodeBaseProposer:
|
||||
|
||||
# We need +1 here because the arange is used to set query_start_loc,
|
||||
# which has one more element than batch_size.
|
||||
max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
|
||||
max_num_slots_for_arange = max(self.max_batch_size + 1, self.max_num_tokens)
|
||||
self.arange = torch.arange(
|
||||
max_num_slots_for_arange, device=device, dtype=torch.int32
|
||||
)
|
||||
@@ -200,7 +209,7 @@ class SpecDecodeBaseProposer:
|
||||
)
|
||||
|
||||
self.backup_next_token_ids = CpuGpuBuffer(
|
||||
max_batch_size,
|
||||
self.max_batch_size,
|
||||
dtype=torch.int32,
|
||||
pin_memory=is_pin_memory_available(),
|
||||
device=device,
|
||||
@@ -208,7 +217,9 @@ class SpecDecodeBaseProposer:
|
||||
)
|
||||
|
||||
self._slot_mapping_buffer = torch.zeros(
|
||||
self.max_num_tokens, dtype=torch.int64, device=device
|
||||
self.max_positions,
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Determine allowed attention backends once during initialization.
|
||||
@@ -275,7 +286,7 @@ class SpecDecodeBaseProposer:
|
||||
# Precompute draft position offsets in flattened tree.
|
||||
self.tree_draft_pos_offsets = torch.arange(
|
||||
1, len(self.tree_choices) + 1, device=device, dtype=torch.int32
|
||||
).repeat(max_batch_size, 1)
|
||||
).repeat(self.max_batch_size, 1)
|
||||
|
||||
def _raise_if_padded_drafter_batch_disabled(self):
|
||||
if self.speculative_config.disable_padded_drafter_batch:
|
||||
@@ -305,14 +316,19 @@ class SpecDecodeBaseProposer:
|
||||
# for those masked slots.
|
||||
|
||||
model_hf_config = self.draft_model_config.hf_config
|
||||
if hasattr(model_hf_config, "pard_token"):
|
||||
# DFlash stores mask_token_id in dflash_config
|
||||
dflash_config = getattr(model_hf_config, "dflash_config", None)
|
||||
if dflash_config and "mask_token_id" in dflash_config:
|
||||
self.parallel_drafting_token_id = dflash_config["mask_token_id"]
|
||||
elif hasattr(model_hf_config, "pard_token"):
|
||||
self.parallel_drafting_token_id = model_hf_config.pard_token
|
||||
elif hasattr(model_hf_config, "ptd_token_id"):
|
||||
self.parallel_drafting_token_id = model_hf_config.ptd_token_id
|
||||
else:
|
||||
raise ValueError(
|
||||
"For parallel drafting, the draft model config must have "
|
||||
"`pard_token` or `ptd_token_id` specified in its config.json."
|
||||
"`pard_token`, `ptd_token_id`, or "
|
||||
"`dflash_config.mask_token_id` specified in its config.json."
|
||||
)
|
||||
|
||||
if self.pass_hidden_states_to_model:
|
||||
@@ -402,9 +418,14 @@ class SpecDecodeBaseProposer:
|
||||
) -> torch.Tensor:
|
||||
batch_size = common_attn_metadata.batch_size()
|
||||
|
||||
if self.method == "eagle3":
|
||||
if self.method in ("eagle3", "dflash"):
|
||||
assert isinstance(
|
||||
self.model, (Eagle3LlamaForCausalLM, Eagle3DeepseekV2ForCausalLM)
|
||||
self.model,
|
||||
(
|
||||
Eagle3LlamaForCausalLM,
|
||||
Eagle3DeepseekV2ForCausalLM,
|
||||
DFlashQwen3ForCausalLM,
|
||||
),
|
||||
)
|
||||
target_hidden_states = self.model.combine_hidden_states(
|
||||
target_hidden_states
|
||||
@@ -423,40 +444,17 @@ class SpecDecodeBaseProposer:
|
||||
)
|
||||
)
|
||||
|
||||
per_layer_attn_metadata: dict[str, object] = {}
|
||||
for attn_group in self.draft_attn_groups:
|
||||
attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata, draft_index=0
|
||||
)
|
||||
for layer_name in attn_group.layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
per_layer_attn_metadata = self.build_per_layer_attn_metadata(
|
||||
common_attn_metadata
|
||||
)
|
||||
|
||||
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
|
||||
self._determine_batch_execution_and_padding(num_tokens)
|
||||
)
|
||||
|
||||
if self.supports_mm_inputs:
|
||||
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
|
||||
|
||||
self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
|
||||
self.input_ids[:num_tokens],
|
||||
multimodal_embeddings=mm_embeds,
|
||||
is_multimodal=is_mm_embed,
|
||||
)
|
||||
|
||||
input_ids = None
|
||||
inputs_embeds = self.inputs_embeds[:num_input_tokens]
|
||||
else:
|
||||
input_ids = self.input_ids[:num_input_tokens]
|
||||
inputs_embeds = None
|
||||
|
||||
model_kwargs = {
|
||||
"input_ids": input_ids,
|
||||
"positions": self._get_positions(num_input_tokens),
|
||||
"inputs_embeds": inputs_embeds,
|
||||
}
|
||||
if self.pass_hidden_states_to_model:
|
||||
model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
|
||||
model_kwargs, slot_mapping_size = self.build_model_inputs_first_pass(
|
||||
num_tokens, num_input_tokens, mm_embed_inputs
|
||||
)
|
||||
|
||||
with set_forward_context(
|
||||
per_layer_attn_metadata,
|
||||
@@ -465,7 +463,7 @@ class SpecDecodeBaseProposer:
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
slot_mapping=self._get_slot_mapping(
|
||||
num_input_tokens, common_attn_metadata.slot_mapping
|
||||
slot_mapping_size, common_attn_metadata.slot_mapping
|
||||
),
|
||||
):
|
||||
ret_hidden_states = self.model(**model_kwargs)
|
||||
@@ -488,7 +486,10 @@ class SpecDecodeBaseProposer:
|
||||
positions = self.positions[token_indices_to_sample]
|
||||
hidden_states = hidden_states[token_indices_to_sample]
|
||||
|
||||
if isinstance(attn_metadata, TreeAttentionMetadata):
|
||||
if any(
|
||||
isinstance(attn_metadata, TreeAttentionMetadata)
|
||||
for attn_metadata in per_layer_attn_metadata.values()
|
||||
):
|
||||
# Draft using tree attention - requires full logits for top-k
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
draft_token_ids_list = self.propose_tree(
|
||||
@@ -504,15 +505,16 @@ class SpecDecodeBaseProposer:
|
||||
|
||||
draft_token_ids = self._greedy_sample(sample_hidden_states)
|
||||
|
||||
if self.allowed_attn_types is not None and not isinstance(
|
||||
attn_metadata, self.allowed_attn_types
|
||||
):
|
||||
raise ValueError(
|
||||
f"Unsupported attention metadata type for speculative "
|
||||
"decoding with num_speculative_tokens > 1: "
|
||||
f"{type(attn_metadata)}. Supported types are: "
|
||||
f"{self.allowed_attn_types}"
|
||||
)
|
||||
for attn_metadata in per_layer_attn_metadata.values():
|
||||
if self.allowed_attn_types is not None and not isinstance(
|
||||
attn_metadata, self.allowed_attn_types
|
||||
):
|
||||
raise ValueError(
|
||||
f"Unsupported attention metadata type for speculative "
|
||||
"decoding with num_speculative_tokens > 1: "
|
||||
f"{type(attn_metadata)}. Supported types are: "
|
||||
f"{self.allowed_attn_types}"
|
||||
)
|
||||
|
||||
# Generate the remaining draft tokens.
|
||||
draft_token_ids_list = [draft_token_ids]
|
||||
@@ -593,13 +595,9 @@ class SpecDecodeBaseProposer:
|
||||
common_attn_metadata._num_computed_tokens_cpu += 1
|
||||
|
||||
# Rebuild attention metadata
|
||||
for attn_group in self.draft_attn_groups:
|
||||
attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
draft_index=token_index + 1,
|
||||
)
|
||||
for layer_name in attn_group.layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
per_layer_attn_metadata = self.build_per_layer_attn_metadata(
|
||||
common_attn_metadata, draft_index=token_index + 1
|
||||
)
|
||||
|
||||
# copy inputs to buffer for cudagraph
|
||||
self.input_ids[:batch_size] = input_ids
|
||||
@@ -780,8 +778,51 @@ class SpecDecodeBaseProposer:
|
||||
|
||||
return total_num_output_tokens, token_indices_to_sample, new_cad
|
||||
|
||||
def build_model_inputs_first_pass(
|
||||
self,
|
||||
num_tokens: int,
|
||||
num_input_tokens: int,
|
||||
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None,
|
||||
) -> tuple[dict[str, Any], int]:
|
||||
if self.supports_mm_inputs:
|
||||
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
|
||||
|
||||
self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
|
||||
self.input_ids[:num_tokens],
|
||||
multimodal_embeddings=mm_embeds,
|
||||
is_multimodal=is_mm_embed,
|
||||
)
|
||||
|
||||
input_ids = None
|
||||
inputs_embeds = self.inputs_embeds[:num_input_tokens]
|
||||
else:
|
||||
input_ids = self.input_ids[:num_input_tokens]
|
||||
inputs_embeds = None
|
||||
|
||||
model_kwargs = {
|
||||
"input_ids": input_ids,
|
||||
"positions": self._get_positions(num_input_tokens),
|
||||
"inputs_embeds": inputs_embeds,
|
||||
}
|
||||
if self.pass_hidden_states_to_model:
|
||||
model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
|
||||
|
||||
return model_kwargs, num_input_tokens
|
||||
|
||||
def build_per_layer_attn_metadata(
|
||||
self, common_attn_metadata: CommonAttentionMetadata, draft_index: int = 0
|
||||
) -> dict[str, object]:
|
||||
per_layer_attn_metadata: dict[str, object] = {}
|
||||
for attn_group in self.draft_attn_groups:
|
||||
attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata, draft_index=draft_index
|
||||
)
|
||||
for layer_name in attn_group.layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
return per_layer_attn_metadata
|
||||
|
||||
def model_returns_tuple(self) -> bool:
|
||||
return self.method not in ("mtp", "draft_model")
|
||||
return self.method not in ("mtp", "draft_model", "dflash")
|
||||
|
||||
def prepare_next_token_ids_cpu(
|
||||
self,
|
||||
@@ -1310,15 +1351,20 @@ class SpecDecodeBaseProposer:
|
||||
self._maybe_share_embeddings(target_language_model)
|
||||
self._maybe_share_lm_head(target_language_model)
|
||||
|
||||
if self.parallel_drafting and self.pass_hidden_states_to_model:
|
||||
assert self.parallel_drafting_hidden_state_tensor is not None
|
||||
self.parallel_drafting_hidden_state_tensor.copy_(
|
||||
self.model.combine_hidden_states(
|
||||
self.model.mask_hidden.view(3 * self.hidden_size)
|
||||
if (
|
||||
self.parallel_drafting
|
||||
and self.pass_hidden_states_to_model
|
||||
and self.parallel_drafting_hidden_state_tensor is not None
|
||||
):
|
||||
flat_mask = self.model.mask_hidden.view(-1)
|
||||
if self.eagle3_use_aux_hidden_state:
|
||||
# EAGLE3: mask_hidden stores all aux hidden states,
|
||||
# project through combine_hidden_states
|
||||
self.parallel_drafting_hidden_state_tensor.copy_(
|
||||
self.model.combine_hidden_states(flat_mask)
|
||||
)
|
||||
if self.eagle3_use_aux_hidden_state
|
||||
else self.model.mask_hidden.view(self.hidden_size)
|
||||
)
|
||||
else:
|
||||
self.parallel_drafting_hidden_state_tensor.copy_(flat_mask)
|
||||
|
||||
def _maybe_share_embeddings(self, target_language_model: nn.Module) -> None:
|
||||
"""
|
||||
@@ -1493,8 +1539,9 @@ class SpecDecodeBaseProposer:
|
||||
) -> None:
|
||||
# FIXME: when using tree-based specdec, adjust number of forward-passes
|
||||
# according to the depth of the tree.
|
||||
only_one_forward_pass = is_graph_capturing or self.parallel_drafting
|
||||
for fwd_idx in range(
|
||||
self.num_speculative_tokens if not is_graph_capturing else 1
|
||||
1 if only_one_forward_pass else self.num_speculative_tokens
|
||||
):
|
||||
if fwd_idx <= 1:
|
||||
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
|
||||
|
||||
@@ -441,6 +441,114 @@ def copy_and_expand_eagle_inputs_kernel(
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def copy_and_expand_dflash_inputs_kernel(
|
||||
# Inputs
|
||||
next_token_ids_ptr, # [num_reqs]
|
||||
target_positions_ptr, # [num_context]
|
||||
# Outputs
|
||||
out_input_ids_ptr, # [num_query_total] (output)
|
||||
out_context_positions_ptr, # [num_context] (output)
|
||||
out_query_positions_ptr, # [num_query_total] (output)
|
||||
out_context_slot_mapping_ptr, # [num_context] (output)
|
||||
out_query_slot_mapping_ptr, # [num_query_total] (output)
|
||||
out_token_indices_ptr, # [num_reqs * num_speculative_tokens] (output)
|
||||
# Block table
|
||||
block_table_ptr, # [max_reqs, max_blocks]
|
||||
block_table_stride, # stride of block_table dim 0 (in elements)
|
||||
# Metadata
|
||||
query_start_loc_ptr, # [num_reqs + 1]
|
||||
num_rejected_tokens_ptr, # [num_reqs] or null (0) when not padded
|
||||
# Scalars
|
||||
parallel_drafting_token_id, # tl.int32
|
||||
block_size, # tl.int32
|
||||
num_query_per_req, # tl.int32
|
||||
num_speculative_tokens, # tl.int32
|
||||
total_input_tokens, # tl.int32
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
HAS_NUM_REJECTED: tl.constexpr = False,
|
||||
):
|
||||
"""
|
||||
Fused kernel for DFlash first-pass input setup.
|
||||
|
||||
Per request, this kernel:
|
||||
1. Copies context positions from target_positions to
|
||||
out_context_positions.
|
||||
2. Computes query positions (last_target_pos + 1 + offset) and writes
|
||||
them to out_query_positions.
|
||||
3. Writes input_ids for query tokens: [next_token, mask, mask, ...].
|
||||
4. Computes slot_mapping for context and query positions into separate
|
||||
buffers via block_table lookup.
|
||||
5. Writes token_indices_to_sample for the mask (speculative) tokens.
|
||||
"""
|
||||
req_idx = tl.program_id(axis=0)
|
||||
block_idx = tl.program_id(axis=1)
|
||||
|
||||
# Load context token range for this request
|
||||
ctx_start = tl.load(query_start_loc_ptr + req_idx)
|
||||
ctx_end = tl.load(query_start_loc_ptr + req_idx + 1)
|
||||
num_ctx = ctx_end - ctx_start
|
||||
total_tokens = num_ctx + num_query_per_req
|
||||
|
||||
j = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
in_bounds = j < total_tokens
|
||||
is_ctx = j < num_ctx
|
||||
is_query = (~is_ctx) & in_bounds
|
||||
query_off = j - num_ctx # offset within query portion (0-indexed)
|
||||
|
||||
# --- Positions ---
|
||||
# Context: load from target_positions
|
||||
ctx_pos_idx = tl.minimum(ctx_start + j, total_input_tokens - 1)
|
||||
ctx_pos = tl.load(target_positions_ptr + ctx_pos_idx, mask=is_ctx, other=0)
|
||||
|
||||
# Query: last_valid_pos + 1 + query_off
|
||||
# In padded mode, ctx_end includes rejected tokens; use valid_ctx_end
|
||||
# to find the last accepted context position.
|
||||
if HAS_NUM_REJECTED:
|
||||
num_rejected = tl.load(num_rejected_tokens_ptr + req_idx)
|
||||
valid_ctx_end = ctx_end - num_rejected
|
||||
else:
|
||||
valid_ctx_end = ctx_end
|
||||
last_pos = tl.load(target_positions_ptr + valid_ctx_end - 1)
|
||||
query_pos = last_pos + 1 + query_off
|
||||
|
||||
positions = tl.where(is_ctx, ctx_pos, query_pos)
|
||||
|
||||
# Context and query positions go to separate buffers.
|
||||
ctx_pos_out = ctx_start + j
|
||||
tl.store(out_context_positions_ptr + ctx_pos_out, ctx_pos, mask=is_ctx)
|
||||
query_out = req_idx * num_query_per_req + query_off
|
||||
tl.store(out_query_positions_ptr + query_out, query_pos, mask=is_query)
|
||||
|
||||
# --- Slot mapping (block_table lookup for all positions) ---
|
||||
block_num = positions // block_size
|
||||
# # Clamp block_number to avoid OOB when position is at max
|
||||
block_num = tl.minimum(block_num, block_table_stride - 1)
|
||||
block_id = tl.load(
|
||||
block_table_ptr + req_idx * block_table_stride + block_num,
|
||||
mask=in_bounds,
|
||||
other=0,
|
||||
).to(tl.int64)
|
||||
slot = block_id * block_size + (positions % block_size)
|
||||
tl.store(out_context_slot_mapping_ptr + ctx_pos_out, slot, mask=is_ctx)
|
||||
tl.store(out_query_slot_mapping_ptr + query_out, slot, mask=is_query)
|
||||
|
||||
# --- Input IDs (query tokens only) ---
|
||||
bonus_token = tl.load(next_token_ids_ptr + req_idx)
|
||||
is_bonus = is_query & (query_off == 0)
|
||||
input_id = tl.where(is_bonus, bonus_token, parallel_drafting_token_id)
|
||||
tl.store(out_input_ids_ptr + query_out, input_id, mask=is_query)
|
||||
|
||||
# --- Token indices to sample (mask tokens, skip the bonus token) ---
|
||||
is_sample = is_query & (query_off > 0)
|
||||
sample_out_idx = req_idx * num_speculative_tokens + (query_off - 1)
|
||||
tl.store(
|
||||
out_token_indices_ptr + sample_out_idx,
|
||||
query_out,
|
||||
mask=is_sample,
|
||||
)
|
||||
|
||||
|
||||
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
|
||||
def update_num_computed_tokens_for_batch_change(
|
||||
num_computed_tokens: torch.Tensor,
|
||||
|
||||
@@ -160,6 +160,7 @@ from vllm.v1.sample.logits_processor.interface import LogitsProcessor
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.rejection_sampler import RejectionSampler
|
||||
from vllm.v1.sample.sampler import Sampler
|
||||
from vllm.v1.spec_decode.dflash import DFlashProposer
|
||||
from vllm.v1.spec_decode.draft_model import DraftModelProposer
|
||||
from vllm.v1.spec_decode.eagle import EagleProposer
|
||||
from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer
|
||||
@@ -515,6 +516,7 @@ class GPUModelRunner(
|
||||
| NgramProposerGPU
|
||||
| SuffixDecodingProposer
|
||||
| EagleProposer
|
||||
| DFlashProposer
|
||||
| DraftModelProposer
|
||||
| MedusaProposer
|
||||
| ExtractHiddenStatesProposer
|
||||
@@ -546,6 +548,9 @@ class GPUModelRunner(
|
||||
self._ngram_pinned_val_buf = torch.zeros(
|
||||
self.max_num_reqs, dtype=torch.int32, pin_memory=True
|
||||
)
|
||||
elif self.speculative_config.use_dflash():
|
||||
self.drafter = DFlashProposer(self.vllm_config, self.device, self)
|
||||
self.use_aux_hidden_state_outputs = True
|
||||
elif self.speculative_config.method == "suffix":
|
||||
self.drafter = SuffixDecodingProposer(self.vllm_config)
|
||||
elif self.speculative_config.use_eagle():
|
||||
@@ -2289,7 +2294,7 @@ class GPUModelRunner(
|
||||
cm.slot_mapping = slot_mappings[kv_cache_gid]
|
||||
|
||||
if self.speculative_config and spec_decode_common_attn_metadata is None:
|
||||
if isinstance(self.drafter, EagleProposer):
|
||||
if isinstance(self.drafter, (EagleProposer, DFlashProposer)):
|
||||
if self.drafter.kv_cache_gid == kv_cache_gid:
|
||||
spec_decode_common_attn_metadata = cm
|
||||
else:
|
||||
@@ -4202,7 +4207,10 @@ class GPUModelRunner(
|
||||
# as inputs, and does not need to wait for bookkeeping to finish.
|
||||
assert isinstance(
|
||||
self.drafter,
|
||||
EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer,
|
||||
EagleProposer
|
||||
| DFlashProposer
|
||||
| DraftModelProposer
|
||||
| ExtractHiddenStatesProposer,
|
||||
)
|
||||
sampled_token_ids = sampler_output.sampled_token_ids
|
||||
if input_fits_in_drafter:
|
||||
@@ -4589,8 +4597,14 @@ class GPUModelRunner(
|
||||
next_token_ids, valid_sampled_tokens_count
|
||||
)
|
||||
|
||||
elif spec_config.use_eagle() or spec_config.uses_draft_model():
|
||||
assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
|
||||
elif (
|
||||
spec_config.use_eagle()
|
||||
or spec_config.use_dflash()
|
||||
or spec_config.uses_draft_model()
|
||||
):
|
||||
assert isinstance(
|
||||
self.drafter, EagleProposer | DFlashProposer | DraftModelProposer
|
||||
)
|
||||
|
||||
if spec_config.disable_padded_drafter_batch:
|
||||
# When padded-batch is disabled, the sampled_token_ids should be
|
||||
@@ -4889,10 +4903,13 @@ class GPUModelRunner(
|
||||
return None
|
||||
|
||||
hf_config = self.speculative_config.draft_model_config.hf_config
|
||||
if not hasattr(hf_config, "eagle_aux_hidden_state_layer_ids"):
|
||||
return None
|
||||
|
||||
layer_ids = hf_config.eagle_aux_hidden_state_layer_ids
|
||||
layer_ids = getattr(hf_config, "eagle_aux_hidden_state_layer_ids", None)
|
||||
if not layer_ids:
|
||||
dflash_config = getattr(hf_config, "dflash_config", None)
|
||||
if dflash_config and isinstance(dflash_config, dict):
|
||||
layer_ids = dflash_config.get("target_layer_ids")
|
||||
|
||||
if layer_ids and isinstance(layer_ids, (list, tuple)):
|
||||
return tuple(layer_ids)
|
||||
|
||||
@@ -5479,7 +5496,10 @@ class GPUModelRunner(
|
||||
):
|
||||
assert isinstance(
|
||||
self.drafter,
|
||||
EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer,
|
||||
EagleProposer
|
||||
| DFlashProposer
|
||||
| DraftModelProposer
|
||||
| ExtractHiddenStatesProposer,
|
||||
)
|
||||
assert self.speculative_config is not None
|
||||
# Eagle currently only supports PIECEWISE cudagraphs.
|
||||
@@ -6236,7 +6256,9 @@ class GPUModelRunner(
|
||||
self.speculative_config.use_eagle()
|
||||
or self.speculative_config.uses_draft_model()
|
||||
):
|
||||
assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
|
||||
assert isinstance(
|
||||
self.drafter, EagleProposer | DFlashProposer | DraftModelProposer
|
||||
)
|
||||
self.drafter.initialize_attn_backend(kv_cache_config, kernel_block_sizes)
|
||||
|
||||
def _check_and_update_cudagraph_mode(
|
||||
@@ -6420,7 +6442,10 @@ class GPUModelRunner(
|
||||
self.speculative_config.use_eagle()
|
||||
or self.speculative_config.uses_extract_hidden_states()
|
||||
):
|
||||
assert isinstance(self.drafter, EagleProposer | ExtractHiddenStatesProposer)
|
||||
assert isinstance(
|
||||
self.drafter,
|
||||
EagleProposer | DFlashProposer | ExtractHiddenStatesProposer,
|
||||
)
|
||||
self.drafter.initialize_cudagraph_keys(cudagraph_mode)
|
||||
|
||||
def calculate_reorder_batch_threshold(self) -> None:
|
||||
|
||||
Reference in New Issue
Block a user