[ROCm] [V1] [SpecDec] Enable Speculative Decoding on ROCm V1 Engine (#21496)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
@@ -8,10 +8,12 @@ from typing import Any, Union
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.utils import get_attn_backend_list_based_on_platform
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.assets.base import VLLM_S3_BUCKET_URL
|
||||
from vllm.assets.image import VLM_IMAGES_DIR
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def get_test_prompts(mm_enabled: bool):
|
||||
@@ -141,11 +143,14 @@ def test_ngram_correctness(
|
||||
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
|
||||
],
|
||||
ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"])
|
||||
@pytest.mark.parametrize("attn_backend",
|
||||
get_attn_backend_list_based_on_platform())
|
||||
def test_eagle_correctness(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
sampling_config: SamplingParams,
|
||||
model_setup: tuple[str, str, str, int],
|
||||
mm_enabled: bool,
|
||||
attn_backend: str,
|
||||
):
|
||||
# Generate test prompts inside the function instead of using a fixture
|
||||
test_prompts = get_test_prompts(mm_enabled)
|
||||
@@ -156,6 +161,16 @@ def test_eagle_correctness(
|
||||
'''
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
||||
|
||||
if (attn_backend == "TRITON_ATTN_VLLM_V1"
|
||||
and not current_platform.is_rocm()):
|
||||
pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
|
||||
"multi-token eagle spec decode on current platform")
|
||||
|
||||
if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
|
||||
m.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
|
||||
method, model_name, spec_model_name, tp_size = model_setup
|
||||
|
||||
ref_llm = LLM(model=model_name,
|
||||
|
||||
Reference in New Issue
Block a user