diff --git a/tests/benchmarks/sweep/__init__.py b/tests/benchmarks/sweep/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/benchmarks/test_param_sweep.py b/tests/benchmarks/sweep/test_param_sweep.py similarity index 100% rename from tests/benchmarks/test_param_sweep.py rename to tests/benchmarks/sweep/test_param_sweep.py diff --git a/tests/benchmarks/sweep/test_serve_sla.py b/tests/benchmarks/sweep/test_serve_sla.py new file mode 100644 index 000000000..5e93cfc36 --- /dev/null +++ b/tests/benchmarks/sweep/test_serve_sla.py @@ -0,0 +1,202 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable +from pathlib import Path +from unittest.mock import patch + +from vllm.benchmarks.sweep.param_sweep import ParameterSweepItem +from vllm.benchmarks.sweep.serve_sla import _estimate_sla_bounds, _find_sla_value +from vllm.benchmarks.sweep.server import ServerProcess +from vllm.benchmarks.sweep.sla_sweep import ( + SLACriterionBase, + SLALessThan, + SLALessThanOrEqualTo, + SLASweepItem, +) + + +def _set_return_value( + var2metric: Callable[[ParameterSweepItem], list[dict[str, float]]], +): + """ + Create a patch for run_sla with a specific function + indicating the relationship between the benchmark combination + (which includes the SLA variable) and the SLA criterion. + """ + + def mock_run_sla( + server: ServerProcess | None, + bench_cmd: list[str], + *, + serve_comb: ParameterSweepItem, + bench_comb: ParameterSweepItem, + iter_path: Path, + num_runs: int, + dry_run: bool, + ): + return var2metric(bench_comb) + + return patch("vllm.benchmarks.sweep.serve_sla.run_sla", side_effect=mock_run_sla) + + +def _var2metric_identity(bench_comb): + return [{"request_throughput": float(bench_comb["request_rate"])}] + + +def _run_estimate_sla_bounds( + var2metric: Callable[[ParameterSweepItem], list[dict[str, float]]], + criterion: SLACriterionBase, + init_value: int, + max_value: int, +): + with _set_return_value(var2metric): + return _estimate_sla_bounds( + server=None, + bench_cmd=[], + serve_comb=ParameterSweepItem(), + bench_comb=ParameterSweepItem(), + sla_comb=SLASweepItem({"request_throughput": criterion}), + base_path=Path(""), + num_runs=1, + dry_run=False, + sla_variable="request_rate", + init_value=init_value, + max_value=max_value, + ) + + +def test_estimate_sla_bounds_le(): + sla_data, (max_passing, min_failing), history = _run_estimate_sla_bounds( + _var2metric_identity, + SLALessThanOrEqualTo(target=32), + init_value=1, + max_value=100, + ) + + assert max_passing == 32 + assert min_failing == 64 + + assert {val: margin <= 0 for val, margin in history.items()} == { + 1: True, + 2: True, + 4: True, + 8: True, + 16: True, + 32: True, + 64: False, + } + + +def test_estimate_sla_bounds_lt(): + sla_data, (max_passing, min_failing), history = _run_estimate_sla_bounds( + _var2metric_identity, + SLALessThan(target=32), + init_value=1, + max_value=100, + ) + + assert max_passing == 16 + assert min_failing == 32 + + assert {val: margin <= 0 for val, margin in history.items()} == { + 1: True, + 2: True, + 4: True, + 8: True, + 16: True, + 32: False, + } + + +def test_estimate_sla_bounds_oob(): + sla_data, (max_passing, min_failing), history = _run_estimate_sla_bounds( + _var2metric_identity, + SLALessThanOrEqualTo(target=32), + init_value=64, + max_value=128, + ) + + assert max_passing == 0 + assert min_failing == 64 + + assert {val: margin <= 0 for val, margin in history.items()} == { + 64: False, + } + + +def _run_test_find_sla_value_le( + var2metric: Callable[[ParameterSweepItem], list[dict[str, float]]], + criterion: SLACriterionBase, + min_value: int, + max_value: int, +): + with _set_return_value(var2metric): + return _find_sla_value( + server=None, + bench_cmd=[], + serve_comb=ParameterSweepItem(), + bench_comb=ParameterSweepItem(), + sla_comb=SLASweepItem({"request_throughput": criterion}), + base_path=Path(""), + num_runs=1, + dry_run=False, + sla_variable="request_rate", + min_value=min_value, + max_value=max_value, + ) + + +def test_find_sla_value_le(): + sla_data, sla_value, history = _run_test_find_sla_value_le( + _var2metric_identity, + SLALessThanOrEqualTo(target=50.0), + min_value=32, + max_value=64, + ) + + assert sla_value == 50 + assert {val: margin <= 0 for val, margin in history.items()} == { + 48: True, + 56: False, + 52: False, + 50: True, + 51: False, + } + + +def test_find_sla_value_lt(): + sla_data, sla_value, history = _run_test_find_sla_value_le( + _var2metric_identity, + SLALessThan(target=50.0), + min_value=32, + max_value=64, + ) + + assert sla_value == 49 + assert {val: margin <= 0 for val, margin in history.items()} == { + 48: True, + 56: False, + 52: False, + 50: False, + 49: True, + } + + +def test_find_sla_value_oob(): + sla_data, sla_value, history = _run_test_find_sla_value_le( + _var2metric_identity, + SLALessThanOrEqualTo(target=50.0), + min_value=64, + max_value=128, + ) + + assert sla_value == 64 + assert {val: margin <= 0 for val, margin in history.items()} == { + 96: False, + 80: False, + 72: False, + 68: False, + 66: False, + 65: False, + 64: False, + } diff --git a/vllm/benchmarks/sweep/param_sweep.py b/vllm/benchmarks/sweep/param_sweep.py index a438a3288..f20134cfc 100644 --- a/vllm/benchmarks/sweep/param_sweep.py +++ b/vllm/benchmarks/sweep/param_sweep.py @@ -74,7 +74,8 @@ class ParameterSweepItem(dict[str, object]): representation of all parameters. """ if "_benchmark_name" in self: - return self["_benchmark_name"] + return str(self["_benchmark_name"]) + return self.as_text(sep="-") # In JSON, we prefer "_" diff --git a/vllm/benchmarks/sweep/serve_sla.py b/vllm/benchmarks/sweep/serve_sla.py index 0403d1ddf..297cd4e5b 100644 --- a/vllm/benchmarks/sweep/serve_sla.py +++ b/vllm/benchmarks/sweep/serve_sla.py @@ -145,12 +145,11 @@ def _estimate_sla_bounds( ): sla_data = list[dict[str, object]]() - max_passing: int = 0 - min_failing: int = 0 - val: int = init_value assert val > 0 + history = dict[int, float]() + while True: print(f"Testing {sla_variable}: {val} req/s") @@ -172,24 +171,33 @@ def _estimate_sla_bounds( for k in sla_comb } - sla_results = [ - criterion.print_and_validate(iter_data_mean, k) + sla_margins = [ + criterion.print_and_compute_margin(iter_data_mean, k) for k, criterion in sla_comb.items() ] + margin = max(sla_margins) + history[val] = margin - if all(sla_results): + if margin <= 0: print("SLA criteria are met.") - max_passing = val val *= 2 else: print("SLA criteria are not met.") - min_failing = val break if val >= max_value: break - return sla_data, (max_passing, min_failing) + max_passing = max( + (val for val, margin in history.items() if margin <= 0), + default=0, + ) + min_failing = min( + (val for val, margin in history.items() if margin > 0), + default=max_value, + ) + + return sla_data, (max_passing, min_failing), history def _find_sla_value( @@ -211,6 +219,8 @@ def _find_sla_value( left: int = min_value right: int = max_value + history = dict[int, float]() + while True: val = (left + right) // 2 print(f"Testing {sla_variable}: {val} req/s") @@ -233,22 +243,24 @@ def _find_sla_value( for k in sla_comb } - sla_results = [ - criterion.print_and_validate(iter_data_mean, k) + sla_margins = [ + criterion.print_and_compute_margin(iter_data_mean, k) for k, criterion in sla_comb.items() ] + margin = max(sla_margins) + history[val] = margin - if all(sla_results): + if margin <= 0: print("SLA criteria are met.") left = val else: print("SLA criteria are not met.") right = val - if right - left <= 1: + if right - left <= 1 and left in history: break - return sla_data, left + return sla_data, left, history def search_sla( @@ -288,7 +300,7 @@ def search_sla( ) print(f"Initial {sla_variable} to search: {sla_init_value} req/s.") - sla_data_1, (sla_min, sla_max) = _estimate_sla_bounds( + sla_data_1, (sla_min, sla_max), _ = _estimate_sla_bounds( server, bench_cmd, serve_comb=serve_comb, @@ -303,7 +315,7 @@ def search_sla( ) print(f"Range of {sla_variable} to search: [{sla_min}, {sla_max}] req/s.") - sla_data_2, sla_value = _find_sla_value( + sla_data_2, sla_value, _ = _find_sla_value( server, bench_cmd, serve_comb=serve_comb, diff --git a/vllm/benchmarks/sweep/sla_sweep.py b/vllm/benchmarks/sweep/sla_sweep.py index 327e3c7c5..0a780860d 100644 --- a/vllm/benchmarks/sweep/sla_sweep.py +++ b/vllm/benchmarks/sweep/sla_sweep.py @@ -7,39 +7,45 @@ from dataclasses import dataclass from typing_extensions import override +SLA_EPS = 1e-8 +"""Offset used to differentiate margins for equality checks.""" + @dataclass class SLACriterionBase(ABC): target: float @abstractmethod - def validate(self, actual: float) -> bool: - """Return `True` if this criterion is met; otherwise `False`.""" + def compute_margin(self, actual: float) -> float: + """ + Return a negative value or `0` if this criterion is met; + otherwise a positive value indicating the distance to the target. + """ raise NotImplementedError @abstractmethod def format_cond(self, lhs: str) -> str: raise NotImplementedError - def print_and_validate( + def print_and_compute_margin( self, metrics: dict[str, float], metrics_key: str, - ) -> bool: + ) -> float: metric = metrics[metrics_key] - result = self.validate(metric) + margin = self.compute_margin(metric) cond = self.format_cond(f"{metrics_key} = {metric:.2f}") - print(f"Validating SLA: {cond} | " + ("PASSED" if result else "FAILED")) + print(f"Validating SLA: {cond} | " + ("PASSED" if margin <= 0 else "FAILED")) - return result + return margin @dataclass class SLALessThan(SLACriterionBase): @override - def validate(self, actual: float) -> bool: - return actual < self.target + def compute_margin(self, actual: float) -> float: + return actual + SLA_EPS - self.target @override def format_cond(self, lhs: str) -> str: @@ -49,8 +55,8 @@ class SLALessThan(SLACriterionBase): @dataclass class SLALessThanOrEqualTo(SLACriterionBase): @override - def validate(self, actual: float) -> bool: - return actual <= self.target + def compute_margin(self, actual: float) -> float: + return actual - self.target @override def format_cond(self, lhs: str) -> str: @@ -60,8 +66,8 @@ class SLALessThanOrEqualTo(SLACriterionBase): @dataclass class SLAGreaterThan(SLACriterionBase): @override - def validate(self, actual: float) -> bool: - return actual > self.target + def compute_margin(self, actual: float) -> float: + return self.target + SLA_EPS - actual @override def format_cond(self, lhs: str) -> str: @@ -71,8 +77,8 @@ class SLAGreaterThan(SLACriterionBase): @dataclass class SLAGreaterThanOrEqualTo(SLACriterionBase): @override - def validate(self, actual: float) -> bool: - return actual >= self.target + def compute_margin(self, actual: float) -> float: + return self.target - actual @override def format_cond(self, lhs: str) -> str: