diff --git a/tests/v1/e2e/test_async_scheduling.py b/tests/v1/e2e/test_async_scheduling.py index c205920d0..b85f8880c 100644 --- a/tests/v1/e2e/test_async_scheduling.py +++ b/tests/v1/e2e/test_async_scheduling.py @@ -163,14 +163,7 @@ def run_tests( uni/multiproc executor with spec decoding.""" # Determine attention config based on platform - if current_platform.is_rocm(): - if is_testing_with_spec_decoding: - # Use TRITON_ATTN for spec decoding test for consistency - attention_config = {"backend": "TRITON_ATTN"} - else: - attention_config = {"backend": "ROCM_ATTN"} - else: - attention_config = {"backend": "FLEX_ATTENTION"} + attention_config = {"backend": "FLEX_ATTENTION"} with monkeypatch.context() as m: # lock matmul precision to full FP32 (IEEE) @@ -226,15 +219,7 @@ def run_tests( name_1=f"config=[{test_config}], params={params}", ) - # On ROCm with TRITON_ATTN (spec decoding test), skip strict - # logprobs comparison when logprobs are requested - skip_logprobs_check = ( - current_platform.is_rocm() - and params.get("logprobs") - and is_testing_with_spec_decoding - ) - if not skip_logprobs_check: - assert _all_logprobs_match(base_logprobs, test_logprobs) + assert _all_logprobs_match(base_logprobs, test_logprobs) if ( base_acceptance_rate is not None @@ -374,12 +359,7 @@ def _all_logprobs_match(req_a, req_b) -> bool: def _logprobs_match(lps_a: dict[int, Logprob], lps_b: dict[int, Logprob]) -> bool: - if current_platform.is_rocm(): - # ROCm has higher numerical variance - # due to use of float16. - rel_tol, abs_tol = 5e-2, 1e-5 - else: - rel_tol, abs_tol = 1e-3, 1e-6 + rel_tol, abs_tol = 1e-3, 1e-6 return ( len(lps_a) == len(lps_b) and lps_a.keys() == lps_b.keys() diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 494ebedc5..bc2afd2c3 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -185,18 +185,21 @@ class FakeNixlWrapper: def _make_fake_nixl_pkg(): """Context manager that creates a temporary package making `from nixl._api import nixl_agent` resolve to our FakeNixlWrapper. + Also creates rixl package for ROCm compatibility. Automatically cleans up the temporary directory when done. """ with tempfile.TemporaryDirectory() as td: - pkg_root = os.path.join(td, "nixl", "_api") - os.makedirs(pkg_root, exist_ok=True) + # Create both nixl and rixl packages for cross-platform compatibility + for pkg_name in ["nixl", "rixl"]: + pkg_root = os.path.join(td, pkg_name, "_api") + os.makedirs(pkg_root, exist_ok=True) - # Get the source code of FakeNixlWrapper class and dedent it - fake_nixl_source = inspect.getsource(FakeNixlWrapper) - fake_nixl_source = textwrap.dedent(fake_nixl_source) + # Get the source code of FakeNixlWrapper class and dedent it + fake_nixl_source = inspect.getsource(FakeNixlWrapper) + fake_nixl_source = textwrap.dedent(fake_nixl_source) - stub = f"""\ + stub = f"""\ # Copy of FakeNixlWrapper implementation for Ray workers import uuid from collections import defaultdict @@ -206,16 +209,17 @@ from collections import defaultdict # Export as nixl_agent nixl_agent = FakeNixlWrapper """ - with open(os.path.join(pkg_root, "__init__.py"), "w") as f: - f.write(stub) + with open(os.path.join(pkg_root, "__init__.py"), "w") as f: + f.write(stub) + + # Mock nixlXferTelemetry class + pkg_root2 = os.path.join(td, pkg_name, "_bindings") + os.makedirs(pkg_root2, exist_ok=True) + with open(os.path.join(pkg_root2, "__init__.py"), "w") as f: + f.write("class nixlXferTelemetry: pass") + # touch parent package + open(os.path.join(td, pkg_name, "__init__.py"), "w").close() - # Mock nixlXferTelemetry class - pkg_root2 = os.path.join(td, "nixl", "_bindings") - os.makedirs(pkg_root2, exist_ok=True) - with open(os.path.join(pkg_root2, "__init__.py"), "w") as f: - f.write("class nixlXferTelemetry: pass") - # touch parent package - open(os.path.join(td, "nixl", "__init__.py"), "w").close() yield td diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 4c5e32340..f6d198f63 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -187,6 +187,11 @@ class EagleProposer: rocm_types.append(MLACommonMetadata) + # FlexAttention backend support + from vllm.v1.attention.backends.flex_attention import FlexAttentionMetadata + + rocm_types.append(FlexAttentionMetadata) + self.allowed_attn_types = tuple(rocm_types) # Parse the speculative token tree.