[Feat][Spec Decode] DFlash (#36847)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
committed by
GitHub
parent
ab1a6a43fa
commit
494636b29d
@@ -7,6 +7,7 @@ from typing import Any
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from tests.evals.gsm8k.gsm8k_eval import _build_gsm8k_prompts, evaluate_gsm8k_offline
|
||||
from tests.utils import (
|
||||
@@ -1105,19 +1106,178 @@ def some_high_acceptance_metrics() -> dict:
|
||||
}
|
||||
|
||||
|
||||
def compute_acceptance_rate(
    metrics: list[Metric], prev_metrics: list[Metric] | None = None
) -> float:
    """Return the spec-decode acceptance rate: accepted tokens / drafted tokens.

    When ``prev_metrics`` is provided, the rate is computed over the delta
    between the two metric snapshots (i.e. only tokens drafted since the
    previous snapshot). Returns NaN when no tokens were drafted in the window.
    """
    by_name = {m.name: m for m in metrics}
    drafted = by_name["vllm:spec_decode_num_draft_tokens"].value
    if drafted == 0:
        # Nothing drafted yet: the rate is undefined.
        return float("nan")
    accepted = by_name["vllm:spec_decode_num_accepted_tokens"].value
    if prev_metrics is not None:
        prev_by_name = {m.name: m for m in prev_metrics}
        drafted -= prev_by_name["vllm:spec_decode_num_draft_tokens"].value
        accepted -= prev_by_name["vllm:spec_decode_num_accepted_tokens"].value
        if drafted <= 0:
            # No new draft tokens since the previous snapshot.
            return float("nan")
    return accepted / drafted
||||
def compute_acceptance_len(
    metrics: list[Metric], prev_metrics: list[Metric] | None = None
) -> float:
    """Return the mean acceptance length: 1 + accepted tokens per draft.

    When ``prev_metrics`` is provided, the value is computed over the delta
    between the two metric snapshots. Returns 1 (the target token only) when
    no drafts were made in the window.
    """
    by_name = {m.name: m for m in metrics}
    drafts = by_name["vllm:spec_decode_num_drafts"].value
    accepted = by_name["vllm:spec_decode_num_accepted_tokens"].value
    if drafts == 0:
        # No speculation happened: every step emits exactly one token.
        return 1
    if prev_metrics is not None:
        prev_by_name = {m.name: m for m in prev_metrics}
        drafts -= prev_by_name["vllm:spec_decode_num_drafts"].value
        accepted -= prev_by_name["vllm:spec_decode_num_accepted_tokens"].value
        if drafts <= 0:
            # No new drafts since the previous snapshot.
            return 1
    return 1 + (accepted / drafts)
|
||||
|
||||
# Datasets in the format used in DFlash validations
def load_and_process_dataset(data_name: str):
    """Load an eval dataset and normalize each example to a ``turns`` field.

    ``turns`` is a list of user-prompt strings, matching the format used in
    the DFlash validations.

    Args:
        data_name: One of "gsm8k", "mt-bench", or "humaneval".

    Returns:
        The loaded HuggingFace dataset with a ``turns`` column added.

    Raises:
        ValueError: If ``data_name`` is not a supported dataset name.
    """
    # Validate up front: previously an unknown name fell through every branch
    # and hit `return dataset` with `dataset` unbound (UnboundLocalError).
    if data_name not in ("gsm8k", "mt-bench", "humaneval"):
        raise ValueError(f"Unknown dataset name: {data_name!r}")

    from datasets import load_dataset

    if data_name == "gsm8k":
        dataset = load_dataset("openai/gsm8k", "main", split="test")
        prompt_fmt = (
            "{question}\nPlease reason step by step,"
            " and put your final answer within \\boxed{{}}."
        )
        dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]})
    elif data_name == "mt-bench":
        dataset = load_dataset("HuggingFaceH4/mt_bench_prompts", split="train")
        dataset = dataset.map(lambda x: {"turns": x["prompt"]})
    else:  # humaneval
        dataset = load_dataset("openai/openai_humaneval", split="test")
        prompt_fmt = (
            "Write a solution to the following problem and make sure"
            " that it passes the tests:\n```python\n{prompt}\n```"
        )
        dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]})

    return dataset
