[ROCm][CI] Add ROCm attention backend support for EAGLE DP tests (#32363)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -9,18 +9,40 @@ import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
|
||||
# Data-parallel size for the test; overridable through the DP_SIZE
# environment variable (falls back to 2 when the variable is unset).
DP_SIZE = int(os.getenv("DP_SIZE", 2))

# Attention backends the test is parametrized over. ROCm gets its own set
# of backends; every other platform uses FLASH_ATTN only.
ATTN_BACKENDS = (
    ["ROCM_ATTN", "TRITON_ATTN", "FLEX_ATTENTION"]
    if current_platform.is_rocm()
    else ["FLASH_ATTN"]
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_eagle_dp(monkeypatch: pytest.MonkeyPatch):
|
||||
# This test checks that running a model with and without eagle
|
||||
# leads to identical tokens. This is only true in batch invariant mode
|
||||
# (because the target model verifies all draft tokens in one big forward pass)
|
||||
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
|
||||
@pytest.mark.parametrize("attn_backend", ATTN_BACKENDS)
|
||||
@pytest.mark.xfail(
|
||||
current_platform.is_rocm(),
|
||||
reason="Test may fail on ROCm until batch invariance is enabled."
|
||||
"See: https://github.com/vllm-project/vllm/issues/27433",
|
||||
strict=False,
|
||||
)
|
||||
async def test_run_eagle_dp(monkeypatch: pytest.MonkeyPatch, attn_backend: str):
|
||||
if not current_platform.is_rocm():
|
||||
# This test checks that running a model with and without eagle
|
||||
# leads to identical tokens.
|
||||
#
|
||||
# NOTE: This is only true in batch invariant mode
|
||||
# (because the target model verifies all draft tokens in one big
|
||||
# forward pass)
|
||||
#
|
||||
# TODO[ROCm]: Test is passing on ROCm CI but may break in future.
|
||||
# Enable batch invariance for ROCm when possible. See:
|
||||
# https://github.com/vllm-project/vllm/issues/27433
|
||||
|
||||
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
|
||||
|
||||
target_model = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
draft_model = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
|
||||
@@ -34,7 +56,7 @@ async def test_run_eagle_dp(monkeypatch: pytest.MonkeyPatch):
|
||||
data_parallel_backend="mp", # ray takes more time
|
||||
trust_remote_code=True,
|
||||
max_model_len=16384,
|
||||
attention_config={"backend": "FLASH_ATTN"},
|
||||
attention_config={"backend": attn_backend},
|
||||
)
|
||||
|
||||
eagle_engine_args = replace(
|
||||
|
||||
Reference in New Issue
Block a user