[Test] test_async_scheduling.py improvements (#36340)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
from itertools import repeat
|
||||
from typing import Any
|
||||
|
||||
@@ -19,6 +20,8 @@ from ...models.utils import check_outputs_equal
|
||||
MODEL = "Qwen/Qwen3-0.6B"
|
||||
MTP_MODEL = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
|
||||
# Need to enforce eager for MRV2 while we sort out cudagraph issues.
|
||||
ENFORCE_EAGER = os.getenv("ENFORCE_EAGER", "0") == "1"
|
||||
|
||||
first_prompt = (
|
||||
"The following numbers of the sequence "
|
||||
@@ -47,10 +50,10 @@ def test_without_spec_decoding(
|
||||
test_sampling_params: list[dict[str, Any]] = [
|
||||
dict(),
|
||||
# dict(min_tokens=20),
|
||||
dict(presence_penalty=-1.0),
|
||||
dict(frequency_penalty=-1.0),
|
||||
dict(bad_words=["the", " the"]),
|
||||
dict(logprobs=2),
|
||||
dict(logprobs=2, presence_penalty=-1.0),
|
||||
dict(logprobs=2, frequency_penalty=-1.0),
|
||||
dict(structured_outputs=struct_outputs),
|
||||
dict(
|
||||
structured_outputs=struct_outputs,
|
||||
@@ -58,12 +61,12 @@ def test_without_spec_decoding(
|
||||
),
|
||||
dict(
|
||||
structured_outputs=struct_outputs,
|
||||
presence_penalty=-1.0,
|
||||
frequency_penalty=-1.0,
|
||||
),
|
||||
dict(
|
||||
structured_outputs=struct_outputs,
|
||||
logprobs=2,
|
||||
presence_penalty=-1.0,
|
||||
frequency_penalty=-1.0,
|
||||
),
|
||||
]
|
||||
|
||||
@@ -116,15 +119,15 @@ def test_with_eagle3_spec_decoding(sample_json_schema, monkeypatch: pytest.Monke
|
||||
|
||||
test_sampling_params = [
|
||||
dict(),
|
||||
dict(presence_penalty=-1.0),
|
||||
dict(frequency_penalty=-1.0),
|
||||
dict(bad_words=["the", " the"]),
|
||||
dict(logprobs=2),
|
||||
dict(logprobs=2, presence_penalty=-1.0),
|
||||
dict(logprobs=2, frequency_penalty=-1.0),
|
||||
dict(structured_outputs=struct_outputs),
|
||||
dict(
|
||||
structured_outputs=struct_outputs,
|
||||
logprobs=2,
|
||||
presence_penalty=-1.0,
|
||||
frequency_penalty=-1.0,
|
||||
),
|
||||
]
|
||||
|
||||
@@ -144,14 +147,7 @@ def test_with_eagle3_spec_decoding(sample_json_schema, monkeypatch: pytest.Monke
|
||||
(True, "uni", True, spec_config_short, True),
|
||||
]
|
||||
|
||||
# On ROCm, use TRITON_ATTN + float32 for better numerical consistency
|
||||
run_tests(
|
||||
monkeypatch,
|
||||
MTP_MODEL,
|
||||
test_configs,
|
||||
test_sampling_params,
|
||||
is_testing_with_spec_decoding=True,
|
||||
)
|
||||
run_tests(monkeypatch, MTP_MODEL, test_configs, test_sampling_params)
|
||||
|
||||
|
||||
def test_with_ngram_gpu_spec_decoding(monkeypatch: pytest.MonkeyPatch):
|
||||
@@ -196,12 +192,11 @@ def run_tests(
|
||||
model: str,
|
||||
test_configs: list[tuple],
|
||||
test_sampling_params: list[dict[str, Any]],
|
||||
is_testing_with_spec_decoding: bool = False,
|
||||
):
|
||||
"""Test consistency of combos of async scheduling, preemption,
|
||||
uni/multiproc executor with spec decoding."""
|
||||
|
||||
# Determine attention config based on platform
|
||||
# Flex attention supports float32.
|
||||
attention_config = {"backend": "FLEX_ATTENTION"}
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
@@ -226,7 +221,6 @@ def run_tests(
|
||||
async_scheduling,
|
||||
spec_config,
|
||||
test_prefill_chunking=test_prefill_chunking,
|
||||
is_testing_with_spec_decoding=is_testing_with_spec_decoding,
|
||||
attention_config=attention_config,
|
||||
)
|
||||
outputs.append(test_results)
|
||||
@@ -250,6 +244,7 @@ def run_tests(
|
||||
test_acceptance_rates or repeat(None),
|
||||
test_sampling_params,
|
||||
):
|
||||
reason = None
|
||||
try:
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=base_outs,
|
||||
@@ -257,42 +252,57 @@ def run_tests(
|
||||
name_0=f"baseline=[{baseline_config}], params={params}",
|
||||
name_1=f"config=[{test_config}], params={params}",
|
||||
)
|
||||
except AssertionError as e:
|
||||
reason = "outputs ", e
|
||||
|
||||
assert _all_logprobs_match(base_logprobs, test_logprobs)
|
||||
if reason is None:
|
||||
try:
|
||||
assert _all_logprobs_match(base_logprobs, test_logprobs)
|
||||
except AssertionError as e:
|
||||
reason = "logprobs", e
|
||||
|
||||
if (
|
||||
base_acceptance_rate is not None
|
||||
and test_acceptance_rate is not None
|
||||
):
|
||||
if "spec_mml=None" in test_config:
|
||||
# Preemption causes more variance in acceptance rates
|
||||
if (
|
||||
current_platform.is_rocm()
|
||||
and "preemption=True" in test_config
|
||||
):
|
||||
tolerance = 0.10
|
||||
if reason is None:
|
||||
try:
|
||||
if (
|
||||
base_acceptance_rate is not None
|
||||
and test_acceptance_rate is not None
|
||||
):
|
||||
if "spec_mml=None" in test_config:
|
||||
# Preemption causes more variance in acceptance rates
|
||||
if (
|
||||
current_platform.is_rocm()
|
||||
and "preemption=True" in test_config
|
||||
):
|
||||
tolerance = 0.10
|
||||
else:
|
||||
tolerance = 0.05
|
||||
assert (
|
||||
test_acceptance_rate > base_acceptance_rate
|
||||
or test_acceptance_rate
|
||||
== pytest.approx(base_acceptance_rate, rel=tolerance)
|
||||
)
|
||||
else:
|
||||
tolerance = 0.05
|
||||
assert (
|
||||
test_acceptance_rate > base_acceptance_rate
|
||||
or test_acceptance_rate
|
||||
== pytest.approx(base_acceptance_rate, rel=tolerance)
|
||||
)
|
||||
else:
|
||||
# Currently the reported acceptance rate is expected to be
|
||||
# lower when we sometimes skip drafting altogether.
|
||||
assert test_acceptance_rate > 0.1
|
||||
# Currently the reported acceptance rate is expected to be
|
||||
# lower when we sometimes skip drafting altogether.
|
||||
assert test_acceptance_rate > 0.1
|
||||
except AssertionError as e:
|
||||
reason = "accept ", e
|
||||
|
||||
if reason is None:
|
||||
print(
|
||||
f"PASSED: config=[{test_config}], params={params}"
|
||||
f"\033[32mPASSED\033[0m: "
|
||||
f"config=[{test_config}], params={params}"
|
||||
f" accept_rate={test_acceptance_rate}"
|
||||
)
|
||||
except AssertionError as e:
|
||||
else:
|
||||
reason_str, _ = reason
|
||||
print(
|
||||
f"FAILED: config=[{test_config}], params={params}"
|
||||
f"\033[31mFAILED\033[0m({reason_str}): "
|
||||
f"config=[{test_config}], params={params}"
|
||||
f" accept_rate={test_acceptance_rate}"
|
||||
)
|
||||
if failure is None:
|
||||
failure = e
|
||||
_, failure = reason
|
||||
|
||||
if failure is not None:
|
||||
raise failure
|
||||
@@ -307,7 +317,6 @@ def run_test(
|
||||
async_scheduling: bool,
|
||||
spec_config: dict[str, Any] | None,
|
||||
test_prefill_chunking: bool,
|
||||
is_testing_with_spec_decoding: bool = False,
|
||||
attention_config: dict[str, Any] | None = None,
|
||||
):
|
||||
spec_decoding = spec_config is not None
|
||||
@@ -335,7 +344,7 @@ def run_test(
|
||||
enable_chunked_prefill=test_prefill_chunking,
|
||||
# Force prefill chunking
|
||||
max_num_batched_tokens=48 if test_prefill_chunking else None,
|
||||
# enforce_eager=True,
|
||||
enforce_eager=ENFORCE_EAGER,
|
||||
async_scheduling=async_scheduling,
|
||||
distributed_executor_backend=executor,
|
||||
dtype="float32",
|
||||
|
||||
Reference in New Issue
Block a user