[ROCm] [V1] [SpecDec] Enable Speculative Decoding on ROCm V1 Engine (#21496)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Authored by TJian on 2025-08-07 19:13:17 -07:00 · committed by GitHub
parent acf8aeb79e
commit 1ee5ead5f8
6 changed files with 128 additions and 41 deletions
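
For orientation, what this change enables is driving EAGLE speculative decoding through the V1 engine on ROCm via the usual `speculative_config` path. The sketch below is illustrative and not part of the diff: the model names and token count are placeholders, and the environment variables mirror what the updated test sets for the AITER-backed FLASH_ATTN_VLLM_V1 backend on ROCm.

import os

from vllm import LLM, SamplingParams

# Assumption: mirrors the env vars the updated test sets on ROCm for the
# AITER-backed FLASH_ATTN_VLLM_V1 path.
os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_ROCM_USE_AITER"] = "1"

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",          # illustrative target model
    speculative_config={
        "method": "eagle",
        "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",  # illustrative draft model
        "num_speculative_tokens": 3,
    },
)
out = llm.generate(["The capital of France is"],
                   SamplingParams(temperature=0, max_tokens=32))
print(out[0].outputs[0].text)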


@@ -8,10 +8,12 @@ from typing import Any, Union
import pytest
import torch
from tests.utils import get_attn_backend_list_based_on_platform
from vllm import LLM, SamplingParams
from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.platforms import current_platform
def get_test_prompts(mm_enabled: bool):
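
The helper imported above comes from tests/utils and its body is not shown in this diff. A minimal sketch of what such a helper might look like, assuming it simply branches on the detected platform and returns the backend names used later in this file:

from vllm.platforms import current_platform


def get_attn_backend_list_based_on_platform() -> list[str]:
    # Sketch only: the real helper may return a different or longer list.
    if current_platform.is_rocm():
        return ["TRITON_ATTN_VLLM_V1", "FLASH_ATTN_VLLM_V1"]
    return ["FLASH_ATTN_VLLM_V1", "TRITON_ATTN_VLLM_V1"]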
@@ -141,11 +143,14 @@ def test_ngram_correctness(
            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
    ],
    ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"])
@pytest.mark.parametrize("attn_backend",
                         get_attn_backend_list_based_on_platform())
def test_eagle_correctness(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_setup: tuple[str, str, str, int],
    mm_enabled: bool,
    attn_backend: str,
):
    # Generate test prompts inside the function instead of using fixture
    test_prompts = get_test_prompts(mm_enabled)
@@ -156,6 +161,16 @@ def test_eagle_correctness(
    '''
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
        if (attn_backend == "TRITON_ATTN_VLLM_V1"
                and not current_platform.is_rocm()):
            pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
                        "multi-token eagle spec decode on current platform")
        if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
            m.setenv("VLLM_ROCM_USE_AITER", "1")
        method, model_name, spec_model_name, tp_size = model_setup
        ref_llm = LLM(model=model_name,
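
The hunk is cut off above at the construction of the reference LLM. For orientation only, the established pattern in this test body (a sketch using the names already bound in the function, not the verbatim continuation) compares a plain reference model against a speculative one built from the same model_setup:

        # Sketch, not the verbatim diff: completion of the call above is assumed.
        ref_llm = LLM(model=model_name, max_model_len=2048,
                      tensor_parallel_size=tp_size)
        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
        del ref_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

        spec_llm = LLM(
            model=model_name,
            tensor_parallel_size=tp_size,
            max_model_len=2048,
            speculative_config={
                "method": method,              # "eagle" or "eagle3"
                "model": spec_model_name,
                "num_speculative_tokens": 3,   # illustrative value
            },
        )
        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
        matches = sum(ref.outputs[0].text == spec.outputs[0].text
                      for ref, spec in zip(ref_outputs, spec_outputs))
        # The exact threshold is illustrative; the test tolerates some divergence.
        assert matches >= int(0.66 * len(ref_outputs))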