[Test] test_async_scheduling.py improvements (#36340)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Nick Hill
2026-03-10 11:17:26 -07:00
committed by GitHub
parent bdd8981dab
commit 2a68464c5b

View File

@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from itertools import repeat
from typing import Any
@@ -19,6 +20,8 @@ from ...models.utils import check_outputs_equal
MODEL = "Qwen/Qwen3-0.6B"
MTP_MODEL = "meta-llama/Llama-3.2-1B-Instruct"
# Need to enforce eager for MRV2 while we sort out cudagraph issues.
ENFORCE_EAGER = os.getenv("ENFORCE_EAGER", "0") == "1"
first_prompt = (
"The following numbers of the sequence "
@@ -47,10 +50,10 @@ def test_without_spec_decoding(
test_sampling_params: list[dict[str, Any]] = [
dict(),
# dict(min_tokens=20),
dict(presence_penalty=-1.0),
dict(frequency_penalty=-1.0),
dict(bad_words=["the", " the"]),
dict(logprobs=2),
dict(logprobs=2, presence_penalty=-1.0),
dict(logprobs=2, frequency_penalty=-1.0),
dict(structured_outputs=struct_outputs),
dict(
structured_outputs=struct_outputs,
@@ -58,12 +61,12 @@ def test_without_spec_decoding(
),
dict(
structured_outputs=struct_outputs,
presence_penalty=-1.0,
frequency_penalty=-1.0,
),
dict(
structured_outputs=struct_outputs,
logprobs=2,
presence_penalty=-1.0,
frequency_penalty=-1.0,
),
]
@@ -116,15 +119,15 @@ def test_with_eagle3_spec_decoding(sample_json_schema, monkeypatch: pytest.Monke
test_sampling_params = [
dict(),
dict(presence_penalty=-1.0),
dict(frequency_penalty=-1.0),
dict(bad_words=["the", " the"]),
dict(logprobs=2),
dict(logprobs=2, presence_penalty=-1.0),
dict(logprobs=2, frequency_penalty=-1.0),
dict(structured_outputs=struct_outputs),
dict(
structured_outputs=struct_outputs,
logprobs=2,
presence_penalty=-1.0,
frequency_penalty=-1.0,
),
]
@@ -144,14 +147,7 @@ def test_with_eagle3_spec_decoding(sample_json_schema, monkeypatch: pytest.Monke
(True, "uni", True, spec_config_short, True),
]
# On ROCm, use TRITON_ATTN + float32 for better numerical consistency
run_tests(
monkeypatch,
MTP_MODEL,
test_configs,
test_sampling_params,
is_testing_with_spec_decoding=True,
)
run_tests(monkeypatch, MTP_MODEL, test_configs, test_sampling_params)
def test_with_ngram_gpu_spec_decoding(monkeypatch: pytest.MonkeyPatch):
@@ -196,12 +192,11 @@ def run_tests(
model: str,
test_configs: list[tuple],
test_sampling_params: list[dict[str, Any]],
is_testing_with_spec_decoding: bool = False,
):
"""Test consistency of combos of async scheduling, preemption,
uni/multiproc executor with spec decoding."""
# Determine attention config based on platform
# Flex attention supports float32.
attention_config = {"backend": "FLEX_ATTENTION"}
with monkeypatch.context() as m:
@@ -226,7 +221,6 @@ def run_tests(
async_scheduling,
spec_config,
test_prefill_chunking=test_prefill_chunking,
is_testing_with_spec_decoding=is_testing_with_spec_decoding,
attention_config=attention_config,
)
outputs.append(test_results)
@@ -250,6 +244,7 @@ def run_tests(
test_acceptance_rates or repeat(None),
test_sampling_params,
):
reason = None
try:
check_outputs_equal(
outputs_0_lst=base_outs,
@@ -257,42 +252,57 @@ def run_tests(
name_0=f"baseline=[{baseline_config}], params={params}",
name_1=f"config=[{test_config}], params={params}",
)
except AssertionError as e:
reason = "outputs ", e
assert _all_logprobs_match(base_logprobs, test_logprobs)
if reason is None:
try:
assert _all_logprobs_match(base_logprobs, test_logprobs)
except AssertionError as e:
reason = "logprobs", e
if (
base_acceptance_rate is not None
and test_acceptance_rate is not None
):
if "spec_mml=None" in test_config:
# Preemption causes more variance in acceptance rates
if (
current_platform.is_rocm()
and "preemption=True" in test_config
):
tolerance = 0.10
if reason is None:
try:
if (
base_acceptance_rate is not None
and test_acceptance_rate is not None
):
if "spec_mml=None" in test_config:
# Preemption causes more variance in acceptance rates
if (
current_platform.is_rocm()
and "preemption=True" in test_config
):
tolerance = 0.10
else:
tolerance = 0.05
assert (
test_acceptance_rate > base_acceptance_rate
or test_acceptance_rate
== pytest.approx(base_acceptance_rate, rel=tolerance)
)
else:
tolerance = 0.05
assert (
test_acceptance_rate > base_acceptance_rate
or test_acceptance_rate
== pytest.approx(base_acceptance_rate, rel=tolerance)
)
else:
# Currently the reported acceptance rate is expected to be
# lower when we sometimes skip drafting altogether.
assert test_acceptance_rate > 0.1
# Currently the reported acceptance rate is expected to be
# lower when we sometimes skip drafting altogether.
assert test_acceptance_rate > 0.1
except AssertionError as e:
reason = "accept ", e
if reason is None:
print(
f"PASSED: config=[{test_config}], params={params}"
f"\033[32mPASSED\033[0m: "
f"config=[{test_config}], params={params}"
f" accept_rate={test_acceptance_rate}"
)
except AssertionError as e:
else:
reason_str, _ = reason
print(
f"FAILED: config=[{test_config}], params={params}"
f"\033[31mFAILED\033[0m({reason_str}): "
f"config=[{test_config}], params={params}"
f" accept_rate={test_acceptance_rate}"
)
if failure is None:
failure = e
_, failure = reason
if failure is not None:
raise failure
@@ -307,7 +317,6 @@ def run_test(
async_scheduling: bool,
spec_config: dict[str, Any] | None,
test_prefill_chunking: bool,
is_testing_with_spec_decoding: bool = False,
attention_config: dict[str, Any] | None = None,
):
spec_decoding = spec_config is not None
@@ -335,7 +344,7 @@ def run_test(
enable_chunked_prefill=test_prefill_chunking,
# Force prefill chunking
max_num_batched_tokens=48 if test_prefill_chunking else None,
# enforce_eager=True,
enforce_eager=ENFORCE_EAGER,
async_scheduling=async_scheduling,
distributed_executor_backend=executor,
dtype="float32",