diff --git a/tests/evals/gsm8k/gsm8k_eval.py b/tests/evals/gsm8k/gsm8k_eval.py index 0421f8bb1..647c149ef 100644 --- a/tests/evals/gsm8k/gsm8k_eval.py +++ b/tests/evals/gsm8k/gsm8k_eval.py @@ -110,6 +110,65 @@ async def call_vllm_api( return "", 0 +def _build_gsm8k_prompts( + num_questions: int = 1319, + num_shots: int = 5, +) -> tuple[list[str], list[int]]: + """Build few-shot GSM8K completion prompts and ground-truth labels.""" + if num_questions == 0: + return [], [] + train_data, test_data = load_gsm8k_data() + num_questions = min(num_questions, len(test_data)) + + few_shot_examples = "" + for i in range(num_shots): + few_shot_examples += ( + f"Question: {train_data[i]['question']}\n" + f"Answer: {train_data[i]['answer']}\n\n" + ) + + prompts = [] + labels = [] + for i in range(num_questions): + prompts.append( + few_shot_examples + f"Question: {test_data[i]['question']}\nAnswer:" + ) + labels.append(get_answer_value(test_data[i]["answer"])) + + assert all(label != INVALID for label in labels), "Some labels are invalid" + return prompts, labels + + +def _score_gsm8k( + states: list[str], + output_tokens: list[int], + labels: list[int], + num_shots: int, + max_tokens: int, + latency: float, +) -> dict[str, float | int]: + """Score GSM8K responses and return a results dict.""" + num_questions = len(labels) + preds = [get_answer_value(state) for state in states] + accuracy = np.mean(np.array(preds) == np.array(labels)) + invalid_rate = np.mean(np.array(preds) == INVALID) + total_output_tokens = sum(output_tokens) + tokens_per_second = total_output_tokens / latency if latency > 0 else 0.0 + + return { + "accuracy": accuracy, + "invalid_rate": invalid_rate, + "latency": latency, + "questions_per_second": num_questions / latency if latency > 0 else 0.0, + "total_output_tokens": total_output_tokens, + "tokens_per_second": tokens_per_second, + "num_questions": num_questions, + "num_shots": num_shots, + "max_tokens": max_tokens, + "timestamp": time.time(), + } + + def evaluate_gsm8k( num_questions: int = 1319, num_shots: int = 5, @@ -125,40 +184,17 @@ def evaluate_gsm8k( Returns dict with accuracy, invalid_rate, latency, etc. 
""" base_url = f"{host}:{port}" + prompts, labels = _build_gsm8k_prompts(num_questions, num_shots) + num_questions = len(prompts) - # Load GSM8K train and test data - train_data, test_data = load_gsm8k_data() - - # Limit to available test questions - num_questions = min(num_questions, len(test_data)) - - # Build few-shot examples from train split (like lm-eval does) - few_shot_examples = "" - for i in range(num_shots): - few_shot_examples += ( - f"Question: {train_data[i]['question']}\n" - f"Answer: {train_data[i]['answer']}\n\n" - ) - - # Prepare test questions and labels from test split - questions = [] - labels = [] - for i in range(num_questions): - questions.append(f"Question: {test_data[i]['question']}\nAnswer:") - labels.append(get_answer_value(test_data[i]["answer"])) - - assert all(label != INVALID for label in labels), "Some labels are invalid" - - # Run evaluation async def run_async_evaluation(): states: list[str] = [""] * num_questions output_tokens: list[int] = [0] * num_questions async def get_answer(session: aiohttp.ClientSession, i: int) -> tuple[str, int]: - prompt = few_shot_examples + questions[i] answer, tokens = await call_vllm_api( session=session, - prompt=prompt, + prompt=prompts[i], temperature=temperature, max_tokens=max_tokens, stop=["Question", "Assistant:", "<|separator|>"], @@ -183,27 +219,43 @@ def evaluate_gsm8k( states, output_tokens = asyncio.run(run_async_evaluation()) latency = time.perf_counter() - tic - # Compute metrics - preds = [get_answer_value(state) for state in states] - accuracy = np.mean(np.array(preds) == np.array(labels)) - invalid_rate = np.mean(np.array(preds) == INVALID) - total_output_tokens = sum(output_tokens) - tokens_per_second = total_output_tokens / latency if latency > 0 else 0.0 + return _score_gsm8k(states, output_tokens, labels, num_shots, max_tokens, latency) - result = { - "accuracy": accuracy, - "invalid_rate": invalid_rate, - "latency": latency, - "questions_per_second": num_questions / latency, - "total_output_tokens": total_output_tokens, - "tokens_per_second": tokens_per_second, - "num_questions": num_questions, - "num_shots": num_shots, - "max_tokens": max_tokens, - "timestamp": time.time(), - } - return result +def evaluate_gsm8k_offline( + llm, + num_questions: int = 1319, + num_shots: int = 5, + max_tokens: int = 256, + temperature: float = 0.0, +) -> dict[str, float | int]: + """Evaluate GSM8K accuracy using an offline vllm.LLM object. + + Same prompts and scoring as evaluate_gsm8k(), but runs generation + directly via llm.generate() instead of calling a server over HTTP. 
+ """ + from vllm import SamplingParams + + prompts, labels = _build_gsm8k_prompts(num_questions, num_shots) + + sampling_params = SamplingParams( + temperature=temperature, + max_tokens=max_tokens, + stop=["Question", "Assistant:", "<|separator|>"], + ) + + print( + f"Running offline GSM8K evaluation: {len(prompts)} questions, {num_shots}-shot" + ) + + tic = time.perf_counter() + outputs = llm.generate(prompts, sampling_params) + latency = time.perf_counter() - tic + + states = [o.outputs[0].text for o in outputs] + output_tokens = [len(o.outputs[0].token_ids) for o in outputs] + + return _score_gsm8k(states, output_tokens, labels, num_shots, max_tokens, latency) def main() -> None: diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index a141e9da0..9289d1ce1 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -8,6 +8,7 @@ from typing import Any import pytest import torch +from tests.evals.gsm8k.gsm8k_eval import _build_gsm8k_prompts, evaluate_gsm8k_offline from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark from vllm import LLM, SamplingParams from vllm.assets.base import VLLM_S3_BUCKET_URL @@ -35,53 +36,57 @@ def _skip_if_insufficient_gpus_for_tp(tp_size: int): Messages = list[dict[str, Any]] -def get_test_prompts( - mm_enabled: bool, quiet: bool = False, num_prompts: int = 100 -) -> list[Messages]: - prompt_types = ["repeat", "sentence"] +def get_test_prompts(mm_enabled: bool, num_prompts: int = 100) -> list[Messages]: + prompt_types = ["repeat", "gsm8k"] if mm_enabled: prompt_types.append("mm") - prompts = [] + prompts: list[Messages] = [] - random.seed(0) - random_prompt_type_choices = random.choices(prompt_types, k=num_prompts) - - if not quiet: - print(f"Prompt types: {random_prompt_type_choices}") + num_repeat_prompts = num_prompts // len(prompt_types) + if mm_enabled: + num_gsm8k_prompts = num_prompts // len(prompt_types) + num_mm_prompts = num_prompts - num_repeat_prompts - num_gsm8k_prompts + else: + num_mm_prompts = 0 + num_gsm8k_prompts = num_prompts - num_repeat_prompts # Generate a mixed batch of prompts, some of which can be easily # predicted by n-gram matching and some which likely cannot. - for kind in random_prompt_type_choices: + random.seed(0) + for _ in range(num_repeat_prompts): word_choices = ["test", "temp", "hello", "where"] word = random.choice(word_choices) - prompt: str | list[dict[str, Any]] = "" - if kind == "repeat": - prompt = f""" - please repeat the word '{word}' 10 times. - give no other output than the word at least ten times in a row, - in lowercase with spaces between each word and without quotes. - """ - elif kind == "sentence": - prompt = f""" - please give a ten-word sentence that - uses the word {word} at least once. - give no other output than that simple sentence without quotes. - """ - elif kind == "mm": - placeholders = [ + prompts.append( + [ { - "type": "image_url", - "image_url": { - "url": f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg" - }, + "role": "user", + "content": f""" + please repeat the word '{word}' 10 times. + give no other output than the word at least ten times in a row, + in lowercase with spaces between each word and without quotes. 
+                    """,
                 }
             ]
-            prompt = [
-                *placeholders,
-                {"type": "text", "text": "The meaning of the image is"},
-            ]
-        else:
-            raise ValueError(f"Unknown prompt type: {kind}")
+        )
+    prompts.extend(
+        [{"role": "user", "content": prompt}]
+        for prompt in _build_gsm8k_prompts(
+            num_questions=num_gsm8k_prompts, num_shots=5
+        )[0]
+    )
+    for _ in range(num_mm_prompts):
+        placeholders = [
+            {
+                "type": "image_url",
+                "image_url": {
+                    "url": f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
+                },
+            }
+        ]
+        prompt = [
+            *placeholders,
+            {"type": "text", "text": "The meaning of the image is"},
+        ]
         prompts.append([{"role": "user", "content": prompt}])
     return prompts
@@ -113,6 +118,25 @@ def model_name():
     return "meta-llama/Llama-3.1-8B-Instruct"
+def evaluate_llm_for_gsm8k(llm: LLM, expected_accuracy_threshold: float = 0.70) -> None:
+    """Evaluate the LLM on GSM8K and check that accuracy is above a sanity threshold.
+
+    The default threshold assumes the LLM uses the same target model as the
+    "model_name" fixture, with max model len == 4096. The precomputed reference
+    accuracy is 75% to 80% on GSM8K with greedy decoding, so we check that it is
+    above a sanity threshold of 70% to verify that the model is correct.
+    """
+    if expected_accuracy_threshold <= 0.0:
+        print("Skipping GSM8K evaluation")
+        return
+    results = evaluate_gsm8k_offline(llm)
+    accuracy = results["accuracy"]
+    print(f"GSM8K accuracy: {accuracy:.3f}")
+    assert accuracy >= expected_accuracy_threshold, (
+        f"Expected GSM8K accuracy >= {expected_accuracy_threshold}, got {accuracy:.3f}"
+    )
+
+
 @pytest.fixture(autouse=True)
 def reset_torch_dynamo():
     """Reset torch dynamo cache before each test"""
@@ -138,41 +162,14 @@ def reset_torch_dynamo():
 )
 def test_ngram_and_suffix_correctness(
     speculative_config: dict,
-    monkeypatch: pytest.MonkeyPatch,
-    sampling_config: SamplingParams,
     model_name: str,
 ):
-    """
-    Compare the outputs of an original LLM and a speculative LLM
-    should be the same when using ngram speculative decoding.
-    """
-    test_prompts = get_test_prompts(mm_enabled=False)
-
-    ref_llm = LLM(model=model_name, max_model_len=1024)
-    ref_outputs = ref_llm.chat(test_prompts, sampling_config)
-    del ref_llm
-    torch.cuda.empty_cache()
-    cleanup_dist_env_and_memory()
-
     spec_llm = LLM(
         model=model_name,
         speculative_config=speculative_config,
-        max_model_len=1024,
+        max_model_len=4096,
     )
-    spec_outputs = spec_llm.chat(test_prompts, sampling_config)
-    matches = 0
-    misses = 0
-    for ref_output, spec_output in zip(ref_outputs, spec_outputs):
-        if ref_output.outputs[0].text == spec_output.outputs[0].text:
-            matches += 1
-        else:
-            misses += 1
-            print(f"ref_output: {ref_output.outputs[0].text}")
-            print(f"spec_output: {spec_output.outputs[0].text}")
-
-    # Heuristic: expect at least 66% of the prompts to match exactly
-    # Upon failure, inspect the outputs to check for inaccuracy.
-    assert matches >= int(0.66 * len(ref_outputs))
+    evaluate_llm_for_gsm8k(spec_llm)
     del spec_llm
     torch.cuda.empty_cache()
     cleanup_dist_env_and_memory()
@@ -238,10 +235,10 @@ def test_suffix_decoding_acceptance(
 @pytest.mark.parametrize(
-    "model_path",
+    ["model_path", "expected_accuracy_threshold"],
     [
-        "RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3",
-        "RedHatAI/Qwen3-8B-speculator.eagle3",
+        ("RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3", 0.7),  # ref: 75%-80%
+        ("RedHatAI/Qwen3-8B-speculator.eagle3", 0.8),  # ref: 87%-92%
     ],
     ids=["llama3_eagle3_speculator", "qwen3_eagle3_speculator"],
 )
@@ -249,6 +246,7 @@ def test_speculators_model_integration(
     monkeypatch: pytest.MonkeyPatch,
     sampling_config: SamplingParams,
     model_path: str,
+    expected_accuracy_threshold: float,
 ):
     """
     Test that speculators models work with the simplified integration.
@@ -262,7 +260,8 @@ def test_speculators_model_integration(
     2. Verifier model is extracted from speculator config
     3. Speculative decoding is automatically enabled
     4. Text generation works correctly
-    5. Output matches reference (non-speculative) generation
+    5. GSM8K accuracy passes a sanity check when speculative decoding is on
+    6. Output matches reference (non-speculative) generation
     """
     monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
     test_prompts = get_test_prompts(mm_enabled=False)
     # First run: Direct speculator model (simplified integration)
-    spec_llm = LLM(model=model_path, max_model_len=1024)
+    spec_llm = LLM(model=model_path, max_model_len=4096)
+    evaluate_llm_for_gsm8k(
+        spec_llm, expected_accuracy_threshold=expected_accuracy_threshold
+    )
     spec_outputs = spec_llm.chat(test_prompts, sampling_config)
     # Verify speculative config was auto-detected
@@ -297,7 +299,7 @@ def test_speculators_model_integration(
     cleanup_dist_env_and_memory()
     # Second run: Reference without speculative decoding
-    ref_llm = LLM(model=verifier_model, max_model_len=1024)
+    ref_llm = LLM(model=verifier_model, max_model_len=4096)
     ref_outputs = ref_llm.chat(test_prompts, sampling_config)
     del ref_llm
     torch.cuda.empty_cache()
@@ -318,19 +320,27 @@ def test_speculators_model_integration(
 @pytest.mark.parametrize(
-    ["model_setup", "mm_enabled", "enable_chunked_prefill", "model_impl"],
+    [
+        "model_setup",
+        "mm_enabled",
+        "enable_chunked_prefill",
+        "model_impl",
+        "expected_accuracy_threshold",
+    ],
     [
         (
             ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
             False,
             False,
             "auto",
+            0.8,  # ref: 90%
         ),
         (
             ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
             False,
             False,
             "transformers",
+            0.8,  # ref: 90%
        ),
         pytest.param(
             (
@@ -342,6 +352,7 @@ def test_speculators_model_integration(
             False,
             False,
             "auto",
+            0.8,  # ref: 90%
             marks=pytest.mark.skip(
                 reason="architecture of its eagle3 is LlamaForCausalLMEagle3"
             ),
@@ -356,6 +367,7 @@ def test_speculators_model_integration(
             False,
             False,
             "auto",
+            0.7,  # TODO: update this with a reference value when re-enabling this case
             marks=pytest.mark.skip(
                 reason="Skipping due to its head_dim not being a a multiple of 32"
             ),
@@ -370,6 +382,7 @@ def test_speculators_model_integration(
             False,
             True,
             "auto",
+            0.7,  # ref: 75%-80%
             marks=large_gpu_mark(min_gb=40),
         ),  # works on 4x H100
         (
@@ -382,6 +395,7 @@ def test_speculators_model_integration(
             False,
             False,
             "auto",
+            0.7,  # ref: 75%-80%
         ),
         pytest.param(
             (
@@ -393,7 +407,8 @@ def test_speculators_model_integration(
             False,
             False,
             "auto",
-            marks=large_gpu_mark(min_gb=80),
+            0.8,  # ref: 90%
+            # marks=large_gpu_mark(min_gb=80),
         ),  # works on 4x H100
         pytest.param(
             (
@@ -405,6 +420,7 @@ def test_speculators_model_integration(
             True,
             True,
             "auto",
+            0.8,  # ref: 90%
             marks=large_gpu_mark(min_gb=80),
         ),  # works on 4x H100
         (
@@ -417,6 +433,7 @@ def test_speculators_model_integration(
             False,
             False,
             "auto",
+            0.0,  # dummy model, skip gsm8k check
         ),
     ],
     ids=[
@@ -437,10 +454,18 @@ def test_eagle_correctness(
     sampling_config: SamplingParams,
     model_setup: tuple[str, str, str, int],
     mm_enabled: bool,
+    expected_accuracy_threshold: float,
     enable_chunked_prefill: bool,
     model_impl: str,
     attn_backend: str,
 ):
+    """
+    Compare the outputs of an original LLM and a speculative LLM
+    which should be the same when using eagle speculative decoding. Due to some variance
+    in the engine, it is possible for some outputs to differ, so we expect that at least
+    6/10 of the outputs match exactly, and that the GSM8K accuracy is above
+    a precomputed reference threshold for each model.
+    """
     if attn_backend == "TREE_ATTN":
         # TODO: Fix this flaky test
         pytest.skip(
@@ -461,11 +486,6 @@ def test_eagle_correctness(
     # Generate test prompts inside the function instead of using fixture
     test_prompts = get_test_prompts(mm_enabled)
-    """
-    Compare the outputs of a original LLM and a speculative LLM
-    should be the same when using eagle speculative decoding.
-    model_setup: (method, model_name, eagle_model_name, tp_size)
-    """
     # Determine attention config
     # Scout requires default backend selection because vision encoder has
     # head_dim 88 being incompatible with FLASH_ATTN and needs to fall back
@@ -505,6 +525,9 @@ def test_eagle_correctness(
         tensor_parallel_size=tp_size,
         attention_config=attention_config,
     )
+    evaluate_llm_for_gsm8k(
+        ref_llm, expected_accuracy_threshold=expected_accuracy_threshold
+    )
     ref_outputs = ref_llm.chat(test_prompts, sampling_config)
     del ref_llm
     torch.cuda.empty_cache()
@@ -526,6 +549,9 @@ def test_eagle_correctness(
         model_impl=model_impl,
         attention_config=attention_config,
     )
+    evaluate_llm_for_gsm8k(
+        spec_llm, expected_accuracy_threshold=expected_accuracy_threshold
+    )
     spec_outputs = spec_llm.chat(test_prompts, sampling_config)
     matches = 0
     misses = 0
@@ -546,10 +572,10 @@ def test_eagle_correctness(
 @pytest.mark.parametrize(
-    ["model_setup", "mm_enabled"],
+    ["model_setup", "mm_enabled", "expected_accuracy_threshold"],
     [
-        (("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False),
-        (("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False),
+        (("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False, 0.5),  # ref: 65%-70%
+        (("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False, 0.0),  # dummy model
     ],
     ids=["mimo", "deepseek"],
 )
@@ -558,14 +584,17 @@ def test_mtp_correctness(
     sampling_config: SamplingParams,
     model_setup: tuple[str, str, int],
     mm_enabled: bool,
+    expected_accuracy_threshold: float,
 ):
-    # Generate test prompts inside the function instead of using fixture
-    test_prompts = get_test_prompts(mm_enabled)
     """
     Compare the outputs of a original LLM and a speculative LLM
-    should be the same when using MTP speculative decoding.
-    model_setup: (method, model_name, tp_size)
+    which should be the same when using MTP speculative decoding. Due to some variance
+    in the engine, it is possible for some outputs to differ, so we expect that at least
+    6/10 of the outputs match exactly, and that the GSM8K accuracy is above a precomputed
+    reference threshold for each model.
""" + # Generate test prompts inside the function instead of using fixture + test_prompts = get_test_prompts(mm_enabled) with monkeypatch.context() as m: m.setenv("VLLM_MLA_DISABLE", "1") @@ -579,6 +608,9 @@ def test_mtp_correctness( trust_remote_code=True, ) ref_outputs = ref_llm.chat(test_prompts, sampling_config) + evaluate_llm_for_gsm8k( + ref_llm, expected_accuracy_threshold=expected_accuracy_threshold + ) del ref_llm torch.cuda.empty_cache() cleanup_dist_env_and_memory() @@ -594,6 +626,9 @@ def test_mtp_correctness( }, max_model_len=2048, ) + evaluate_llm_for_gsm8k( + spec_llm, expected_accuracy_threshold=expected_accuracy_threshold + ) spec_outputs = spec_llm.chat(test_prompts, sampling_config) matches = 0 misses = 0 @@ -621,12 +656,13 @@ class ArgsTest: num_speculative_tokens: int expected_acceptance_rate: float expected_acceptance_len: float + expected_gsm8k_accuracy: float = 0.0 # skip by default # Defaults enforce_eager: bool = True parallel_drafting: bool = False target_tensor_parallel_size: int = 1 draft_tensor_parallel_size: int = 1 - max_model_len: int = 1024 + max_model_len: int = 2048 gpu_memory_utilization: float = 0.5 dataset: str = "test_prompts" num_prompts: int = 100 @@ -639,8 +675,9 @@ cases = [ draft_model="Qwen/Qwen3-0.6B", sampling_config=greedy_sampling(), num_speculative_tokens=3, # K - expected_acceptance_len=3 + 1, # K + 1 - expected_acceptance_rate=1.0, + expected_acceptance_len=0.98 * (3 + 1), # epsilon discount of K + 1 + expected_acceptance_rate=0.98, # slight epsilon + expected_gsm8k_accuracy=0.25, # ref: 35-40% ), # Smaller draft model, stochastic sampling. ArgsTest( @@ -648,8 +685,9 @@ cases = [ draft_model="Qwen/Qwen3-0.6B", sampling_config=stochastic_sampling(), num_speculative_tokens=3, - expected_acceptance_len=2.8 + 1, - expected_acceptance_rate=0.9, + expected_acceptance_len=3.4, # ref: 3.7 + expected_acceptance_rate=0.80, # ref: 0.90 + expected_gsm8k_accuracy=0.5, # ref: 60%. 
Note gsm8k always runs greedy sampling ), ] @@ -669,9 +707,8 @@ def test_draft_model_realistic_example(): num_speculative_tokens=3, sampling_config=greedy_sampling(), enforce_eager=False, - # values below are not derived, but just prevent a regression - expected_acceptance_len=2.8, - expected_acceptance_rate=0.55, + expected_acceptance_len=2.6, # ref: 2.86 + expected_acceptance_rate=0.5, # ref: 0.62 ) assert_draft_model_correctness(args) @@ -685,9 +722,8 @@ def test_draft_model_parallel_drafting(): sampling_config=greedy_sampling(), parallel_drafting=True, enforce_eager=False, - # values below are collected from a stable run, with ~5% tolerance - expected_acceptance_len=2.375, - expected_acceptance_rate=0.45, + expected_acceptance_len=2.3, # ref: 2.52 + expected_acceptance_rate=0.4, # ref: 0.51 ) assert_draft_model_correctness(args) @@ -723,6 +759,7 @@ def test_draft_model_tensor_parallelism(): draft_tensor_parallel_size=2, **some_high_acceptance_metrics(), enforce_eager=False, + expected_gsm8k_accuracy=0.5, ) assert_draft_model_correctness(sd_case) @@ -797,9 +834,14 @@ def assert_draft_model_correctness(args: ArgsTest): # we don't check the outputs, only check the metrics spec_llm.chat(test_prompts, args.sampling_config) metrics = spec_llm.get_metrics() - acceptance_rate: float = compute_acceptance_rate(metrics) acceptance_len: float = compute_acceptance_len(metrics) + + # Need to evaluate after getting metrics to avoid polluting the AR + evaluate_llm_for_gsm8k( + spec_llm, expected_accuracy_threshold=args.expected_gsm8k_accuracy + ) + del spec_llm # CLEANUP torch.cuda.empty_cache() cleanup_dist_env_and_memory() @@ -817,7 +859,7 @@ def assert_draft_model_correctness(args: ArgsTest): def get_messages(dataset: str, n: int) -> list[Messages]: if dataset == "test_prompts": - return get_test_prompts(mm_enabled=False, quiet=True, num_prompts=n) + return get_test_prompts(mm_enabled=False, num_prompts=n) elif dataset == "likaixin/InstructCoder": return get_instruct_coder_messages(n=n) else: @@ -828,8 +870,8 @@ def some_high_acceptance_metrics() -> dict: return { "sampling_config": greedy_sampling(), "num_speculative_tokens": 3, - "expected_acceptance_len": 2.8 + 1, - "expected_acceptance_rate": 0.90, + "expected_acceptance_len": 3.4, # ref: 3.75 + "expected_acceptance_rate": 0.8, # ref: 0.9 }