diff --git a/docs/usage/security.md b/docs/usage/security.md
index 1e85a4a2d..b126d2a1e 100644
--- a/docs/usage/security.md
+++ b/docs/usage/security.md
@@ -231,6 +231,18 @@ The most effective approach is to deploy vLLM behind a reverse proxy (such as ng
 - Blocks all other endpoints, including the unauthenticated inference and operational control endpoints
 - Implements additional authentication, rate limiting, and logging at the proxy layer
 
+## Request Parameter Resource Limits
+
+Certain API request parameters can have a large impact on resource consumption and may be abused to exhaust server resources. The `n` parameter in the `/v1/completions` and `/v1/chat/completions` endpoints controls how many independent output sequences are generated per request. A very large value causes the engine to allocate memory, CPU, and GPU time proportional to `n`, which can lead to out-of-memory conditions on the host and block the server from processing other requests.
+
+To mitigate this, vLLM enforces a configurable upper bound on the `n` parameter via the `VLLM_MAX_N_SEQUENCES` environment variable (default: **16384**). Requests exceeding this limit are rejected before reaching the engine.
+
+### Recommendations
+
+- **Public-facing deployments:** Consider setting `VLLM_MAX_N_SEQUENCES` to a value appropriate for your workload (e.g., `64` or `128`) to limit the blast radius of a single request.
+- **Reverse proxy layer:** In addition to vLLM's built-in limit, consider enforcing request body validation and rate limiting at your reverse proxy to further constrain abusive payloads.
+- **Monitoring:** Monitor per-request resource consumption to detect anomalous patterns that may indicate abuse.
+
 ## Tool Server and MCP Security
 
 vLLM supports connecting to external tool servers via the `--tool-server` argument. This enables models to call tools through the Responses API (`/v1/responses`). Tool server support works with all models — it is not limited to specific model architectures.
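Note (illustration, not part of the diff): a minimal sketch of how the new cap surfaces to an API client. It assumes a server launched with `VLLM_MAX_N_SEQUENCES=64` and that, as with other invalid sampling parameters, the rejected request comes back as an HTTP 400 error; the model name, port, and exact error wording are placeholders, not guarantees made by this change.

```python
# Illustrative client-side check of the n-parameter cap (values are examples).
# Assumes the server was started with something like:
#   VLLM_MAX_N_SEQUENCES=64 vllm serve <model> --port 8000
from openai import OpenAI, BadRequestError

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

try:
    client.chat.completions.create(
        model="<model>",
        messages=[{"role": "user", "content": "Hello"}],
        n=100_000_000,  # far above the configured VLLM_MAX_N_SEQUENCES
        max_tokens=1,
    )
except BadRequestError as exc:
    # Expected: the request is rejected before reaching the engine, with a
    # message along the lines of "n must be at most 64, got 100000000."
    print(exc)
```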
diff --git a/tests/entrypoints/openai/chat_completion/test_chat.py b/tests/entrypoints/openai/chat_completion/test_chat.py
index 25f4c7d7a..212839f78 100644
--- a/tests/entrypoints/openai/chat_completion/test_chat.py
+++ b/tests/entrypoints/openai/chat_completion/test_chat.py
@@ -1020,3 +1020,114 @@ def test_chat_completion_request_n_parameter_various_values():
         assert sampling_params.n == n_value, (
             f"Expected n={n_value}, got n={sampling_params.n}"
         )
+
+
+def test_chat_completion_request_n_parameter_exceeds_default_limit(
+    monkeypatch: pytest.MonkeyPatch,
+):
+    """Test that n values exceeding the default limit are rejected."""
+    import vllm.envs as envs
+
+    monkeypatch.delenv("VLLM_MAX_N_SEQUENCES", raising=False)
+    if hasattr(envs.__getattr__, "cache_clear"):
+        envs.__getattr__.cache_clear()
+
+    max_n = envs.VLLM_MAX_N_SEQUENCES
+    request = ChatCompletionRequest(
+        model="test-model",
+        messages=[{"role": "user", "content": "Test"}],
+        n=max_n + 1,
+        max_tokens=10,
+    )
+
+    with pytest.raises(ValueError, match="n must be at most"):
+        request.to_sampling_params(
+            max_tokens=10,
+            default_sampling_params={},
+        )
+
+
+def test_chat_completion_request_n_parameter_at_limit(
+    monkeypatch: pytest.MonkeyPatch,
+):
+    """Test that n at exactly the limit is accepted."""
+    import vllm.envs as envs
+
+    monkeypatch.delenv("VLLM_MAX_N_SEQUENCES", raising=False)
+    if hasattr(envs.__getattr__, "cache_clear"):
+        envs.__getattr__.cache_clear()
+
+    max_n = envs.VLLM_MAX_N_SEQUENCES
+    request = ChatCompletionRequest(
+        model="test-model",
+        messages=[{"role": "user", "content": "Test"}],
+        n=max_n,
+        max_tokens=10,
+    )
+
+    sampling_params = request.to_sampling_params(
+        max_tokens=10,
+        default_sampling_params={},
+    )
+    assert sampling_params.n == max_n
+
+
+def test_chat_completion_request_n_parameter_custom_limit(
+    monkeypatch: pytest.MonkeyPatch,
+):
+    """Test that VLLM_MAX_N_SEQUENCES env var overrides the default limit."""
+    import vllm.envs as envs
+
+    monkeypatch.setenv("VLLM_MAX_N_SEQUENCES", "128")
+    if hasattr(envs.__getattr__, "cache_clear"):
+        envs.__getattr__.cache_clear()
+
+    request = ChatCompletionRequest(
+        model="test-model",
+        messages=[{"role": "user", "content": "Test"}],
+        n=128,
+        max_tokens=10,
+    )
+
+    sampling_params = request.to_sampling_params(
+        max_tokens=10,
+        default_sampling_params={},
+    )
+    assert sampling_params.n == 128
+
+    request_over = ChatCompletionRequest(
+        model="test-model",
+        messages=[{"role": "user", "content": "Test"}],
+        n=129,
+        max_tokens=10,
+    )
+
+    with pytest.raises(ValueError, match="n must be at most 128"):
+        request_over.to_sampling_params(
+            max_tokens=10,
+            default_sampling_params={},
+        )
+
+
+def test_chat_completion_request_n_parameter_massive_value(
+    monkeypatch: pytest.MonkeyPatch,
+):
+    """Test that astronomically large n values are rejected (CVE fix)."""
+    import vllm.envs as envs
+
+    monkeypatch.delenv("VLLM_MAX_N_SEQUENCES", raising=False)
+    if hasattr(envs.__getattr__, "cache_clear"):
+        envs.__getattr__.cache_clear()
+
+    request = ChatCompletionRequest(
+        model="test-model",
+        messages=[{"role": "user", "content": "Test"}],
+        n=100_000_000,
+        max_tokens=1,
+    )
+
+    with pytest.raises(ValueError, match="n must be at most"):
+        request.to_sampling_params(
+            max_tokens=1,
+            default_sampling_params={},
+        )
diff --git a/tests/test_envs.py b/tests/test_envs.py
index b6b7cf38d..3f3add2ab 100644
--- a/tests/test_envs.py
+++ b/tests/test_envs.py
@@ -454,3 +454,55 @@ class TestVllmConfigureLogging:
         with pytest.raises(ValueError, match="invalid literal for int"):
             _ = envs.VLLM_CONFIGURE_LOGGING
+
+
+class TestVllmMaxNSequences:
+    def test_default_value(self):
+        """Test that VLLM_MAX_N_SEQUENCES defaults to 16384."""
+        with patch.dict(os.environ, {}, clear=False):
+            os.environ.pop("VLLM_MAX_N_SEQUENCES", None)
+            if hasattr(envs.__getattr__, "cache_clear"):
+                envs.__getattr__.cache_clear()
+
+            assert envs.VLLM_MAX_N_SEQUENCES == 16384
+
+    def test_custom_value(self, monkeypatch: pytest.MonkeyPatch):
+        """Test that VLLM_MAX_N_SEQUENCES can be overridden."""
+        monkeypatch.setenv("VLLM_MAX_N_SEQUENCES", "128")
+        if hasattr(envs.__getattr__, "cache_clear"):
+            envs.__getattr__.cache_clear()
+
+        assert envs.VLLM_MAX_N_SEQUENCES == 128
+
+    def test_sampling_params_respects_limit(
+        self,
+        monkeypatch: pytest.MonkeyPatch,
+    ):
+        """Test that SamplingParams rejects n above the limit."""
+        from vllm.sampling_params import SamplingParams
+
+        monkeypatch.delenv("VLLM_MAX_N_SEQUENCES", raising=False)
+        if hasattr(envs.__getattr__, "cache_clear"):
+            envs.__getattr__.cache_clear()
+
+        max_n = envs.VLLM_MAX_N_SEQUENCES
+        SamplingParams(n=max_n)
+
+        with pytest.raises(ValueError, match="n must be at most"):
+            SamplingParams(n=max_n + 1)
+
+    def test_sampling_params_respects_custom_limit(
+        self,
+        monkeypatch: pytest.MonkeyPatch,
+    ):
+        """Test that SamplingParams uses the overridden env var limit."""
+        from vllm.sampling_params import SamplingParams
+
+        monkeypatch.setenv("VLLM_MAX_N_SEQUENCES", "128")
+        if hasattr(envs.__getattr__, "cache_clear"):
+            envs.__getattr__.cache_clear()
+
+        SamplingParams(n=128)
+
+        with pytest.raises(ValueError, match="n must be at most 128"):
+            SamplingParams(n=129)
diff --git a/vllm/envs.py b/vllm/envs.py
index d29e367bc..2944bb111 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -86,6 +86,7 @@ if TYPE_CHECKING:
     VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
     VLLM_RPC_TIMEOUT: int = 10000  # ms
     VLLM_HTTP_TIMEOUT_KEEP_ALIVE: int = 5  # seconds
+    VLLM_MAX_N_SEQUENCES: int = 16384
     VLLM_PLUGINS: list[str] | None = None
     VLLM_LORA_RESOLVER_CACHE_DIR: str | None = None
     VLLM_LORA_RESOLVER_HF_REPO_LIST: str | None = None
@@ -870,6 +871,12 @@
     "VLLM_HTTP_TIMEOUT_KEEP_ALIVE": lambda: int(
         os.environ.get("VLLM_HTTP_TIMEOUT_KEEP_ALIVE", "5")
     ),
+    # Maximum allowed value for the `n` sampling parameter (number of output
+    # sequences per request). Limits resource consumption to prevent
+    # denial-of-service via excessively large fan-out. Default: 16384.
+    "VLLM_MAX_N_SEQUENCES": lambda: int(
+        os.environ.get("VLLM_MAX_N_SEQUENCES", "16384")
+    ),
     # a list of plugin names to load, separated by commas.
     # if this is not set, it means all plugins will be loaded
     # if this is set to an empty string, no plugins will be loaded
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index dc3e1d49c..3a2a04fd7 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -12,6 +12,7 @@ from typing import Any
 import msgspec
 from pydantic.dataclasses import dataclass
 
+import vllm.envs as envs
 from vllm.config import ModelConfig, SpeculativeConfig, StructuredOutputsConfig
 from vllm.exceptions import VLLMValidationError
 from vllm.logger import init_logger
@@ -169,6 +170,9 @@ class SamplingParams(
     n: int = 1
     """Number of outputs to return for the given prompt request.
 
+    The maximum allowed value is controlled by the ``VLLM_MAX_N_SEQUENCES``
+    environment variable (default: 16384).
+
     NOTE: `AsyncLLM` streams outputs by default.
     When `n > 1`, all `n` outputs are generated and streamed cumulatively per request. To see all `n`
@@ -425,6 +429,13 @@
             raise ValueError(f"n must be an int, but is of type {type(self.n)}")
         if self.n < 1:
             raise ValueError(f"n must be at least 1, got {self.n}.")
+        max_n = envs.VLLM_MAX_N_SEQUENCES
+        if self.n > max_n:
+            raise ValueError(
+                f"n must be at most {max_n}, got {self.n}. "
+                "To increase this limit, set the VLLM_MAX_N_SEQUENCES "
+                "environment variable."
+            )
         if not -2.0 <= self.presence_penalty <= 2.0:
             raise ValueError(
                 f"presence_penalty must be in [-2, 2], got {self.presence_penalty}."
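Note (illustration, not part of the diff): a minimal in-process sketch of the same limit as enforced by `SamplingParams`, mirroring the tests above; the value `64` is an example. Because `vllm.envs` may cache the value after its first read (hence the `cache_clear()` calls in the tests), the environment variable is set before vLLM is imported.

```python
# Illustrative only: exercise the new SamplingParams validation directly.
import os

# Set the cap before vllm reads it; 64 is an example value.
os.environ["VLLM_MAX_N_SEQUENCES"] = "64"

import vllm.envs as envs
from vllm.sampling_params import SamplingParams

assert envs.VLLM_MAX_N_SEQUENCES == 64

SamplingParams(n=64)  # accepted: exactly at the limit

try:
    SamplingParams(n=65)  # rejected: above the limit
except ValueError as exc:
    print(exc)  # "n must be at most 64, got 65. To increase this limit, ..."
```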