[ROCm][CI][V1] Fix nixl_connector test failure and achieve CUDA parity in test_async_scheduling (#32000)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -163,14 +163,7 @@ def run_tests(
|
||||
uni/multiproc executor with spec decoding."""
|
||||
|
||||
# Determine attention config based on platform
|
||||
if current_platform.is_rocm():
|
||||
if is_testing_with_spec_decoding:
|
||||
# Use TRITON_ATTN for spec decoding test for consistency
|
||||
attention_config = {"backend": "TRITON_ATTN"}
|
||||
else:
|
||||
attention_config = {"backend": "ROCM_ATTN"}
|
||||
else:
|
||||
attention_config = {"backend": "FLEX_ATTENTION"}
|
||||
attention_config = {"backend": "FLEX_ATTENTION"}
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
# lock matmul precision to full FP32 (IEEE)
|
||||
@@ -226,15 +219,7 @@ def run_tests(
|
||||
name_1=f"config=[{test_config}], params={params}",
|
||||
)
|
||||
|
||||
# On ROCm with TRITON_ATTN (spec decoding test), skip strict
|
||||
# logprobs comparison when logprobs are requested
|
||||
skip_logprobs_check = (
|
||||
current_platform.is_rocm()
|
||||
and params.get("logprobs")
|
||||
and is_testing_with_spec_decoding
|
||||
)
|
||||
if not skip_logprobs_check:
|
||||
assert _all_logprobs_match(base_logprobs, test_logprobs)
|
||||
assert _all_logprobs_match(base_logprobs, test_logprobs)
|
||||
|
||||
if (
|
||||
base_acceptance_rate is not None
|
||||
@@ -374,12 +359,7 @@ def _all_logprobs_match(req_a, req_b) -> bool:
|
||||
|
||||
|
||||
def _logprobs_match(lps_a: dict[int, Logprob], lps_b: dict[int, Logprob]) -> bool:
|
||||
if current_platform.is_rocm():
|
||||
# ROCm has higher numerical variance
|
||||
# due to use of float16.
|
||||
rel_tol, abs_tol = 5e-2, 1e-5
|
||||
else:
|
||||
rel_tol, abs_tol = 1e-3, 1e-6
|
||||
rel_tol, abs_tol = 1e-3, 1e-6
|
||||
return (
|
||||
len(lps_a) == len(lps_b)
|
||||
and lps_a.keys() == lps_b.keys()
|
||||
|
||||
@@ -185,18 +185,21 @@ class FakeNixlWrapper:
|
||||
def _make_fake_nixl_pkg():
|
||||
"""Context manager that creates a temporary package making
|
||||
`from nixl._api import nixl_agent` resolve to our FakeNixlWrapper.
|
||||
Also creates rixl package for ROCm compatibility.
|
||||
|
||||
Automatically cleans up the temporary directory when done.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
pkg_root = os.path.join(td, "nixl", "_api")
|
||||
os.makedirs(pkg_root, exist_ok=True)
|
||||
# Create both nixl and rixl packages for cross-platform compatibility
|
||||
for pkg_name in ["nixl", "rixl"]:
|
||||
pkg_root = os.path.join(td, pkg_name, "_api")
|
||||
os.makedirs(pkg_root, exist_ok=True)
|
||||
|
||||
# Get the source code of FakeNixlWrapper class and dedent it
|
||||
fake_nixl_source = inspect.getsource(FakeNixlWrapper)
|
||||
fake_nixl_source = textwrap.dedent(fake_nixl_source)
|
||||
# Get the source code of FakeNixlWrapper class and dedent it
|
||||
fake_nixl_source = inspect.getsource(FakeNixlWrapper)
|
||||
fake_nixl_source = textwrap.dedent(fake_nixl_source)
|
||||
|
||||
stub = f"""\
|
||||
stub = f"""\
|
||||
# Copy of FakeNixlWrapper implementation for Ray workers
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
@@ -206,16 +209,17 @@ from collections import defaultdict
|
||||
# Export as nixl_agent
|
||||
nixl_agent = FakeNixlWrapper
|
||||
"""
|
||||
with open(os.path.join(pkg_root, "__init__.py"), "w") as f:
|
||||
f.write(stub)
|
||||
with open(os.path.join(pkg_root, "__init__.py"), "w") as f:
|
||||
f.write(stub)
|
||||
|
||||
# Mock nixlXferTelemetry class
|
||||
pkg_root2 = os.path.join(td, pkg_name, "_bindings")
|
||||
os.makedirs(pkg_root2, exist_ok=True)
|
||||
with open(os.path.join(pkg_root2, "__init__.py"), "w") as f:
|
||||
f.write("class nixlXferTelemetry: pass")
|
||||
# touch parent package
|
||||
open(os.path.join(td, pkg_name, "__init__.py"), "w").close()
|
||||
|
||||
# Mock nixlXferTelemetry class
|
||||
pkg_root2 = os.path.join(td, "nixl", "_bindings")
|
||||
os.makedirs(pkg_root2, exist_ok=True)
|
||||
with open(os.path.join(pkg_root2, "__init__.py"), "w") as f:
|
||||
f.write("class nixlXferTelemetry: pass")
|
||||
# touch parent package
|
||||
open(os.path.join(td, "nixl", "__init__.py"), "w").close()
|
||||
yield td
|
||||
|
||||
|
||||
|
||||
@@ -187,6 +187,11 @@ class EagleProposer:
|
||||
|
||||
rocm_types.append(MLACommonMetadata)
|
||||
|
||||
# FlexAttention backend support
|
||||
from vllm.v1.attention.backends.flex_attention import FlexAttentionMetadata
|
||||
|
||||
rocm_types.append(FlexAttentionMetadata)
|
||||
|
||||
self.allowed_attn_types = tuple(rocm_types)
|
||||
|
||||
# Parse the speculative token tree.
|
||||
|
||||
Reference in New Issue
Block a user