fix(security): Add VLLM_MAX_N_SEQUENCES environment variable and enforce limit (#37952)

Signed-off-by: jperezde <jperezde@redhat.com>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
Juan Pérez de Algaba
2026-03-27 14:02:10 +01:00
committed by GitHub
parent 497e234d38
commit b111f8a61f
5 changed files with 193 additions and 0 deletions

View File

@@ -1020,3 +1020,114 @@ def test_chat_completion_request_n_parameter_various_values():
assert sampling_params.n == n_value, (
f"Expected n={n_value}, got n={sampling_params.n}"
)
def test_chat_completion_request_n_parameter_exceeds_default_limit(
monkeypatch: pytest.MonkeyPatch,
):
"""Test that n values exceeding the default limit are rejected."""
import vllm.envs as envs
monkeypatch.delenv("VLLM_MAX_N_SEQUENCES", raising=False)
if hasattr(envs.__getattr__, "cache_clear"):
envs.__getattr__.cache_clear()
max_n = envs.VLLM_MAX_N_SEQUENCES
request = ChatCompletionRequest(
model="test-model",
messages=[{"role": "user", "content": "Test"}],
n=max_n + 1,
max_tokens=10,
)
with pytest.raises(ValueError, match="n must be at most"):
request.to_sampling_params(
max_tokens=10,
default_sampling_params={},
)
def test_chat_completion_request_n_parameter_at_limit(
monkeypatch: pytest.MonkeyPatch,
):
"""Test that n at exactly the limit is accepted."""
import vllm.envs as envs
monkeypatch.delenv("VLLM_MAX_N_SEQUENCES", raising=False)
if hasattr(envs.__getattr__, "cache_clear"):
envs.__getattr__.cache_clear()
max_n = envs.VLLM_MAX_N_SEQUENCES
request = ChatCompletionRequest(
model="test-model",
messages=[{"role": "user", "content": "Test"}],
n=max_n,
max_tokens=10,
)
sampling_params = request.to_sampling_params(
max_tokens=10,
default_sampling_params={},
)
assert sampling_params.n == max_n
def test_chat_completion_request_n_parameter_custom_limit(
monkeypatch: pytest.MonkeyPatch,
):
"""Test that VLLM_MAX_N_SEQUENCES env var overrides the default limit."""
import vllm.envs as envs
monkeypatch.setenv("VLLM_MAX_N_SEQUENCES", "128")
if hasattr(envs.__getattr__, "cache_clear"):
envs.__getattr__.cache_clear()
request = ChatCompletionRequest(
model="test-model",
messages=[{"role": "user", "content": "Test"}],
n=128,
max_tokens=10,
)
sampling_params = request.to_sampling_params(
max_tokens=10,
default_sampling_params={},
)
assert sampling_params.n == 128
request_over = ChatCompletionRequest(
model="test-model",
messages=[{"role": "user", "content": "Test"}],
n=129,
max_tokens=10,
)
with pytest.raises(ValueError, match="n must be at most 128"):
request_over.to_sampling_params(
max_tokens=10,
default_sampling_params={},
)
def test_chat_completion_request_n_parameter_massive_value(
monkeypatch: pytest.MonkeyPatch,
):
"""Test that astronomically large n values are rejected (CVE fix)."""
import vllm.envs as envs
monkeypatch.delenv("VLLM_MAX_N_SEQUENCES", raising=False)
if hasattr(envs.__getattr__, "cache_clear"):
envs.__getattr__.cache_clear()
request = ChatCompletionRequest(
model="test-model",
messages=[{"role": "user", "content": "Test"}],
n=100_000_000,
max_tokens=1,
)
with pytest.raises(ValueError, match="n must be at most"):
request.to_sampling_params(
max_tokens=1,
default_sampling_params={},
)