[ROCm][CI] Extending attention backend coverage for Eagle spec decode tests (#35265)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2026-02-25 16:16:18 -06:00
committed by GitHub
parent c97234c08b
commit 9571e99945
4 changed files with 314 additions and 150 deletions

View File

@@ -30,7 +30,7 @@ steps:
- pytest -v -s v1/engine --ignore v1/engine/test_preprocess_error_handling.py
mirror:
amd:
device: mi325_8
device: mi325_1
depends_on:
- image-build-amd
commands:

View File

@@ -1327,6 +1327,57 @@ def multi_gpu_test(*, num_gpus: int):
return wrapper
def gpu_tier_mark(*, min_gpus: int = 1, max_gpus: int | None = None):
    """Build pytest marks gating a test to a GPU-count range.

    The test runs only when the available GPU count falls within
    [min_gpus, max_gpus]; otherwise a skip mark is attached.

    Args:
        min_gpus: Minimum number of GPUs required (inclusive).
        max_gpus: Maximum number of GPUs allowed (inclusive), or ``None``
            for no upper bound.

    Returns:
        A list of pytest marks: a ``distributed`` mark when more than one
        GPU is required, plus a skip mark when the current GPU count is
        outside the requested range.

    Examples:
        @gpu_tier_mark(min_gpus=2)              # only on multi-GPU
        @gpu_tier_mark(max_gpus=1)              # only on single-GPU
        @gpu_tier_mark(min_gpus=2, max_gpus=4)  # 2-4 GPUs only
    """
    gpu_count = cuda_device_count_stateless()
    marks = []
    if min_gpus > 1:
        # Tag multi-GPU tests so CI can route them to distributed runners.
        marks.append(pytest.mark.distributed(num_gpus=min_gpus))
    reasons = []
    if gpu_count < min_gpus:
        reasons.append(f"Need at least {min_gpus} GPUs (have {gpu_count})")
    if max_gpus is not None and gpu_count > max_gpus:
        reasons.append(f"Need at most {max_gpus} GPUs (have {gpu_count})")
    if reasons:
        # The condition is already decided here, so use the idiomatic
        # unconditional pytest.mark.skip instead of skipif(True, ...).
        marks.append(pytest.mark.skip(reason="; ".join(reasons)))
    return marks
def single_gpu_only(f=None):
    """Skip this test when running in a multi-GPU environment.

    Usable both bare (``@single_gpu_only``) and called
    (``@single_gpu_only()``).
    """
    # Resolve the marks eagerly, mirroring the call-time GPU probe.
    tier_marks = gpu_tier_mark(max_gpus=1)

    def decorate(func):
        decorated = func
        for mark in reversed(tier_marks):
            decorated = mark(decorated)
        return decorated

    if f is None:
        return decorate
    return decorate(f)
def multi_gpu_only(*, num_gpus: int = 2):
    """Skip this test when running on fewer than num_gpus GPUs."""
    # Resolve the marks eagerly, mirroring the call-time GPU probe.
    tier_marks = gpu_tier_mark(min_gpus=num_gpus)

    def decorate(func):
        decorated = func
        for mark in reversed(tier_marks):
            decorated = mark(decorated)
        return decorated

    return decorate
async def completions_with_server_args(
prompts: list[str],
model_name: str,

View File

@@ -6,6 +6,7 @@ from typing import Any
import pytest
import torch._dynamo.config as dynamo_config
from tests.utils import large_gpu_mark, single_gpu_only
from vllm import SamplingParams
from vllm.logprobs import Logprob
from vllm.platforms import current_platform
@@ -36,6 +37,7 @@ default_params = dict(
)
@single_gpu_only
def test_without_spec_decoding(
sample_json_schema,
monkeypatch: pytest.MonkeyPatch,
@@ -95,6 +97,8 @@ def test_without_spec_decoding(
run_tests(monkeypatch, MODEL, test_configs, test_sampling_params)
@single_gpu_only
@large_gpu_mark(min_gb=16)
def test_with_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch):
"""Test consistency and acceptance rates with some different combos of
preemption, executor, async scheduling, prefill chunking,

View File

@@ -9,7 +9,13 @@ import pytest
import torch
from tests.evals.gsm8k.gsm8k_eval import _build_gsm8k_prompts, evaluate_gsm8k_offline
from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark
from tests.utils import (
get_attn_backend_list_based_on_platform,
large_gpu_mark,
multi_gpu_marks,
multi_gpu_only,
single_gpu_only,
)
from vllm import LLM, SamplingParams
from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR
@@ -160,6 +166,8 @@ def reset_torch_dynamo():
},
],
)
@single_gpu_only
@large_gpu_mark(min_gb=20)
def test_ngram_and_suffix_correctness(
speculative_config: dict,
model_name: str,
@@ -175,6 +183,8 @@ def test_ngram_and_suffix_correctness(
cleanup_dist_env_and_memory()
@single_gpu_only
@large_gpu_mark(min_gb=20)
def test_suffix_decoding_acceptance(
monkeypatch: pytest.MonkeyPatch,
sampling_config: SamplingParams,
@@ -242,6 +252,8 @@ def test_suffix_decoding_acceptance(
],
ids=["llama3_eagle3_speculator", "qwen3_eagle3_speculator"],
)
@single_gpu_only
@large_gpu_mark(min_gb=24)
def test_speculators_model_integration(
monkeypatch: pytest.MonkeyPatch,
sampling_config: SamplingParams,
@@ -319,137 +331,7 @@ def test_speculators_model_integration(
)
@pytest.mark.parametrize(
[
"model_setup",
"mm_enabled",
"enable_chunked_prefill",
"model_impl",
"expected_accuracy_threshold",
],
[
(
("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
False,
False,
"auto",
0.8, # ref: 90%
),
(
("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
False,
False,
"transformers",
0.8, # ref: 90%
),
pytest.param(
(
"eagle3",
"Qwen/Qwen3-VL-8B-Instruct",
"taobao-mnn/Qwen3-VL-8B-Instruct-Eagle3",
1,
),
False,
False,
"auto",
0.8, # ref: 90%
marks=pytest.mark.skip(
reason="architecture of its eagle3 is LlamaForCausalLMEagle3"
),
),
pytest.param(
(
"eagle3",
"Qwen/Qwen2.5-VL-7B-Instruct",
"Rayzl/qwen2.5-vl-7b-eagle3-sgl",
1,
),
False,
False,
"auto",
0.7, # TODO, update this with a reference value when re-enabling this case
marks=pytest.mark.skip(
reason="Skipping due to its head_dim not being a a multiple of 32"
),
),
pytest.param(
(
"eagle",
"meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
1,
),
False,
True,
"auto",
0.7, # ref: 75%-80%
marks=large_gpu_mark(min_gb=40),
), # works on 4x H100
(
(
"eagle3",
"meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
1,
),
False,
False,
"auto",
0.7, # ref: 75%-80%
),
pytest.param(
(
"eagle",
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
4,
),
False,
False,
"auto",
0.8, # ref: 90%
# marks=large_gpu_mark(min_gb=80),
), # works on 4x H100
pytest.param(
(
"eagle",
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
4,
),
True,
True,
"auto",
0.8, # ref: 90%
marks=large_gpu_mark(min_gb=80),
), # works on 4x H100
(
(
"eagle",
"eagle618/deepseek-v3-random",
"eagle618/eagle-deepseek-v3-random",
1,
),
False,
False,
"auto",
0.0, # dummy model, skip gsm8k check
),
],
ids=[
"qwen3_eagle3",
"qwen3_eagle3-transformers",
"qwen3_vl_eagle3",
"qwen2_5_vl_eagle3",
"llama3_eagle",
"llama3_eagle3",
"llama4_eagle",
"llama4_eagle_mm",
"deepseek_eagle",
],
)
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
def test_eagle_correctness(
def _run_eagle_correctness(
monkeypatch: pytest.MonkeyPatch,
sampling_config: SamplingParams,
model_setup: tuple[str, str, str, int],
@@ -460,14 +342,10 @@ def test_eagle_correctness(
attn_backend: str,
):
"""
Compare the outputs of a original LLM and a speculative LLM
which should be the same when using eagle speculative decoding. Due to some variance
in the engine, it is possible for some outputs to differ, so we expect that at least
6/10 output tokens match exactly, and that the GSM8k accuracy is above
a precomputed reference threshold for each model.
Compare the outputs of an original LLM and a speculative LLM
which should be the same when using eagle speculative decoding.
"""
if attn_backend == "TREE_ATTN":
# TODO: Fix this flaky test
pytest.skip(
"TREE_ATTN is flaky in the test disable for now until it can be "
"resolved (see https://github.com/vllm-project/vllm/issues/22922)"
@@ -484,17 +362,17 @@ def test_eagle_correctness(
f"transformers>={required}, but got {installed}"
)
# Generate test prompts inside the function instead of using fixture
test_prompts = get_test_prompts(mm_enabled)
# Determine attention config
# Scout requires default backend selection because vision encoder has
# head_dim 88 being incompatible with FLASH_ATTN and needs to fall back
# to Flex Attn
if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
if current_platform.is_rocm():
# TODO: Enable Flex Attn for spec_decode on ROCm
pytest.skip("Flex Attn for spec_decode not supported on ROCm currently")
attention_config = None # Let it fall back to default
print(
"FLASH_ATTN for spec_decode not supported on "
"ROCm currently. Changing to FLEX_ATTENTION backend."
)
attention_config = {"backend": "FLEX_ATTENTION"}
else:
attention_config = None
else:
attention_config = {"backend": attn_backend}
@@ -509,7 +387,9 @@ def test_eagle_correctness(
if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
if "deepseek" in model_setup[1].lower():
pytest.skip("ROCM_AITER_FA for deepseek not supported on ROCm platform")
m.setenv("VLLM_ROCM_USE_AITER", "1")
m.delenv("VLLM_MLA_DISABLE", raising=False)
attention_config = {"backend": "TRITON_MLA"}
else:
m.setenv("VLLM_ROCM_USE_AITER", "1")
@@ -563,14 +443,235 @@ def test_eagle_correctness(
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
# Heuristic: expect at least 60% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.6 * len(ref_outputs))
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
@single_gpu_only
@pytest.mark.parametrize(
    [
        "model_setup",
        "mm_enabled",
        "enable_chunked_prefill",
        "model_impl",
        "expected_accuracy_threshold",
    ],
    [
        (
            (
                "eagle",
                "eagle618/deepseek-v3-random",
                "eagle618/eagle-deepseek-v3-random",
                1,
            ),
            False,
            False,
            "auto",
            0.0,  # dummy random-weight model, so the gsm8k check is skipped
        ),
    ],
    ids=["deepseek_eagle"],
)
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
def test_eagle_correctness_light(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_setup: tuple[str, str, str, int],
    mm_enabled: bool,
    expected_accuracy_threshold: float,
    enable_chunked_prefill: bool,
    model_impl: str,
    attn_backend: str,
):
    """Lightweight Eagle spec-decode correctness test (single GPU).

    Runs the shared harness against a tiny random-weight DeepSeek Eagle
    pair; the 0.0 accuracy threshold means only output consistency is
    exercised, not GSM8k accuracy.
    """
    _run_eagle_correctness(
        monkeypatch,
        sampling_config,
        model_setup,
        mm_enabled,
        expected_accuracy_threshold,
        enable_chunked_prefill,
        model_impl,
        attn_backend,
    )
@single_gpu_only
@large_gpu_mark(min_gb=24)
@pytest.mark.parametrize(
    [
        "model_setup",
        "mm_enabled",
        "enable_chunked_prefill",
        "model_impl",
        "expected_accuracy_threshold",
    ],
    [
        (
            ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
            False,
            False,
            "auto",
            0.8,  # ref: 90%
        ),
        (
            ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
            False,
            False,
            "transformers",
            0.8,  # ref: 90%
        ),
        pytest.param(
            (
                "eagle3",
                "Qwen/Qwen3-VL-8B-Instruct",
                "taobao-mnn/Qwen3-VL-8B-Instruct-Eagle3",
                1,
            ),
            False,
            False,
            "auto",
            0.8,
            marks=pytest.mark.skip(
                reason="architecture of its eagle3 is LlamaForCausalLMEagle3"
            ),
        ),
        pytest.param(
            (
                "eagle3",
                "Qwen/Qwen2.5-VL-7B-Instruct",
                "Rayzl/qwen2.5-vl-7b-eagle3-sgl",
                1,
            ),
            False,
            False,
            "auto",
            0.7,
            marks=pytest.mark.skip(
                reason="Skipping due to its head_dim not being a multiple of 32"
            ),
        ),
        (
            (
                "eagle3",
                "meta-llama/Llama-3.1-8B-Instruct",
                "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
                1,
            ),
            False,
            False,
            "auto",
            0.7,  # ref: 75%-80%
        ),
    ],
    ids=[
        "qwen3_eagle3",
        "qwen3_eagle3-transformers",
        "qwen3_vl_eagle3",
        "qwen2_5_vl_eagle3",
        "llama3_eagle3",
    ],
)
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
def test_eagle_correctness_medium(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_setup: tuple[str, str, str, int],
    mm_enabled: bool,
    expected_accuracy_threshold: float,
    enable_chunked_prefill: bool,
    model_impl: str,
    attn_backend: str,
):
    """Medium-weight Eagle3 correctness tests (single GPU, >=24 GB).

    Delegates to the shared harness ``_run_eagle_correctness`` for each
    parametrized (model pair, backend) combination.
    """
    _run_eagle_correctness(
        monkeypatch,
        sampling_config,
        model_setup,
        mm_enabled,
        expected_accuracy_threshold,
        enable_chunked_prefill,
        model_impl,
        attn_backend,
    )
@pytest.mark.parametrize(
    [
        "model_setup",
        "mm_enabled",
        "enable_chunked_prefill",
        "model_impl",
        "expected_accuracy_threshold",
    ],
    [
        pytest.param(
            (
                "eagle",
                "meta-llama/Llama-3.1-8B-Instruct",
                "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
                1,
            ),
            False,
            True,
            "auto",
            0.7,  # ref: 75%-80%
            marks=large_gpu_mark(min_gb=40),
            id="llama3_eagle",
        ),
        pytest.param(
            (
                "eagle",
                "meta-llama/Llama-4-Scout-17B-16E-Instruct",
                "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
                4,
            ),
            False,
            False,
            "auto",
            0.8,  # ref: 90%
            marks=multi_gpu_marks(num_gpus=4),
            id="llama4_eagle",
        ),
        pytest.param(
            (
                "eagle",
                "meta-llama/Llama-4-Scout-17B-16E-Instruct",
                "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
                4,
            ),
            True,
            True,
            "auto",
            0.8,  # ref: 90%
            marks=[*multi_gpu_marks(num_gpus=4), large_gpu_mark(min_gb=80)],
            id="llama4_eagle_mm",
        ),
    ],
)
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
def test_eagle_correctness_heavy(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_setup: tuple[str, str, str, int],
    mm_enabled: bool,
    expected_accuracy_threshold: float,
    enable_chunked_prefill: bool,
    model_impl: str,
    attn_backend: str,
):
    """Heavyweight Eagle correctness tests (multi-GPU and/or large GPUs).

    Covers the Llama-3.1 and Llama-4-Scout Eagle pairs; Scout cases are
    gated to 4-GPU runners via ``multi_gpu_marks`` and, for the multimodal
    case, to >=80 GB GPUs. Delegates to ``_run_eagle_correctness``.
    """
    _run_eagle_correctness(
        monkeypatch,
        sampling_config,
        model_setup,
        mm_enabled,
        expected_accuracy_threshold,
        enable_chunked_prefill,
        model_impl,
        attn_backend,
    )
@pytest.mark.parametrize(
["model_setup", "mm_enabled", "expected_accuracy_threshold"],
[
@@ -579,6 +680,8 @@ def test_eagle_correctness(
],
ids=["mimo", "deepseek"],
)
@single_gpu_only
@large_gpu_mark(min_gb=20)
def test_mtp_correctness(
monkeypatch: pytest.MonkeyPatch,
sampling_config: SamplingParams,
@@ -694,11 +797,13 @@ cases = [
@pytest.mark.parametrize("args", cases)
@pytest.mark.parametrize("enforce_eager", [True, False])
@single_gpu_only
def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
args.enforce_eager = enforce_eager
assert_draft_model_correctness(args)
@single_gpu_only
def test_draft_model_realistic_example():
args = ArgsTest(
target_model="Qwen/Qwen3-1.7B",
@@ -713,6 +818,7 @@ def test_draft_model_realistic_example():
assert_draft_model_correctness(args)
@single_gpu_only
def test_draft_model_parallel_drafting():
args = ArgsTest(
target_model="Qwen/Qwen3-1.7B",
@@ -738,6 +844,7 @@ def test_draft_model_parallel_drafting():
ids=["target_quantized", "draft_quantized"],
)
@pytest.mark.parametrize("enforce_eager", [True, False])
@single_gpu_only
def test_draft_model_quantization(models: tuple[str, str], enforce_eager: bool):
tgt_model, draft_model = models
sd_case = ArgsTest(
@@ -749,6 +856,7 @@ def test_draft_model_quantization(models: tuple[str, str], enforce_eager: bool):
assert_draft_model_correctness(sd_case)
@multi_gpu_only(num_gpus=2)
def test_draft_model_tensor_parallelism():
"""Ensure spec decode works when running with TP > 1."""
_skip_if_insufficient_gpus_for_tp(2)
@@ -764,6 +872,7 @@ def test_draft_model_tensor_parallelism():
assert_draft_model_correctness(sd_case)
@multi_gpu_only(num_gpus=2)
def test_draft_model_engine_args_tensor_parallelism():
"""Ensure the vllm_config for the draft model is created correctly,
and independently of the target model (quantization, TP, etc.)"""