tests/v1/e2e/spec_decode: assert async scheduling is used (#39206)
Signed-off-by: Rishi Puri <riship@nvidia.com> Signed-off-by: Rishi Puri <puririshi98@berkeley.edu> Signed-off-by: sfeng33 <4florafeng@gmail.com> Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com> Co-authored-by: Flora Feng <4florafeng@gmail.com>
This commit is contained in:
@@ -30,6 +30,13 @@ from vllm.v1.metrics.reader import Metric
|
||||
MTP_SIMILARITY_RATE = 0.8
|
||||
|
||||
|
||||
class AsyncSchedulingNotEnabledError(AssertionError):
    """Signal that draft_model spec decode ran without async scheduling.

    Raised when ``async_scheduling`` is expected to be True for draft_model
    spec decode but is False. It subclasses ``AssertionError`` so it still
    counts as an ordinary test failure, while giving
    ``pytest.mark.xfail(raises=...)`` a specific type to match.

    Tracked in: https://github.com/vllm-project/vllm/issues/38929
    """
|
||||
|
||||
|
||||
def _skip_if_insufficient_gpus_for_tp(tp_size: int):
|
||||
"""Skip test if available GPUs < tp_size on ROCm."""
|
||||
available_gpus = torch.accelerator.device_count()
|
||||
@@ -206,6 +213,8 @@ def test_ngram_gpu_default_with_async_scheduling(
|
||||
max_model_len=4096,
|
||||
async_scheduling=async_scheduling,
|
||||
)
|
||||
# Assert the resolved async_scheduling config matches what was requested.
|
||||
assert spec_llm.llm_engine.vllm_config.scheduler_config.async_scheduling == async_scheduling
|
||||
evaluate_llm_for_gsm8k(spec_llm, expected_accuracy_threshold=0.8)
|
||||
del spec_llm
|
||||
cleanup_dist_env_and_memory()
|
||||
@@ -457,6 +466,8 @@ def _run_eagle_correctness(
|
||||
model_impl=model_impl,
|
||||
attention_config=attention_config,
|
||||
)
|
||||
# EAGLE/EAGLE3 supports async scheduling; assert it is active by default.
|
||||
assert spec_llm.llm_engine.vllm_config.scheduler_config.async_scheduling
|
||||
evaluate_llm_for_gsm8k(
|
||||
spec_llm, expected_accuracy_threshold=expected_accuracy_threshold
|
||||
)
|
||||
@@ -760,6 +771,8 @@ def test_mtp_correctness(
|
||||
max_model_len=2048,
|
||||
attention_backend=attn_backend,
|
||||
)
|
||||
# MTP supports async scheduling; assert it is active by default.
|
||||
assert spec_llm.llm_engine.vllm_config.scheduler_config.async_scheduling
|
||||
evaluate_llm_for_gsm8k(
|
||||
spec_llm, expected_accuracy_threshold=expected_accuracy_threshold
|
||||
)
|
||||
@@ -829,12 +842,22 @@ cases = [
|
||||
@pytest.mark.parametrize("args", cases)
@pytest.mark.parametrize("enforce_eager", [True, False])
@single_gpu_only
# TODO: Fix async_scheduling and engine initialization issues - see https://github.com/vllm-project/vllm/issues/38929
@pytest.mark.xfail(
    raises=AsyncSchedulingNotEnabledError,
    reason="draft_model does not yet enable async_scheduling: issue #38929",
)
def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
    """Run the draft-model correctness check for one case and eager setting.

    Each entry of ``cases`` is exercised with ``enforce_eager`` both True and
    False. The test is marked xfail for ``AsyncSchedulingNotEnabledError``
    only, so any other failure is still reported as a real failure.

    Args:
        args: The parametrized test case taken from ``cases``.
        enforce_eager: Value to set on the case's ``enforce_eager`` field.
    """
    import copy

    # The same ``args`` object is shared by every parametrization that draws
    # from ``cases``; work on a shallow copy so this run's eager setting does
    # not leak into other tests.
    case = copy.copy(args)
    case.enforce_eager = enforce_eager
    assert_draft_model_correctness(case)
|
||||
|
||||
|
||||
@single_gpu_only
|
||||
# TODO: Fix async_scheduling and engine initialization issues - see https://github.com/vllm-project/vllm/issues/38929
|
||||
@pytest.mark.xfail(
|
||||
raises=AsyncSchedulingNotEnabledError,
|
||||
reason="draft_model does not yet enable async_scheduling: issue #38929",
|
||||
)
|
||||
def test_draft_model_realistic_example():
|
||||
args = ArgsTest(
|
||||
target_model="Qwen/Qwen3-1.7B",
|
||||
@@ -850,6 +873,11 @@ def test_draft_model_realistic_example():
|
||||
|
||||
|
||||
@single_gpu_only
|
||||
# TODO: Fix async_scheduling and engine initialization issues - see https://github.com/vllm-project/vllm/issues/38929
|
||||
@pytest.mark.xfail(
|
||||
raises=AsyncSchedulingNotEnabledError,
|
||||
reason="draft_model does not yet enable async_scheduling: issue #38929",
|
||||
)
|
||||
def test_draft_model_parallel_drafting():
|
||||
args = ArgsTest(
|
||||
target_model="Qwen/Qwen3-1.7B",
|
||||
@@ -876,6 +904,11 @@ def test_draft_model_parallel_drafting():
|
||||
)
|
||||
@pytest.mark.parametrize("enforce_eager", [True, False])
|
||||
@single_gpu_only
|
||||
# TODO: Fix async_scheduling and engine initialization issues - see https://github.com/vllm-project/vllm/issues/38929
|
||||
@pytest.mark.xfail(
|
||||
raises=AsyncSchedulingNotEnabledError,
|
||||
reason="draft_model does not yet enable async_scheduling: issue #38929",
|
||||
)
|
||||
def test_draft_model_quantization(models: tuple[str, str], enforce_eager: bool):
|
||||
tgt_model, draft_model = models
|
||||
sd_case = ArgsTest(
|
||||
@@ -888,6 +921,11 @@ def test_draft_model_quantization(models: tuple[str, str], enforce_eager: bool):
|
||||
|
||||
|
||||
@multi_gpu_only(num_gpus=2)
|
||||
# TODO: Fix async_scheduling and engine initialization issues - see https://github.com/vllm-project/vllm/issues/38929
|
||||
@pytest.mark.xfail(
|
||||
raises=AsyncSchedulingNotEnabledError,
|
||||
reason="draft_model does not yet enable async_scheduling: issue #38929",
|
||||
)
|
||||
def test_draft_model_tensor_parallelism():
|
||||
"""Ensure spec decode works when running with TP > 1."""
|
||||
_skip_if_insufficient_gpus_for_tp(2)
|
||||
@@ -1062,6 +1100,7 @@ def assert_draft_model_correctness(args: ArgsTest):
|
||||
enforce_eager=args.enforce_eager,
|
||||
disable_log_stats=False, # enables get_metrics()
|
||||
)
|
||||
|
||||
# we don't check the outputs, only check the metrics
|
||||
spec_llm.chat(test_prompts, args.sampling_config)
|
||||
metrics = spec_llm.get_metrics()
|
||||
@@ -1073,10 +1112,6 @@ def assert_draft_model_correctness(args: ArgsTest):
|
||||
spec_llm, expected_accuracy_threshold=args.expected_gsm8k_accuracy
|
||||
)
|
||||
|
||||
del spec_llm # CLEANUP
|
||||
torch.accelerator.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
print(
|
||||
f"spec-decode: target={args.target_model}, draft={args.draft_model}, "
|
||||
f"temperature={args.sampling_config.temperature:.2f}, "
|
||||
@@ -1086,6 +1121,20 @@ def assert_draft_model_correctness(args: ArgsTest):
|
||||
|
||||
assert acceptance_rate >= args.expected_acceptance_rate
|
||||
assert acceptance_len >= args.expected_acceptance_len
|
||||
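# draft_model is expected to run with async scheduling by default; verify it.
# NOTE(review): the acceptance asserts above run before the cleanup below, so
# a failure there skips cleanup — consider try/finally; confirm intent.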
|
||||
# Raise AsyncSchedulingNotEnabledError (a subclass of AssertionError) so that
|
||||
# @pytest.mark.xfail(raises=AsyncSchedulingNotEnabledError) catches only this
|
||||
# specific failure — leaving all other assertion failures (e.g. correctness or
|
||||
# acceptance-rate checks above) visible as real test failures.
|
||||
has_async = spec_llm.llm_engine.vllm_config.scheduler_config.async_scheduling
|
||||
del spec_llm # CLEANUP
|
||||
torch.accelerator.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
if not has_async:
|
||||
raise AsyncSchedulingNotEnabledError(
|
||||
"Expected async_scheduling=True for draft_model spec decode, got False."
|
||||
" See https://github.com/vllm-project/vllm/issues/38929"
|
||||
)
|
||||
|
||||
|
||||
def get_messages(dataset: str, n: int) -> list[Messages]:
|
||||
|
||||
Reference in New Issue
Block a user