fix(security): Add VLLM_MAX_N_SEQUENCES environment variable and enforce limit (#37952)
Signed-off-by: jperezde <jperezde@redhat.com> Signed-off-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
committed by
GitHub
parent
497e234d38
commit
b111f8a61f
@@ -231,6 +231,18 @@ The most effective approach is to deploy vLLM behind a reverse proxy (such as ng
- Blocks all other endpoints, including the unauthenticated inference and operational control endpoints
- Implements additional authentication, rate limiting, and logging at the proxy layer

## Request Parameter Resource Limits

Certain API request parameters can have a large impact on resource consumption and may be abused to exhaust server resources. The `n` parameter in the `/v1/completions` and `/v1/chat/completions` endpoints controls how many independent output sequences are generated per request. A very large value causes the engine to allocate memory, CPU, and GPU time proportional to `n`, which can lead to out-of-memory conditions on the host and block the server from processing other requests.

To mitigate this, vLLM enforces a configurable upper bound on the `n` parameter via the `VLLM_MAX_N_SEQUENCES` environment variable (default: **16384**). Requests exceeding this limit are rejected before reaching the engine.

### Recommendations

- **Public-facing deployments:** Consider setting `VLLM_MAX_N_SEQUENCES` to a value appropriate for your workload (e.g., `64` or `128`) to limit the blast radius of a single request.
- **Reverse proxy layer:** In addition to vLLM's built-in limit, consider enforcing request body validation and rate limiting at your reverse proxy to further constrain abusive payloads.
- **Monitoring:** Monitor per-request resource consumption to detect anomalous patterns that may indicate abuse.

## Tool Server and MCP Security

vLLM supports connecting to external tool servers via the `--tool-server` argument. This enables models to call tools through the Responses API (`/v1/responses`). Tool server support works with all models — it is not limited to specific model architectures.
@@ -1020,3 +1020,114 @@ def test_chat_completion_request_n_parameter_various_values():
|
||||
assert sampling_params.n == n_value, (
|
||||
f"Expected n={n_value}, got n={sampling_params.n}"
|
||||
)
|
||||
|
||||
|
||||
def test_chat_completion_request_n_parameter_exceeds_default_limit(
    monkeypatch: pytest.MonkeyPatch,
):
    """A request whose ``n`` is one past the default limit is rejected."""
    import vllm.envs as envs

    # Make sure the built-in default limit is in effect: drop any env
    # override and flush the lazily-cached envs attribute, if cached.
    monkeypatch.delenv("VLLM_MAX_N_SEQUENCES", raising=False)
    if hasattr(envs.__getattr__, "cache_clear"):
        envs.__getattr__.cache_clear()

    limit = envs.VLLM_MAX_N_SEQUENCES
    req = ChatCompletionRequest(
        model="test-model",
        messages=[{"role": "user", "content": "Test"}],
        n=limit + 1,
        max_tokens=10,
    )

    # Conversion to sampling params is where the limit is enforced.
    with pytest.raises(ValueError, match="n must be at most"):
        req.to_sampling_params(max_tokens=10, default_sampling_params={})
|
||||
|
||||
|
||||
def test_chat_completion_request_n_parameter_at_limit(
    monkeypatch: pytest.MonkeyPatch,
):
    """A request whose ``n`` equals the limit exactly is accepted."""
    import vllm.envs as envs

    # Restore the built-in default limit (no env override, no stale cache).
    monkeypatch.delenv("VLLM_MAX_N_SEQUENCES", raising=False)
    if hasattr(envs.__getattr__, "cache_clear"):
        envs.__getattr__.cache_clear()

    limit = envs.VLLM_MAX_N_SEQUENCES
    req = ChatCompletionRequest(
        model="test-model",
        messages=[{"role": "user", "content": "Test"}],
        n=limit,
        max_tokens=10,
    )

    # The boundary value must pass validation and be carried through.
    params = req.to_sampling_params(max_tokens=10, default_sampling_params={})
    assert params.n == limit
|
||||
|
||||
|
||||
def test_chat_completion_request_n_parameter_custom_limit(
    monkeypatch: pytest.MonkeyPatch,
):
    """Test that VLLM_MAX_N_SEQUENCES env var overrides the default limit."""
    import vllm.envs as envs

    monkeypatch.setenv("VLLM_MAX_N_SEQUENCES", "128")
    # Flush the lazily-cached envs attribute so the override is picked up.
    if hasattr(envs.__getattr__, "cache_clear"):
        envs.__getattr__.cache_clear()

    def make_request(num: int) -> ChatCompletionRequest:
        # Minimal chat request with the given fan-out value.
        return ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "Test"}],
            n=num,
            max_tokens=10,
        )

    # Exactly at the custom limit: accepted and preserved.
    params = make_request(128).to_sampling_params(
        max_tokens=10,
        default_sampling_params={},
    )
    assert params.n == 128

    # One past the custom limit: rejected with the limit in the message.
    with pytest.raises(ValueError, match="n must be at most 128"):
        make_request(129).to_sampling_params(
            max_tokens=10,
            default_sampling_params={},
        )
