[ROCm][CI][V1] Fix nixl_connector test failure and achieve CUDA parity in test_async_scheduling (#32000)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2026-01-09 06:48:32 -06:00
committed by GitHub
parent b474782ad7
commit e02706d2d2
3 changed files with 27 additions and 38 deletions

View File

@@ -163,14 +163,7 @@ def run_tests(
     uni/multiproc executor with spec decoding."""
     # Determine attention config based on platform
-    if current_platform.is_rocm():
-        if is_testing_with_spec_decoding:
-            # Use TRITON_ATTN for spec decoding test for consistency
-            attention_config = {"backend": "TRITON_ATTN"}
-        else:
-            attention_config = {"backend": "ROCM_ATTN"}
-    else:
-        attention_config = {"backend": "FLEX_ATTENTION"}
+    attention_config = {"backend": "FLEX_ATTENTION"}
     with monkeypatch.context() as m:
         # lock matmul precision to full FP32 (IEEE)
@@ -226,15 +219,7 @@ def run_tests(
         name_1=f"config=[{test_config}], params={params}",
     )
-    # On ROCm with TRITON_ATTN (spec decoding test), skip strict
-    # logprobs comparison when logprobs are requested
-    skip_logprobs_check = (
-        current_platform.is_rocm()
-        and params.get("logprobs")
-        and is_testing_with_spec_decoding
-    )
-    if not skip_logprobs_check:
-        assert _all_logprobs_match(base_logprobs, test_logprobs)
+    assert _all_logprobs_match(base_logprobs, test_logprobs)
     if (
         base_acceptance_rate is not None
@@ -374,12 +359,7 @@ def _all_logprobs_match(req_a, req_b) -> bool:
 def _logprobs_match(lps_a: dict[int, Logprob], lps_b: dict[int, Logprob]) -> bool:
-    if current_platform.is_rocm():
-        # ROCm has higher numerical variance
-        # due to use of float16.
-        rel_tol, abs_tol = 5e-2, 1e-5
-    else:
-        rel_tol, abs_tol = 1e-3, 1e-6
+    rel_tol, abs_tol = 1e-3, 1e-6
     return (
         len(lps_a) == len(lps_b)
         and lps_a.keys() == lps_b.keys()