[Tests] Add GSM8k check to SpecDec E2E tests (#34772)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
Benjamin Chislett
2026-02-25 07:51:14 -05:00
committed by GitHub
parent 709eadbb0b
commit ee59a7c615
2 changed files with 239 additions and 145 deletions

View File

@@ -8,6 +8,7 @@ from typing import Any
import pytest
import torch
from tests.evals.gsm8k.gsm8k_eval import _build_gsm8k_prompts, evaluate_gsm8k_offline
from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark
from vllm import LLM, SamplingParams
from vllm.assets.base import VLLM_S3_BUCKET_URL
@@ -35,53 +36,57 @@ def _skip_if_insufficient_gpus_for_tp(tp_size: int):
Messages = list[dict[str, Any]]
def get_test_prompts(
mm_enabled: bool, quiet: bool = False, num_prompts: int = 100
) -> list[Messages]:
prompt_types = ["repeat", "sentence"]
def get_test_prompts(mm_enabled: bool, num_prompts: int = 100) -> list[Messages]:
prompt_types = ["repeat", "gsm8k"]
if mm_enabled:
prompt_types.append("mm")
prompts = []
prompts: list[Messages] = []
random.seed(0)
random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
if not quiet:
print(f"Prompt types: {random_prompt_type_choices}")
num_repeat_prompts = num_prompts // len(prompt_types)
if mm_enabled:
num_gsm8k_prompts = num_prompts // len(prompt_types)
num_mm_prompts = num_prompts - num_repeat_prompts - num_gsm8k_prompts
else:
num_mm_prompts = 0
num_gsm8k_prompts = num_prompts - num_repeat_prompts
# Generate a mixed batch of prompts, some of which can be easily
# predicted by n-gram matching and some which likely cannot.
for kind in random_prompt_type_choices:
random.seed(0)
for _ in range(num_repeat_prompts):
word_choices = ["test", "temp", "hello", "where"]
word = random.choice(word_choices)
prompt: str | list[dict[str, Any]] = ""
if kind == "repeat":
prompt = f"""
please repeat the word '{word}' 10 times.
give no other output than the word at least ten times in a row,
in lowercase with spaces between each word and without quotes.
"""
elif kind == "sentence":
prompt = f"""
please give a ten-word sentence that
uses the word {word} at least once.
give no other output than that simple sentence without quotes.
"""
elif kind == "mm":
placeholders = [
prompts.append(
[
{
"type": "image_url",
"image_url": {
"url": f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
},
"role": "user",
"content": f"""
please repeat the word '{word}' 10 times.
give no other output than the word at least ten times in a row,
in lowercase with spaces between each word and without quotes.
""",
}
]
prompt = [
*placeholders,
{"type": "text", "text": "The meaning of the image is"},
]
else:
raise ValueError(f"Unknown prompt type: {kind}")
)
prompts.extend(
[{"role": "user", "content": prompt}]
for prompt in _build_gsm8k_prompts(
num_questions=num_gsm8k_prompts, num_shots=5
)[0]
)
for _ in range(num_mm_prompts):
placeholders = [
{
"type": "image_url",
"image_url": {
"url": f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
},
}
]
prompt = [
*placeholders,
{"type": "text", "text": "The meaning of the image is"},
]
prompts.append([{"role": "user", "content": prompt}])
return prompts
@@ -113,6 +118,25 @@ def model_name():
return "meta-llama/Llama-3.1-8B-Instruct"
def evaluate_llm_for_gsm8k(llm: LLM, expected_accuracy_threshold: float = 0.70) -> None:
"""Evaluate the LLM on GSM8K and check that accuracy is above a sanity threshold.
The default threshold assumes the LLM uses the same target model as the "model_name"
fixture, with max model len == 4096. Precomputed reference value is 75% to 80%
on GSM8K with greedy decoding, so we check that it's above a sanity threshold of 70%
to verify that the model is correct.
"""
if expected_accuracy_threshold <= 0.0:
print("Skipping GSM8K evaluation")
return
results = evaluate_gsm8k_offline(llm)
accuracy = results["accuracy"]
print(f"GSM8K accuracy: {accuracy:.3f}")
assert accuracy >= expected_accuracy_threshold, (
f"Expected GSM8K accuracy >= {expected_accuracy_threshold}, got {accuracy:.3f}"
)
@pytest.fixture(autouse=True)
def reset_torch_dynamo():
"""Reset torch dynamo cache before each test"""
@@ -138,41 +162,14 @@ def reset_torch_dynamo():
)
def test_ngram_and_suffix_correctness(
speculative_config: dict,
monkeypatch: pytest.MonkeyPatch,
sampling_config: SamplingParams,
model_name: str,
):
"""
Compare the outputs of an original LLM and a speculative LLM
should be the same when using ngram speculative decoding.
"""
test_prompts = get_test_prompts(mm_enabled=False)
ref_llm = LLM(model=model_name, max_model_len=1024)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
spec_llm = LLM(
model=model_name,
speculative_config=speculative_config,
max_model_len=1024,
max_model_len=4096,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
# Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches >= int(0.66 * len(ref_outputs))
evaluate_llm_for_gsm8k(spec_llm)
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
@@ -238,10 +235,10 @@ def test_suffix_decoding_acceptance(
@pytest.mark.parametrize(
"model_path",
["model_path", "expected_accuracy_threshold"],
[
"RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3",
"RedHatAI/Qwen3-8B-speculator.eagle3",
("RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3", 0.7), # ref: 75%-80%
("RedHatAI/Qwen3-8B-speculator.eagle3", 0.8), # ref: 87%-92%
],
ids=["llama3_eagle3_speculator", "qwen3_eagle3_speculator"],
)
@@ -249,6 +246,7 @@ def test_speculators_model_integration(
monkeypatch: pytest.MonkeyPatch,
sampling_config: SamplingParams,
model_path: str,
expected_accuracy_threshold: float,
):
"""
Test that speculators models work with the simplified integration.
@@ -262,7 +260,8 @@ def test_speculators_model_integration(
2. Verifier model is extracted from speculator config
3. Speculative decoding is automatically enabled
4. Text generation works correctly
5. Output matches reference (non-speculative) generation
5. GSM8k accuracy of the model passes a sanity check when speculative decoding is on
6. Output matches reference (non-speculative) generation
"""
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
@@ -270,7 +269,10 @@ def test_speculators_model_integration(
test_prompts = get_test_prompts(mm_enabled=False)
# First run: Direct speculator model (simplified integration)
spec_llm = LLM(model=model_path, max_model_len=1024)
spec_llm = LLM(model=model_path, max_model_len=4096)
evaluate_llm_for_gsm8k(
spec_llm, expected_accuracy_threshold=expected_accuracy_threshold
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
# Verify speculative config was auto-detected
@@ -297,7 +299,7 @@ def test_speculators_model_integration(
cleanup_dist_env_and_memory()
# Second run: Reference without speculative decoding
ref_llm = LLM(model=verifier_model, max_model_len=1024)
ref_llm = LLM(model=verifier_model, max_model_len=4096)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
torch.cuda.empty_cache()
@@ -318,19 +320,27 @@ def test_speculators_model_integration(
@pytest.mark.parametrize(
["model_setup", "mm_enabled", "enable_chunked_prefill", "model_impl"],
[
"model_setup",
"mm_enabled",
"enable_chunked_prefill",
"model_impl",
"expected_accuracy_threshold",
],
[
(
("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
False,
False,
"auto",
0.8, # ref: 90%
),
(
("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
False,
False,
"transformers",
0.8, # ref: 90%
),
pytest.param(
(
@@ -342,6 +352,7 @@ def test_speculators_model_integration(
False,
False,
"auto",
0.8, # ref: 90%
marks=pytest.mark.skip(
reason="architecture of its eagle3 is LlamaForCausalLMEagle3"
),
@@ -356,6 +367,7 @@ def test_speculators_model_integration(
False,
False,
"auto",
0.7, # TODO, update this with a reference value when re-enabling this case
marks=pytest.mark.skip(
reason="Skipping due to its head_dim not being a multiple of 32"
),
@@ -370,6 +382,7 @@ def test_speculators_model_integration(
False,
True,
"auto",
0.7, # ref: 75%-80%
marks=large_gpu_mark(min_gb=40),
), # works on 4x H100
(
@@ -382,6 +395,7 @@ def test_speculators_model_integration(
False,
False,
"auto",
0.7, # ref: 75%-80%
),
pytest.param(
(
@@ -393,7 +407,8 @@ def test_speculators_model_integration(
False,
False,
"auto",
marks=large_gpu_mark(min_gb=80),
0.8, # ref: 90%
# marks=large_gpu_mark(min_gb=80),
), # works on 4x H100
pytest.param(
(
@@ -405,6 +420,7 @@ def test_speculators_model_integration(
True,
True,
"auto",
0.8, # ref: 90%
marks=large_gpu_mark(min_gb=80),
), # works on 4x H100
(
@@ -417,6 +433,7 @@ def test_speculators_model_integration(
False,
False,
"auto",
0.0, # dummy model, skip gsm8k check
),
],
ids=[
@@ -437,10 +454,18 @@ def test_eagle_correctness(
sampling_config: SamplingParams,
model_setup: tuple[str, str, str, int],
mm_enabled: bool,
expected_accuracy_threshold: float,
enable_chunked_prefill: bool,
model_impl: str,
attn_backend: str,
):
"""
Compare the outputs of an original LLM and a speculative LLM
which should be the same when using eagle speculative decoding. Due to some variance
in the engine, it is possible for some outputs to differ, so we expect that at least
6/10 output tokens match exactly, and that the GSM8k accuracy is above
a precomputed reference threshold for each model.
"""
if attn_backend == "TREE_ATTN":
# TODO: Fix this flaky test
pytest.skip(
@@ -461,11 +486,6 @@ def test_eagle_correctness(
# Generate test prompts inside the function instead of using fixture
test_prompts = get_test_prompts(mm_enabled)
"""
Compare the outputs of an original LLM and a speculative LLM
should be the same when using eagle speculative decoding.
model_setup: (method, model_name, eagle_model_name, tp_size)
"""
# Determine attention config
# Scout requires default backend selection because vision encoder has
# head_dim 88 being incompatible with FLASH_ATTN and needs to fall back
@@ -505,6 +525,9 @@ def test_eagle_correctness(
tensor_parallel_size=tp_size,
attention_config=attention_config,
)
evaluate_llm_for_gsm8k(
ref_llm, expected_accuracy_threshold=expected_accuracy_threshold
)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
torch.cuda.empty_cache()
@@ -526,6 +549,9 @@ def test_eagle_correctness(
model_impl=model_impl,
attention_config=attention_config,
)
evaluate_llm_for_gsm8k(
spec_llm, expected_accuracy_threshold=expected_accuracy_threshold
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
misses = 0
@@ -546,10 +572,10 @@ def test_eagle_correctness(
@pytest.mark.parametrize(
["model_setup", "mm_enabled"],
["model_setup", "mm_enabled", "expected_accuracy_threshold"],
[
(("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False),
(("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False),
(("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False, 0.5), # ref: 65%-70%
(("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False, 0.0), # dummy model
],
ids=["mimo", "deepseek"],
)
@@ -558,14 +584,17 @@ def test_mtp_correctness(
sampling_config: SamplingParams,
model_setup: tuple[str, str, int],
mm_enabled: bool,
expected_accuracy_threshold: float,
):
# Generate test prompts inside the function instead of using fixture
test_prompts = get_test_prompts(mm_enabled)
"""
Compare the outputs of an original LLM and a speculative LLM
should be the same when using MTP speculative decoding.
model_setup: (method, model_name, tp_size)
which should be the same when using MTP speculative decoding. Due to some variance
in the engine, it is possible for some outputs to differ, so we expect that at least
6/10 output tokens match exactly, and that the GSM8k accuracy is above a precomputed
reference threshold for each model.
"""
# Generate test prompts inside the function instead of using fixture
test_prompts = get_test_prompts(mm_enabled)
with monkeypatch.context() as m:
m.setenv("VLLM_MLA_DISABLE", "1")
@@ -579,6 +608,9 @@ def test_mtp_correctness(
trust_remote_code=True,
)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
evaluate_llm_for_gsm8k(
ref_llm, expected_accuracy_threshold=expected_accuracy_threshold
)
del ref_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
@@ -594,6 +626,9 @@ def test_mtp_correctness(
},
max_model_len=2048,
)
evaluate_llm_for_gsm8k(
spec_llm, expected_accuracy_threshold=expected_accuracy_threshold
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
misses = 0
@@ -621,12 +656,13 @@ class ArgsTest:
num_speculative_tokens: int
expected_acceptance_rate: float
expected_acceptance_len: float
expected_gsm8k_accuracy: float = 0.0 # skip by default
# Defaults
enforce_eager: bool = True
parallel_drafting: bool = False
target_tensor_parallel_size: int = 1
draft_tensor_parallel_size: int = 1
max_model_len: int = 1024
max_model_len: int = 2048
gpu_memory_utilization: float = 0.5
dataset: str = "test_prompts"
num_prompts: int = 100
@@ -639,8 +675,9 @@ cases = [
draft_model="Qwen/Qwen3-0.6B",
sampling_config=greedy_sampling(),
num_speculative_tokens=3, # K
expected_acceptance_len=3 + 1, # K + 1
expected_acceptance_rate=1.0,
expected_acceptance_len=0.98 * (3 + 1), # epsilon discount of K + 1
expected_acceptance_rate=0.98, # slight epsilon
expected_gsm8k_accuracy=0.25, # ref: 35-40%
),
# Smaller draft model, stochastic sampling.
ArgsTest(
@@ -648,8 +685,9 @@ cases = [
draft_model="Qwen/Qwen3-0.6B",
sampling_config=stochastic_sampling(),
num_speculative_tokens=3,
expected_acceptance_len=2.8 + 1,
expected_acceptance_rate=0.9,
expected_acceptance_len=3.4, # ref: 3.7
expected_acceptance_rate=0.80, # ref: 0.90
expected_gsm8k_accuracy=0.5, # ref: 60%. Note gsm8k always runs greedy sampling
),
]
@@ -669,9 +707,8 @@ def test_draft_model_realistic_example():
num_speculative_tokens=3,
sampling_config=greedy_sampling(),
enforce_eager=False,
# values below are not derived, but just prevent a regression
expected_acceptance_len=2.8,
expected_acceptance_rate=0.55,
expected_acceptance_len=2.6, # ref: 2.86
expected_acceptance_rate=0.5, # ref: 0.62
)
assert_draft_model_correctness(args)
@@ -685,9 +722,8 @@ def test_draft_model_parallel_drafting():
sampling_config=greedy_sampling(),
parallel_drafting=True,
enforce_eager=False,
# values below are collected from a stable run, with ~5% tolerance
expected_acceptance_len=2.375,
expected_acceptance_rate=0.45,
expected_acceptance_len=2.3, # ref: 2.52
expected_acceptance_rate=0.4, # ref: 0.51
)
assert_draft_model_correctness(args)
@@ -723,6 +759,7 @@ def test_draft_model_tensor_parallelism():
draft_tensor_parallel_size=2,
**some_high_acceptance_metrics(),
enforce_eager=False,
expected_gsm8k_accuracy=0.5,
)
assert_draft_model_correctness(sd_case)
@@ -797,9 +834,14 @@ def assert_draft_model_correctness(args: ArgsTest):
# we don't check the outputs, only check the metrics
spec_llm.chat(test_prompts, args.sampling_config)
metrics = spec_llm.get_metrics()
acceptance_rate: float = compute_acceptance_rate(metrics)
acceptance_len: float = compute_acceptance_len(metrics)
# Need to evaluate after getting metrics to avoid polluting the AR
evaluate_llm_for_gsm8k(
spec_llm, expected_accuracy_threshold=args.expected_gsm8k_accuracy
)
del spec_llm # CLEANUP
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
@@ -817,7 +859,7 @@ def assert_draft_model_correctness(args: ArgsTest):
def get_messages(dataset: str, n: int) -> list[Messages]:
if dataset == "test_prompts":
return get_test_prompts(mm_enabled=False, quiet=True, num_prompts=n)
return get_test_prompts(mm_enabled=False, num_prompts=n)
elif dataset == "likaixin/InstructCoder":
return get_instruct_coder_messages(n=n)
else:
@@ -828,8 +870,8 @@ def some_high_acceptance_metrics() -> dict:
return {
"sampling_config": greedy_sampling(),
"num_speculative_tokens": 3,
"expected_acceptance_len": 2.8 + 1,
"expected_acceptance_rate": 0.90,
"expected_acceptance_len": 3.4, # ref: 3.75
"expected_acceptance_rate": 0.8, # ref: 0.9
}