|
||||
|
||||
|
||||
def test_chat_completion_request_n_parameter_massive_value(
    monkeypatch: pytest.MonkeyPatch,
):
    """Test that astronomically large n values are rejected (CVE fix)."""
    import vllm.envs as envs

    # Use the built-in default limit: clear any override and the attr cache.
    monkeypatch.delenv("VLLM_MAX_N_SEQUENCES", raising=False)
    if hasattr(envs.__getattr__, "cache_clear"):
        envs.__getattr__.cache_clear()

    huge_request = ChatCompletionRequest(
        model="test-model",
        messages=[{"role": "user", "content": "Test"}],
        n=100_000_000,
        max_tokens=1,
    )

    # The request must be rejected before any engine resources are spent.
    with pytest.raises(ValueError, match="n must be at most"):
        huge_request.to_sampling_params(
            max_tokens=1,
            default_sampling_params={},
        )
|
||||
|
||||
@@ -454,3 +454,55 @@ class TestVllmConfigureLogging:
|
||||
|
||||
with pytest.raises(ValueError, match="invalid literal for int"):
|
||||
_ = envs.VLLM_CONFIGURE_LOGGING
|
||||
|
||||
|
||||
class TestVllmMaxNSequences:
    """Tests for the VLLM_MAX_N_SEQUENCES limit on the ``n`` sampling
    parameter: the env-var plumbing in ``vllm.envs`` and its enforcement
    in ``SamplingParams`` validation.
    """

    def test_default_value(self):
        """Test that VLLM_MAX_N_SEQUENCES defaults to 16384."""
        # BUG FIX: the docstring previously claimed a default of 64, but the
        # implementation (and this assertion) use 16384.
        with patch.dict(os.environ, {}, clear=False):
            os.environ.pop("VLLM_MAX_N_SEQUENCES", None)
            # vllm.envs resolves attributes lazily and may cache them;
            # clear the cache so the default is re-read from the env.
            if hasattr(envs.__getattr__, "cache_clear"):
                envs.__getattr__.cache_clear()

            assert envs.VLLM_MAX_N_SEQUENCES == 16384

    def test_custom_value(self, monkeypatch: pytest.MonkeyPatch):
        """Test that VLLM_MAX_N_SEQUENCES can be overridden."""
        monkeypatch.setenv("VLLM_MAX_N_SEQUENCES", "128")
        if hasattr(envs.__getattr__, "cache_clear"):
            envs.__getattr__.cache_clear()

        assert envs.VLLM_MAX_N_SEQUENCES == 128

    def test_sampling_params_respects_limit(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ):
        """Test that SamplingParams rejects n above the limit."""
        from vllm.sampling_params import SamplingParams

        monkeypatch.delenv("VLLM_MAX_N_SEQUENCES", raising=False)
        if hasattr(envs.__getattr__, "cache_clear"):
            envs.__getattr__.cache_clear()

        max_n = envs.VLLM_MAX_N_SEQUENCES
        # Exactly at the limit: must construct without raising.
        SamplingParams(n=max_n)

        # One past the limit: rejected at validation time.
        with pytest.raises(ValueError, match="n must be at most"):
            SamplingParams(n=max_n + 1)

    def test_sampling_params_respects_custom_limit(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ):
        """Test that SamplingParams uses the overridden env var limit."""
        from vllm.sampling_params import SamplingParams

        monkeypatch.setenv("VLLM_MAX_N_SEQUENCES", "128")
        if hasattr(envs.__getattr__, "cache_clear"):
            envs.__getattr__.cache_clear()

        # At the overridden limit: accepted.
        SamplingParams(n=128)

        # Past it: rejected, with the custom limit named in the message.
        with pytest.raises(ValueError, match="n must be at most 128"):
            SamplingParams(n=129)
