[Benchmark][2/2] Use spline interpolation to tune SLA variables (#32095)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -129,10 +129,10 @@ vllm bench sweep serve_sla \
|
||||
|
||||
The algorithm for adjusting the SLA variable is as follows:
|
||||
|
||||
1. Run the benchmark with infinite QPS, and use the corresponding metrics to determine the initial value of the variable.
|
||||
- For example, the initial request rate is set to the concurrency under infinite QPS.
|
||||
2. If the SLA is still satisfied, keep doubling the value until the SLA is no longer satisfied. This gives a relatively narrow window that contains the point where the SLA is barely satisfied.
|
||||
3. Apply binary search over the window to find the maximum value that still satisfies the SLA.
|
||||
1. Run the benchmark once with maximum possible QPS, and once with minimum possible QPS. For each run, calculate the distance of the SLA metrics from their targets, resulting in data points of QPS vs SLA distance.
|
||||
2. Perform spline interpolation between the data points to estimate the QPS that results in zero SLA distance.
|
||||
3. Run the benchmark with the estimated QPS and add the resulting data point to the history.
|
||||
4. Repeat Steps 2 and 3 until the maximum QPS that passes SLA and the minimum QPS that fails SLA in the history are close enough to each other.
|
||||
|
||||
!!! important
|
||||
SLA tuning is applied over each combination of `--serve-params`, `--bench-params`, and `--sla-params`.
|
||||
|
||||
@@ -5,7 +5,7 @@ from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from vllm.benchmarks.sweep.param_sweep import ParameterSweepItem
|
||||
from vllm.benchmarks.sweep.serve_sla import _estimate_sla_bounds, _find_sla_value
|
||||
from vllm.benchmarks.sweep.serve_sla import solve_sla
|
||||
from vllm.benchmarks.sweep.server import ServerProcess
|
||||
from vllm.benchmarks.sweep.sla_sweep import (
|
||||
SLACriterionBase,
|
||||
@@ -39,18 +39,70 @@ def _set_return_value(
|
||||
return patch("vllm.benchmarks.sweep.serve_sla.run_sla", side_effect=mock_run_sla)
|
||||
|
||||
|
||||
def _var2metric_identity(bench_comb):
|
||||
return [{"request_throughput": float(bench_comb["request_rate"])}]
|
||||
def _var2metric_linear():
|
||||
def wrapped(bench_comb):
|
||||
x = float(bench_comb["request_rate"])
|
||||
y = x
|
||||
|
||||
return [{"request_throughput": y}]
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def _run_estimate_sla_bounds(
|
||||
def _var2metric_concave(elbow_point: float):
|
||||
def wrapped(bench_comb):
|
||||
x = float(bench_comb["request_rate"])
|
||||
if x < elbow_point:
|
||||
y = 0.5 * (x - elbow_point) + elbow_point
|
||||
else:
|
||||
y = 1.5 * (x - elbow_point) + elbow_point
|
||||
|
||||
return [{"request_throughput": y}]
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def _var2metric_convex(elbow_point: float):
|
||||
def wrapped(bench_comb):
|
||||
x = float(bench_comb["request_rate"])
|
||||
if x < elbow_point:
|
||||
y = 1.5 * (x - elbow_point) + elbow_point
|
||||
else:
|
||||
y = 0.5 * (x - elbow_point) + elbow_point
|
||||
|
||||
return [{"request_throughput": y}]
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def _var2metric_quadratic(y_intercept: float):
|
||||
def wrapped(bench_comb):
|
||||
x = float(bench_comb["request_rate"])
|
||||
y = y_intercept + 0.1 * x**2
|
||||
|
||||
return [{"request_throughput": y}]
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def _var2metric_sqrt(y_intercept: float):
|
||||
def wrapped(bench_comb):
|
||||
x = float(bench_comb["request_rate"])
|
||||
y = y_intercept + 10 * x**0.5
|
||||
|
||||
return [{"request_throughput": y}]
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def _run_solve_sla(
|
||||
var2metric: Callable[[ParameterSweepItem], list[dict[str, float]]],
|
||||
criterion: SLACriterionBase,
|
||||
init_value: int,
|
||||
max_value: int,
|
||||
min_value: int = 1,
|
||||
max_value: int = 100,
|
||||
):
|
||||
with _set_return_value(var2metric):
|
||||
return _estimate_sla_bounds(
|
||||
result = solve_sla(
|
||||
server=None,
|
||||
bench_cmd=[],
|
||||
serve_comb=ParameterSweepItem(),
|
||||
@@ -60,143 +112,129 @@ def _run_estimate_sla_bounds(
|
||||
num_runs=1,
|
||||
dry_run=False,
|
||||
sla_variable="request_rate",
|
||||
init_value=init_value,
|
||||
max_value=max_value,
|
||||
sla_min_value=min_value,
|
||||
sla_max_value=max_value,
|
||||
)
|
||||
assert result is not None
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def test_estimate_sla_bounds_le():
|
||||
sla_data, (max_passing, min_failing), history = _run_estimate_sla_bounds(
|
||||
_var2metric_identity,
|
||||
def test_solve_linear_sla_le():
|
||||
sla_data, history = _run_solve_sla(
|
||||
_var2metric_linear(),
|
||||
SLALessThanOrEqualTo(target=32),
|
||||
init_value=1,
|
||||
max_value=100,
|
||||
)
|
||||
|
||||
assert max_passing == 32
|
||||
assert min_failing == 64
|
||||
assert history.get_max_passing() == 32
|
||||
|
||||
assert {val: margin <= 0 for val, margin in history.items()} == {
|
||||
100: False,
|
||||
1: True,
|
||||
2: True,
|
||||
4: True,
|
||||
8: True,
|
||||
16: True,
|
||||
32: True,
|
||||
64: False,
|
||||
33: False,
|
||||
}
|
||||
|
||||
|
||||
def test_estimate_sla_bounds_lt():
|
||||
sla_data, (max_passing, min_failing), history = _run_estimate_sla_bounds(
|
||||
_var2metric_identity,
|
||||
def test_solve_linear_sla_lt():
|
||||
sla_data, history = _run_solve_sla(
|
||||
_var2metric_linear(),
|
||||
SLALessThan(target=32),
|
||||
init_value=1,
|
||||
max_value=100,
|
||||
)
|
||||
|
||||
assert max_passing == 16
|
||||
assert min_failing == 32
|
||||
assert history.get_max_passing() == 31
|
||||
|
||||
assert {val: margin <= 0 for val, margin in history.items()} == {
|
||||
100: False,
|
||||
1: True,
|
||||
2: True,
|
||||
4: True,
|
||||
8: True,
|
||||
16: True,
|
||||
31: True,
|
||||
32: False,
|
||||
}
|
||||
|
||||
|
||||
def test_estimate_sla_bounds_oob():
|
||||
sla_data, (max_passing, min_failing), history = _run_estimate_sla_bounds(
|
||||
_var2metric_identity,
|
||||
def test_solve_linear_sla_oob():
|
||||
sla_data, history = _run_solve_sla(
|
||||
_var2metric_linear(),
|
||||
SLALessThanOrEqualTo(target=32),
|
||||
init_value=64,
|
||||
max_value=128,
|
||||
)
|
||||
|
||||
assert max_passing == 0
|
||||
assert min_failing == 64
|
||||
|
||||
assert {val: margin <= 0 for val, margin in history.items()} == {
|
||||
64: False,
|
||||
}
|
||||
|
||||
|
||||
def _run_test_find_sla_value_le(
|
||||
var2metric: Callable[[ParameterSweepItem], list[dict[str, float]]],
|
||||
criterion: SLACriterionBase,
|
||||
min_value: int,
|
||||
max_value: int,
|
||||
):
|
||||
with _set_return_value(var2metric):
|
||||
return _find_sla_value(
|
||||
server=None,
|
||||
bench_cmd=[],
|
||||
serve_comb=ParameterSweepItem(),
|
||||
bench_comb=ParameterSweepItem(),
|
||||
sla_comb=SLASweepItem({"request_throughput": criterion}),
|
||||
base_path=Path(""),
|
||||
num_runs=1,
|
||||
dry_run=False,
|
||||
sla_variable="request_rate",
|
||||
min_value=min_value,
|
||||
max_value=max_value,
|
||||
)
|
||||
|
||||
|
||||
def test_find_sla_value_le():
|
||||
sla_data, sla_value, history = _run_test_find_sla_value_le(
|
||||
_var2metric_identity,
|
||||
SLALessThanOrEqualTo(target=50.0),
|
||||
min_value=32,
|
||||
max_value=64,
|
||||
)
|
||||
|
||||
assert sla_value == 50
|
||||
assert {val: margin <= 0 for val, margin in history.items()} == {
|
||||
48: True,
|
||||
56: False,
|
||||
52: False,
|
||||
50: True,
|
||||
51: False,
|
||||
}
|
||||
|
||||
|
||||
def test_find_sla_value_lt():
|
||||
sla_data, sla_value, history = _run_test_find_sla_value_le(
|
||||
_var2metric_identity,
|
||||
SLALessThan(target=50.0),
|
||||
min_value=32,
|
||||
max_value=64,
|
||||
)
|
||||
|
||||
assert sla_value == 49
|
||||
assert {val: margin <= 0 for val, margin in history.items()} == {
|
||||
48: True,
|
||||
56: False,
|
||||
52: False,
|
||||
50: False,
|
||||
49: True,
|
||||
}
|
||||
|
||||
|
||||
def test_find_sla_value_oob():
|
||||
sla_data, sla_value, history = _run_test_find_sla_value_le(
|
||||
_var2metric_identity,
|
||||
SLALessThanOrEqualTo(target=50.0),
|
||||
min_value=64,
|
||||
max_value=128,
|
||||
)
|
||||
|
||||
assert sla_value == 64
|
||||
assert history.get_max_passing() == 64
|
||||
assert history.get_min_failing() == 64
|
||||
|
||||
assert {val: margin <= 0 for val, margin in history.items()} == {
|
||||
96: False,
|
||||
80: False,
|
||||
72: False,
|
||||
68: False,
|
||||
66: False,
|
||||
65: False,
|
||||
100: False,
|
||||
64: False,
|
||||
}
|
||||
|
||||
|
||||
def test_solve_concave_sla_le():
|
||||
sla_data, history = _run_solve_sla(
|
||||
_var2metric_concave(elbow_point=32),
|
||||
SLALessThanOrEqualTo(target=24),
|
||||
)
|
||||
|
||||
assert history.get_max_passing() == 16
|
||||
|
||||
assert {val: margin <= 0 for val, margin in history.items()} == {
|
||||
100: False,
|
||||
1: True,
|
||||
7: True,
|
||||
13: True,
|
||||
15: True,
|
||||
16: True,
|
||||
17: False,
|
||||
}
|
||||
|
||||
|
||||
def test_solve_convex_sla_le():
|
||||
sla_data, history = _run_solve_sla(
|
||||
_var2metric_convex(elbow_point=32),
|
||||
SLALessThanOrEqualTo(target=24),
|
||||
)
|
||||
|
||||
assert history.get_max_passing() == 26
|
||||
|
||||
assert {val: margin <= 0 for val, margin in history.items()} == {
|
||||
100: False,
|
||||
1: True,
|
||||
48: False,
|
||||
30: False,
|
||||
24: True,
|
||||
26: True,
|
||||
27: False,
|
||||
}
|
||||
|
||||
|
||||
def test_solve_quadratic_sla_le():
|
||||
sla_data, history = _run_solve_sla(
|
||||
_var2metric_quadratic(y_intercept=10),
|
||||
SLALessThanOrEqualTo(target=50),
|
||||
)
|
||||
|
||||
assert history.get_max_passing() == 20
|
||||
|
||||
assert {val: margin <= 0 for val, margin in history.items()} == {
|
||||
100: False,
|
||||
1: True,
|
||||
4: True,
|
||||
20: True,
|
||||
21: False,
|
||||
}
|
||||
|
||||
|
||||
def test_solve_sqrt_sla_le():
|
||||
sla_data, history = _run_solve_sla(
|
||||
_var2metric_sqrt(y_intercept=10),
|
||||
SLALessThanOrEqualTo(target=100),
|
||||
)
|
||||
|
||||
assert history.get_max_passing() == 81
|
||||
|
||||
assert {val: margin <= 0 for val, margin in history.items()} == {
|
||||
100: False,
|
||||
1: True,
|
||||
89: False,
|
||||
81: True,
|
||||
82: False,
|
||||
}
|
||||
|
||||
@@ -3,14 +3,11 @@
|
||||
import argparse
|
||||
import contextlib
|
||||
import json
|
||||
import math
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Literal, get_args
|
||||
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
from .param_sweep import ParameterSweep, ParameterSweepItem
|
||||
@@ -24,6 +21,15 @@ try:
|
||||
except ImportError:
|
||||
pd = PlaceholderModule("pandas")
|
||||
|
||||
try:
|
||||
from scipy.interpolate import PchipInterpolator
|
||||
except ImportError:
|
||||
PchipInterpolator = (
|
||||
PlaceholderModule("scipy")
|
||||
.placeholder_attr("interpolate")
|
||||
.placeholder_attr("PchipInterpolator")
|
||||
)
|
||||
|
||||
|
||||
def _get_sla_base_path(
|
||||
output_dir: Path,
|
||||
@@ -118,18 +124,36 @@ def run_sla(
|
||||
SLAVariable = Literal["request_rate", "max_concurrency"]
|
||||
|
||||
|
||||
def _estimate_sla_value(run_data: dict[str, object], sla_variable: SLAVariable):
|
||||
request_throughput = float(run_data["request_throughput"]) # type: ignore
|
||||
if sla_variable == "request_rate":
|
||||
return request_throughput
|
||||
if sla_variable == "max_concurrency":
|
||||
mean_latency_ms = float(run_data["mean_e2el_ms"]) # type: ignore
|
||||
return request_throughput * mean_latency_ms / 1000
|
||||
class SLAHistory(dict[int, float]):
|
||||
def __init__(self, min_value: int, max_value: int) -> None:
|
||||
super().__init__()
|
||||
|
||||
assert_never(sla_variable)
|
||||
self.min_value = min_value
|
||||
self.max_value = max_value
|
||||
|
||||
def get_xy(self) -> tuple[list[int], list[float]]:
|
||||
xs = list[int]()
|
||||
ys = list[float]()
|
||||
for x, y in sorted(self.items()):
|
||||
xs.append(x)
|
||||
ys.append(y)
|
||||
|
||||
return xs, ys
|
||||
|
||||
def get_max_passing(self) -> float:
|
||||
return max(
|
||||
(val for val, margin in self.items() if margin <= 0),
|
||||
default=self.min_value,
|
||||
)
|
||||
|
||||
def get_min_failing(self) -> float:
|
||||
return min(
|
||||
(val for val, margin in self.items() if margin > 0),
|
||||
default=self.max_value,
|
||||
)
|
||||
|
||||
|
||||
def _estimate_sla_bounds(
|
||||
def solve_sla(
|
||||
server: ServerProcess | None,
|
||||
bench_cmd: list[str],
|
||||
*,
|
||||
@@ -140,17 +164,33 @@ def _estimate_sla_bounds(
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
sla_variable: SLAVariable,
|
||||
init_value: int,
|
||||
max_value: int,
|
||||
sla_min_value: int = 1,
|
||||
sla_max_value: int = 8192, # The value that represents infinite QPS
|
||||
):
|
||||
sla_data = list[dict[str, object]]()
|
||||
history = SLAHistory(min_value=sla_min_value, max_value=sla_max_value)
|
||||
|
||||
val: int = init_value
|
||||
assert val > 0
|
||||
# NOTE: We don't use equality here to be more robust against noisy results
|
||||
while history.get_max_passing() + 1 < history.get_min_failing():
|
||||
if len(history) == 0:
|
||||
val = sla_max_value
|
||||
elif len(history) == 1:
|
||||
val = sla_min_value
|
||||
else:
|
||||
spl = PchipInterpolator(*history.get_xy(), extrapolate=False)
|
||||
spl_roots = spl.solve()
|
||||
if len(spl_roots) == 0:
|
||||
# Fallback to binary search
|
||||
val = int((history.get_max_passing() + history.get_min_failing()) / 2)
|
||||
else:
|
||||
val = int(spl_roots[0])
|
||||
|
||||
history = dict[int, float]()
|
||||
if val in history:
|
||||
# Cover both sides (floor and ceil) of the root to be sure
|
||||
# that it is indeed the target value
|
||||
val += 1
|
||||
|
||||
while True:
|
||||
val = max(sla_min_value, min(val, sla_max_value))
|
||||
print(f"Testing {sla_variable}: {val} req/s")
|
||||
|
||||
iter_data = run_sla(
|
||||
@@ -162,8 +202,9 @@ def _estimate_sla_bounds(
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
if iter_data is None:
|
||||
return None
|
||||
|
||||
assert iter_data is not None
|
||||
sla_data.extend(iter_data)
|
||||
|
||||
iter_data_mean = {
|
||||
@@ -175,92 +216,14 @@ def _estimate_sla_bounds(
|
||||
criterion.print_and_compute_margin(iter_data_mean, k)
|
||||
for k, criterion in sla_comb.items()
|
||||
]
|
||||
margin = max(sla_margins)
|
||||
history[val] = margin
|
||||
history[val] = margin = max(sla_margins)
|
||||
|
||||
if margin <= 0:
|
||||
print("SLA criteria are met.")
|
||||
val *= 2
|
||||
print(f"SLA criteria are met. ({margin=:.2f})")
|
||||
else:
|
||||
print("SLA criteria are not met.")
|
||||
break
|
||||
print(f"SLA criteria are not met. ({margin=:.2f})")
|
||||
|
||||
if val >= max_value:
|
||||
break
|
||||
|
||||
max_passing = max(
|
||||
(val for val, margin in history.items() if margin <= 0),
|
||||
default=0,
|
||||
)
|
||||
min_failing = min(
|
||||
(val for val, margin in history.items() if margin > 0),
|
||||
default=max_value,
|
||||
)
|
||||
|
||||
return sla_data, (max_passing, min_failing), history
|
||||
|
||||
|
||||
def _find_sla_value(
|
||||
server: ServerProcess | None,
|
||||
bench_cmd: list[str],
|
||||
*,
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_comb: ParameterSweepItem,
|
||||
sla_comb: SLASweepItem,
|
||||
base_path: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
sla_variable: SLAVariable,
|
||||
min_value: int,
|
||||
max_value: int,
|
||||
):
|
||||
sla_data = list[dict[str, object]]()
|
||||
|
||||
left: int = min_value
|
||||
right: int = max_value
|
||||
|
||||
history = dict[int, float]()
|
||||
|
||||
while True:
|
||||
val = (left + right) // 2
|
||||
print(f"Testing {sla_variable}: {val} req/s")
|
||||
|
||||
iter_data = run_sla(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb | {sla_variable: val},
|
||||
iter_path=_get_sla_iter_path(base_path, sla_comb, sla_variable, val),
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
|
||||
assert iter_data is not None
|
||||
sla_data.extend(iter_data)
|
||||
|
||||
iter_data_mean = {
|
||||
k: sum(float(run_data[k]) for run_data in iter_data) / len(iter_data) # type: ignore
|
||||
for k in sla_comb
|
||||
}
|
||||
|
||||
sla_margins = [
|
||||
criterion.print_and_compute_margin(iter_data_mean, k)
|
||||
for k, criterion in sla_comb.items()
|
||||
]
|
||||
margin = max(sla_margins)
|
||||
history[val] = margin
|
||||
|
||||
if margin <= 0:
|
||||
print("SLA criteria are met.")
|
||||
left = val
|
||||
else:
|
||||
print("SLA criteria are not met.")
|
||||
right = val
|
||||
|
||||
if right - left <= 1 and left in history:
|
||||
break
|
||||
|
||||
return sla_data, left, history
|
||||
return sla_data, history
|
||||
|
||||
|
||||
def search_sla(
|
||||
@@ -271,7 +234,6 @@ def search_sla(
|
||||
bench_comb: ParameterSweepItem,
|
||||
sla_comb: SLASweepItem,
|
||||
sla_variable: SLAVariable,
|
||||
sla_inf_value: int = 65536, # The value that represents infinite QPS
|
||||
base_path: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
@@ -279,57 +241,25 @@ def search_sla(
|
||||
print("[SLA START]")
|
||||
print(f"SLA criteria: {sla_comb.as_text()}")
|
||||
|
||||
sla_data_0 = run_sla(
|
||||
result = solve_sla(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb | {sla_variable: sla_inf_value},
|
||||
iter_path=_get_sla_iter_path(base_path, sla_comb, sla_variable, sla_inf_value),
|
||||
bench_comb=bench_comb,
|
||||
sla_comb=sla_comb,
|
||||
base_path=base_path,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
sla_variable=sla_variable,
|
||||
)
|
||||
if sla_data_0 is None:
|
||||
if result is None:
|
||||
assert dry_run
|
||||
print("Omitting SLA search.")
|
||||
print("[SLA END]")
|
||||
return None
|
||||
return
|
||||
|
||||
sla_init_value = math.ceil(
|
||||
sum(_estimate_sla_value(item, sla_variable) for item in sla_data_0)
|
||||
/ len(sla_data_0)
|
||||
)
|
||||
print(f"Initial {sla_variable} to search: {sla_init_value} req/s.")
|
||||
|
||||
sla_data_1, (sla_min, sla_max), _ = _estimate_sla_bounds(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb,
|
||||
sla_comb=sla_comb,
|
||||
base_path=base_path,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
sla_variable=sla_variable,
|
||||
init_value=sla_init_value,
|
||||
max_value=sla_inf_value,
|
||||
)
|
||||
print(f"Range of {sla_variable} to search: [{sla_min}, {sla_max}] req/s.")
|
||||
|
||||
sla_data_2, sla_value, _ = _find_sla_value(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb,
|
||||
sla_comb=sla_comb,
|
||||
base_path=base_path,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
sla_variable=sla_variable,
|
||||
min_value=sla_min,
|
||||
max_value=sla_max,
|
||||
)
|
||||
|
||||
sla_data = sla_data_0 + sla_data_1 + sla_data_2
|
||||
sla_data, sla_history = result
|
||||
sla_value = sla_history.get_max_passing()
|
||||
print(f"Maximum {sla_variable} for SLA: {sla_value} req/s.")
|
||||
|
||||
with _get_sla_iter_path(
|
||||
|
||||
Reference in New Issue
Block a user