||||
|
||||
@pytest.fixture
def dflash_config():
    """LLM constructor kwargs shared by the DFlash spec-decode tests."""
    return {
        "model": "Qwen/Qwen3-8B",
        "trust_remote_code": True,
        "speculative_config": {
            "method": "dflash",
            "model": "z-lab/Qwen3-8B-DFlash-b16",
            "num_speculative_tokens": 16,
            "max_model_len": 32768,
        },
        "max_model_len": 32768,
        "max_num_seqs": 128,
        "gpu_memory_utilization": 0.85,
        "enforce_eager": False,
        "disable_log_stats": False,
    }
|
||||
def test_dflash_acceptance_rates(dflash_config):
    """
    E2E test for DFlash (block diffusion) speculative decoding.

    Runs acceptance rate validation on GSM8k, MT-Bench, and HumanEval
    comparing against baseline results from the paper (Table 1).
    See https://github.com/z-lab/dflash/blob/main/benchmark_sglang.py for methodology.
    """
    spec_llm = LLM(**dflash_config)

    # mt-bench has 80, humaneval has 164, truncates gsm8k
    max_prompts_per_dataset = 200

    # All scores from Table 1 in https://arxiv.org/pdf/2602.06036
    paper_acceptance_lengths = {
        "mt-bench": 4.24,
        "humaneval": 6.50,
        # runs with a subset of prompts so extra wide tol here
        "gsm8k": 6.54 * 0.95,
    }

    tokenizer = spec_llm.get_tokenizer()
    for dataset_name, paper_len in paper_acceptance_lengths.items():
        dataset = load_and_process_dataset(dataset_name)
        n_prompts = min(max_prompts_per_dataset, len(dataset))
        per_prompt_als = []
        prev_metrics = None
        for idx in tqdm(range(n_prompts), desc=f"Processing {dataset_name}"):
            # First user turn only, rendered through the chat template.
            chat = [{"role": "user", "content": dataset[idx]["turns"][0]}]
            prompt_text = tokenizer.apply_chat_template(
                chat,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=False,
            )

            # Temp=0, MaxTokens=2048 from the paper
            spec_llm.generate(
                [prompt_text],
                SamplingParams(temperature=0, max_tokens=2048),
                use_tqdm=False,
            )
            # Per-request AL: diff this snapshot against the previous one.
            current_metrics = spec_llm.get_metrics()
            per_prompt_als.append(compute_acceptance_len(current_metrics, prev_metrics))
            prev_metrics = current_metrics

        mean_al = sum(per_prompt_als) / len(per_prompt_als)
        threshold = paper_len * 0.9
        print(
            f"DFlash acceptance_len for {dataset_name}: {mean_al:.2f}"
            f" (expected at least {threshold:.2f})"
        )

        assert mean_al >= threshold, (
            f"DFlash acceptance_len for {dataset_name} is below expected threshold:"
            f"{mean_al:.2f} < {threshold:.2f}"
        )

    del spec_llm
    torch.accelerator.empty_cache()
    cleanup_dist_env_and_memory()
|
||||
def test_dflash_correctness(dflash_config):
    """
    E2E test for DFlash (block diffusion) speculative decoding.

    Ensures output correctness on GSM8k, with cudagraphs and batching on.
    """
    spec_llm = LLM(**dflash_config)

    # Evaluate GSM8k accuracy (Qwen3-8B ref: ~87-92% on GSM8k)
    evaluate_llm_for_gsm8k(spec_llm, expected_accuracy_threshold=0.8)

    acceptance_len = compute_acceptance_len(spec_llm.get_metrics())

    # AR is thoroughly validated in test_dflash_acceptance_rates, in a manner consistent
    # with the DFlash paper. However, that test measures AL per-request and thus runs
    # with a batch size of 1. To ensure that AL does not collapse with large batch sizes
    # we enforce a baseline on the AL over the full lm-eval-style GSM8k test.
    expected_len = 3.5  # Measured is 3.9 to 4.0
    print(f"DFlash GSM8k correctness test got AL {acceptance_len}")
    assert acceptance_len >= expected_len, (
        "DFlash correctness check failed with"
        f" {acceptance_len=}, expected at least {expected_len}"
    )

    del spec_llm
    torch.accelerator.empty_cache()
    cleanup_dist_env_and_memory()
Reference in New Issue
Block a user