|
||||
|
||||
@@ -86,6 +86,7 @@ if TYPE_CHECKING:
|
||||
VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
|
||||
VLLM_RPC_TIMEOUT: int = 10000 # ms
|
||||
VLLM_HTTP_TIMEOUT_KEEP_ALIVE: int = 5 # seconds
|
||||
VLLM_MAX_N_SEQUENCES: int = 16384
|
||||
VLLM_PLUGINS: list[str] | None = None
|
||||
VLLM_LORA_RESOLVER_CACHE_DIR: str | None = None
|
||||
VLLM_LORA_RESOLVER_HF_REPO_LIST: str | None = None
|
||||
@@ -870,6 +871,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_HTTP_TIMEOUT_KEEP_ALIVE": lambda: int(
|
||||
os.environ.get("VLLM_HTTP_TIMEOUT_KEEP_ALIVE", "5")
|
||||
),
|
||||
# Maximum allowed value for the `n` sampling parameter (number of output
|
||||
# sequences per request). Limits resource consumption to prevent
|
||||
# denial-of-service via excessively large fan-out. Default: 16384.
|
||||
"VLLM_MAX_N_SEQUENCES": lambda: int(
|
||||
os.environ.get("VLLM_MAX_N_SEQUENCES", "16384")
|
||||
),
|
||||
# a list of plugin names to load, separated by commas.
|
||||
# if this is not set, it means all plugins will be loaded
|
||||
# if this is set to an empty string, no plugins will be loaded
|
||||
|
||||
@@ -12,6 +12,7 @@ from typing import Any
|
||||
import msgspec
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ModelConfig, SpeculativeConfig, StructuredOutputsConfig
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.logger import init_logger
|
||||
@@ -169,6 +170,9 @@ class SamplingParams(
|
||||
n: int = 1
|
||||
"""Number of outputs to return for the given prompt request.
|
||||
|
||||
The maximum allowed value is controlled by the ``VLLM_MAX_N_SEQUENCES``
|
||||
environment variable (default: 16384).
|
||||
|
||||
NOTE:
|
||||
`AsyncLLM` streams outputs by default. When `n > 1`, all `n` outputs
|
||||
are generated and streamed cumulatively per request. To see all `n`
|
||||
@@ -425,6 +429,13 @@ class SamplingParams(
|
||||
raise ValueError(f"n must be an int, but is of type {type(self.n)}")
|
||||
if self.n < 1:
|
||||
raise ValueError(f"n must be at least 1, got {self.n}.")
|
||||
max_n = envs.VLLM_MAX_N_SEQUENCES
|
||||
if self.n > max_n:
|
||||
raise ValueError(
|
||||
f"n must be at most {max_n}, got {self.n}. "
|
||||
"To increase this limit, set the VLLM_MAX_N_SEQUENCES "
|
||||
"environment variable."
|
||||
)
|
||||
if not -2.0 <= self.presence_penalty <= 2.0:
|
||||
raise ValueError(
|
||||
f"presence_penalty must be in [-2, 2], got {self.presence_penalty}."
|
||||
|
||||
Reference in New Issue
Block a user