From ef1f7030f016cc811236517e02fa51ee8876cc31 Mon Sep 17 00:00:00 2001 From: Micah Williamson Date: Tue, 25 Nov 2025 01:55:09 -0600 Subject: [PATCH] [ROCm][CI] Fix test_cudagraph_mode failure in AMD CI (#29367) Signed-off-by: Micah Williamson --- tests/v1/attention/utils.py | 7 +++ tests/v1/cudagraph/test_cudagraph_mode.py | 62 +++++++++++++++-------- vllm/platforms/rocm.py | 4 +- 3 files changed, 51 insertions(+), 22 deletions(-) diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index dea89babd..df3d53332 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -340,4 +340,11 @@ full_cg_backend_configs = { "cudagraph_mode": "FULL_AND_PIECEWISE", }, ), + "RocmAttn": BackendConfig( + name="RocmAttn", + env_vars={"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1"}, + comp_config={ + "cudagraph_mode": "FULL", + }, + ), } diff --git a/tests/v1/cudagraph/test_cudagraph_mode.py b/tests/v1/cudagraph/test_cudagraph_mode.py index d6bde16eb..7f9c2a057 100644 --- a/tests/v1/cudagraph/test_cudagraph_mode.py +++ b/tests/v1/cudagraph/test_cudagraph_mode.py @@ -35,14 +35,22 @@ def temporary_environ(env_vars): # test attention backend and cudagraph_mode combo # (backend_name, cudagraph_mode, supported) -combo_cases_1 = [ - ("FA3", "FULL", True), - ("FA3", "FULL_AND_PIECEWISE", True), - ("FA2", "FULL", True), # Should fallback to FULL_AND_PIECEWISE - ("FA2", "FULL_AND_PIECEWISE", True), - ("FlashInfer", "FULL", True), # Should fallback to FULL_AND_PIECEWISE - ("FlashInfer", "FULL_AND_PIECEWISE", True), -] +if current_platform.is_rocm(): + combo_cases_1 = [ + ("RocmAttn", "FULL", True), + ("RocmAttn", "FULL_AND_PIECEWISE", True), + ("TritonAttn", "FULL", True), + ("TritonAttn", "FULL_AND_PIECEWISE", True), + ] +else: + combo_cases_1 = [ + ("FA3", "FULL", True), + ("FA3", "FULL_AND_PIECEWISE", True), + ("FA2", "FULL", True), # Should fallback to FULL_AND_PIECEWISE + ("FA2", "FULL_AND_PIECEWISE", True), + ("FlashInfer", "FULL", True), # Should fallback to FULL_AND_PIECEWISE + ("FlashInfer", "FULL_AND_PIECEWISE", True), + ] @pytest.mark.parametrize("backend_name, cudagraph_mode, supported", combo_cases_1) @@ -92,18 +100,32 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte # test cudagraph_mode with different compilation mode. # (backend_name, cudagraph_mode, compilation_mode, supported) -combo_cases_2 = [ - ("FA2", "FULL", CompilationMode.NONE, True), - ("FA2", "FULL", CompilationMode.VLLM_COMPILE, True), - ("FA2", "PIECEWISE", CompilationMode.NONE, False), - ("FA2", "PIECEWISE", CompilationMode.VLLM_COMPILE, True), - ("FA2", "FULL_AND_PIECEWISE", CompilationMode.NONE, False), - ("FA2", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True), - ("FA2", "FULL_DECODE_ONLY", CompilationMode.NONE, True), - ("FA2", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True), - ("FA2", "NONE", CompilationMode.NONE, True), - ("FA2", "NONE", CompilationMode.VLLM_COMPILE, True), -] +if current_platform.is_rocm(): + combo_cases_2 = [ + ("RocmAttn", "FULL", CompilationMode.NONE, True), + ("RocmAttn", "FULL", CompilationMode.VLLM_COMPILE, True), + ("RocmAttn", "PIECEWISE", CompilationMode.NONE, False), + ("RocmAttn", "PIECEWISE", CompilationMode.VLLM_COMPILE, True), + ("RocmAttn", "FULL_AND_PIECEWISE", CompilationMode.NONE, False), + ("RocmAttn", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True), + ("RocmAttn", "FULL_DECODE_ONLY", CompilationMode.NONE, True), + ("RocmAttn", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True), + ("RocmAttn", "NONE", CompilationMode.NONE, True), + ("RocmAttn", "NONE", CompilationMode.VLLM_COMPILE, True), + ] +else: + combo_cases_2 = [ + ("FA2", "FULL", CompilationMode.NONE, True), + ("FA2", "FULL", CompilationMode.VLLM_COMPILE, True), + ("FA2", "PIECEWISE", CompilationMode.NONE, False), + ("FA2", "PIECEWISE", CompilationMode.VLLM_COMPILE, True), + ("FA2", "FULL_AND_PIECEWISE", CompilationMode.NONE, False), + ("FA2", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True), + ("FA2", "FULL_DECODE_ONLY", CompilationMode.NONE, True), + ("FA2", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True), + ("FA2", "NONE", CompilationMode.NONE, True), + ("FA2", "NONE", CompilationMode.VLLM_COMPILE, True), + ] @pytest.mark.parametrize( diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index b0434b964..0483f6c06 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -321,8 +321,8 @@ class RocmPlatform(Platform): return AttentionBackendEnum.TRITON_ATTN.get_path() raise RuntimeError( - "V0 attention backends have been removed. Set VLLM_USE_V1=1 " - "to select a supported backend." + f"Attention backend {selected_backend.name} is not supported on " + "ROCm. Note that V0 attention backends have been removed." ) @classmethod