diff --git a/tests/benchmarks/sweep/test_serve_sla.py b/tests/benchmarks/sweep/test_serve_sla.py index 3b85801b0..19f4740bc 100644 --- a/tests/benchmarks/sweep/test_serve_sla.py +++ b/tests/benchmarks/sweep/test_serve_sla.py @@ -1,11 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json from collections.abc import Callable from pathlib import Path from unittest.mock import patch from vllm.benchmarks.sweep.param_sweep import ParameterSweepItem -from vllm.benchmarks.sweep.serve_sla import solve_sla +from vllm.benchmarks.sweep.serve_sla import _get_sla_run_path, solve_sla from vllm.benchmarks.sweep.server import ServerProcess from vllm.benchmarks.sweep.sla_sweep import ( SLACriterionBase, @@ -34,7 +35,14 @@ def _set_return_value( num_runs: int, dry_run: bool, ): - return var2metric(bench_comb) + iter_data = var2metric(bench_comb) + + summary_path = _get_sla_run_path(iter_path, run_number=None) + summary_path.parent.mkdir(parents=True, exist_ok=True) + with summary_path.open("w") as f: + json.dump(iter_data, f, indent=4) + + return iter_data return patch("vllm.benchmarks.sweep.serve_sla.run_sla", side_effect=mock_run_sla) @@ -98,6 +106,7 @@ def _var2metric_sqrt(y_intercept: float): def _run_solve_sla( var2metric: Callable[[ParameterSweepItem], list[dict[str, float]]], criterion: SLACriterionBase, + base_path: Path, min_value: int = 1, max_value: int = 100, ): @@ -108,7 +117,7 @@ def _run_solve_sla( serve_comb=ParameterSweepItem(), bench_comb=ParameterSweepItem(), sla_comb=SLASweepItem({"request_throughput": criterion}), - base_path=Path(""), + base_path=base_path, num_runs=1, dry_run=False, sla_variable="request_rate", @@ -120,10 +129,11 @@ def _run_solve_sla( return result -def test_solve_linear_sla_le(): +def test_solve_linear_sla_le(tmp_path): sla_data, history = _run_solve_sla( _var2metric_linear(), SLALessThanOrEqualTo(target=32), + tmp_path, ) assert history.get_max_passing() == 32 @@ -136,10 +146,11 @@ def test_solve_linear_sla_le(): } -def test_solve_linear_sla_lt(): +def test_solve_linear_sla_lt(tmp_path): sla_data, history = _run_solve_sla( _var2metric_linear(), SLALessThan(target=32), + tmp_path, ) assert history.get_max_passing() == 31 @@ -152,10 +163,11 @@ def test_solve_linear_sla_lt(): } -def test_solve_linear_sla_oob(): +def test_solve_linear_sla_oob(tmp_path): sla_data, history = _run_solve_sla( _var2metric_linear(), SLALessThanOrEqualTo(target=32), + tmp_path, min_value=64, ) @@ -168,10 +180,11 @@ def test_solve_linear_sla_oob(): } -def test_solve_concave_sla_le(): +def test_solve_concave_sla_le(tmp_path): sla_data, history = _run_solve_sla( _var2metric_concave(elbow_point=32), SLALessThanOrEqualTo(target=24), + tmp_path, ) assert history.get_max_passing() == 16 @@ -187,10 +200,11 @@ def test_solve_concave_sla_le(): } -def test_solve_convex_sla_le(): +def test_solve_convex_sla_le(tmp_path): sla_data, history = _run_solve_sla( _var2metric_convex(elbow_point=32), SLALessThanOrEqualTo(target=24), + tmp_path, ) assert history.get_max_passing() == 26 @@ -206,10 +220,11 @@ def test_solve_convex_sla_le(): } -def test_solve_quadratic_sla_le(): +def test_solve_quadratic_sla_le(tmp_path): sla_data, history = _run_solve_sla( _var2metric_quadratic(y_intercept=10), SLALessThanOrEqualTo(target=50), + tmp_path, ) assert history.get_max_passing() == 20 @@ -223,10 +238,11 @@ def test_solve_quadratic_sla_le(): } -def test_solve_sqrt_sla_le(): +def test_solve_sqrt_sla_le(tmp_path): sla_data, history = _run_solve_sla( _var2metric_sqrt(y_intercept=10), SLALessThanOrEqualTo(target=100), + tmp_path, ) assert history.get_max_passing() == 81 @@ -238,3 +254,45 @@ def test_solve_sqrt_sla_le(): 81: True, 82: False, } + + +def test_solve_reuse_history(tmp_path): + sla_data, history = _run_solve_sla( + _var2metric_linear(), + SLALessThanOrEqualTo(target=10), + tmp_path, + min_value=1, + max_value=20, + ) + + assert history.get_max_passing() == 10 + + assert {val: margin <= 0 for val, margin in history.items()} == { + 20: False, + 1: True, + 10: True, + 11: False, + } + + sla_data, history = _run_solve_sla( + _var2metric_linear(), + SLALessThanOrEqualTo(target=30), + tmp_path, + min_value=21, + max_value=40, + ) + + assert history.get_max_passing() == 30 + + assert {val: margin <= 0 for val, margin in history.items()} == { + # Items from the past run + # (the margins are different because the target changed) + 20: True, + 1: True, + 10: True, + 11: True, + # Items from this run + 40: False, + 30: True, + 31: False, + } diff --git a/vllm/benchmarks/sweep/serve_sla.py b/vllm/benchmarks/sweep/serve_sla.py index 1a2091d8b..26f0d6bf6 100644 --- a/vllm/benchmarks/sweep/serve_sla.py +++ b/vllm/benchmarks/sweep/serve_sla.py @@ -65,6 +65,14 @@ def _get_sla_run_path(iter_path: Path, run_number: int | None): return iter_path / f"run={run_number}.json" +def _iter_sla_val_paths(base_path: Path, sla_variable: str): + for iter_path in base_path.glob(f"{sla_variable}=*"): + sla_value = int(iter_path.name.removeprefix(f"{sla_variable}=")) + summary_path = iter_path / "summary.json" + if summary_path.exists(): + yield sla_value, summary_path + + def _sla_needs_server( serve_comb: ParameterSweepItem, bench_combs: ParameterSweep, @@ -153,6 +161,25 @@ class SLAHistory(dict[int, float]): ) +def _compute_margin( + sla_comb: SLASweepItem, + iter_data: list[dict[str, object]], +): + assert iter_data, "Summary should not be empty" + + iter_data_mean = { + k: sum(float(run_data[k]) for run_data in iter_data) / len(iter_data) # type: ignore + for k in sla_comb + } + + sla_margins = [ + criterion.print_and_compute_margin(iter_data_mean, k) + for k, criterion in sla_comb.items() + ] + + return max(sla_margins) + + def solve_sla( server: ServerProcess | None, bench_cmd: list[str], @@ -170,11 +197,18 @@ def solve_sla( sla_data = list[dict[str, object]]() history = SLAHistory(min_value=sla_min_value, max_value=sla_max_value) + # Use results from previous runs + for past_sla_value, path in _iter_sla_val_paths(base_path, sla_variable): + with path.open("rb") as f: + past_iter_data = json.load(f) + + history[past_sla_value] = _compute_margin(sla_comb, past_iter_data) + # NOTE: We don't use equality here to be more robust against noisy results while history.get_max_passing() + 1 < history.get_min_failing(): - if len(history) == 0: + if max(history, default=sla_min_value) < sla_max_value: val = sla_max_value - elif len(history) == 1: + elif min(history, default=sla_max_value) > sla_min_value: val = sla_min_value else: spl = PchipInterpolator(*history.get_xy(), extrapolate=False) @@ -205,24 +239,15 @@ def solve_sla( if iter_data is None: return None - sla_data.extend(iter_data) - - iter_data_mean = { - k: sum(float(run_data[k]) for run_data in iter_data) / len(iter_data) # type: ignore - for k in sla_comb - } - - sla_margins = [ - criterion.print_and_compute_margin(iter_data_mean, k) - for k, criterion in sla_comb.items() - ] - history[val] = margin = max(sla_margins) - + margin = _compute_margin(sla_comb, iter_data) if margin <= 0: print(f"SLA criteria are met. ({margin=:.2f})") else: print(f"SLA criteria are not met. ({margin=:.2f})") + sla_data.extend(iter_data) + history[val] = margin + return sla_data, history