diff --git a/tests/v1/distributed/test_eagle_dp.py b/tests/v1/distributed/test_eagle_dp.py
index 1b7c2d8ea..e20893b63 100644
--- a/tests/v1/distributed/test_eagle_dp.py
+++ b/tests/v1/distributed/test_eagle_dp.py
@@ -9,18 +9,40 @@ import pytest
 
 from vllm import SamplingParams
 from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.platforms import current_platform
 from vllm.sampling_params import RequestOutputKind
 from vllm.v1.engine.async_llm import AsyncLLM
 
 DP_SIZE = int(os.getenv("DP_SIZE", 2))
 
+if current_platform.is_rocm():
+    ATTN_BACKENDS = ["ROCM_ATTN", "TRITON_ATTN", "FLEX_ATTENTION"]
+else:
+    ATTN_BACKENDS = ["FLASH_ATTN"]
+
 
 @pytest.mark.asyncio
-async def test_run_eagle_dp(monkeypatch: pytest.MonkeyPatch):
-    # This test checks that running a model with and without eagle
-    # leads to identical tokens. This is only true in batch invariant mode
-    # (because the target model verifies all draft tokens in one big forward pass)
-    monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
+@pytest.mark.parametrize("attn_backend", ATTN_BACKENDS)
+@pytest.mark.xfail(
+    current_platform.is_rocm(),
+    reason="Test may fail on ROCm until batch invariance is enabled. "
+    "See: https://github.com/vllm-project/vllm/issues/27433",
+    strict=False,
+)
+async def test_run_eagle_dp(monkeypatch: pytest.MonkeyPatch, attn_backend: str):
+    if not current_platform.is_rocm():
+        # This test checks that running a model with and without eagle
+        # leads to identical tokens.
+        #
+        # NOTE: This is only true in batch invariant mode
+        # (because the target model verifies all draft tokens in one big
+        # forward pass)
+        #
+        # TODO[ROCm]: Test is passing on ROCm CI but may break in future.
+        # Enable batch invariance for ROCm when possible. See:
+        # https://github.com/vllm-project/vllm/issues/27433
+
+        monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
 
     target_model = "meta-llama/Llama-3.1-8B-Instruct"
     draft_model = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
@@ -34,7 +56,7 @@ async def test_run_eagle_dp(monkeypatch: pytest.MonkeyPatch):
         data_parallel_backend="mp",  # ray takes more time
         trust_remote_code=True,
         max_model_len=16384,
-        attention_config={"backend": "FLASH_ATTN"},
+        attention_config={"backend": attn_backend},
     )
 
     eagle_engine_args = replace(