[Feat][Spec Decode] DFlash (#36847)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Benjamin Chislett
2026-03-30 15:03:15 -04:00
committed by GitHub
parent ab1a6a43fa
commit 494636b29d
17 changed files with 1577 additions and 107 deletions
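
For orientation, here is a minimal offline-inference sketch of how the new "dflash" method is enabled. It mirrors the dflash_config fixture added in the diff below; the model names, config keys, and values are taken from that fixture rather than from separate documentation, so treat it as an illustration, not canonical usage.

# Sketch only: enables DFlash speculative decoding with the same settings as the
# dflash_config fixture in this diff (Qwen3-8B target, z-lab DFlash draft model).
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-8B",
    trust_remote_code=True,
    speculative_config={
        "method": "dflash",  # new speculative decoding method added by this PR
        "model": "z-lab/Qwen3-8B-DFlash-b16",  # block diffusion draft model
        "num_speculative_tokens": 16,
        "max_model_len": 32768,
    },
    max_model_len=32768,
)

outputs = llm.generate(
    ["Explain speculative decoding in one paragraph."],
    SamplingParams(temperature=0, max_tokens=256),
)
print(outputs[0].outputs[0].text)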


@@ -7,6 +7,7 @@ from typing import Any

 import pytest
 import torch
+from tqdm import tqdm

 from tests.evals.gsm8k.gsm8k_eval import _build_gsm8k_prompts, evaluate_gsm8k_offline
 from tests.utils import (
@@ -1105,19 +1106,178 @@ def some_high_acceptance_metrics() -> dict:
     }


-def compute_acceptance_rate(metrics: list[Metric]) -> float:
+def compute_acceptance_rate(
+    metrics: list[Metric], prev_metrics: list[Metric] | None = None
+) -> float:
     name2metric = {metric.name: metric for metric in metrics}
-    n_draft_toks = name2metric["vllm:spec_decode_num_draft_tokens"].value  # type: ignore
+    n_draft_toks = name2metric["vllm:spec_decode_num_draft_tokens"].value
     if n_draft_toks == 0:
         return float("nan")
-    n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value  # type: ignore
+    n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value
+    if prev_metrics is not None:
+        prev_name2metric = {metric.name: metric for metric in prev_metrics}
+        n_draft_toks -= prev_name2metric["vllm:spec_decode_num_draft_tokens"].value
+        n_accepted_toks -= prev_name2metric[
+            "vllm:spec_decode_num_accepted_tokens"
+        ].value
+        if n_draft_toks <= 0:
+            return float("nan")
     return n_accepted_toks / n_draft_toks


-def compute_acceptance_len(metrics: list[Metric]) -> float:
+def compute_acceptance_len(
+    metrics: list[Metric], prev_metrics: list[Metric] | None = None
+) -> float:
     name2metric = {metric.name: metric for metric in metrics}
-    n_drafts = name2metric["vllm:spec_decode_num_drafts"].value  # type: ignore
-    n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value  # type: ignore
+    n_drafts = name2metric["vllm:spec_decode_num_drafts"].value
+    n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value
     if n_drafts == 0:
         return 1
+    if prev_metrics is not None:
+        prev_name2metric = {metric.name: metric for metric in prev_metrics}
+        n_drafts -= prev_name2metric["vllm:spec_decode_num_drafts"].value
+        n_accepted_toks -= prev_name2metric[
+            "vllm:spec_decode_num_accepted_tokens"
+        ].value
+        if n_drafts <= 0:
+            return 1
     return 1 + (n_accepted_toks / n_drafts)
+
+
+# Datasets in the format used in DFlash validations
+def load_and_process_dataset(data_name: str):
+    from datasets import load_dataset
+
+    if data_name == "gsm8k":
+        dataset = load_dataset("openai/gsm8k", "main", split="test")
+        prompt_fmt = (
+            "{question}\nPlease reason step by step,"
+            " and put your final answer within \\boxed{{}}."
+        )
+        dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]})
+    elif data_name == "mt-bench":
+        dataset = load_dataset("HuggingFaceH4/mt_bench_prompts", split="train")
+        dataset = dataset.map(lambda x: {"turns": x["prompt"]})
+    elif data_name == "humaneval":
+        dataset = load_dataset("openai/openai_humaneval", split="test")
+        prompt_fmt = (
+            "Write a solution to the following problem and make sure"
+            " that it passes the tests:\n```python\n{prompt}\n```"
+        )
+        dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]})
+    return dataset
+
+
+@pytest.fixture
+def dflash_config():
+    target_model = "Qwen/Qwen3-8B"
+    draft_model = "z-lab/Qwen3-8B-DFlash-b16"
+    return dict(
+        model=target_model,
+        trust_remote_code=True,
+        speculative_config={
+            "method": "dflash",
+            "model": draft_model,
+            "num_speculative_tokens": 16,
+            "max_model_len": 32768,
+        },
+        max_model_len=32768,
+        max_num_seqs=128,
+        gpu_memory_utilization=0.85,
+        enforce_eager=False,
+        disable_log_stats=False,
+    )
+
+
+def test_dflash_acceptance_rates(dflash_config):
+    """
+    E2E test for DFlash (block diffusion) speculative decoding.
+    Runs acceptance rate validation on GSM8k, MT-Bench, and HumanEval
+    comparing against baseline results from the paper (Table 1).
+    See https://github.com/z-lab/dflash/blob/main/benchmark_sglang.py for methodology.
+    """
+    spec_llm = LLM(**dflash_config)
+
+    max_prompts_per_dataset = 200  # mt-bench has 80, humaneval has 164, truncates gsm8k
+
+    # All scores from Table 1 in https://arxiv.org/pdf/2602.06036
+    expected_acceptance_lengths = {
+        "mt-bench": 4.24,
+        "humaneval": 6.50,
+        "gsm8k": 6.54 * 0.95,  # runs with a subset of prompts so extra wide tol here
+    }
+
+    tokenizer = spec_llm.get_tokenizer()
+    for dataset_name, expected_len in expected_acceptance_lengths.items():
+        dataset = load_and_process_dataset(dataset_name)
+        prev_metrics = None
+        acceptance_lengths = []
+        for i in tqdm(
+            range(min(max_prompts_per_dataset, len(dataset))),
+            desc=f"Processing {dataset_name}",
+        ):
+            user_content = dataset[i]["turns"][0]
+            prompt_text = tokenizer.apply_chat_template(
+                [{"role": "user", "content": user_content}],
+                tokenize=False,
+                add_generation_prompt=True,
+                enable_thinking=False,
+            )
+            # Temp=0, MaxTokens=2048 from the paper
+            spec_llm.generate(
+                [prompt_text],
+                SamplingParams(temperature=0, max_tokens=2048),
+                use_tqdm=False,
+            )
+            current_metrics = spec_llm.get_metrics()
+            acceptance_len = compute_acceptance_len(current_metrics, prev_metrics)
+            prev_metrics = current_metrics
+            acceptance_lengths.append(acceptance_len)
+
+        mean_acceptance_length = sum(acceptance_lengths) / len(acceptance_lengths)
+        expected_len = expected_len * 0.9
+        print(
+            f"DFlash acceptance_len for {dataset_name}: {mean_acceptance_length:.2f}"
+            f" (expected at least {expected_len:.2f})"
+        )
+        assert mean_acceptance_length >= expected_len, (
+            f"DFlash acceptance_len for {dataset_name} is below expected threshold:"
+            f"{mean_acceptance_length:.2f} < {expected_len:.2f}"
+        )
+
+    del spec_llm
+    torch.accelerator.empty_cache()
+    cleanup_dist_env_and_memory()
+
+
+def test_dflash_correctness(dflash_config):
+    """
+    E2E test for DFlash (block diffusion) speculative decoding.
+    Ensures output correctness on GSM8k, with cudagraphs and batching on.
+    """
+    spec_llm = LLM(**dflash_config)
+
+    # Evaluate GSM8k accuracy (Qwen3-8B ref: ~87-92% on GSM8k)
+    evaluate_llm_for_gsm8k(spec_llm, expected_accuracy_threshold=0.8)
+
+    current_metrics = spec_llm.get_metrics()
+    acceptance_len = compute_acceptance_len(current_metrics)
+
+    # AR is thoroughly validated in test_dflash_acceptance_rates, in a manner consistent
+    # with the DFlash paper. However, that test measures AL per-request and thus runs
+    # with a batch size of 1. To ensure that AL does not collapse with large batch sizes
+    # we enforce a baseline on the AL over the full lm-eval-style GSM8k test.
+    expected_len = 3.5  # Measured is 3.9 to 4.0
+    print(f"DFlash GSM8k correctness test got AL {acceptance_len}")
+    assert acceptance_len >= expected_len, (
+        "DFlash correctness check failed with"
+        f" {acceptance_len=}, expected at least {expected_len}"
+    )
+
+    del spec_llm
+    torch.accelerator.empty_cache()
+    cleanup_dist_env_and_memory()
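
As a sanity check on the metric arithmetic used by compute_acceptance_rate and compute_acceptance_len above, here is a small standalone example with made-up counter values. The metric names match the vLLM counters read by those helpers; the numbers themselves are invented purely for illustration.

# Illustrative only: the counter values below are invented, not measurements.
# Acceptance length (AL) = 1 + accepted_tokens / num_drafts; passing a previous
# metrics snapshot subtracts it out so the result reflects only the latest request.
prev = {
    "vllm:spec_decode_num_drafts": 100,
    "vllm:spec_decode_num_accepted_tokens": 250,
}
curr = {
    "vllm:spec_decode_num_drafts": 120,
    "vllm:spec_decode_num_accepted_tokens": 330,
}

# Cumulative AL since engine start: 1 + 330 / 120 = 3.75
cumulative_al = 1 + (
    curr["vllm:spec_decode_num_accepted_tokens"] / curr["vllm:spec_decode_num_drafts"]
)

# Delta AL for the most recent request only: 1 + (330 - 250) / (120 - 100) = 5.0
delta_drafts = curr["vllm:spec_decode_num_drafts"] - prev["vllm:spec_decode_num_drafts"]
delta_accepted = (
    curr["vllm:spec_decode_num_accepted_tokens"]
    - prev["vllm:spec_decode_num_accepted_tokens"]
)
delta_al = 1 + delta_accepted / delta_drafts

assert (cumulative_al, delta_al) == (3.75, 5.0)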