[Benchmark][2/2] Use spline interpolation to tune SLA variables (#32095)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-11 12:27:27 +08:00
committed by GitHub
parent 2a4dbe24ea
commit ef96fa3f1f
3 changed files with 232 additions and 264 deletions

View File

@@ -129,10 +129,10 @@ vllm bench sweep serve_sla \
The algorithm for adjusting the SLA variable is as follows:
1. Run the benchmark with infinite QPS, and use the corresponding metrics to determine the initial value of the variable.
- For example, the initial request rate is set to the concurrency under infinite QPS.
2. If the SLA is still satisfied, keep doubling the value until the SLA is no longer satisfied. This gives a relatively narrow window that contains the point where the SLA is barely satisfied.
3. Apply binary search over the window to find the maximum value that still satisfies the SLA.
1. Run the benchmark once with maximum possible QPS, and once with minimum possible QPS. For each run, calculate the distance of the SLA metrics from their targets, resulting in data points of QPS vs SLA distance.
2. Perform spline interpolation between the data points to estimate the QPS that results in zero SLA distance.
3. Run the benchmark with the estimated QPS and add the resulting data point to the history.
4. Repeat Steps 2 and 3 until the maximum QPS that passes SLA and the minimum QPS that fails SLA in the history differ by at most 1.
!!! important
SLA tuning is applied over each combination of `--serve-params`, `--bench-params`, and `--sla-params`.

View File

@@ -5,7 +5,7 @@ from pathlib import Path
from unittest.mock import patch
from vllm.benchmarks.sweep.param_sweep import ParameterSweepItem
from vllm.benchmarks.sweep.serve_sla import _estimate_sla_bounds, _find_sla_value
from vllm.benchmarks.sweep.serve_sla import solve_sla
from vllm.benchmarks.sweep.server import ServerProcess
from vllm.benchmarks.sweep.sla_sweep import (
SLACriterionBase,
@@ -39,18 +39,70 @@ def _set_return_value(
return patch("vllm.benchmarks.sweep.serve_sla.run_sla", side_effect=mock_run_sla)
def _var2metric_identity(bench_comb):
return [{"request_throughput": float(bench_comb["request_rate"])}]
def _var2metric_linear():
def wrapped(bench_comb):
x = float(bench_comb["request_rate"])
y = x
return [{"request_throughput": y}]
return wrapped
def _run_estimate_sla_bounds(
def _var2metric_concave(elbow_point: float):
def wrapped(bench_comb):
x = float(bench_comb["request_rate"])
if x < elbow_point:
y = 0.5 * (x - elbow_point) + elbow_point
else:
y = 1.5 * (x - elbow_point) + elbow_point
return [{"request_throughput": y}]
return wrapped
def _var2metric_convex(elbow_point: float):
def wrapped(bench_comb):
x = float(bench_comb["request_rate"])
if x < elbow_point:
y = 1.5 * (x - elbow_point) + elbow_point
else:
y = 0.5 * (x - elbow_point) + elbow_point
return [{"request_throughput": y}]
return wrapped
def _var2metric_quadratic(y_intercept: float):
def wrapped(bench_comb):
x = float(bench_comb["request_rate"])
y = y_intercept + 0.1 * x**2
return [{"request_throughput": y}]
return wrapped
def _var2metric_sqrt(y_intercept: float):
def wrapped(bench_comb):
x = float(bench_comb["request_rate"])
y = y_intercept + 10 * x**0.5
return [{"request_throughput": y}]
return wrapped
def _run_solve_sla(
var2metric: Callable[[ParameterSweepItem], list[dict[str, float]]],
criterion: SLACriterionBase,
init_value: int,
max_value: int,
min_value: int = 1,
max_value: int = 100,
):
with _set_return_value(var2metric):
return _estimate_sla_bounds(
result = solve_sla(
server=None,
bench_cmd=[],
serve_comb=ParameterSweepItem(),
@@ -60,143 +112,129 @@ def _run_estimate_sla_bounds(
num_runs=1,
dry_run=False,
sla_variable="request_rate",
init_value=init_value,
max_value=max_value,
sla_min_value=min_value,
sla_max_value=max_value,
)
assert result is not None
return result
def test_estimate_sla_bounds_le():
sla_data, (max_passing, min_failing), history = _run_estimate_sla_bounds(
_var2metric_identity,
def test_solve_linear_sla_le():
sla_data, history = _run_solve_sla(
_var2metric_linear(),
SLALessThanOrEqualTo(target=32),
init_value=1,
max_value=100,
)
assert max_passing == 32
assert min_failing == 64
assert history.get_max_passing() == 32
assert {val: margin <= 0 for val, margin in history.items()} == {
100: False,
1: True,
2: True,
4: True,
8: True,
16: True,
32: True,
64: False,
33: False,
}
def test_estimate_sla_bounds_lt():
sla_data, (max_passing, min_failing), history = _run_estimate_sla_bounds(
_var2metric_identity,
def test_solve_linear_sla_lt():
sla_data, history = _run_solve_sla(
_var2metric_linear(),
SLALessThan(target=32),
init_value=1,
max_value=100,
)
assert max_passing == 16
assert min_failing == 32
assert history.get_max_passing() == 31
assert {val: margin <= 0 for val, margin in history.items()} == {
100: False,
1: True,
2: True,
4: True,
8: True,
16: True,
31: True,
32: False,
}
def test_estimate_sla_bounds_oob():
sla_data, (max_passing, min_failing), history = _run_estimate_sla_bounds(
_var2metric_identity,
def test_solve_linear_sla_oob():
sla_data, history = _run_solve_sla(
_var2metric_linear(),
SLALessThanOrEqualTo(target=32),
init_value=64,
max_value=128,
)
assert max_passing == 0
assert min_failing == 64
assert {val: margin <= 0 for val, margin in history.items()} == {
64: False,
}
def _run_test_find_sla_value_le(
    var2metric: Callable[[ParameterSweepItem], list[dict[str, float]]],
    criterion: SLACriterionBase,
    min_value: int,
    max_value: int,
):
    """Call `_find_sla_value` with `run_sla` replaced by `var2metric`.

    The search is run on the `request_rate` variable over
    `[min_value, max_value]` with a single SLA criterion attached to the
    `request_throughput` metric; whatever `_find_sla_value` returns is
    passed straight through.
    """
    with _set_return_value(var2metric):
        # Built inside the patch context so evaluation order matches a
        # direct keyword call.
        search_kwargs = dict(
            server=None,
            bench_cmd=[],
            serve_comb=ParameterSweepItem(),
            bench_comb=ParameterSweepItem(),
            sla_comb=SLASweepItem({"request_throughput": criterion}),
            base_path=Path(""),
            num_runs=1,
            dry_run=False,
            sla_variable="request_rate",
            min_value=min_value,
            max_value=max_value,
        )
        return _find_sla_value(**search_kwargs)
def test_find_sla_value_le():
    """Binary search with a `<= 50` SLA on the identity mock over [32, 64]."""
    _, sla_value, history = _run_test_find_sla_value_le(
        _var2metric_identity,
        SLALessThanOrEqualTo(target=50.0),
        min_value=32,
        max_value=64,
    )

    # Map every probed value to whether it met the SLA (margin <= 0).
    passed_by_value = {val: margin <= 0 for val, margin in history.items()}

    assert sla_value == 50
    assert passed_by_value == {
        48: True,
        56: False,
        52: False,
        50: True,
        51: False,
    }
def test_find_sla_value_lt():
    """Binary search with a strict `< 50` SLA on the identity mock over [32, 64]."""
    _, sla_value, history = _run_test_find_sla_value_le(
        _var2metric_identity,
        SLALessThan(target=50.0),
        min_value=32,
        max_value=64,
    )

    # Map every probed value to whether it met the SLA (margin <= 0).
    passed_by_value = {val: margin <= 0 for val, margin in history.items()}

    assert sla_value == 49
    assert passed_by_value == {
        48: True,
        56: False,
        52: False,
        50: False,
        49: True,
    }
def test_find_sla_value_oob():
sla_data, sla_value, history = _run_test_find_sla_value_le(
_var2metric_identity,
SLALessThanOrEqualTo(target=50.0),
min_value=64,
max_value=128,
)
assert sla_value == 64
assert history.get_max_passing() == 64
assert history.get_min_failing() == 64
assert {val: margin <= 0 for val, margin in history.items()} == {
96: False,
80: False,
72: False,
68: False,
66: False,
65: False,
100: False,
64: False,
}
def test_solve_concave_sla_le():
    """Solve a `<= 24` SLA on the `_var2metric_concave` mock (elbow at 32)."""
    _, history = _run_solve_sla(
        _var2metric_concave(elbow_point=32),
        SLALessThanOrEqualTo(target=24),
    )

    # Map every probed value to whether it met the SLA (margin <= 0).
    passed_by_value = {val: margin <= 0 for val, margin in history.items()}

    assert history.get_max_passing() == 16
    assert passed_by_value == {
        100: False,
        1: True,
        7: True,
        13: True,
        15: True,
        16: True,
        17: False,
    }
def test_solve_convex_sla_le():
    """Solve a `<= 24` SLA on the `_var2metric_convex` mock (elbow at 32)."""
    _, history = _run_solve_sla(
        _var2metric_convex(elbow_point=32),
        SLALessThanOrEqualTo(target=24),
    )

    # Map every probed value to whether it met the SLA (margin <= 0).
    passed_by_value = {val: margin <= 0 for val, margin in history.items()}

    assert history.get_max_passing() == 26
    assert passed_by_value == {
        100: False,
        1: True,
        48: False,
        30: False,
        24: True,
        26: True,
        27: False,
    }
def test_solve_quadratic_sla_le():
    """Solve a `<= 50` SLA on the quadratic mock with intercept 10."""
    _, history = _run_solve_sla(
        _var2metric_quadratic(y_intercept=10),
        SLALessThanOrEqualTo(target=50),
    )

    # Map every probed value to whether it met the SLA (margin <= 0).
    passed_by_value = {val: margin <= 0 for val, margin in history.items()}

    assert history.get_max_passing() == 20
    assert passed_by_value == {
        100: False,
        1: True,
        4: True,
        20: True,
        21: False,
    }
def test_solve_sqrt_sla_le():
    """Solve a `<= 100` SLA on the square-root mock with intercept 10."""
    _, history = _run_solve_sla(
        _var2metric_sqrt(y_intercept=10),
        SLALessThanOrEqualTo(target=100),
    )

    # Map every probed value to whether it met the SLA (margin <= 0).
    passed_by_value = {val: margin <= 0 for val, margin in history.items()}

    assert history.get_max_passing() == 81
    assert passed_by_value == {
        100: False,
        1: True,
        89: False,
        81: True,
        82: False,
    }

View File

@@ -3,14 +3,11 @@
import argparse
import contextlib
import json
import math
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
from typing import ClassVar, Literal, get_args
from typing_extensions import assert_never
from vllm.utils.import_utils import PlaceholderModule
from .param_sweep import ParameterSweep, ParameterSweepItem
@@ -24,6 +21,15 @@ try:
except ImportError:
pd = PlaceholderModule("pandas")
try:
from scipy.interpolate import PchipInterpolator
except ImportError:
PchipInterpolator = (
PlaceholderModule("scipy")
.placeholder_attr("interpolate")
.placeholder_attr("PchipInterpolator")
)
def _get_sla_base_path(
output_dir: Path,
@@ -118,18 +124,36 @@ def run_sla(
SLAVariable = Literal["request_rate", "max_concurrency"]
def _estimate_sla_value(run_data: dict[str, object], sla_variable: SLAVariable):
request_throughput = float(run_data["request_throughput"]) # type: ignore
if sla_variable == "request_rate":
return request_throughput
if sla_variable == "max_concurrency":
mean_latency_ms = float(run_data["mean_e2el_ms"]) # type: ignore
return request_throughput * mean_latency_ms / 1000
class SLAHistory(dict[int, float]):
def __init__(self, min_value: int, max_value: int) -> None:
super().__init__()
assert_never(sla_variable)
self.min_value = min_value
self.max_value = max_value
def get_xy(self) -> tuple[list[int], list[float]]:
xs = list[int]()
ys = list[float]()
for x, y in sorted(self.items()):
xs.append(x)
ys.append(y)
return xs, ys
def get_max_passing(self) -> float:
return max(
(val for val, margin in self.items() if margin <= 0),
default=self.min_value,
)
def get_min_failing(self) -> float:
return min(
(val for val, margin in self.items() if margin > 0),
default=self.max_value,
)
def _estimate_sla_bounds(
def solve_sla(
server: ServerProcess | None,
bench_cmd: list[str],
*,
@@ -140,17 +164,33 @@ def _estimate_sla_bounds(
num_runs: int,
dry_run: bool,
sla_variable: SLAVariable,
init_value: int,
max_value: int,
sla_min_value: int = 1,
sla_max_value: int = 8192, # The value that represents infinite QPS
):
sla_data = list[dict[str, object]]()
history = SLAHistory(min_value=sla_min_value, max_value=sla_max_value)
val: int = init_value
assert val > 0
# NOTE: We don't use equality here to be more robust against noisy results
while history.get_max_passing() + 1 < history.get_min_failing():
if len(history) == 0:
val = sla_max_value
elif len(history) == 1:
val = sla_min_value
else:
spl = PchipInterpolator(*history.get_xy(), extrapolate=False)
spl_roots = spl.solve()
if len(spl_roots) == 0:
# Fallback to binary search
val = int((history.get_max_passing() + history.get_min_failing()) / 2)
else:
val = int(spl_roots[0])
history = dict[int, float]()
if val in history:
# Cover both sides (floor and ceil) of the root to be sure
# that it is indeed the target value
val += 1
while True:
val = max(sla_min_value, min(val, sla_max_value))
print(f"Testing {sla_variable}: {val} req/s")
iter_data = run_sla(
@@ -162,8 +202,9 @@ def _estimate_sla_bounds(
num_runs=num_runs,
dry_run=dry_run,
)
if iter_data is None:
return None
assert iter_data is not None
sla_data.extend(iter_data)
iter_data_mean = {
@@ -175,92 +216,14 @@ def _estimate_sla_bounds(
criterion.print_and_compute_margin(iter_data_mean, k)
for k, criterion in sla_comb.items()
]
margin = max(sla_margins)
history[val] = margin
history[val] = margin = max(sla_margins)
if margin <= 0:
print("SLA criteria are met.")
val *= 2
print(f"SLA criteria are met. ({margin=:.2f})")
else:
print("SLA criteria are not met.")
break
print(f"SLA criteria are not met. ({margin=:.2f})")
if val >= max_value:
break
max_passing = max(
(val for val, margin in history.items() if margin <= 0),
default=0,
)
min_failing = min(
(val for val, margin in history.items() if margin > 0),
default=max_value,
)
return sla_data, (max_passing, min_failing), history
def _find_sla_value(
    server: ServerProcess | None,
    bench_cmd: list[str],
    *,
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
    sla_comb: SLASweepItem,
    base_path: Path,
    num_runs: int,
    dry_run: bool,
    sla_variable: SLAVariable,
    min_value: int,
    max_value: int,
):
    """Bisect `[min_value, max_value]` for the largest value meeting the SLA.

    Each iteration benchmarks the midpoint of the current window via
    `run_sla`, averages every SLA metric over the returned runs, and takes
    the largest criterion margin as the margin of that value. A margin
    `<= 0` moves the lower bound up to the midpoint; otherwise the upper
    bound moves down. The search ends once the window is at most 1 wide and
    the lower bound itself has been benchmarked.

    Returns:
        A tuple `(sla_data, value, history)`: all collected run records, the
        final lower bound, and a dict mapping each tested value to its
        margin.
    """
    sla_data = list[dict[str, object]]()
    history = dict[int, float]()
    lo: int = min_value
    hi: int = max_value

    while True:
        candidate = (lo + hi) // 2
        print(f"Testing {sla_variable}: {candidate} req/s")

        iter_data = run_sla(
            server,
            bench_cmd,
            serve_comb=serve_comb,
            bench_comb=bench_comb | {sla_variable: candidate},
            iter_path=_get_sla_iter_path(
                base_path, sla_comb, sla_variable, candidate
            ),
            num_runs=num_runs,
            dry_run=dry_run,
        )
        assert iter_data is not None
        sla_data.extend(iter_data)

        # Average each SLA metric across the runs of this iteration.
        iter_data_mean = {}
        for k in sla_comb:
            total = sum(float(run_data[k]) for run_data in iter_data)  # type: ignore
            iter_data_mean[k] = total / len(iter_data)

        sla_margins = [
            criterion.print_and_compute_margin(iter_data_mean, k)
            for k, criterion in sla_comb.items()
        ]
        # The largest margin across all criteria decides pass/fail.
        margin = max(sla_margins)
        history[candidate] = margin

        if margin > 0:
            print("SLA criteria are not met.")
            hi = candidate
        else:
            print("SLA criteria are met.")
            lo = candidate

        if hi - lo <= 1 and lo in history:
            return sla_data, lo, history
return sla_data, history
def search_sla(
@@ -271,7 +234,6 @@ def search_sla(
bench_comb: ParameterSweepItem,
sla_comb: SLASweepItem,
sla_variable: SLAVariable,
sla_inf_value: int = 65536, # The value that represents infinite QPS
base_path: Path,
num_runs: int,
dry_run: bool,
@@ -279,57 +241,25 @@ def search_sla(
print("[SLA START]")
print(f"SLA criteria: {sla_comb.as_text()}")
sla_data_0 = run_sla(
result = solve_sla(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb | {sla_variable: sla_inf_value},
iter_path=_get_sla_iter_path(base_path, sla_comb, sla_variable, sla_inf_value),
bench_comb=bench_comb,
sla_comb=sla_comb,
base_path=base_path,
num_runs=num_runs,
dry_run=dry_run,
sla_variable=sla_variable,
)
if sla_data_0 is None:
if result is None:
assert dry_run
print("Omitting SLA search.")
print("[SLA END]")
return None
return
sla_init_value = math.ceil(
sum(_estimate_sla_value(item, sla_variable) for item in sla_data_0)
/ len(sla_data_0)
)
print(f"Initial {sla_variable} to search: {sla_init_value} req/s.")
sla_data_1, (sla_min, sla_max), _ = _estimate_sla_bounds(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
sla_comb=sla_comb,
base_path=base_path,
num_runs=num_runs,
dry_run=dry_run,
sla_variable=sla_variable,
init_value=sla_init_value,
max_value=sla_inf_value,
)
print(f"Range of {sla_variable} to search: [{sla_min}, {sla_max}] req/s.")
sla_data_2, sla_value, _ = _find_sla_value(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
sla_comb=sla_comb,
base_path=base_path,
num_runs=num_runs,
dry_run=dry_run,
sla_variable=sla_variable,
min_value=sla_min,
max_value=sla_max,
)
sla_data = sla_data_0 + sla_data_1 + sla_data_2
sla_data, sla_history = result
sla_value = sla_history.get_max_passing()
print(f"Maximum {sla_variable} for SLA: {sla_value} req/s.")
with _get_sla_iter_path(