[ROCm][CI][V1] Fix nixl_connector test failure and achieve CUDA parity in test_async_scheduling (#32000)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -163,14 +163,7 @@ def run_tests(
|
||||
uni/multiproc executor with spec decoding."""
|
||||
|
||||
# Determine attention config based on platform
|
||||
if current_platform.is_rocm():
|
||||
if is_testing_with_spec_decoding:
|
||||
# Use TRITON_ATTN for spec decoding test for consistency
|
||||
attention_config = {"backend": "TRITON_ATTN"}
|
||||
else:
|
||||
attention_config = {"backend": "ROCM_ATTN"}
|
||||
else:
|
||||
attention_config = {"backend": "FLEX_ATTENTION"}
|
||||
attention_config = {"backend": "FLEX_ATTENTION"}
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
# lock matmul precision to full FP32 (IEEE)
|
||||
@@ -226,15 +219,7 @@ def run_tests(
|
||||
name_1=f"config=[{test_config}], params={params}",
|
||||
)
|
||||
|
||||
# On ROCm with TRITON_ATTN (spec decoding test), skip strict
|
||||
# logprobs comparison when logprobs are requested
|
||||
skip_logprobs_check = (
|
||||
current_platform.is_rocm()
|
||||
and params.get("logprobs")
|
||||
and is_testing_with_spec_decoding
|
||||
)
|
||||
if not skip_logprobs_check:
|
||||
assert _all_logprobs_match(base_logprobs, test_logprobs)
|
||||
assert _all_logprobs_match(base_logprobs, test_logprobs)
|
||||
|
||||
if (
|
||||
base_acceptance_rate is not None
|
||||
@@ -374,12 +359,7 @@ def _all_logprobs_match(req_a, req_b) -> bool:
|
||||
|
||||
|
||||
def _logprobs_match(lps_a: dict[int, Logprob], lps_b: dict[int, Logprob]) -> bool:
|
||||
if current_platform.is_rocm():
|
||||
# ROCm has higher numerical variance
|
||||
# due to use of float16.
|
||||
rel_tol, abs_tol = 5e-2, 1e-5
|
||||
else:
|
||||
rel_tol, abs_tol = 1e-3, 1e-6
|
||||
rel_tol, abs_tol = 1e-3, 1e-6
|
||||
return (
|
||||
len(lps_a) == len(lps_b)
|
||||
and lps_a.keys() == lps_b.keys()
|
||||
|
||||
Reference in New Issue
Block a user