[ROCm][CI] Extending attention backend coverage for Eagle spec decode tests (#35265)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -9,7 +9,13 @@ import pytest
|
||||
import torch
|
||||
|
||||
from tests.evals.gsm8k.gsm8k_eval import _build_gsm8k_prompts, evaluate_gsm8k_offline
|
||||
from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark
|
||||
from tests.utils import (
|
||||
get_attn_backend_list_based_on_platform,
|
||||
large_gpu_mark,
|
||||
multi_gpu_marks,
|
||||
multi_gpu_only,
|
||||
single_gpu_only,
|
||||
)
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.assets.base import VLLM_S3_BUCKET_URL
|
||||
from vllm.assets.image import VLM_IMAGES_DIR
|
||||
@@ -160,6 +166,8 @@ def reset_torch_dynamo():
|
||||
},
|
||||
],
|
||||
)
|
||||
@single_gpu_only
|
||||
@large_gpu_mark(min_gb=20)
|
||||
def test_ngram_and_suffix_correctness(
|
||||
speculative_config: dict,
|
||||
model_name: str,
|
||||
@@ -175,6 +183,8 @@ def test_ngram_and_suffix_correctness(
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@single_gpu_only
|
||||
@large_gpu_mark(min_gb=20)
|
||||
def test_suffix_decoding_acceptance(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
sampling_config: SamplingParams,
|
||||
@@ -242,6 +252,8 @@ def test_suffix_decoding_acceptance(
|
||||
],
|
||||
ids=["llama3_eagle3_speculator", "qwen3_eagle3_speculator"],
|
||||
)
|
||||
@single_gpu_only
|
||||
@large_gpu_mark(min_gb=24)
|
||||
def test_speculators_model_integration(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
sampling_config: SamplingParams,
|
||||
@@ -319,137 +331,7 @@ def test_speculators_model_integration(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
[
|
||||
"model_setup",
|
||||
"mm_enabled",
|
||||
"enable_chunked_prefill",
|
||||
"model_impl",
|
||||
"expected_accuracy_threshold",
|
||||
],
|
||||
[
|
||||
(
|
||||
("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
|
||||
False,
|
||||
False,
|
||||
"auto",
|
||||
0.8, # ref: 90%
|
||||
),
|
||||
(
|
||||
("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
|
||||
False,
|
||||
False,
|
||||
"transformers",
|
||||
0.8, # ref: 90%
|
||||
),
|
||||
pytest.param(
|
||||
(
|
||||
"eagle3",
|
||||
"Qwen/Qwen3-VL-8B-Instruct",
|
||||
"taobao-mnn/Qwen3-VL-8B-Instruct-Eagle3",
|
||||
1,
|
||||
),
|
||||
False,
|
||||
False,
|
||||
"auto",
|
||||
0.8, # ref: 90%
|
||||
marks=pytest.mark.skip(
|
||||
reason="architecture of its eagle3 is LlamaForCausalLMEagle3"
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
(
|
||||
"eagle3",
|
||||
"Qwen/Qwen2.5-VL-7B-Instruct",
|
||||
"Rayzl/qwen2.5-vl-7b-eagle3-sgl",
|
||||
1,
|
||||
),
|
||||
False,
|
||||
False,
|
||||
"auto",
|
||||
0.7, # TODO, update this with a reference value when re-enabling this case
|
||||
marks=pytest.mark.skip(
|
||||
reason="Skipping due to its head_dim not being a a multiple of 32"
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
(
|
||||
"eagle",
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
|
||||
1,
|
||||
),
|
||||
False,
|
||||
True,
|
||||
"auto",
|
||||
0.7, # ref: 75%-80%
|
||||
marks=large_gpu_mark(min_gb=40),
|
||||
), # works on 4x H100
|
||||
(
|
||||
(
|
||||
"eagle3",
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
|
||||
1,
|
||||
),
|
||||
False,
|
||||
False,
|
||||
"auto",
|
||||
0.7, # ref: 75%-80%
|
||||
),
|
||||
pytest.param(
|
||||
(
|
||||
"eagle",
|
||||
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
|
||||
4,
|
||||
),
|
||||
False,
|
||||
False,
|
||||
"auto",
|
||||
0.8, # ref: 90%
|
||||
# marks=large_gpu_mark(min_gb=80),
|
||||
), # works on 4x H100
|
||||
pytest.param(
|
||||
(
|
||||
"eagle",
|
||||
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
|
||||
4,
|
||||
),
|
||||
True,
|
||||
True,
|
||||
"auto",
|
||||
0.8, # ref: 90%
|
||||
marks=large_gpu_mark(min_gb=80),
|
||||
), # works on 4x H100
|
||||
(
|
||||
(
|
||||
"eagle",
|
||||
"eagle618/deepseek-v3-random",
|
||||
"eagle618/eagle-deepseek-v3-random",
|
||||
1,
|
||||
),
|
||||
False,
|
||||
False,
|
||||
"auto",
|
||||
0.0, # dummy model, skip gsm8k check
|
||||
),
|
||||
],
|
||||
ids=[
|
||||
"qwen3_eagle3",
|
||||
"qwen3_eagle3-transformers",
|
||||
"qwen3_vl_eagle3",
|
||||
"qwen2_5_vl_eagle3",
|
||||
"llama3_eagle",
|
||||
"llama3_eagle3",
|
||||
"llama4_eagle",
|
||||
"llama4_eagle_mm",
|
||||
"deepseek_eagle",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
|
||||
def test_eagle_correctness(
|
||||
def _run_eagle_correctness(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
sampling_config: SamplingParams,
|
||||
model_setup: tuple[str, str, str, int],
|
||||
@@ -460,14 +342,10 @@ def test_eagle_correctness(
|
||||
attn_backend: str,
|
||||
):
|
||||
"""
|
||||
Compare the outputs of a original LLM and a speculative LLM
|
||||
which should be the same when using eagle speculative decoding. Due to some variance
|
||||
in the engine, it is possible for some outputs to differ, so we expect that at least
|
||||
6/10 output tokens match exactly, and that the GSM8k accuracy is above
|
||||
a precomputed reference threshold for each model.
|
||||
Compare the outputs of an original LLM and a speculative LLM
|
||||
which should be the same when using eagle speculative decoding.
|
||||
"""
|
||||
if attn_backend == "TREE_ATTN":
|
||||
# TODO: Fix this flaky test
|
||||
pytest.skip(
|
||||
"TREE_ATTN is flaky in the test disable for now until it can be "
|
||||
"resolved (see https://github.com/vllm-project/vllm/issues/22922)"
|
||||
@@ -484,17 +362,17 @@ def test_eagle_correctness(
|
||||
f"transformers>={required}, but got {installed}"
|
||||
)
|
||||
|
||||
# Generate test prompts inside the function instead of using fixture
|
||||
test_prompts = get_test_prompts(mm_enabled)
|
||||
# Determine attention config
|
||||
# Scout requires default backend selection because vision encoder has
|
||||
# head_dim 88 being incompatible with FLASH_ATTN and needs to fall back
|
||||
# to Flex Attn
|
||||
|
||||
if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
|
||||
if current_platform.is_rocm():
|
||||
# TODO: Enable Flex Attn for spec_decode on ROCm
|
||||
pytest.skip("Flex Attn for spec_decode not supported on ROCm currently")
|
||||
attention_config = None # Let it fall back to default
|
||||
print(
|
||||
"FLASH_ATTN for spec_decode not supported on "
|
||||
"ROCm currently. Changing to FLEX_ATTENTION backend."
|
||||
)
|
||||
attention_config = {"backend": "FLEX_ATTENTION"}
|
||||
else:
|
||||
attention_config = None
|
||||
else:
|
||||
attention_config = {"backend": attn_backend}
|
||||
|
||||
@@ -509,7 +387,9 @@ def test_eagle_correctness(
|
||||
|
||||
if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
|
||||
if "deepseek" in model_setup[1].lower():
|
||||
pytest.skip("ROCM_AITER_FA for deepseek not supported on ROCm platform")
|
||||
m.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
m.delenv("VLLM_MLA_DISABLE", raising=False)
|
||||
attention_config = {"backend": "TRITON_MLA"}
|
||||
else:
|
||||
m.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
|
||||
@@ -563,14 +443,235 @@ def test_eagle_correctness(
|
||||
print(f"ref_output: {ref_output.outputs[0].text}")
|
||||
print(f"spec_output: {spec_output.outputs[0].text}")
|
||||
|
||||
# Heuristic: expect at least 60% of the prompts to match exactly
|
||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||
assert matches > int(0.6 * len(ref_outputs))
|
||||
del spec_llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@single_gpu_only
@pytest.mark.parametrize(
    (
        "model_setup",
        "mm_enabled",
        "enable_chunked_prefill",
        "model_impl",
        "expected_accuracy_threshold",
    ),
    [
        pytest.param(
            # model_setup = (method, target model, draft model, tensor-parallel size)
            # NOTE(review): the meaning of the 4th tuple element is assumed; confirm
            # against _run_eagle_correctness.
            (
                "eagle",
                "eagle618/deepseek-v3-random",
                "eagle618/eagle-deepseek-v3-random",
                1,
            ),
            False,  # mm_enabled
            False,  # enable_chunked_prefill
            "auto",  # model_impl
            0.0,  # dummy model, skip gsm8k check
            id="deepseek_eagle",
        ),
    ],
)
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
def test_eagle_correctness_light(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_setup: tuple[str, str, str, int],
    mm_enabled: bool,
    expected_accuracy_threshold: float,
    enable_chunked_prefill: bool,
    model_impl: str,
    attn_backend: str,
):
    """Run the shared EAGLE correctness check on the dummy DeepSeek model pair.

    The single parametrized case is crossed with every attention backend
    returned by ``get_attn_backend_list_based_on_platform()``. All of the
    actual work is delegated to ``_run_eagle_correctness``; this function only
    binds the parametrization and the GPU marks to it.
    """
    _run_eagle_correctness(
        monkeypatch=monkeypatch,
        sampling_config=sampling_config,
        model_setup=model_setup,
        mm_enabled=mm_enabled,
        expected_accuracy_threshold=expected_accuracy_threshold,
        enable_chunked_prefill=enable_chunked_prefill,
        model_impl=model_impl,
        attn_backend=attn_backend,
    )
|
||||
|
||||
|
||||
@single_gpu_only
@large_gpu_mark(min_gb=24)
@pytest.mark.parametrize(
    (
        "model_setup",
        "mm_enabled",
        "enable_chunked_prefill",
        "model_impl",
        "expected_accuracy_threshold",
    ),
    [
        pytest.param(
            ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
            False,
            False,
            "auto",
            0.8,  # ref: 90%
            id="qwen3_eagle3",
        ),
        pytest.param(
            ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
            False,
            False,
            "transformers",
            0.8,  # ref: 90%
            id="qwen3_eagle3-transformers",
        ),
        pytest.param(
            (
                "eagle3",
                "Qwen/Qwen3-VL-8B-Instruct",
                "taobao-mnn/Qwen3-VL-8B-Instruct-Eagle3",
                1,
            ),
            False,
            False,
            "auto",
            0.8,  # ref: 90%
            marks=pytest.mark.skip(
                reason="architecture of its eagle3 is LlamaForCausalLMEagle3"
            ),
            id="qwen3_vl_eagle3",
        ),
        pytest.param(
            (
                "eagle3",
                "Qwen/Qwen2.5-VL-7B-Instruct",
                "Rayzl/qwen2.5-vl-7b-eagle3-sgl",
                1,
            ),
            False,
            False,
            "auto",
            0.7,  # TODO, update this with a reference value when re-enabling this case
            marks=pytest.mark.skip(
                reason="Skipping due to its head_dim not being a multiple of 32"
            ),
            id="qwen2_5_vl_eagle3",
        ),
        pytest.param(
            (
                "eagle3",
                "meta-llama/Llama-3.1-8B-Instruct",
                "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
                1,
            ),
            False,
            False,
            "auto",
            0.7,  # ref: 75%-80%
            id="llama3_eagle3",
        ),
    ],
)
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
def test_eagle_correctness_medium(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_setup: tuple[str, str, str, int],
    mm_enabled: bool,
    expected_accuracy_threshold: float,
    enable_chunked_prefill: bool,
    model_impl: str,
    attn_backend: str,
):
    """Run the shared EAGLE correctness check on the 7B/8B EAGLE3 model pairs.

    Covers the Qwen3, Qwen3-VL, Qwen2.5-VL and Llama-3.1 EAGLE3 cases (the two
    VL cases are unconditionally skipped), each crossed with every attention
    backend returned by ``get_attn_backend_list_based_on_platform()``. All of
    the actual work is delegated to ``_run_eagle_correctness``.
    """
    _run_eagle_correctness(
        monkeypatch=monkeypatch,
        sampling_config=sampling_config,
        model_setup=model_setup,
        mm_enabled=mm_enabled,
        expected_accuracy_threshold=expected_accuracy_threshold,
        enable_chunked_prefill=enable_chunked_prefill,
        model_impl=model_impl,
        attn_backend=attn_backend,
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    (
        "model_setup",
        "mm_enabled",
        "enable_chunked_prefill",
        "model_impl",
        "expected_accuracy_threshold",
    ),
    [
        pytest.param(
            (
                "eagle",
                "meta-llama/Llama-3.1-8B-Instruct",
                "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
                1,
            ),
            False,
            True,
            "auto",
            0.7,  # ref: 75%-80%
            marks=large_gpu_mark(min_gb=40),
            id="llama3_eagle",
        ),
        pytest.param(
            (
                "eagle",
                "meta-llama/Llama-4-Scout-17B-16E-Instruct",
                "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
                4,
            ),
            False,
            False,
            "auto",
            0.8,  # ref: 90%
            marks=multi_gpu_marks(num_gpus=4),
            id="llama4_eagle",
        ),
        pytest.param(
            (
                "eagle",
                "meta-llama/Llama-4-Scout-17B-16E-Instruct",
                "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
                4,
            ),
            True,
            True,
            "auto",
            0.8,  # ref: 90%
            marks=[*multi_gpu_marks(num_gpus=4), large_gpu_mark(min_gb=80)],
            id="llama4_eagle_mm",
        ),
    ],
)
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
def test_eagle_correctness_heavy(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_setup: tuple[str, str, str, int],
    mm_enabled: bool,
    expected_accuracy_threshold: float,
    enable_chunked_prefill: bool,
    model_impl: str,
    attn_backend: str,
):
    """Run the shared EAGLE correctness check on the most demanding cases.

    Covers Llama-3.1 EAGLE with chunked prefill and the two Llama-4-Scout
    cases (text-only and multimodal), which carry per-case GPU marks rather
    than a function-level one. Each case is crossed with every attention
    backend returned by ``get_attn_backend_list_based_on_platform()``. All of
    the actual work is delegated to ``_run_eagle_correctness``.
    """
    _run_eagle_correctness(
        monkeypatch=monkeypatch,
        sampling_config=sampling_config,
        model_setup=model_setup,
        mm_enabled=mm_enabled,
        expected_accuracy_threshold=expected_accuracy_threshold,
        enable_chunked_prefill=enable_chunked_prefill,
        model_impl=model_impl,
        attn_backend=attn_backend,
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["model_setup", "mm_enabled", "expected_accuracy_threshold"],
|
||||
[
|
||||
@@ -579,6 +680,8 @@ def test_eagle_correctness(
|
||||
],
|
||||
ids=["mimo", "deepseek"],
|
||||
)
|
||||
@single_gpu_only
|
||||
@large_gpu_mark(min_gb=20)
|
||||
def test_mtp_correctness(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
sampling_config: SamplingParams,
|
||||
@@ -694,11 +797,13 @@ cases = [
|
||||
|
||||
@pytest.mark.parametrize("args", cases)
@pytest.mark.parametrize("enforce_eager", [True, False])
@single_gpu_only
def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
    """Check draft-model speculative decoding for each case, eager and non-eager.

    Args:
        args: One entry of the module-level ``cases`` list.
        enforce_eager: Value written to the ``enforce_eager`` field of the
            arguments before the correctness check runs.
    """
    import copy

    # `args` is the very object stored in the module-level `cases` list, so
    # assigning to it would leak `enforce_eager` into every other
    # parametrization (and any other test) that reuses the same entry.
    # Operate on a shallow copy to keep the shared cases pristine.
    run_args = copy.copy(args)
    run_args.enforce_eager = enforce_eager
    assert_draft_model_correctness(run_args)
|
||||
|
||||
|
||||
@single_gpu_only
|
||||
def test_draft_model_realistic_example():
|
||||
args = ArgsTest(
|
||||
target_model="Qwen/Qwen3-1.7B",
|
||||
@@ -713,6 +818,7 @@ def test_draft_model_realistic_example():
|
||||
assert_draft_model_correctness(args)
|
||||
|
||||
|
||||
@single_gpu_only
|
||||
def test_draft_model_parallel_drafting():
|
||||
args = ArgsTest(
|
||||
target_model="Qwen/Qwen3-1.7B",
|
||||
@@ -738,6 +844,7 @@ def test_draft_model_parallel_drafting():
|
||||
ids=["target_quantized", "draft_quantized"],
|
||||
)
|
||||
@pytest.mark.parametrize("enforce_eager", [True, False])
|
||||
@single_gpu_only
|
||||
def test_draft_model_quantization(models: tuple[str, str], enforce_eager: bool):
|
||||
tgt_model, draft_model = models
|
||||
sd_case = ArgsTest(
|
||||
@@ -749,6 +856,7 @@ def test_draft_model_quantization(models: tuple[str, str], enforce_eager: bool):
|
||||
assert_draft_model_correctness(sd_case)
|
||||
|
||||
|
||||
@multi_gpu_only(num_gpus=2)
|
||||
def test_draft_model_tensor_parallelism():
|
||||
"""Ensure spec decode works when running with TP > 1."""
|
||||
_skip_if_insufficient_gpus_for_tp(2)
|
||||
@@ -764,6 +872,7 @@ def test_draft_model_tensor_parallelism():
|
||||
assert_draft_model_correctness(sd_case)
|
||||
|
||||
|
||||
@multi_gpu_only(num_gpus=2)
|
||||
def test_draft_model_engine_args_tensor_parallelism():
|
||||
"""Ensure the vllm_config for the draft model is created correctly,
|
||||
and independently of the target model (quantization, TP, etc.)"""
|
||||
|
||||
Reference in New Issue
Block a user